#include <tchar.h>
#include <stdlib.h>
#include <strsafe.h>
+#include <wtsapi32.h>
+#include <malloc.h>
+#include <assert.h>
#include <version.h>
HANDLE StopEvent;
HANDLE RequestEvent;
HKEY RequestKey;
+ BOOL RebootPending;
} MONITOR_CONTEXT, *PMONITOR_CONTEXT;
MONITOR_CONTEXT MonitorContext;
return ERROR_CALL_NOT_IMPLEMENTED;
}
+static const CHAR *
+WTSStateName(
+ IN DWORD State
+ )
+{
+#define _STATE_NAME(_State) \
+ case WTS ## _State: \
+ return #_State
+
+ switch (State) {
+ _STATE_NAME(Active);
+ _STATE_NAME(Connected);
+ _STATE_NAME(ConnectQuery);
+ _STATE_NAME(Shadow);
+ _STATE_NAME(Disconnected);
+ _STATE_NAME(Idle);
+ _STATE_NAME(Listen);
+ _STATE_NAME(Reset);
+ _STATE_NAME(Down);
+ _STATE_NAME(Init);
+ default:
+ break;
+ }
+
+ return "UNKNOWN";
+
+#undef _STATE_NAME
+}
+
+static VOID
+DoReboot(
+ VOID
+ )
+{
+ (VOID) InitiateSystemShutdownEx(NULL,
+ NULL,
+ 0,
+ TRUE,
+ TRUE,
+ SHTDN_REASON_MAJOR_OPERATINGSYSTEM |
+ SHTDN_REASON_MINOR_INSTALLATION |
+ SHTDN_REASON_FLAG_PLANNED);
+}
+
+static VOID
+PromptForReboot(
+ IN PTCHAR DriverName
+ )
+{
+ PMONITOR_CONTEXT Context = &MonitorContext;
+ HRESULT Result;
+ TCHAR ServiceKeyName[MAX_PATH];
+ HKEY ServiceKey;
+ DWORD MaxValueLength;
+ DWORD DisplayNameLength;
+ PTCHAR DisplayName;
+ DWORD Type;
+ TCHAR Title[] = TEXT(VENDOR_NAME_STR);
+ TCHAR Message[MAXIMUM_BUFFER_SIZE];
+ PWTS_SESSION_INFO SessionInfo;
+ DWORD Count;
+ DWORD Index;
+ BOOL Success;
+ HRESULT Error;
+
+ Log("====> (%s)", DriverName);
+
+ Result = StringCbPrintf(ServiceKeyName,
+ MAX_PATH,
+ SERVICES_KEY "\\%s",
+ DriverName);
+ assert(SUCCEEDED(Result));
+
+ Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+ ServiceKeyName,
+ 0,
+ KEY_READ,
+ &ServiceKey);
+ if (Error != ERROR_SUCCESS) {
+ SetLastError(Error);
+ goto fail1;
+ }
+
+ Error = RegQueryInfoKey(ServiceKey,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ &MaxValueLength,
+ NULL,
+ NULL);
+ if (Error != ERROR_SUCCESS) {
+ SetLastError(Error);
+ goto fail2;
+ }
+
+ DisplayNameLength = MaxValueLength + sizeof (TCHAR);
+
+ DisplayName = calloc(1, DisplayNameLength);
+ if (DisplayName == NULL)
+ goto fail3;
+
+ Error = RegQueryValueEx(ServiceKey,
+ "DisplayName",
+ NULL,
+ &Type,
+ (LPBYTE)DisplayName,
+ &DisplayNameLength);
+ if (Error != ERROR_SUCCESS) {
+ SetLastError(Error);
+ goto fail4;
+ }
+
+ if (Type != REG_SZ) {
+ SetLastError(ERROR_BAD_FORMAT);
+ goto fail5;
+ }
+
+ Result = StringCbPrintf(Message,
+ MAXIMUM_BUFFER_SIZE,
+ TEXT("%s needs to restart the system to "
+ "complete installation.\n"
+ "Press 'Yes' to restart the system "
+ "now or 'No' if you plan to restart "
+ "the system later.\n"),
+ DisplayName);
+ assert(SUCCEEDED(Result));
+
+ Success = WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE,
+ 0,
+ 1,
+ &SessionInfo,
+ &Count);
+
+ if (!Success)
+ goto fail6;
+
+ for (Index = 0; Index < Count; Index++) {
+ DWORD SessionId = SessionInfo[Index].SessionId;
+ PTCHAR Name = SessionInfo[Index].pWinStationName;
+ WTS_CONNECTSTATE_CLASS State = SessionInfo[Index].State;
+ DWORD Response;
+
+ Log("[%u]: %s [%s]",
+ SessionId,
+ Name,
+ WTSStateName(State));
+
+ if (State != WTSActive)
+ continue;
+
+ Success = WTSSendMessage(WTS_CURRENT_SERVER_HANDLE,
+ SessionId,
+ Title,
+ sizeof (Title),
+ Message,
+ sizeof (Message),
+ MB_YESNO | MB_ICONEXCLAMATION,
+ 0,
+ &Response,
+ TRUE);
+
+ if (!Success)
+ goto fail7;
+
+ Context->RebootPending = TRUE;
+
+ if (Response == IDYES)
+ DoReboot();
+
+ break;
+ }
+
+ WTSFreeMemory(SessionInfo);
+
+ free(DisplayName);
+
+ RegCloseKey(ServiceKey);
+
+ Log("<====");
+
+ return;
+
+fail7:
+ Log("fail7");
+
+ WTSFreeMemory(SessionInfo);
+
+fail6:
+ Log("fail6");
+
+fail5:
+ Log("fail5");
+
+fail4:
+ Log("fail4");
+
+ free(DisplayName);
+
+fail3:
+ Log("fail3");
+
+fail2:
+ Log("fail2");
+
+ RegCloseKey(ServiceKey);
+
+fail1:
+ Error = GetLastError();
+
+ {
+ PTCHAR Message;
+ Message = GetErrorMessage(Error);
+ Log("fail1 (%s)", Message);
+ LocalFree(Message);
+ }
+}
+
+static VOID
+CheckRebootValue(
+ VOID
+ )
+{
+ PMONITOR_CONTEXT Context = &MonitorContext;
+ HRESULT Error;
+ DWORD MaxValueLength;
+ DWORD RebootLength;
+ PTCHAR Reboot;
+ DWORD Type;
+
+ Log("====>");
+
+ Error = RegQueryInfoKey(Context->RequestKey,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ &MaxValueLength,
+ NULL,
+ NULL);
+ if (Error != ERROR_SUCCESS) {
+ SetLastError(Error);
+ goto fail1;
+ }
+
+ RebootLength = MaxValueLength + sizeof (TCHAR);
+
+ Reboot = calloc(1, RebootLength);
+ if (Reboot == NULL)
+ goto fail2;
+
+ Error = RegQueryValueEx(Context->RequestKey,
+ "Reboot",
+ NULL,
+ &Type,
+ (LPBYTE)Reboot,
+ &RebootLength);
+ if (Error != ERROR_SUCCESS) {
+ if (Error == ERROR_FILE_NOT_FOUND)
+ goto done;
+
+ SetLastError(Error);
+ goto fail3;
+ }
+
+ if (Type != REG_SZ) {
+ SetLastError(ERROR_BAD_FORMAT);
+ goto fail4;
+ }
+
+ if (!Context->RebootPending)
+ PromptForReboot(Reboot);
+
+ (VOID) RegDeleteValue(Context->RequestKey, "Reboot");
+
+done:
+ free(Reboot);
+
+ Log("<====");
+
+ return;
+
+fail4:
+ Log("fail4");
+
+fail3:
+ Log("fail3");
+
+ free(Reboot);
+
+fail2:
+ Log("fail2");
+
+fail1:
+ Error = GetLastError();
+
+ {
+ PTCHAR Message;
+ Message = GetErrorMessage(Error);
+ Log("fail1 (%s)", Message);
+ LocalFree(Message);
+ }
+}
+
static VOID
CheckRequestKey(
VOID
Log("====>");
+ CheckRebootValue();
+
Error = RegNotifyChangeKeyValue(Context->RequestKey,
TRUE,
REG_NOTIFY_CHANGE_LAST_SET,
}
}
+static BOOL
+AcquireShutdownPrivilege(
+ VOID
+ )
+{
+ HANDLE Token;
+ TOKEN_PRIVILEGES New;
+ BOOL Success;
+ HRESULT Error;
+
+ Log("====>");
+
+ New.PrivilegeCount = 1;
+
+ Success = LookupPrivilegeValue(NULL,
+ SE_SHUTDOWN_NAME,
+ &New.Privileges[0].Luid);
+
+ if (!Success)
+ goto fail1;
+
+ New.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED;
+
+ Success = OpenProcessToken(GetCurrentProcess(),
+ TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
+ &Token);
+
+ if (!Success)
+ goto fail2;
+
+ Success = AdjustTokenPrivileges(Token,
+ FALSE,
+ &New,
+ 0,
+ NULL,
+ NULL);
+
+ if (!Success)
+ goto fail3;
+
+ CloseHandle(Token);
+
+ Log("<====");
+
+ return TRUE;
+
+fail3:
+ Log("fail3");
+
+ CloseHandle(Token);
+
+fail2:
+ Log("fail2");
+
+fail1:
+ Error = GetLastError();
+
+ {
+ PTCHAR Message;
+ Message = GetErrorMessage(Error);
+ Log("fail1 (%s)", Message);
+ LocalFree(Message);
+ }
+
+ return FALSE;
+}
+
VOID WINAPI
MonitorMain(
_In_ DWORD argc,
Log("====>");
+ Success = AcquireShutdownPrivilege();
+
+ if (!Success)
+ goto fail1;
+
Context->Service = RegisterServiceCtrlHandlerEx(MONITOR_NAME,
MonitorCtrlHandlerEx,
NULL);
if (Context->Service == NULL)
- goto fail1;
+ goto fail2;
Context->EventLog = RegisterEventSource(NULL,
MONITOR_NAME);
if (Context->EventLog == NULL)
- goto fail2;
+ goto fail3;
Context->Status.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
Context->Status.dwServiceSpecificExitCode = 0;
NULL);
if (Context->StopEvent == NULL)
- goto fail3;
+ goto fail4;
Context->RequestEvent = CreateEvent(NULL,
TRUE,
NULL);
if (Context->RequestEvent == NULL)
- goto fail4;
+ goto fail5;
Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
REQUEST_KEY,
&Context->RequestKey);
if (Error != ERROR_SUCCESS)
- goto fail5;
+ goto fail6;
SetEvent(Context->RequestEvent);
return;
-fail5:
- Log("fail5");
+fail6:
+ Log("fail6");
ReportStatus(SERVICE_STOPPED, GetLastError(), 0);
CloseHandle(Context->RequestEvent);
+fail5:
+ Log("fail5");
+
+ CloseHandle(Context->StopEvent);
+
fail4:
Log("fail4");
- CloseHandle(Context->StopEvent);
+ (VOID) DeregisterEventSource(Context->EventLog);
fail3:
Log("fail3");
- (VOID) DeregisterEventSource(Context->EventLog);
-
fail2:
Log("fail2");