]> xenbits.xensource.com Git - pvdrivers/win/xennet.git/commitdiff
Update co-installer
authorPaul Durrant <paul.durrant@citrix.com>
Thu, 19 Sep 2013 10:25:55 +0000 (11:25 +0100)
committerPaul Durrant <paul.durrant@citrix.com>
Thu, 19 Sep 2013 10:25:55 +0000 (11:25 +0100)
Signed-off-by: Paul Durrant <paul.durrant@citrix.com>
src/coinst/coinst.c

index 90bc0b97f1bcfdcf579c0b8d5d69720ee456cdcf..0380fd83a220bd25807e7b08eda4b89b55d258e8 100644 (file)
@@ -60,15 +60,24 @@ __user_code;
 #define ADDRESSES_KEY(_Driver)  \
         SERVICE_KEY(_Driver) ## "\\Addresses"
 
-#define ALIASES_KEY(_Driver)  \
+#define ALIASES_KEY(_Driver)    \
         SERVICE_KEY(_Driver) ## "\\Aliases"
 
 #define UNPLUG_KEY(_Driver)     \
         SERVICE_KEY(_Driver) ## "\\Unplug"
 
-#define CLASS_KEY "SYSTEM\\CurrentControlSet\\Control\\Class"
+#define CONTROL_KEY "SYSTEM\\CurrentControlSet\\Control"
 
-#define NSI_KEY "SYSTEM\\CurrentControlSet\\Control\\Nsi"
+#define CLASS_KEY   \
+        CONTROL_KEY ## "\\Class"
+
+#define NSI_KEY \
+        CONTROL_KEY ## "\\Nsi"
+
+#define SOFTWARE_KEY "SOFTWARE\\Citrix"
+
+#define INSTALLER_KEY   \
+        SOFTWARE_KEY ## "\\XenToolsNetSettings\\XEN\\VIF"
 
 static VOID
 #pragma prefast(suppress:6262) // Function uses '1036' bytes of stack: exceeds /analyze:stacksize'1024'
@@ -115,8 +124,8 @@ __Log(
 #define Log(_Format, ...) \
         __Log(__MODULE__ "|" __FUNCTION__ ": " _Format, __VA_ARGS__)
 
-static PTCHAR
-GetErrorMessage(
+static FORCEINLINE PTCHAR
+__GetErrorMessage(
     IN  DWORD   Error
     )
 {
@@ -143,8 +152,8 @@ GetErrorMessage(
     return Message;
 }
 
-static const CHAR *
-FunctionName(
+static FORCEINLINE const CHAR *
+__FunctionName(
     IN  DI_FUNCTION Function
     )
 {
@@ -226,7 +235,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -264,12 +273,10 @@ GetProperty(
 
     PropertyLength += sizeof (TCHAR);
 
-    Property = malloc(PropertyLength);
+    Property = calloc(1, PropertyLength);
     if (Property == NULL)
         goto fail3;
 
-    memset(Property, 0, PropertyLength);
-
     if (!SetupDiGetDeviceRegistryProperty(DeviceInfoSet,
                                           DeviceInfoData,
                                           Index,
@@ -296,7 +303,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -373,7 +380,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -443,7 +450,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -503,12 +510,10 @@ GetPermanentAddress(
 
     BufferLength = MaxValueLength + sizeof (TCHAR);
 
-    Buffer = malloc(BufferLength);
+    Buffer = calloc(1, BufferLength);
     if (Buffer == NULL)
         goto fail4;
 
-    memset(Buffer, 0, BufferLength);
-
     Error = RegQueryValueEx(AddressesKey,
                             Location,
                             NULL,
@@ -525,7 +530,7 @@ GetPermanentAddress(
         goto fail6;
     }
 
-    Address = malloc(sizeof (ETHERNET_ADDRESS));
+    Address = calloc(1, sizeof (ETHERNET_ADDRESS));
     if (Address == NULL)
         goto fail7;
 
@@ -584,7 +589,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -628,7 +633,7 @@ GetNetLuid(
     goto done;
 
 found:
-    *NetLuid = malloc(sizeof (NET_LUID));
+    *NetLuid = calloc(1, sizeof (NET_LUID));
     if (*NetLuid == NULL)
         goto fail2;
 
@@ -650,7 +655,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -710,7 +715,7 @@ fail1:
 
     {
         PTCHAR  Message;
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -851,7 +856,7 @@ found:
     NameLength = (DWORD)(sizeof ("{XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}\\") +
                          ((strlen(SubKeyName) + 1) * sizeof (TCHAR)));
 
-    *Name = malloc(NameLength);
+    *Name = calloc(1, NameLength);
     if (*Name == NULL)
         goto fail7;
 
@@ -914,7 +919,7 @@ fail1:
 
     {
         PTCHAR  Message;
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -986,7 +991,7 @@ fail1:
 
     {
         PTCHAR  Message;
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1043,12 +1048,10 @@ GetAliasSoftwareKeyName(
 
     NameLength = MaxValueLength + sizeof (TCHAR);
 
-    *Name = malloc(NameLength);
+    *Name = calloc(1, NameLength);
     if (Name == NULL)
         goto fail4;
 
-    memset(*Name, 0, NameLength);
-
     Error = RegQueryValueEx(AliasesKey,
                             Location,
                             NULL,
@@ -1065,8 +1068,6 @@ GetAliasSoftwareKeyName(
         goto fail6;
     }
 
-    Log("%s", (strlen(*Name) == 0) ? "[NONE]" : *Name);
-
     if (strlen(*Name) == 0) {
         free(*Name);
         *Name = NULL;
@@ -1076,6 +1077,8 @@ GetAliasSoftwareKeyName(
 
     free(Location);
 
+    Log("%s", (*Name == NULL) ? "[NONE]" : *Name);
+
     return TRUE;
 
 fail6:
@@ -1104,7 +1107,69 @@ fail1:
 
     {
         PTCHAR  Message;
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+
+    return FALSE;
+}
+
+static BOOLEAN
+ClearAliasSoftwareKeyName(
+    IN  HDEVINFO            DeviceInfoSet,
+    IN  PSP_DEVINFO_DATA    DeviceInfoData
+    )
+{
+    PTCHAR                  Location;
+    HKEY                    AliasesKey;
+    HRESULT                 Error;
+
+    Location = GetProperty(DeviceInfoSet,
+                           DeviceInfoData,
+                           SPDRP_LOCATION_INFORMATION);
+    if (Location == NULL)
+        goto fail1;
+
+    Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+                         ALIASES_KEY(XENVIF),
+                         0,
+                         KEY_ALL_ACCESS,
+                         &AliasesKey);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail2;
+    }
+
+    Error = RegDeleteValue(AliasesKey,
+                           Location);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail3;
+    }
+
+    RegCloseKey(AliasesKey);
+
+    free(Location);
+
+    return TRUE;
+
+fail3:
+    Log("fail3");
+
+    RegCloseKey(AliasesKey);
+
+fail2:
+    Log("fail2");
+
+    free(Location);
+
+fail1:
+    Error = GetLastError();
+
+    {
+        PTCHAR  Message;
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1152,7 +1217,181 @@ fail1:
 
     {
         PTCHAR  Message;
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+
+    return NULL;
+}
+
+static BOOLEAN
+GetInstallerSettingsKeyName(
+    IN  HDEVINFO            DeviceInfoSet,
+    IN  PSP_DEVINFO_DATA    DeviceInfoData,
+    OUT PTCHAR              *Name
+    )
+{
+    PTCHAR                  Location;
+    HKEY                    InstallerKey;
+    DWORD                   MaxValueLength;
+    DWORD                   NameLength;
+    DWORD                   Type;
+    HRESULT                 Error;
+
+    Log("====>");
+
+    Location = GetProperty(DeviceInfoSet,
+                           DeviceInfoData,
+                           SPDRP_LOCATION_INFORMATION);
+    if (Location == NULL)
+        goto fail1;
+
+    *Name = NULL;
+
+    Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+                         INSTALLER_KEY,
+                         0,
+                         KEY_READ,
+                         &InstallerKey);
+    if (Error != ERROR_SUCCESS) {
+        if (Error == ERROR_FILE_NOT_FOUND)
+            goto done;
+
+        SetLastError(Error);
+        goto fail2;
+    }
+
+    Error = RegQueryInfoKey(InstallerKey,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            &MaxValueLength,
+                            NULL,
+                            NULL);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail3;
+    }
+
+    NameLength = MaxValueLength + sizeof (TCHAR);
+
+    *Name = calloc(1, NameLength);
+    if (Name == NULL)
+        goto fail4;
+
+    Error = RegQueryValueEx(InstallerKey,
+                            Location,
+                            NULL,
+                            &Type,
+                            (LPBYTE)*Name,
+                            &NameLength);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail5;
+    }
+
+    if (Type != REG_SZ) {
+        SetLastError(ERROR_BAD_FORMAT);
+        goto fail6;
+    }
+
+    if (strlen(*Name) == 0) {
+        free(*Name);
+        *Name = NULL;
+    }
+
+    RegCloseKey(InstallerKey);
+
+    free(Location);
+
+done:
+    Log("%s", (*Name == NULL) ? "[NONE]" : *Name);
+
+    Log("<====");
+
+    return TRUE;
+
+fail6:
+    Log("fail6");
+
+fail5:
+    Log("fail5");
+
+    free(*Name);
+
+fail4:
+    Log("fail4");
+
+fail3:
+    Log("fail3");
+
+    RegCloseKey(InstallerKey);
+
+fail2:
+    Log("fail2");
+
+    free(Location);
+
+fail1:
+    Error = GetLastError();
+
+    {
+        PTCHAR  Message;
+        Message = __GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+
+    return FALSE;
+}
+
+static HKEY
+OpenInstallerSettingsKey(
+    IN  PTCHAR  Name
+    )
+{
+    HRESULT     Result;
+    TCHAR       KeyName[MAX_PATH];
+    HKEY        Key;
+    HRESULT     Error;
+
+    Result = StringCbPrintf(KeyName,
+                            MAX_PATH,
+                            "%s\\%s",
+                            INSTALLER_KEY,
+                            Name);
+    if (!SUCCEEDED(Result)) {
+        SetLastError(ERROR_BUFFER_OVERFLOW);
+        goto fail1;
+    }
+
+    Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+                         KeyName,
+                         0,
+                         KEY_READ,
+                         &Key);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail2;
+    }
+
+    return Key;
+
+fail2:
+    Log("fail2");
+
+fail1:
+    Error = GetLastError();
+
+    {
+        PTCHAR  Message;
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1201,12 +1440,10 @@ GetInterfaceName(
 
     RootDeviceLength = MaxValueLength + sizeof (TCHAR);
 
-    RootDevice = malloc(RootDeviceLength);
+    RootDevice = calloc(1, RootDeviceLength);
     if (RootDevice == NULL)
         goto fail2;
 
-    memset(RootDevice, 0, RootDeviceLength);
-
     Error = RegQueryValueEx(LinkageKey,
                             "RootDevice",
                             NULL,
@@ -1256,7 +1493,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1416,7 +1653,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1442,7 +1679,7 @@ CopyParameters(
                       strlen(DestinationName) +
                       1) * sizeof (TCHAR));
 
-    DestinationKeyName = malloc(Length);
+    DestinationKeyName = calloc(1, Length);
     if (DestinationKeyName == NULL)
         goto fail1;
 
@@ -1460,7 +1697,7 @@ CopyParameters(
                       strlen(SourceName) +
                       1) * sizeof (TCHAR));
 
-    SourceKeyName = malloc(Length);
+    SourceKeyName = calloc(1, Length);
     if (SourceKeyName == NULL)
         goto fail3;
 
@@ -1500,7 +1737,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1669,7 +1906,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1713,12 +1950,10 @@ GetIpVersion6AddressValueName(
 
     ValueLength = MaxValueLength;
 
-    Value = malloc(ValueLength);
+    Value = calloc(1, ValueLength);
     if (Value == NULL)
         goto fail2;
 
-    memset(Value, 0, ValueLength);
-
     Error = RegQueryValueEx(Key,
                             "NetLuidIndex",
                             NULL,
@@ -1757,7 +1992,7 @@ GetIpVersion6AddressValueName(
 
     BufferLength = ((sizeof (ULONG64) * 2) + 1) * sizeof (TCHAR);
 
-    Buffer = malloc(BufferLength);
+    Buffer = calloc(1, BufferLength);
     if (Buffer == NULL)
         goto fail7;
 
@@ -1805,7 +2040,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1892,7 +2127,57 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+
+    return FALSE;
+}
+
+static BOOLEAN
+CopySettingsFromInstaller(
+    IN  HDEVINFO                    DeviceInfoSet,
+    IN  PSP_DEVINFO_DATA            DeviceInfoData,
+    IN  PTCHAR                      Name
+    )
+{
+    HKEY                            Source;
+    HKEY                            Destination;
+    BOOLEAN                         Success;
+    HRESULT                         Error;
+
+    Source = OpenInstallerSettingsKey(Name);
+    if (Source == NULL)
+        goto fail1;
+
+    Destination = OpenSoftwareKey(DeviceInfoSet, DeviceInfoData);
+    if (Destination == NULL)
+        goto fail2;
+
+    Success = CopySettings(Destination, Source);
+    if (!Success)
+        goto fail3;
+
+    return TRUE;
+
+fail3:
+    Log("fail3");
+
+    RegCloseKey(Destination);
+
+fail2:
+    Log("fail2");
+
+    RegCloseKey(Source);
+
+fail1:
+    Error = GetLastError();
+
+    {
+        PTCHAR  Message;
+
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1942,7 +2227,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -1992,7 +2277,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2006,7 +2291,11 @@ RequestUnplug(
     )
 {
     HKEY    UnplugKey;
+    DWORD   NicsLength;
+    PTCHAR  Nics;
+    DWORD   Offset;
     HRESULT Error;
+    HRESULT Result;
 
     Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
                          UNPLUG_KEY(XENFILT),
@@ -2018,21 +2307,66 @@ RequestUnplug(
         goto fail1;
     }
 
+    NicsLength = (DWORD)((strlen("XENVIF") + 1 +
+                          strlen("XENNET") + 1 +
+                          1) * sizeof (TCHAR));
+
+    Nics = calloc(1, NicsLength);
+    if (Nics == NULL)
+        goto fail2;
+
+    Offset = 0;
+
+    Result = StringCbPrintf(Nics + Offset,
+                            NicsLength - (Offset * sizeof (TCHAR)),
+                            "XENVIF");
+    if (!SUCCEEDED(Result)) {
+        SetLastError(ERROR_BUFFER_OVERFLOW);
+        goto fail3;
+    }
+
+    Offset += (DWORD)(strlen("XENVIF") + 1);
+
+    Result = StringCbPrintf(Nics + Offset,
+                            NicsLength - (Offset * sizeof (TCHAR)),
+                            "XENNET");
+    if (!SUCCEEDED(Result)) {
+        SetLastError(ERROR_BUFFER_OVERFLOW);
+        goto fail4;
+    }
+
+    Offset += (DWORD)(strlen("XENNET") + 1);
+
+    *(Nics + Offset) = '\0';
+
     Error = RegSetValueEx(UnplugKey,
                           "NICS",
                           0,
-                          REG_SZ,
-                          (LPBYTE)"XENNET",
-                          (DWORD)sizeof ("XENNET"));
+                          REG_MULTI_SZ,
+                          (LPBYTE)Nics,
+                          NicsLength);
     if (Error != ERROR_SUCCESS) {
         SetLastError(Error);
-        goto fail2;
+        goto fail5;
     }
 
+    free(Nics);
+
     RegCloseKey(UnplugKey);
 
     return TRUE;
 
+fail5:
+    Log("fail5");
+
+fail4:
+    Log("fail4");
+
+fail3:
+    Log("fail3");
+
+    free(Nics);
+
 fail2:
     Log("fail2");
 
@@ -2044,7 +2378,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2133,7 +2467,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2224,7 +2558,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2263,30 +2597,6 @@ fail1:
     return FALSE;
 }
 
-static BOOLEAN
-IsUpgrade(
-    IN  HDEVINFO            DeviceInfoSet,
-    IN  PSP_DEVINFO_DATA    DeviceInfoData
-    )
-{
-    HKEY                    SoftwareKey;
-    PTCHAR                  InterfaceName;
-
-    SoftwareKey = OpenSoftwareKey(DeviceInfoSet, DeviceInfoData);
-
-    if (SoftwareKey == NULL)
-        return FALSE;
-
-    InterfaceName = GetInterfaceName(SoftwareKey);
-    RegCloseKey(SoftwareKey);
-
-    if (InterfaceName == NULL)
-        return FALSE;
-
-    free(InterfaceName);
-    return TRUE;
-}
-
 static FORCEINLINE HRESULT
 __DifInstallPreProcess(
     IN  HDEVINFO                    DeviceInfoSet,
@@ -2299,40 +2609,39 @@ __DifInstallPreProcess(
     BOOLEAN                         Success;
     HRESULT                         Error;
 
-    Log("====>");
-
-    if (!IsUpgrade(DeviceInfoSet, DeviceInfoData)) {
-        Log("INITIAL INSTALLATION");
+    UNREFERENCED_PARAMETER(Context);
 
-        Address = GetPermanentAddress(DeviceInfoSet, DeviceInfoData);
-        if (Address == NULL)
-            goto fail1;
+    Log("====>");
 
-        Success = FindAliasSoftwareKeyName(Address, &Name);
-        if (!Success)
-            goto fail2;
+    Success = GetAliasSoftwareKeyName(DeviceInfoSet,
+                                      DeviceInfoData,
+                                      &Name);
+    if (Success)
+        goto done;
 
-        Success = SetAliasSoftwareKeyName(DeviceInfoSet,
-                                          DeviceInfoData,
-                                          Name);
-        if (!Success)
-            goto fail3;
+    Address = GetPermanentAddress(DeviceInfoSet, DeviceInfoData);
+    if (Address == NULL)
+        goto fail1;
 
-        if (Name != NULL)
-            free(Name);
+    Success = FindAliasSoftwareKeyName(Address, &Name);
+    if (!Success)
+        goto fail2;
 
-        free(Address);
+    Success = SetAliasSoftwareKeyName(DeviceInfoSet,
+                                      DeviceInfoData,
+                                      Name);
+    if (!Success)
+        goto fail3;
 
-        Context->PrivateData = (PVOID)FALSE;
-    } else {
-        Log("UPGRADE INSTALLATION");
+    free(Address);
 
-        Context->PrivateData = (PVOID)TRUE;
-    }
+done:
+    if (Name != NULL)
+        free(Name);
 
     Log("<====");
 
-    return ERROR_DI_POSTPROCESSING_REQUIRED
+    return NO_ERROR
 
 fail3:
     Log("fail3");
@@ -2350,7 +2659,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2365,7 +2674,6 @@ __DifInstallPostProcess(
     IN  PCOINSTALLER_CONTEXT_DATA   Context
     )
 {
-    BOOLEAN                         Upgrade;
     BOOLEAN                         Success;
     PTCHAR                          Name;
     DWORD                           Count;
@@ -2382,21 +2690,35 @@ __DifInstallPostProcess(
         goto fail1;
     }
 
-    Upgrade = (BOOLEAN)(ULONG_PTR)Context->PrivateData;
-
-    if (Upgrade)
-        goto done;
-
     Success = SetFriendlyName(DeviceInfoSet,
                               DeviceInfoData);
     if (!Success)
         goto fail2;
 
+    Success = GetInstallerSettingsKeyName(DeviceInfoSet,
+                                          DeviceInfoData,
+                                          &Name);
+    if (!Success)
+        goto fail3;
+
+    if (Name != NULL) {
+        Success = CopySettingsFromInstaller(DeviceInfoSet,
+                                            DeviceInfoData,
+                                            Name);
+
+        free(Name);
+
+        if (!Success)
+            goto fail4;
+
+        goto done;
+    }
+
     Success = GetAliasSoftwareKeyName(DeviceInfoSet,
                                       DeviceInfoData,
                                       &Name);
     if (!Success)
-        goto fail3;
+        goto fail5;
 
     if (Name != NULL) {
         Success = CopySettingsFromAlias(DeviceInfoSet,
@@ -2406,23 +2728,31 @@ __DifInstallPostProcess(
         free(Name);
 
         if (!Success)
-            goto fail4;
+            goto fail6;
     }
 
+done:
     Success = RequestUnplug();
     if (!Success)
-        goto fail5;
+        goto fail7;
 
-done:
     Success = IncrementServiceCount(&Count);
     if (!Success)
-        goto fail6;
+        goto fail8;
 
     if (Count == 1)
         (VOID) RequestReboot(DeviceInfoSet, DeviceInfoData);
 
+    Log("<====");
+
     return NO_ERROR;
 
+fail8:
+    Log("fail8");
+
+fail7:
+    Log("fail7");
+
 fail6:
     Log("fail6");
 
@@ -2444,7 +2774,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2471,9 +2801,20 @@ DifInstall(
 
     Log("Flags = %08x", DeviceInstallParams.Flags);
 
-    Error = (!Context->PostProcessing) ?
-            __DifInstallPreProcess(DeviceInfoSet, DeviceInfoData, Context) :
-            __DifInstallPostProcess(DeviceInfoSet, DeviceInfoData, Context);
+    if (!Context->PostProcessing) {
+        Error = __DifInstallPreProcess(DeviceInfoSet, DeviceInfoData, Context);
+
+        Context->PrivateData = (PVOID)Error;
+
+        Error = ERROR_DI_POSTPROCESSING_REQUIRED; 
+    } else {
+        Error = (HRESULT)(DWORD_PTR)Context->PrivateData;
+        
+        if (Error == NO_ERROR)
+            (VOID) __DifInstallPostProcess(DeviceInfoSet, DeviceInfoData, Context);
+
+        Error = NO_ERROR; 
+    }
 
     return Error;
 
@@ -2483,7 +2824,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2500,7 +2841,6 @@ __DifRemovePreProcess(
 {
     BOOLEAN                         Success;
     PTCHAR                          Name;
-    DWORD                           Count;
     HRESULT                         Error;
 
     UNREFERENCED_PARAMETER(Context);
@@ -2524,18 +2864,9 @@ __DifRemovePreProcess(
             goto fail2;
     }
 
-    Success = DecrementServiceCount(&Count);
-    if (!Success)
-        goto fail3;    
-
-    Context->PrivateData = (Count == 0) ? (PVOID)TRUE : (PVOID)FALSE;
-
     Log("<====");
 
-    return ERROR_DI_POSTPROCESSING_REQUIRED; 
-
-fail3:
-    Log("fail3");
+    return NO_ERROR;
 
 fail2:
     Log("fail2");
@@ -2546,7 +2877,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2561,8 +2892,8 @@ __DifRemovePostProcess(
     IN  PCOINSTALLER_CONTEXT_DATA   Context
     )
 {
-    BOOLEAN                         NeedReboot;
     BOOLEAN                         Success;
+    DWORD                           Count;
     HRESULT                         Error;
 
     Log("====>");
@@ -2573,20 +2904,25 @@ __DifRemovePostProcess(
         goto fail1;
     }
 
-    NeedReboot = (BOOLEAN)(ULONG_PTR)Context->PrivateData;
-
-    if (!NeedReboot)
-        goto done;
-
-    Success = RequestReboot(DeviceInfoSet, DeviceInfoData);
+    Success = ClearAliasSoftwareKeyName(DeviceInfoSet,
+                                        DeviceInfoData);
     if (!Success)
         goto fail2;
 
-done:
+    Success = DecrementServiceCount(&Count);
+    if (!Success)
+        goto fail3;
+
+    if (Count == 0)
+        (VOID) RequestReboot(DeviceInfoSet, DeviceInfoData);
+
     Log("<====");
 
     return NO_ERROR;
 
+fail3:
+    Log("fail3");
+
 fail2:
     Log("fail2");
 
@@ -2596,7 +2932,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2623,9 +2959,20 @@ DifRemove(
 
     Log("Flags = %08x", DeviceInstallParams.Flags);
 
-    Error = (!Context->PostProcessing) ?
-            __DifRemovePreProcess(DeviceInfoSet, DeviceInfoData, Context) :
-            __DifRemovePostProcess(DeviceInfoSet, DeviceInfoData, Context);
+    if (!Context->PostProcessing) {
+        Error = __DifRemovePreProcess(DeviceInfoSet, DeviceInfoData, Context);
+
+        Context->PrivateData = (PVOID)Error;
+
+        Error = ERROR_DI_POSTPROCESSING_REQUIRED; 
+    } else {
+        Error = (HRESULT)(DWORD_PTR)Context->PrivateData;
+        
+        if (Error == NO_ERROR)
+            (VOID) __DifRemovePostProcess(DeviceInfoSet, DeviceInfoData, Context);
+
+        Error = NO_ERROR; 
+    }
 
     return Error;
 
@@ -2635,7 +2982,7 @@ fail1:
     {
         PTCHAR  Message;
 
-        Message = GetErrorMessage(Error);
+        Message = __GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
@@ -2659,10 +3006,10 @@ Entry(
 
     if (!Context->PostProcessing) {
         Log("%s PreProcessing",
-            FunctionName(Function));
+            __FunctionName(Function));
     } else {
         Log("%s PostProcessing (%08x)",
-            FunctionName(Function),
+            __FunctionName(Function),
             Context->InstallResult);
     }