diff --git a/src/addrdb.h b/src/addrdb.h index 3ffcfe3e1..d8c66d872 100644 --- a/src/addrdb.h +++ b/src/addrdb.h @@ -14,6 +14,7 @@ class CSubNet; class CAddrMan; +class CDataStream; typedef enum BanReason { diff --git a/src/init.cpp b/src/init.cpp index 27843fa88..6aaa7bfc5 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -42,6 +42,7 @@ #endif #include #include +#include #ifndef WIN32 #include @@ -70,6 +71,7 @@ static const bool DEFAULT_REST_ENABLE = false; static const bool DEFAULT_DISABLE_SAFEMODE = false; static const bool DEFAULT_STOPAFTERBLOCKIMPORT = false; +std::unique_ptr g_connman; #if ENABLE_ZMQ static CZMQNotificationInterface* pzmqNotificationInterface = NULL; @@ -197,7 +199,9 @@ void Shutdown() if (pwalletMain) pwalletMain->Flush(false); #endif - StopNode(); + StopNode(*g_connman); + g_connman.reset(); + StopTorControl(); UnregisterNodeSignals(GetNodeSignals()); @@ -1101,6 +1105,10 @@ bool AppInit2(boost::thread_group& threadGroup, CScheduler& scheduler) #endif // ENABLE_WALLET // ********************************************************* Step 6: network initialization + assert(!g_connman); + g_connman = std::unique_ptr(new CConnman()); + CConnman& connman = *g_connman; + RegisterNodeSignals(GetNodeSignals()); // sanitize comments per BIP-0014, format user agent and check total size @@ -1497,7 +1505,9 @@ bool AppInit2(boost::thread_group& threadGroup, CScheduler& scheduler) if (GetBoolArg("-listenonion", DEFAULT_LISTEN_ONION)) StartTorControl(threadGroup, scheduler); - StartNode(threadGroup, scheduler); + std::string strNodeError; + if(!StartNode(connman, threadGroup, scheduler, strNodeError)) + return InitError(strNodeError); // ********************************************************* Step 12: finished diff --git a/src/net.cpp b/src/net.cpp index cee149ee7..6177dc04f 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -65,13 +65,6 @@ namespace { const int MAX_OUTBOUND_CONNECTIONS = 8; const int MAX_FEELER_CONNECTIONS = 1; - - struct ListenSocket { - SOCKET socket; - bool whitelisted; - - ListenSocket(SOCKET _socket, bool _whitelisted) : socket(_socket), whitelisted(_whitelisted) {} - }; } const static std::string NET_MESSAGE_COMMAND_OTHER = "*other*"; @@ -1015,7 +1008,7 @@ static bool AttemptToEvictConnection() { return false; } -static void AcceptConnection(const ListenSocket& hListenSocket) { +void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); SOCKET hSocket = accept(hListenSocket.socket, (struct sockaddr*)&sockaddr, &len); @@ -1089,7 +1082,7 @@ static void AcceptConnection(const ListenSocket& hListenSocket) { } } -void ThreadSocketHandler() +void CConnman::ThreadSocketHandler() { unsigned int nPrevNodeCount = 0; while (true) @@ -1497,7 +1490,7 @@ static std::string GetDNSHost(const CDNSSeedData& data, ServiceFlags* requiredSe } -void ThreadDNSAddressSeed() +void CConnman::ThreadDNSAddressSeed() { // goal: only query DNS seeds if address need is acute if ((addrman.size() > 0) && @@ -1577,7 +1570,7 @@ void DumpData() DumpBanlist(); } -void static ProcessOneShot() +void CConnman::ProcessOneShot() { std::string strDest; { @@ -1595,7 +1588,7 @@ void static ProcessOneShot() } } -void ThreadOpenConnections() +void CConnman::ThreadOpenConnections() { // Connect to specific addresses if (mapArgs.count("-connect") && mapMultiArgs["-connect"].size() > 0) @@ -1791,7 +1784,7 @@ std::vector GetAddedNodeInfo() return ret; } -void ThreadOpenAddedConnections() +void CConnman::ThreadOpenAddedConnections() { { LOCK(cs_vAddedNodes); @@ -1848,7 +1841,7 @@ bool OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSem } -void ThreadMessageHandler() +void CConnman::ThreadMessageHandler() { boost::mutex condition_mutex; boost::unique_lock lock(condition_mutex); @@ -2064,7 +2057,11 @@ void static Discover(boost::thread_group& threadGroup) #endif } -void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler) +CConnman::CConnman() +{ +} + +bool StartNode(CConnman& connman, boost::thread_group& threadGroup, CScheduler& scheduler, std::string& strNodeError) { uiInterface.InitMessage(_("Loading addresses...")); // Load addresses from peers.dat @@ -2102,6 +2099,17 @@ void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler) fAddressesInitialized = true; + Discover(threadGroup); + + bool ret = connman.Start(threadGroup, strNodeError); + + // Dump network addresses + scheduler.scheduleEvery(DumpData, DUMP_ADDRESSES_INTERVAL); + return ret; +} + +bool CConnman::Start(boost::thread_group& threadGroup, std::string& strNodeError) +{ if (semOutbound == NULL) { // initialize semaphore int nMaxOutbound = std::min((MAX_OUTBOUND_CONNECTIONS + MAX_FEELER_CONNECTIONS), nMaxConnections); @@ -2114,8 +2122,6 @@ void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler) pnodeLocalHost = new CNode(INVALID_SOCKET, CAddress(CService(local, 0), nLocalServices)); } - Discover(threadGroup); - // // Start threads // @@ -2123,34 +2129,30 @@ void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler) if (!GetBoolArg("-dnsseed", true)) LogPrintf("DNS seeding disabled\n"); else - threadGroup.create_thread(boost::bind(&TraceThread, "dnsseed", &ThreadDNSAddressSeed)); + threadGroup.create_thread(boost::bind(&TraceThread >, "dnsseed", boost::function(boost::bind(&CConnman::ThreadDNSAddressSeed, this)))); // Map ports with UPnP MapPort(GetBoolArg("-upnp", DEFAULT_UPNP)); // Send and receive from sockets, accept connections - threadGroup.create_thread(boost::bind(&TraceThread, "net", &ThreadSocketHandler)); + threadGroup.create_thread(boost::bind(&TraceThread >, "net", boost::function(boost::bind(&CConnman::ThreadSocketHandler, this)))); // Initiate outbound connections from -addnode - threadGroup.create_thread(boost::bind(&TraceThread, "addcon", &ThreadOpenAddedConnections)); + threadGroup.create_thread(boost::bind(&TraceThread >, "addcon", boost::function(boost::bind(&CConnman::ThreadOpenAddedConnections, this)))); // Initiate outbound connections - threadGroup.create_thread(boost::bind(&TraceThread, "opencon", &ThreadOpenConnections)); + threadGroup.create_thread(boost::bind(&TraceThread >, "opencon", boost::function(boost::bind(&CConnman::ThreadOpenConnections, this)))); // Process messages - threadGroup.create_thread(boost::bind(&TraceThread, "msghand", &ThreadMessageHandler)); + threadGroup.create_thread(boost::bind(&TraceThread >, "msghand", boost::function(boost::bind(&CConnman::ThreadMessageHandler, this)))); - // Dump network addresses - scheduler.scheduleEvery(&DumpData, DUMP_ADDRESSES_INTERVAL); + return true; } -bool StopNode() +bool StopNode(CConnman& connman) { LogPrintf("StopNode()\n"); MapPort(false); - if (semOutbound) - for (int i=0; i<(MAX_OUTBOUND_CONNECTIONS + MAX_FEELER_CONNECTIONS); i++) - semOutbound->post(); if (fAddressesInitialized) { @@ -2158,6 +2160,7 @@ bool StopNode() fAddressesInitialized = false; } + connman.Stop(); return true; } @@ -2168,28 +2171,6 @@ public: ~CNetCleanup() { - // Close sockets - BOOST_FOREACH(CNode* pnode, vNodes) - if (pnode->hSocket != INVALID_SOCKET) - CloseSocket(pnode->hSocket); - BOOST_FOREACH(ListenSocket& hListenSocket, vhListenSocket) - if (hListenSocket.socket != INVALID_SOCKET) - if (!CloseSocket(hListenSocket.socket)) - LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError())); - - // clean up some globals (to help leak detection) - BOOST_FOREACH(CNode *pnode, vNodes) - delete pnode; - BOOST_FOREACH(CNode *pnode, vNodesDisconnected) - delete pnode; - vNodes.clear(); - vNodesDisconnected.clear(); - vhListenSocket.clear(); - delete semOutbound; - semOutbound = NULL; - delete pnodeLocalHost; - pnodeLocalHost = NULL; - #ifdef WIN32 // Shutdown Windows Sockets WSACleanup(); @@ -2198,6 +2179,38 @@ public: } instance_of_cnetcleanup; +void CConnman::Stop() +{ + if (semOutbound) + for (int i=0; i<(MAX_OUTBOUND_CONNECTIONS + MAX_FEELER_CONNECTIONS); i++) + semOutbound->post(); + + // Close sockets + BOOST_FOREACH(CNode* pnode, vNodes) + if (pnode->hSocket != INVALID_SOCKET) + CloseSocket(pnode->hSocket); + BOOST_FOREACH(ListenSocket& hListenSocket, vhListenSocket) + if (hListenSocket.socket != INVALID_SOCKET) + if (!CloseSocket(hListenSocket.socket)) + LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError())); + + // clean up some globals (to help leak detection) + BOOST_FOREACH(CNode *pnode, vNodes) + delete pnode; + BOOST_FOREACH(CNode *pnode, vNodesDisconnected) + delete pnode; + vNodes.clear(); + vNodesDisconnected.clear(); + vhListenSocket.clear(); + delete semOutbound; + semOutbound = NULL; + delete pnodeLocalHost; + pnodeLocalHost = NULL; +} + +CConnman::~CConnman() +{ +} void RelayTransaction(const CTransaction& tx) { diff --git a/src/net.h b/src/net.h index 0d1c62e42..7f212f233 100644 --- a/src/net.h +++ b/src/net.h @@ -21,6 +21,7 @@ #include #include #include +#include #ifndef WIN32 #include @@ -93,11 +94,36 @@ CNode* FindNode(const std::string& addrName); CNode* FindNode(const CService& ip); CNode* FindNode(const NodeId id); //TODO: Remove this bool OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant *grantOutbound = NULL, const char *strDest = NULL, bool fOneShot = false, bool fFeeler = false); + +struct ListenSocket { + SOCKET socket; + bool whitelisted; + + ListenSocket(SOCKET socket_, bool whitelisted_) : socket(socket_), whitelisted(whitelisted_) {} +}; + +class CConnman +{ +public: + CConnman(); + ~CConnman(); + bool Start(boost::thread_group& threadGroup, std::string& strNodeError); + void Stop(); +private: + void ThreadOpenAddedConnections(); + void ProcessOneShot(); + void ThreadOpenConnections(); + void ThreadMessageHandler(); + void AcceptConnection(const ListenSocket& hListenSocket); + void ThreadSocketHandler(); + void ThreadDNSAddressSeed(); +}; +extern std::unique_ptr g_connman; void MapPort(bool fUseUPnP); unsigned short GetListenPort(); bool BindListenPort(const CService &bindAddr, std::string& strError, bool fWhitelisted = false); -void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler); -bool StopNode(); +bool StartNode(CConnman& connman, boost::thread_group& threadGroup, CScheduler& scheduler, std::string& strNodeError); +bool StopNode(CConnman& connman); void SocketSendData(CNode *pnode); struct CombinerAll diff --git a/src/test/test_bitcoin.cpp b/src/test/test_bitcoin.cpp index 056f2982c..ed74418e3 100644 --- a/src/test/test_bitcoin.cpp +++ b/src/test/test_bitcoin.cpp @@ -26,6 +26,8 @@ #include #include +std::unique_ptr g_connman; + extern bool fPrintToConsole; extern void noui_connect(); @@ -43,6 +45,7 @@ BasicTestingSetup::BasicTestingSetup(const std::string& chainName) BasicTestingSetup::~BasicTestingSetup() { ECC_Stop(); + g_connman.reset(); } TestingSetup::TestingSetup(const std::string& chainName) : BasicTestingSetup(chainName) @@ -50,6 +53,7 @@ TestingSetup::TestingSetup(const std::string& chainName) : BasicTestingSetup(cha const CChainParams& chainparams = Params(); // Ideally we'd move all the RPC tests to the functional testing framework // instead of unit tests, but for now we need these here. + RegisterAllCoreRPCCommands(tableRPC); ClearDatadirCache(); pathTemp = GetTempPath() / strprintf("test_bitcoin_%lu_%i", (unsigned long)GetTime(), (int)(GetRand(100000))); @@ -68,6 +72,8 @@ TestingSetup::TestingSetup(const std::string& chainName) : BasicTestingSetup(cha nScriptCheckThreads = 3; for (int i=0; i < nScriptCheckThreads-1; i++) threadGroup.create_thread(&ThreadScriptCheck); + g_connman = std::unique_ptr(new CConnman()); + connman = g_connman.get(); RegisterNodeSignals(GetNodeSignals()); } diff --git a/src/test/test_bitcoin.h b/src/test/test_bitcoin.h index bc0d2fe31..9819a7097 100644 --- a/src/test/test_bitcoin.h +++ b/src/test/test_bitcoin.h @@ -27,10 +27,12 @@ struct BasicTestingSetup { /** Testing setup that configures a complete environment. * Included are data directory, coins database, script check threads setup. */ +class CConnman; struct TestingSetup: public BasicTestingSetup { CCoinsViewDB *pcoinsdbview; boost::filesystem::path pathTemp; boost::thread_group threadGroup; + CConnman* connman; TestingSetup(const std::string& chainName = CBaseChainParams::MAIN); ~TestingSetup();