diff --git a/examples/button-clicker/Assets/DiscordRpc.cs b/examples/button-clicker/Assets/DiscordRpc.cs index 83e742e..b3f6ffd 100644 --- a/examples/button-clicker/Assets/DiscordRpc.cs +++ b/examples/button-clicker/Assets/DiscordRpc.cs @@ -87,6 +87,9 @@ public class DiscordRpc [DllImport("discord-rpc", EntryPoint = "Discord_Respond", CallingConvention = CallingConvention.Cdecl)] public static extern void Respond(string userId, Reply reply); + [DllImport("discord-rpc", EntryPoint = "Discord_UpdateHandlers", CallingConvention = CallingConvention.Cdecl)] + public static extern void UpdateHandlers(ref EventHandlers handlers); + public static void UpdatePresence(RichPresence presence) { var presencestruct = presence.GetStruct(); diff --git a/include/discord_rpc.h b/include/discord_rpc.h index feb874b..72e5a4f 100644 --- a/include/discord_rpc.h +++ b/include/discord_rpc.h @@ -80,6 +80,8 @@ DISCORD_EXPORT void Discord_ClearPresence(void); DISCORD_EXPORT void Discord_Respond(const char* userid, /* DISCORD_REPLY_ */ int reply); +DISCORD_EXPORT void Discord_UpdateHandlers(DiscordEventHandlers* handlers); + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/src/discord_rpc.cpp b/src/discord_rpc.cpp index 5e68c0f..635bba7 100644 --- a/src/discord_rpc.cpp +++ b/src/discord_rpc.cpp @@ -60,6 +60,7 @@ static char LastErrorMessage[256]; static int LastDisconnectErrorCode{0}; static char LastDisconnectErrorMessage[256]; static std::mutex PresenceMutex; +static std::mutex HandlerMutex; static QueuedMessage QueuedPresence{}; static MsgQueue SendQueue; static MsgQueue JoinAskQueue; @@ -212,15 +213,15 @@ static void Discord_UpdateConnection(void) // writes if (QueuedPresence.length) { QueuedMessage local; - PresenceMutex.lock(); - local.Copy(QueuedPresence); - QueuedPresence.length = 0; - PresenceMutex.unlock(); + { + std::lock_guard guard(PresenceMutex); + local.Copy(QueuedPresence); + QueuedPresence.length = 0; + } if (!Connection->Write(local.buffer, local.length)) { // if we fail to send, requeue - PresenceMutex.lock(); + std::lock_guard guard(PresenceMutex); QueuedPresence.Copy(local); - PresenceMutex.unlock(); } } @@ -250,6 +251,19 @@ static bool RegisterForEvent(const char* evtName) return false; } +static bool DeregisterForEvent(const char* evtName) +{ + auto qmessage = SendQueue.GetNextAddMessage(); + if (qmessage) { + qmessage->length = + JsonWriteUnsubscribeCommand(qmessage->buffer, sizeof(qmessage->buffer), Nonce++, evtName); + SendQueue.CommitAdd(); + SignalIOActivity(); + return true; + } + return false; +} + extern "C" DISCORD_EXPORT void Discord_Initialize(const char* applicationId, DiscordEventHandlers* handlers, int autoRegister, @@ -266,11 +280,14 @@ extern "C" DISCORD_EXPORT void Discord_Initialize(const char* applicationId, Pid = GetProcessId(); - if (handlers) { - Handlers = *handlers; - } - else { - Handlers = {}; + { + std::lock_guard guard(HandlerMutex); + if (handlers) { + Handlers = *handlers; + } + else { + Handlers = {}; + } } if (Connection) { @@ -279,20 +296,9 @@ extern "C" DISCORD_EXPORT void Discord_Initialize(const char* applicationId, Connection = RpcConnection::Create(applicationId); Connection->onConnect = []() { + Discord_UpdateHandlers(&Handlers); WasJustConnected.exchange(true); ReconnectTimeMs.reset(); - - if (Handlers.joinGame) { - RegisterForEvent("ACTIVITY_JOIN"); - } - - if (Handlers.spectateGame) { - RegisterForEvent("ACTIVITY_SPECTATE"); - } - - if (Handlers.joinRequest) { - RegisterForEvent("ACTIVITY_JOIN_REQUEST"); - } }; Connection->onDisconnect = [](int err, const char* message) { LastDisconnectErrorCode = err; @@ -318,10 +324,11 @@ extern "C" DISCORD_EXPORT void Discord_Shutdown(void) extern "C" DISCORD_EXPORT void Discord_UpdatePresence(const DiscordRichPresence* presence) { - PresenceMutex.lock(); - QueuedPresence.length = JsonWriteRichPresenceObj( - QueuedPresence.buffer, sizeof(QueuedPresence.buffer), Nonce++, Pid, presence); - PresenceMutex.unlock(); + { + std::lock_guard guard(PresenceMutex); + QueuedPresence.length = JsonWriteRichPresenceObj( + QueuedPresence.buffer, sizeof(QueuedPresence.buffer), Nonce++, Pid, presence); + } SignalIOActivity(); } @@ -360,25 +367,38 @@ extern "C" DISCORD_EXPORT void Discord_RunCallbacks(void) if (isConnected) { // if we are connected, disconnect cb first + std::lock_guard guard(HandlerMutex); if (wasDisconnected && Handlers.disconnected) { Handlers.disconnected(LastDisconnectErrorCode, LastDisconnectErrorMessage); } } - if (WasJustConnected.exchange(false) && Handlers.ready) { - Handlers.ready(); + if (WasJustConnected.exchange(false)) { + std::lock_guard guard(HandlerMutex); + if (Handlers.ready) { + Handlers.ready(); + } } - if (GotErrorMessage.exchange(false) && Handlers.errored) { - Handlers.errored(LastErrorCode, LastErrorMessage); + if (GotErrorMessage.exchange(false)) { + std::lock_guard guard(HandlerMutex); + if (Handlers.errored) { + Handlers.errored(LastErrorCode, LastErrorMessage); + } } - if (WasJoinGame.exchange(false) && Handlers.joinGame) { - Handlers.joinGame(JoinGameSecret); + if (WasJoinGame.exchange(false)) { + std::lock_guard guard(HandlerMutex); + if (Handlers.joinGame) { + Handlers.joinGame(JoinGameSecret); + } } - if (WasSpectateGame.exchange(false) && Handlers.spectateGame) { - Handlers.spectateGame(SpectateGameSecret); + if (WasSpectateGame.exchange(false)) { + std::lock_guard guard(HandlerMutex); + if (Handlers.spectateGame) { + Handlers.spectateGame(SpectateGameSecret); + } } // Right now this batches up any requests and sends them all in a burst; I could imagine a world @@ -388,17 +408,50 @@ extern "C" DISCORD_EXPORT void Discord_RunCallbacks(void) // not it should be trivial for the implementer to make a queue themselves. while (JoinAskQueue.HavePendingSends()) { auto req = JoinAskQueue.GetNextSendMessage(); - if (Handlers.joinRequest) { - DiscordJoinRequest djr{req->userId, req->username, req->discriminator, req->avatar}; - Handlers.joinRequest(&djr); + { + std::lock_guard guard(HandlerMutex); + if (Handlers.joinRequest) { + DiscordJoinRequest djr{req->userId, req->username, req->discriminator, req->avatar}; + Handlers.joinRequest(&djr); + } } JoinAskQueue.CommitSend(); } if (!isConnected) { // if we are not connected, disconnect message last + std::lock_guard guard(HandlerMutex); if (wasDisconnected && Handlers.disconnected) { Handlers.disconnected(LastDisconnectErrorCode, LastDisconnectErrorMessage); } } } + +extern "C" DISCORD_EXPORT void Discord_UpdateHandlers(DiscordEventHandlers* newHandlers) +{ + if (newHandlers) { + +#define HANDLE_EVENT_REGISTRATION(handler_name, event) \ + if (!Handlers.handler_name && newHandlers->handler_name) { \ + RegisterForEvent(event); \ + } \ + else if (Handlers.handler_name && !newHandlers->handler_name) { \ + DeregisterForEvent(event); \ + } + + std::lock_guard guard(HandlerMutex); + HANDLE_EVENT_REGISTRATION(joinGame, "ACTIVITY_JOIN") + HANDLE_EVENT_REGISTRATION(spectateGame, "ACTIVITY_SPECTATE") + HANDLE_EVENT_REGISTRATION(joinRequest, "ACTIVITY_JOIN_REQUEST") + +#undef HANDLE_EVENT_REGISTRATION + + Handlers = *newHandlers; + } + else + { + std::lock_guard guard(HandlerMutex); + Handlers = {}; + } + return; +} diff --git a/src/serialization.cpp b/src/serialization.cpp index 8a3215f..4190ee0 100644 --- a/src/serialization.cpp +++ b/src/serialization.cpp @@ -197,6 +197,25 @@ size_t JsonWriteSubscribeCommand(char* dest, size_t maxLen, int nonce, const cha return writer.Size(); } +size_t JsonWriteUnsubscribeCommand(char* dest, size_t maxLen, int nonce, const char* evtName) +{ + JsonWriter writer(dest, maxLen); + + { + WriteObject obj(writer); + + JsonWriteNonce(writer, nonce); + + WriteKey(writer, "cmd"); + writer.String("UNSUBSCRIBE"); + + WriteKey(writer, "evt"); + writer.String(evtName); + } + + return writer.Size(); +} + size_t JsonWriteJoinReply(char* dest, size_t maxLen, const char* userId, int reply, int nonce) { JsonWriter writer(dest, maxLen); diff --git a/src/serialization.h b/src/serialization.h index ad9382b..106dce7 100644 --- a/src/serialization.h +++ b/src/serialization.h @@ -47,6 +47,8 @@ size_t JsonWriteRichPresenceObj(char* dest, const DiscordRichPresence* presence); size_t JsonWriteSubscribeCommand(char* dest, size_t maxLen, int nonce, const char* evtName); +size_t JsonWriteUnsubscribeCommand(char* dest, size_t maxLen, int nonce, const char* evtName); + size_t JsonWriteJoinReply(char* dest, size_t maxLen, const char* userId, int reply, int nonce); // I want to use as few allocations as I can get away with, and to do that with RapidJson, you need