]> xenbits.xensource.com Git - xcp/xen-api.git/commitdiff
Add a new type: Db_ref.t which may be either (In_memory x) or Remote. At the xapi...
authorDavid Scott <dave.scott@eu.citrix.com>
Wed, 26 Jan 2011 17:39:06 +0000 (17:39 +0000)
committerDavid Scott <dave.scott@eu.citrix.com>
Wed, 26 Jan 2011 17:39:06 +0000 (17:39 +0000)
Signed-off-by: David Scott <dave.scott@eu.citrix.com>
59 files changed:
ocaml/database/backend_xml.ml
ocaml/database/database_server_main.ml
ocaml/database/database_test.ml
ocaml/database/db_backend.ml
ocaml/database/db_cache.ml
ocaml/database/db_cache_impl.ml
ocaml/database/db_cache_impl.mli
ocaml/database/db_interface.ml
ocaml/database/db_ref.ml [new file with mode: 0644]
ocaml/database/db_remote_cache_access_v1.ml
ocaml/database/db_remote_cache_access_v2.ml
ocaml/database/db_rpc_client_v1.ml
ocaml/database/db_rpc_client_v2.ml
ocaml/database/eventgen.ml
ocaml/database/ref_index.ml
ocaml/db_process/xapi-db-process.ml
ocaml/idl/ocaml_backend/OMakefile
ocaml/idl/ocaml_backend/context.ml
ocaml/idl/ocaml_backend/context.mli
ocaml/idl/ocaml_backend/gen_db_actions.ml
ocaml/xapi/cli_operations.ml
ocaml/xapi/console.ml
ocaml/xapi/create_misc.ml
ocaml/xapi/db.ml
ocaml/xapi/db_gc.ml
ocaml/xapi/dbsync.ml
ocaml/xapi/dbsync_master.ml
ocaml/xapi/dbsync_slave.ml
ocaml/xapi/export.ml
ocaml/xapi/helpers.ml
ocaml/xapi/import_raw_vdi.ml
ocaml/xapi/message_forwarding.ml
ocaml/xapi/monitor_master.ml
ocaml/xapi/monitor_rrds.ml
ocaml/xapi/monitor_self.ml
ocaml/xapi/nm.ml
ocaml/xapi/pool_db_backup.ml
ocaml/xapi/redo_log_usage.ml
ocaml/xapi/xapi.ml
ocaml/xapi/xapi_guest_agent.ml
ocaml/xapi/xapi_ha.ml
ocaml/xapi/xapi_host.ml
ocaml/xapi/xapi_host.mli
ocaml/xapi/xapi_host_helpers.ml
ocaml/xapi/xapi_http.ml
ocaml/xapi/xapi_mgmt_iface.ml
ocaml/xapi/xapi_network.ml
ocaml/xapi/xapi_pif.ml
ocaml/xapi/xapi_pif.mli
ocaml/xapi/xapi_pool.ml
ocaml/xapi/xapi_pool.mli
ocaml/xapi/xapi_vif_helpers.ml
ocaml/xapi/xapi_vm.ml
ocaml/xapi/xapi_vm.mli
ocaml/xapi/xapi_vm_clone.ml
ocaml/xapi/xapi_vm_lifecycle.ml
ocaml/xapi/xapi_vm_placement.ml
ocaml/xapi/xapi_vm_snapshot.ml
ocaml/xapi/xha_metadata_vdi.ml

index 478e9c7ee44daab15b7a4148ad646c79141d82df..820929ade3bdfb292701a75602d94d7ef4726a85 100644 (file)
@@ -79,7 +79,7 @@ let flush dbconn db =
 (* NB We don't do incremental flushing *)
 
 let flush_dirty dbconn =
-       let db = get_database () in
+       let db = Db_ref.get_database (Db_backend.make ()) in
        let g = Manifest.generation (Database.manifest db) in
        if g > dbconn.Parse_db_conf.last_generation_count then begin
                flush dbconn db;
index fa9529fd8d02fc2d14325449b6a611d45171f769..10d4007faba90c4816db87b05c270acdee23e76f 100644 (file)
@@ -54,10 +54,9 @@ let _ =
                                        Printf.printf "Database path: %s\n%!" db_filename;
                                        let db = Parse_db_conf.make db_filename in
                                        Db_conn_store.initialise_db_connections [ db ];
-                                       Db_cache.set_master true;
-
-                                       Db_cache_impl.make [ db ] (Schema.of_datamodel ());
-                                       Db_cache_impl.sync [ db ] (Db_backend.get_database ());
+                                       let t = Db_backend.make () in                                   
+                                       Db_cache_impl.make t [ db ] (Schema.of_datamodel ());
+                                       Db_cache_impl.sync [ db ] (Db_ref.get_database t);
 
                                        Unixext.unlink_safe !listen_path;
                                        let sockaddr = Unix.ADDR_UNIX !listen_path in
index aa7fcc4db28cc89ea2676b1fc138197a22371d1a..74758ccd92e5b0b28525588abc43550b4626bcd0 100644 (file)
@@ -155,20 +155,20 @@ module Tests = functor(Client: Db_interface.DB_ACCESS) -> struct
                        )
 
        (* Verify the ref_index contents are correct for a given [tblname] and [key] (uuid/ref) *)
-       let check_ref_index tblname key = match Ref_index.lookup key with
+       let check_ref_index t tblname key = match Ref_index.lookup key with
                | None ->
                        (* We should fail to find the row *)
                        expect_missing_row tblname key
-                               (fun () -> let (_: string) = Client.read_field tblname "uuid" key in ());
+                               (fun () -> let (_: string) = Client.read_field t tblname "uuid" key in ());
                        expect_missing_uuid tblname key
-                               (fun () -> let (_: string) = Client.db_get_by_uuid tblname key in ())
+                               (fun () -> let (_: string) = Client.db_get_by_uuid t tblname key in ())
                | Some { Ref_index.name_label = name_label; uuid = uuid; _ref = _ref } ->
                        (* key should be either uuid or _ref *)
                        if key <> uuid && (key <> _ref)
                        then failwith (Printf.sprintf "check_ref_index %s key %s: got ref %s uuid %s" tblname key _ref uuid);
-                       let real_ref = if Client.is_valid_ref key then key else Client.db_get_by_uuid tblname key in
+                       let real_ref = if Client.is_valid_ref t key then key else Client.db_get_by_uuid t tblname key in
                        let real_name_label = 
-                               try Some (Client.read_field tblname "name__label" real_ref)
+                               try Some (Client.read_field t tblname "name__label" real_ref)
                                with _ -> None in
                        if name_label <> real_name_label
                        then failwith (Printf.sprintf "check_ref_index %s key %s: ref_index name_label = %s; db has %s" tblname key (Opt.default "None" name_label) (Opt.default "None" real_name_label))
@@ -248,6 +248,8 @@ module Tests = functor(Client: Db_interface.DB_ACCESS) -> struct
                let invalid_ref = "foo" in
                let invalid_uuid = "bar" in
                
+               let t = if in_process then Db_backend.make () else Db_ref.Remote in
+
        let vbd_ref = "waz" in
                let vbd_uuid = "whatever" in
 
@@ -256,255 +258,255 @@ module Tests = functor(Client: Db_interface.DB_ACCESS) -> struct
                (* Before we begin, clear out any old state: *)
                expect_missing_row "VM" valid_ref
                        (fun () ->
-                               Client.delete_row "VM" valid_ref;
+                               Client.delete_row "VM" valid_ref;
                );
-               if in_process then check_ref_index "VM" valid_ref;
+               if in_process then check_ref_index "VM" valid_ref;
 
                expect_missing_row "VBD" vbd_ref
                (fun () ->
-                       Client.delete_row "VBD" vbd_ref;
+                       Client.delete_row "VBD" vbd_ref;
                );
-               if in_process then check_ref_index "VBD" vbd_ref;
+               if in_process then check_ref_index "VBD" vbd_ref;
 
                Printf.printf "Deleted stale state from previous test\n";
                
                Printf.printf "get_table_from_ref <invalid ref>\n";
                begin
-                       match Client.get_table_from_ref invalid_ref with
+                       match Client.get_table_from_ref invalid_ref with
                                | None -> Printf.printf "Reference '%s' has no associated table\n" invalid_ref
                                | Some t -> failwith (Printf.sprintf "Reference '%s' exists in table '%s'" invalid_ref t)
                end;
                Printf.printf "is_valid_ref <invalid_ref>\n";
-               if Client.is_valid_ref invalid_ref then failwith "is_valid_ref <invalid_ref> = true";
+               if Client.is_valid_ref invalid_ref then failwith "is_valid_ref <invalid_ref> = true";
                
                Printf.printf "read_refs <valid tbl>\n";
-               let existing_refs = Client.read_refs "VM" in
+               let existing_refs = Client.read_refs "VM" in
                Printf.printf "VM refs: [ %s ]\n" (String.concat "; " existing_refs);
                Printf.printf "read_refs <invalid tbl>\n";
                expect_missing_tbl "Vm"
                        (fun () ->
-                               let (_: string list) = Client.read_refs "Vm" in
+                               let (_: string list) = Client.read_refs "Vm" in
                                ()
                        );
                Printf.printf "delete_row <invalid ref>\n";
                expect_missing_row "VM" invalid_ref
                        (fun () ->
-                               Client.delete_row "VM" invalid_ref;
+                               Client.delete_row "VM" invalid_ref;
                                failwith "delete_row of a non-existent row silently succeeded"
                        );
                Printf.printf "create_row <unique ref> <unique uuid> <missing required field>\n";
                expect_missing_field "name__label"
                        (fun () ->
                                let broken_vm = List.filter (fun (k, _) -> k <> "name__label") (make_vm valid_ref valid_uuid) in
-                               Client.create_row "VM" broken_vm valid_ref;
+                               Client.create_row "VM" broken_vm valid_ref;
                                failwith "create_row <unique ref> <unique uuid> <missing required field>"
                        );
                Printf.printf "create_row <unique ref> <unique uuid>\n";
-               Client.create_row "VM" (make_vm valid_ref valid_uuid) valid_ref;
-               if in_process then check_ref_index "VM" valid_ref;
+               Client.create_row "VM" (make_vm valid_ref valid_uuid) valid_ref;
+               if in_process then check_ref_index "VM" valid_ref;
                Printf.printf "is_valid_ref <valid ref>\n";
-               if not (Client.is_valid_ref valid_ref)
+               if not (Client.is_valid_ref valid_ref)
                then failwith "is_valid_ref <valid_ref> = false, after create_row";
                Printf.printf "get_table_from_ref <valid ref>\n";
-               begin match Client.get_table_from_ref valid_ref with
+               begin match Client.get_table_from_ref valid_ref with
                        | Some "VM" -> ()
                        | Some t -> failwith "get_table_from_ref <valid ref> : invalid table"
                        | None -> failwith "get_table_from_ref <valid ref> : None"
                end;
                Printf.printf "read_refs includes <valid ref>\n";
-               if not (List.mem valid_ref (Client.read_refs "VM"))
+               if not (List.mem valid_ref (Client.read_refs "VM"))
                then failwith "read_refs did not include <valid ref>";
                
                Printf.printf "create_row <duplicate ref> <unique uuid>\n";
                expect_uniqueness_violation "VM" "_ref" valid_ref
                        (fun () ->
-                               Client.create_row "VM" (make_vm valid_ref (valid_uuid ^ "unique")) valid_ref;
+                               Client.create_row "VM" (make_vm valid_ref (valid_uuid ^ "unique")) valid_ref;
                                failwith "create_row <duplicate ref> <unique uuid>"
                        );
                Printf.printf "create_row <unique ref> <duplicate uuid>\n";
                expect_uniqueness_violation "VM" "uuid" valid_uuid
                        (fun () ->
-                               Client.create_row "VM" (make_vm (valid_ref ^ "unique") valid_uuid) (valid_ref ^ "unique");
+                               Client.create_row "VM" (make_vm (valid_ref ^ "unique") valid_uuid) (valid_ref ^ "unique");
                                failwith "create_row <unique ref> <duplicate uuid>"
                        );
                Printf.printf "db_get_by_uuid <valid uuid>\n";
-               let r = Client.db_get_by_uuid "VM" valid_uuid in
+               let r = Client.db_get_by_uuid "VM" valid_uuid in
                if r <> valid_ref
                then failwith (Printf.sprintf "db_get_by_uuid <valid uuid>: got %s; expected %s" r valid_ref);
                Printf.printf "db_get_by_uuid <invalid uuid>\n";
                expect_missing_uuid "VM" invalid_uuid
                        (fun () ->
-                               let (_: string) = Client.db_get_by_uuid "VM" invalid_uuid in
+                               let (_: string) = Client.db_get_by_uuid "VM" invalid_uuid in
                                failwith "db_get_by_uuid <invalid uuid>"
                        );
                Printf.printf "get_by_name_label <invalid name label>\n";
-               if Client.db_get_by_name_label "VM" invalid_name <> []
+               if Client.db_get_by_name_label "VM" invalid_name <> []
                then failwith "db_get_by_name_label <invalid name label>";
                
                Printf.printf "get_by_name_label <valid name label>\n";
-               if Client.db_get_by_name_label "VM" name <> [ valid_ref ]
+               if Client.db_get_by_name_label "VM" name <> [ valid_ref ]
                then failwith "db_get_by_name_label <valid name label>";
                
                Printf.printf "read_field <valid field> <valid objref>\n";
-               if Client.read_field "VM" "name__label" valid_ref <> name
+               if Client.read_field "VM" "name__label" valid_ref <> name
                then failwith "read_field <valid field> <valid objref> : invalid name";
 
                Printf.printf "read_field <valid defaulted field> <valid objref>\n";
-               if Client.read_field "VM" "protection_policy" valid_ref <> "OpaqueRef:NULL"
+               if Client.read_field "VM" "protection_policy" valid_ref <> "OpaqueRef:NULL"
                then failwith "read_field <valid defaulted field> <valid objref> : invalid protection_policy";
 
                Printf.printf "read_field <valid field> <invalid objref>\n";
                expect_missing_row "VM" invalid_ref
                        (fun () ->
-                               let (_: string) = Client.read_field "VM" "name__label" invalid_ref in
+                               let (_: string) = Client.read_field "VM" "name__label" invalid_ref in
                                failwith "read_field <valid field> <invalid objref>"
                        );
                Printf.printf "read_field <invalid field> <valid objref>\n";
                expect_missing_field "name_label"
                        (fun () ->
-                               let (_: string) = Client.read_field "VM" "name_label" valid_ref in
+                               let (_: string) = Client.read_field "VM" "name_label" valid_ref in
                                failwith "read_field <invalid field> <valid objref>"
                        );
                Printf.printf "read_field <invalid field> <invalid objref>\n";
                expect_missing_row "VM" invalid_ref
                        (fun () ->
-                               let (_: string) = Client.read_field "VM" "name_label" invalid_ref in
+                               let (_: string) = Client.read_field "VM" "name_label" invalid_ref in
                                failwith "read_field <invalid field> <invalid objref>"
                        );
                Printf.printf "read_field_where <valid table> <valid return> <valid field> <valid value>\n";
                let where_name_label = 
                        { Db_cache_types.table = "VM"; return = Escaping.escape_id(["name"; "label"]); where_field="uuid"; where_value = valid_uuid } in
-               let xs = Client.read_field_where where_name_label in
+               let xs = Client.read_field_where where_name_label in
                if not (List.mem name xs)
                then failwith "read_field_where <valid table> <valid return> <valid field> <valid value>";
-               test_invalid_where_record "read_field_where" Client.read_field_where;
+               test_invalid_where_record "read_field_where" (Client.read_field_where t);
                
-               let xs = Client.read_set_ref where_name_label in
+               let xs = Client.read_set_ref where_name_label in
                if not (List.mem name xs)
                then failwith "read_set_ref <valid table> <valid return> <valid field> <valid value>";
-               test_invalid_where_record "read_set_ref" Client.read_set_ref;
+               test_invalid_where_record "read_set_ref" (Client.read_set_ref t);
                
                Printf.printf "write_field <invalid table>\n";
                expect_missing_tbl "Vm"
                        (fun () ->
-                               let (_: unit) = Client.write_field "Vm" "" "" "" in
+                               let (_: unit) = Client.write_field "Vm" "" "" "" in
                                failwith "write_field <invalid table>"
                        );
                Printf.printf "write_field <valid table> <invalid ref>\n";
                expect_missing_row "VM" invalid_ref
                        (fun () ->
-                               let (_: unit) = Client.write_field "VM" invalid_ref "" "" in
+                               let (_: unit) = Client.write_field "VM" invalid_ref "" "" in
                                failwith "write_field <valid table> <invalid ref>"
                        );
                Printf.printf "write_field <valid table> <valid ref> <invalid field>\n";
                expect_missing_field "wibble"
                        (fun () ->
-                               let (_: unit) = Client.write_field "VM" valid_ref "wibble" "" in
+                               let (_: unit) = Client.write_field "VM" valid_ref "wibble" "" in
                                failwith "write_field <valid table> <valid ref> <invalid field>"
                        );
                Printf.printf "write_field <valid table> <valid ref> <valid field>\n";
-               let (_: unit) = Client.write_field "VM" valid_ref (Escaping.escape_id ["name"; "description"]) "description" in
-               if in_process then check_ref_index "VM" valid_ref;              
+               let (_: unit) = Client.write_field "VM" valid_ref (Escaping.escape_id ["name"; "description"]) "description" in
+               if in_process then check_ref_index t "VM" valid_ref;            
                Printf.printf "write_field <valid table> <valid ref> <valid field> - invalidating ref_index\n";
-               let (_: unit) = Client.write_field "VM" valid_ref (Escaping.escape_id ["name"; "label"]) "newlabel" in
-               if in_process then check_ref_index "VM" valid_ref;              
+               let (_: unit) = Client.write_field "VM" valid_ref (Escaping.escape_id ["name"; "label"]) "newlabel" in
+               if in_process then check_ref_index t "VM" valid_ref;            
 
                Printf.printf "read_record <invalid table> <invalid ref>\n";
                expect_missing_tbl "Vm"
                        (fun () ->
-                               let _ = Client.read_record "Vm" invalid_ref in
+                               let _ = Client.read_record "Vm" invalid_ref in
                                failwith "read_record <invalid table> <invalid ref>"
                        );
                Printf.printf "read_record <valid table> <valid ref>\n";
                expect_missing_row "VM" invalid_ref
                        (fun () ->
-                               let _ = Client.read_record "VM" invalid_ref in
+                               let _ = Client.read_record "VM" invalid_ref in
                                failwith "read_record <valid table> <invalid ref>"
                        );
                Printf.printf "read_record <valid table> <valid ref>\n";
-               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
+               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
                if not(List.mem_assoc (Escaping.escape_id [ "name"; "label" ]) fv_list)
                then failwith "read_record <valid table> <valid ref> 1";
                if List.assoc "VBDs" fvs_list <> []
                then failwith "read_record <valid table> <valid ref> 2";
                Printf.printf "read_record <valid table> <valid ref> foreign key\n";
-               Client.create_row "VBD" (make_vbd valid_ref vbd_ref vbd_uuid) vbd_ref;
-               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
+               Client.create_row "VBD" (make_vbd valid_ref vbd_ref vbd_uuid) vbd_ref;
+               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
                if List.assoc "VBDs" fvs_list <> [ vbd_ref ] then begin
                        Printf.printf "fv_list = [ %s ] fvs_list = [ %s ]\n%!" (String.concat "; " (List.map (fun (k, v) -> k ^":" ^ v) fv_list))  (String.concat "; " (List.map (fun (k, v) -> k ^ ":" ^ (String.concat ", " v)) fvs_list));
                        failwith "read_record <valid table> <valid ref> 3"
                end;
                Printf.printf "read_record <valid table> <valid ref> deleted foreign key\n";
-               Client.delete_row "VBD" vbd_ref;
-               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
+               Client.delete_row "VBD" vbd_ref;
+               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
                if List.assoc "VBDs" fvs_list <> []
                then failwith "read_record <valid table> <valid ref> 4";
                Printf.printf "read_record <valid table> <valid ref> overwritten foreign key\n";
-               Client.create_row "VBD" (make_vbd valid_ref vbd_ref vbd_uuid) vbd_ref;
-               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
+               Client.create_row "VBD" (make_vbd valid_ref vbd_ref vbd_uuid) vbd_ref;
+               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
                if List.assoc "VBDs" fvs_list = []
                then failwith "read_record <valid table> <valid ref> 5";
-               Client.write_field "VBD" vbd_ref (Escaping.escape_id [ "VM" ]) "overwritten";
-               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
+               Client.write_field "VBD" vbd_ref (Escaping.escape_id [ "VM" ]) "overwritten";
+               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
                if List.assoc "VBDs" fvs_list <> []
                then failwith "read_record <valid table> <valid ref> 6";        
                
                expect_missing_tbl "Vm"
                        (fun () ->
-                               let _ = Client.read_records_where "Vm" Db_filter_types.True in
+                               let _ = Client.read_records_where "Vm" Db_filter_types.True in
                                ()
                        );
-               let xs = Client.read_records_where "VM" Db_filter_types.True in
+               let xs = Client.read_records_where "VM" Db_filter_types.True in
                if List.length xs <> 1
                then failwith "read_records_where <valid table> 2";
-               let xs = Client.read_records_where "VM" Db_filter_types.False in
+               let xs = Client.read_records_where "VM" Db_filter_types.False in
                if xs <> []
                then failwith "read_records_where <valid table> 3";
                
                expect_missing_tbl "Vm"
                        (fun () ->
-                               let xs = Client.find_refs_with_filter "Vm" Db_filter_types.True in
+                               let xs = Client.find_refs_with_filter "Vm" Db_filter_types.True in
                                failwith "find_refs_with_filter <invalid table>";
                        );
-               let xs = Client.find_refs_with_filter "VM" Db_filter_types.True in
+               let xs = Client.find_refs_with_filter "VM" Db_filter_types.True in
                if List.length xs <> 1
                then failwith "find_refs_with_filter <valid table> 1";
-               let xs = Client.find_refs_with_filter "VM" Db_filter_types.False in
+               let xs = Client.find_refs_with_filter "VM" Db_filter_types.False in
                if xs <> []
                then failwith "find_refs_with_filter <valid table> 2";
                
                expect_missing_tbl "Vm"
                        (fun () ->
-                               Client.process_structured_field ("","") "Vm" "wibble" invalid_ref Db_cache_types.AddSet;
+                               Client.process_structured_field ("","") "Vm" "wibble" invalid_ref Db_cache_types.AddSet;
                                failwith "process_structure_field <invalid table> <invalid fld> <invalid ref>"
                        );
                expect_missing_field "wibble"
                        (fun () ->
-                               Client.process_structured_field ("","") "VM" "wibble" valid_ref Db_cache_types.AddSet;
+                               Client.process_structured_field ("","") "VM" "wibble" valid_ref Db_cache_types.AddSet;
                                failwith "process_structure_field <valid table> <invalid fld> <valid ref>"
                        );
                expect_missing_row "VM" invalid_ref
                        (fun () ->
-                               Client.process_structured_field ("","") "VM" (Escaping.escape_id ["name"; "label"]) invalid_ref Db_cache_types.AddSet;
+                               Client.process_structured_field ("","") "VM" (Escaping.escape_id ["name"; "label"]) invalid_ref Db_cache_types.AddSet;
                                failwith "process_structure_field <valid table> <valid fld> <invalid ref>"
                        );
-               Client.process_structured_field ("foo", "") "VM" "tags" valid_ref Db_cache_types.AddSet;
-               if Client.read_field "VM" "tags" valid_ref <> "('foo')"
+               Client.process_structured_field ("foo", "") "VM" "tags" valid_ref Db_cache_types.AddSet;
+               if Client.read_field "VM" "tags" valid_ref <> "('foo')"
                then failwith "process_structure_field expected ('foo')";
-               Client.process_structured_field ("foo", "") "VM" "tags" valid_ref Db_cache_types.AddSet;
-               if Client.read_field "VM" "tags" valid_ref <> "('foo')"
+               Client.process_structured_field ("foo", "") "VM" "tags" valid_ref Db_cache_types.AddSet;
+               if Client.read_field "VM" "tags" valid_ref <> "('foo')"
                then failwith "process_structure_field expected ('foo') 2";
-               Client.process_structured_field ("foo", "bar") "VM" "other_config" valid_ref Db_cache_types.AddMap;
+               Client.process_structured_field ("foo", "bar") "VM" "other_config" valid_ref Db_cache_types.AddMap;
                
-               if Client.read_field "VM" "other_config" valid_ref <> "(('foo' 'bar'))"
+               if Client.read_field "VM" "other_config" valid_ref <> "(('foo' 'bar'))"
                then failwith "process_structure_field expected (('foo' 'bar')) 3";
                
                begin
                        try
-                               Client.process_structured_field ("foo", "bar") "VM" "other_config" valid_ref Db_cache_types.AddMap;
+                               Client.process_structured_field ("foo", "bar") "VM" "other_config" valid_ref Db_cache_types.AddMap;
                        with Db_exn.Duplicate_key("VM", "other_config", r', "foo") when r' = valid_ref -> ()
                end;
-               if Client.read_field "VM" "other_config" valid_ref <> "(('foo' 'bar'))"
+               if Client.read_field "VM" "other_config" valid_ref <> "(('foo' 'bar'))"
                then failwith "process_structure_field expected (('foo' 'bar')) 4";
                
                (* Check that non-persistent fields are filled with an empty value *)
@@ -523,7 +525,7 @@ module Tests = functor(Client: Db_interface.DB_ACCESS) -> struct
                        let n = 5000 in
                        
                        let rpc_time = time n (fun _ ->
-                               let (_: bool) = Client.is_valid_ref valid_ref in ()) in
+                               let (_: bool) = Client.is_valid_ref valid_ref in ()) in
                        
                        Printf.printf "%.2f primitive RPC calls/sec\n" rpc_time;
                        
@@ -532,14 +534,14 @@ module Tests = functor(Client: Db_interface.DB_ACCESS) -> struct
                                (fun i ->
                                        let rf = Printf.sprintf "%s:%d" vbd_ref i in
                                        try
-                                               Client.delete_row "VBD" rf
+                                               Client.delete_row "VBD" rf
                                        with _ -> ()
                                ) in
                        Printf.printf "Deleted %d VBD records, %.2f calls/sec\n%!" n delete_time;
                        
                        expect_missing_row "VBD" vbd_ref
                                (fun () ->
-                                       Client.delete_row "VBD" vbd_ref;
+                                       Client.delete_row "VBD" vbd_ref;
                                );
                        
                        (* Create lots of VBDs referening no VM *)
@@ -547,7 +549,7 @@ module Tests = functor(Client: Db_interface.DB_ACCESS) -> struct
                                (fun i ->
                                        let rf = Printf.sprintf "%s:%d" vbd_ref i in
                                        let uuid = Printf.sprintf "%s:%d" vbd_uuid i in
-                                       Client.create_row "VBD" (make_vbd invalid_ref rf uuid) rf;
+                                       Client.create_row "VBD" (make_vbd invalid_ref rf uuid) rf;
                                ) in
                        Printf.printf "Created %d VBD records, %.2f calls/sec\n%!" n create_time;
                        
@@ -558,10 +560,10 @@ module Tests = functor(Client: Db_interface.DB_ACCESS) -> struct
                                (fun i ->
                                        if i < (m / 3 * 2) then begin
                                                if i mod 2 = 0
-                                               then Client.create_row "VBD" (make_vbd valid_ref vbd_ref vbd_uuid) vbd_ref
-                                               else Client.delete_row "VBD" vbd_ref
+                                               then Client.create_row "VBD" (make_vbd valid_ref vbd_ref vbd_uuid) vbd_ref
+                                               else Client.delete_row "VBD" vbd_ref
                                        end else
-                                               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
+                                               let fv_list, fvs_list = Client.read_record "VM" valid_ref in
                                                ()
                                ) in
                        Printf.printf "good sequence: %.2f calls/sec\n%!" benign_time;
@@ -569,9 +571,9 @@ module Tests = functor(Client: Db_interface.DB_ACCESS) -> struct
                        let malign_time = time m
                                (fun i ->
                                        match i mod 3 with
-                                               | 0 -> Client.create_row "VBD" (make_vbd valid_ref vbd_ref vbd_uuid) vbd_ref
-                                               | 1 -> Client.delete_row "VBD" vbd_ref
-                                               | 2 -> let fv_list, fvs_list = Client.read_record "VM" valid_ref in
+                                               | 0 -> Client.create_row "VBD" (make_vbd valid_ref vbd_ref vbd_uuid) vbd_ref
+                                               | 1 -> Client.delete_row "VBD" vbd_ref
+                                               | 2 -> let fv_list, fvs_list = Client.read_record "VM" valid_ref in
                                                ()
                                ) in
                        Printf.printf "bad sequence: %.2f calls/sec\n%!" malign_time;
index 46262df33fbc876d2eb83c546c872a95244f7ff9..60caad0449ad721390805ce8b208f881932fe9a7 100644 (file)
@@ -24,15 +24,11 @@ open D
 let db_FLUSH_TIMER=2.0 (* flush db write buffer every db_FLUSH_TIMER seconds *)
 let display_sql_writelog_val = ref true (* compute/write sql-writelog debug string *)
 
-(* The cache itself: *)
-let database : Db_cache_types.Database.t ref = ref (Db_cache_types.Database.make (Schema.of_datamodel ()))
-
 (* --------------------- Util functions on db datastructures *)
 
-let update_database f = 
-       database := f (!database)
+let master_database = ref (Db_cache_types.Database.make Schema.empty)
 
-let get_database () = !database
+let make () = Db_ref.in_memory (ref master_database)
 
 
 (* !!! Right now this is called at cache population time. It would probably be preferable to call it on flush time instead, so we
index 2d7e8ebc51ad9c7e602701496dfbfad6e7713ca3..685b07f0a9df186a1b042c40e7e178318c52f04e 100644 (file)
@@ -30,27 +30,21 @@ module Remote_db : DB_ACCESS = Db_rpc_client_v1.Make(struct
        let rpc request = Master_connection.execute_remote_fn request Constants.remote_db_access_uri
 end)
 
-exception Must_initialise_database_mode
-let implementation = ref None
-let set_master = function
-       | false ->
-               implementation := Some (module Remote_db : DB_ACCESS)
-       | true ->
-               implementation := Some (module Local_db : DB_ACCESS)
-let get () = match !implementation with
-       | None -> raise Must_initialise_database_mode
-       | Some m -> m
+let get = function
+       | Db_ref.In_memory _ -> (module Local_db  : DB_ACCESS)
+       | Db_ref.Remote      -> (module Remote_db : DB_ACCESS)
 
 let apply_delta_to_cache entry =
-       let module DB = (val (get ()) : DB_ACCESS) in   
+       let module DB = (Local_db : DB_ACCESS) in
+       let t = Db_backend.make () in
     let context = Context.make "redo_log" in
     match entry with 
                | Redo_log.CreateRow(tblname, objref, kvs) ->
                        debug "Redoing create_row %s (%s)" tblname objref;
-                       DB.create_row tblname kvs objref
+                       DB.create_row t tblname kvs objref
                | Redo_log.DeleteRow(tblname, objref) ->
                        debug "Redoing delete_row %s (%s)" tblname objref;
-                       DB.delete_row tblname objref
+                       DB.delete_row t tblname objref
                | Redo_log.WriteField(tblname, objref, fldname, newval) ->
                        debug "Redoing write_field %s (%s) [%s -> %s]" tblname objref fldname newval;
-                       DB.write_field tblname objref fldname newval
+                       DB.write_field t tblname objref fldname newval
index 0f83a52d86b650493024641a3d1709efb72e2a26..72c1091d8e8dd5650038d5249e2c323deece7d46 100644 (file)
 
 (** An in-memory cache, used by pool master *)
 
+(* Locking strategy:
+   1. functions which read/modify/write must acquire the db lock. Such
+      functions have the suffix "_locked" to clearly identify them.
+   2. functions which only read must only call "get_database" once,
+      to ensure they see a consistent snapshot.
+*)
 open Db_exn
 open Db_lock
 open Pervasiveext
@@ -23,30 +29,31 @@ open D
 module W = Debug.Debugger(struct let name = "db_write" end)
     
 open Db_cache_types
-open Db_backend
+open Db_ref
+
+(* Only needed by the DB_ACCESS signature *)
+let initialise () = ()
 
 (* This fn is part of external interface, so need to take lock *)
-let get_table_from_ref objref =
-    with_lock
-               (fun () ->
-                       let db = get_database () in
-                       try
-                               Some (Database.table_of_ref objref db)
-                       with Not_found -> 
-                               None)
+let get_table_from_ref t objref =
+       try
+               Some (Database.table_of_ref objref (get_database t))
+       with Not_found -> 
+               None
                
-let is_valid_ref objref =
-    match (get_table_from_ref objref) with
+let is_valid_ref objref =
+    match (get_table_from_ref objref) with
                | Some _ -> true
                | None -> false
                        
+let read_field_internal t tblname fldname objref db = 
+       Row.find fldname (Table.find_exn tblname objref (TableSet.find tblname (Database.tableset db)))
+
 (* Read field from cache *)
-let read_field tblname fldname objref =
-    with_lock
-               (fun () ->
-                       let db = get_database () in
-                       Row.find fldname (Table.find_exn tblname objref (TableSet.find tblname (Database.tableset db)))
-               )
+let read_field t tblname fldname objref =
+       read_field_internal t tblname fldname objref (get_database t)
+
+
 
 
 (** Finds the longest XML-compatible UTF-8 prefix of the given *)
@@ -62,35 +69,33 @@ let ensure_utf8_xml string =
 
                
 (* Write field in cache *)
-let write_field tblname objref fldname newval =
-    with_lock
-               (fun () ->
-                       let db = get_database () in
-
-                       let row = Table.find_exn tblname objref (TableSet.find tblname (Database.tableset db)) in
-                       let current_val = Row.find fldname row in
-                       
-                       let newval = ensure_utf8_xml newval in
-                       
-                       if current_val<>newval then
-                               begin
-                                       W.debug "write_field %s,%s: %s |-> %s" tblname objref fldname newval;
-                                       
-                                       (* Update the field in the cache whether it's persistent or not *)
-                                       update_database (set_field_in_row tblname objref fldname newval);
-
-                                       Database.notify (WriteField(tblname, objref, fldname, current_val, newval)) db;
-
-                                       (* then only persist the change if the schema says so *)
-                                       if Schema.is_field_persistent (Database.schema db) tblname fldname 
-                                       then update_database Database.increment;
-                               end)
+let write_field_locked t tblname objref fldname newval =
+       let row = Table.find_exn tblname objref (TableSet.find tblname (Database.tableset (get_database t))) in
+       let current_val = Row.find fldname row in
+       
+       let newval = ensure_utf8_xml newval in
+       
+       if current_val<>newval then begin
+               W.debug "write_field %s,%s: %s |-> %s" tblname objref fldname newval;
+               
+               (* Update the field in the cache whether it's persistent or not *)
+               update_database t (set_field_in_row tblname objref fldname newval);
                
+               Database.notify (WriteField(tblname, objref, fldname, current_val, newval)) (get_database t);
+               
+               (* then only persist the change if the schema says so *)
+               if Schema.is_field_persistent (Database.schema (get_database t)) tblname fldname 
+               then update_database t Database.increment
+       end
+               
+let write_field t tblname objref fldname newval =
+       with_lock (fun () -> 
+               write_field_locked t tblname objref fldname newval)
+
 (* This function *should* only be used by db_actions code looking up Set(Ref _) fields:
    if we detect another (illegal) use we log the problem and fall back to a slow scan *)
-let read_set_ref rcd =
-       let db = get_database () in
-
+let read_set_ref t rcd =
+       let db = get_database t in
        (* The where_record should correspond to the 'one' end of a 'one to many' *)
        let one_tbl = rcd.table in
        let one_fld = rcd.where_field in
@@ -106,21 +111,17 @@ let read_set_ref rcd =
                let _, many_tbl, many_fld = List.find (fun (a, _, _) -> a = one_fld) rels in
                let objref = rcd.where_value in
                
-               let str = read_field many_tbl many_fld objref in
+               let str = read_field_internal t many_tbl many_fld objref db in
                String_unmarshall_helper.set (fun x -> x) str           
        end else begin
                error "Illegal read_set_ref query { table = %s; where_field = %s; where_value = %s; return = %s }; falling back to linear scan" rcd.table rcd.where_field rcd.where_value rcd.return;
                Printf.printf "Illegal read_set_ref query { table = %s; where_field = %s; where_value = %s; return = %s }; falling back to linear scan\n%!" rcd.table rcd.where_field rcd.where_value rcd.return;
-               with_lock
-                       (fun () ->
-                               let db = get_database () in
-                               let tbl = TableSet.find rcd.table (Database.tableset db) in
-                               Table.fold
-                                       (fun rf row acc ->
-                                               if Row.find rcd.where_field row = rcd.where_value 
-                                               then Row.find rcd.return row :: acc else acc)
-                                       tbl []
-                       )
+               let tbl = TableSet.find rcd.table (Database.tableset db) in
+               Table.fold
+                       (fun rf row acc ->
+                               if Row.find rcd.where_field row = rcd.where_value 
+                               then Row.find rcd.return row :: acc else acc)
+                       tbl []
        end
                        
 
@@ -130,48 +131,46 @@ let read_set_ref rcd =
    and iterates through set-refs [returning (fieldname, ref list) list; where fieldname is the
    name of the Set Ref field in tbl; and ref list is the list of foreign keys from related
    table with remote-fieldname=objref] *)
-let read_record tblname objref  =
-    with_lock
-               (fun ()->
-                       let db = get_database () in
-                       let tbl = TableSet.find tblname (Database.tableset db) in
-                       let row = Table.find_exn tblname objref tbl in
-                       let fvlist = Row.fold (fun k d env -> (k,d)::env) row [] in
-                       (* Unfortunately the interface distinguishes between Set(Ref _) types and 
-                          ordinary fields *)
-                       let schema = Schema.table tblname (Database.schema db) in
-                       let set_ref = List.filter (fun (k, _) ->
-                               try
-                                       let column = Schema.Table.find k schema in
-                                       column.Schema.Column.issetref
-                               with Not_found as e ->
-                                       Printf.printf "Failed to find table %s in schema\n%!" k;
-                                       raise e
-                       ) fvlist in
-                       (* the set_ref fields must be converted back into lists *)
-                       let set_ref = List.map (fun (k, v) -> 
-                               k, String_unmarshall_helper.set (fun x -> x) v) set_ref in
-                       (fvlist, set_ref))
+let read_record t tblname objref  =
+       let db = get_database t in
+       let tbl = TableSet.find tblname (Database.tableset db) in
+       let row = Table.find_exn tblname objref tbl in
+       let fvlist = Row.fold (fun k d env -> (k,d)::env) row [] in
+       (* Unfortunately the interface distinguishes between Set(Ref _) types and 
+          ordinary fields *)
+       let schema = Schema.table tblname (Database.schema db) in
+       let set_ref = List.filter (fun (k, _) ->
+               try
+                       let column = Schema.Table.find k schema in
+                       column.Schema.Column.issetref
+               with Not_found as e ->
+                       Printf.printf "Failed to find table %s in schema\n%!" k;
+                       raise e
+       ) fvlist in
+       (* the set_ref fields must be converted back into lists *)
+       let set_ref = List.map (fun (k, v) -> 
+               k, String_unmarshall_helper.set (fun x -> x) v) set_ref in
+       (fvlist, set_ref)
 
 (* Delete row from tbl *)
-let delete_row tblname objref =
-    with_lock
-               (fun () ->
-                       W.debug "delete_row %s (%s)" tblname objref;
-
-                       let db = get_database () in
-                       let tbl = TableSet.find tblname (Database.tableset db) in
-                       let row = Table.find_exn tblname objref tbl in
-                       
-                       Database.notify (PreDelete(tblname, objref)) db;
-                       update_database (remove_row_from_table tblname objref);
-                       Database.notify (Delete(tblname, objref, Row.fold (fun k v acc -> (k, v) :: acc) row [])) db;
-                       if Schema.is_table_persistent (Database.schema db) tblname 
-                       then update_database Database.increment;
-               )
+let delete_row_locked t tblname objref =
+       W.debug "delete_row %s (%s)" tblname objref;
+       
+       let tbl = TableSet.find tblname (Database.tableset (get_database t)) in
+       let row = Table.find_exn tblname objref tbl in
+       
+       let db = get_database t in
+       Database.notify (PreDelete(tblname, objref)) db;
+       update_database t (remove_row_from_table tblname objref);
+       Database.notify (Delete(tblname, objref, Row.fold (fun k v acc -> (k, v) :: acc) row [])) db;
+       if Schema.is_table_persistent (Database.schema db) tblname 
+       then update_database t Database.increment
                
+let delete_row t tblname objref = 
+       with_lock (fun () -> delete_row_locked t tblname objref)
+
 (* Create new row in tbl containing specified k-v pairs *)
-let create_row tblname kvs' new_objref =
+let create_row_locked t tblname kvs' new_objref =
        
     (* Ensure values are valid for UTF-8-encoded XML. *)
     let kvs' = List.map (fun (key, value) -> (key, ensure_utf8_xml value)) kvs' in
@@ -181,38 +180,33 @@ let create_row tblname kvs' new_objref =
     let kvs' = (Db_names.ref, new_objref) :: kvs' in
 
        let row = List.fold_left (fun row (k, v) -> Row.add k v row) Row.empty kvs' in
-       let schema = Schema.table tblname (Database.schema (get_database ())) in
+       let schema = Schema.table tblname (Database.schema (get_database t)) in
     (* fill in default values if kv pairs for these are not supplied already *)
        let row = Row.add_defaults schema row in
        
-    with_lock
-               (fun () ->
-                       W.debug "create_row %s (%s) [%s]" tblname new_objref (String.concat "," (List.map (fun (k,v)->"("^k^","^"v"^")") kvs'));
-                       let db = get_database () in
-                       let tbl = TableSet.find tblname (Database.tableset db) in
-                       update_database (set_row_in_table tblname new_objref row);
-
-                       Database.notify (Create(tblname, new_objref, Row.fold (fun k v acc -> (k, v) :: acc) row [])) db;
-
-                       if Schema.is_table_persistent (Database.schema db) tblname 
-                       then update_database Database.increment;
-               )
+       W.debug "create_row %s (%s) [%s]" tblname new_objref (String.concat "," (List.map (fun (k,v)->"("^k^","^"v"^")") kvs'));
+       let tbl = TableSet.find tblname (Database.tableset (get_database t)) in
+       update_database t (set_row_in_table tblname new_objref row);
+       
+       Database.notify (Create(tblname, new_objref, Row.fold (fun k v acc -> (k, v) :: acc) row [])) (get_database t);
+       if Schema.is_table_persistent (Database.schema (get_database t)) tblname 
+       then update_database t Database.increment
                
+let create_row t tblname kvs' new_objref =
+       with_lock (fun () -> create_row_locked t tblname kvs' new_objref)
+
 (* Do linear scan to find field values which match where clause *)
-let read_field_where rcd =
-    with_lock
-               (fun () ->
-                       let db = get_database () in
-                       let tbl = TableSet.find rcd.table (Database.tableset db) in
-                       Table.fold
-                               (fun r row acc ->
-                                       let field = Row.find rcd.where_field row in
-                                       if field = rcd.where_value then Row.find rcd.return row :: acc else acc
-                               ) tbl []
-               )
+let read_field_where t rcd =
+       let db = get_database t in
+       let tbl = TableSet.find rcd.table (Database.tableset db) in
+       Table.fold
+               (fun r row acc ->
+                       let field = Row.find rcd.where_field row in
+                       if field = rcd.where_value then Row.find rcd.return row :: acc else acc
+               ) tbl []
                
-let db_get_by_uuid tbl uuid_val =
-    match (read_field_where
+let db_get_by_uuid t tbl uuid_val =
+    match (read_field_where t
         {table=tbl; return=Db_names.ref;
         where_field=Db_names.uuid; where_value=uuid_val}) with
                | [] -> raise (Read_missing_uuid (tbl, "", uuid_val))
@@ -220,69 +214,61 @@ let db_get_by_uuid tbl uuid_val =
                | _ -> raise (Too_many_values (tbl, "", uuid_val))
                        
 (** Return reference fields from tbl that matches specified name_label field *)
-let db_get_by_name_label tbl label =
-    read_field_where
+let db_get_by_name_label t tbl label =
+    read_field_where t
         {table=tbl; return=Db_names.ref;
         where_field=(Escaping.escape_id ["name"; "label"]);
         where_value=label}
                
 (* Read references from tbl *)
-let read_refs tblname =
-    with_lock
-               (fun () ->
-                       let db = get_database () in
-                       let tbl = TableSet.find tblname (Database.tableset db) in
-                       Table.fold (fun r _ acc -> r :: acc) tbl [])
+let read_refs t tblname =
+       let tbl = TableSet.find tblname (Database.tableset (get_database t)) in
+       Table.fold (fun r _ acc -> r :: acc) tbl []
                
 (* Return a list of all the refs for which the expression returns true. *)
-let find_refs_with_filter (tblname: string) (expr: Db_filter_types.expr) = 
-    with_lock
-               (fun ()->
-                       let db = get_database () in
-                       let tbl = TableSet.find tblname (Database.tableset db) in
-                       let eval_val row = function
-                               | Db_filter_types.Literal x -> x
-                               | Db_filter_types.Field x -> Row.find x row in
-                       Table.fold
-                               (fun r row acc ->
-                                       if Db_filter.eval_expr (eval_val row) expr
-                                       then Row.find Db_names.ref row :: acc else acc
-                               ) tbl []
-               )
+let find_refs_with_filter t (tblname: string) (expr: Db_filter_types.expr) = 
+       let db = get_database t in
+       let tbl = TableSet.find tblname (Database.tableset db) in
+       let eval_val row = function
+               | Db_filter_types.Literal x -> x
+               | Db_filter_types.Field x -> Row.find x row in
+       Table.fold
+               (fun r row acc ->
+                       if Db_filter.eval_expr (eval_val row) expr
+                       then Row.find Db_names.ref row :: acc else acc
+               ) tbl []
                
-let read_records_where tbl expr =
-    with_lock
-               (fun ()->
-                       let reqd_refs = find_refs_with_filter tbl expr in
-                       List.map (fun ref->ref, read_record tbl ref) reqd_refs
-               )
+let read_records_where t tbl expr =
+       let reqd_refs = find_refs_with_filter t tbl expr in
+       List.map (fun ref->ref, read_record t tbl ref) reqd_refs
        
-let process_structured_field (key,value) tblname fld objref proc_fn_selector =
+let process_structured_field_locked t (key,value) tblname fld objref proc_fn_selector =
        
     (* Ensure that both keys and values are valid for UTF-8-encoded XML. *)
     let key = ensure_utf8_xml key in
     let value = ensure_utf8_xml value in
        
-    with_lock
-               (fun () ->
-                       let db = get_database () in
-                       let tbl = TableSet.find tblname (Database.tableset db) in
-                       let row = Table.find_exn tblname objref tbl in
-                       let existing_str = Row.find fld row in
-                       let new_str = match proc_fn_selector with
-                               | AddSet -> add_to_set key existing_str
-                               | RemoveSet -> remove_from_set key existing_str
-                               | AddMap -> 
-                                       begin
-                                               try
-                                                       add_to_map key value existing_str
-                                               with Duplicate ->
-                                                       error "Duplicate key in set or map: table %s; field %s; ref %s; key %s" tblname fld objref key;
-                                                       raise (Duplicate_key (tblname,fld,objref,key));
-                                       end
-                               | RemoveMap -> remove_from_map key existing_str in
-                       write_field tblname objref fld new_str)
-               
+       let tbl = TableSet.find tblname (Database.tableset (get_database t)) in
+       let row = Table.find_exn tblname objref tbl in
+       let existing_str = Row.find fld row in
+       let new_str = match proc_fn_selector with
+               | AddSet -> add_to_set key existing_str
+               | RemoveSet -> remove_from_set key existing_str
+               | AddMap -> 
+                       begin
+                               try
+                                       add_to_map key value existing_str
+                               with Duplicate ->
+                                       error "Duplicate key in set or map: table %s; field %s; ref %s; key %s" tblname fld objref key;
+                                       raise (Duplicate_key (tblname,fld,objref,key));
+                       end
+               | RemoveMap -> remove_from_map key existing_str in
+       write_field t tblname objref fld new_str
+       
+let process_structured_field t (key,value) tblname fld objref proc_fn_selector =
+       with_lock (fun () -> 
+               process_structured_field_locked t (key,value) tblname fld objref proc_fn_selector)
+       
 (* -------------------------------------------------------------------- *)
                
 let load connections default_schema =
@@ -400,26 +386,22 @@ let spawn_db_flush_threads() =
                
                
 (* Called by server at start-of-day to initialiase cache. Populates cache and starts flushing threads *)
-let make connections default_schema =
+let make connections default_schema =
     let db = load connections default_schema in
        let db = Database.reindex db in
-       update_database (fun _ -> db);
+       update_database (fun _ -> db);
 
     spawn_db_flush_threads()
                
 
 (** Return an association list of table name * record count *)
-let stats () = 
-    with_lock 
-               (fun () ->
-                       TableSet.fold (fun name tbl acc ->
-                               let size = Table.fold (fun _ _ acc -> acc + 1) tbl 0 in
-                               (name, size) :: acc)
-                               (Database.tableset (Db_backend.get_database ()))
-                               []
-               )
+let stats t = 
+       TableSet.fold (fun name tbl acc ->
+               let size = Table.fold (fun _ _ acc -> acc + 1) tbl 0 in
+               (name, size) :: acc)
+               (Database.tableset (get_database t))
+               []
+
 
 
-(* Only needed by the DB_ACCESS signature *)
-let initialise () = ()
 
index 070a4225735357a9131abde5b7dbbdd143b884b9..41bb909ab5d85e3c9825f11ec98b7ce29e3e08d2 100644 (file)
@@ -1,7 +1,7 @@
 include Db_interface.DB_ACCESS
 
-(** [make connections default_schema] initialises the in-memory cache *)
-val make : Parse_db_conf.db_connection list -> Schema.t -> unit
+(** [make connections default_schema] initialises the in-memory cache *)
+val make : Db_ref.t -> Parse_db_conf.db_connection list -> Schema.t -> unit
 
 (** [flush_and_exit db code] flushes the specific backend [db] and exits
        xapi with [code] *)
@@ -10,5 +10,5 @@ val flush_and_exit : Parse_db_conf.db_connection -> int -> unit
 (** [sync db] forcibly flushes the database to disk *)
 val sync : Parse_db_conf.db_connection list -> Db_cache_types.Database.t -> unit
 
-(** [stats ()] returns some stats data for logging *)
-val stats : unit -> (string * int) list
+(** [stats t] returns some stats data for logging *)
+val stats : Db_ref.t -> (string * int) list
index b7f0f74984171a206a8834d9601cd7bea0ac108e..90f03104f270cd06fae1999f95cb64c7db1efc4f 100644 (file)
@@ -33,59 +33,59 @@ module type DB_ACCESS = sig
                
        (** [get_table_from_ref ref] returns [Some tbl] if [ref] is a 
                valid reference; None otherwise *)
-    val get_table_from_ref : string -> string option
+    val get_table_from_ref : Db_ref.t -> string -> string option
                
        (** [is_valid_ref ref] returns true if [ref] is valid; false otherwise *)
-    val is_valid_ref : string -> bool
+    val is_valid_ref : Db_ref.t -> string -> bool
                
        (** [read_refs tbl] returns a list of all references in table [tbl] *)
-    val read_refs : string -> string list
+    val read_refs : Db_ref.t -> string -> string list
                
        (** [find_refs_with_filter tbl expr] returns a list of all references
                to rows which match [expr] *)
     val find_refs_with_filter :
-        string -> Db_filter_types.expr -> string list
+        Db_ref.t -> string -> Db_filter_types.expr -> string list
                
        (** [read_field_where {tbl,return,where_field,where_value}] returns a
                list of the [return] fields in table [tbl] where the [where_field]
                equals [where_value] *)
-    val read_field_where : Db_cache_types.where_record -> string list
+    val read_field_where : Db_ref.t -> Db_cache_types.where_record -> string list
                
        (** [db_get_by_uuid tbl uuid] returns the single object reference
                associated with [uuid] *)
-    val db_get_by_uuid : string -> string -> string
+    val db_get_by_uuid : Db_ref.t -> string -> string -> string
                
        (** [db_get_by_name_label tbl label] returns the list of object references
                associated with [label] *)
-    val db_get_by_name_label : string -> string -> string list
+    val db_get_by_name_label : Db_ref.t -> string -> string -> string list
                
        (** [read_set_ref {tbl,return,where_field,where_value}] is identical
                to [read_field_where ...]. *)
-    val read_set_ref : Db_cache_types.where_record -> string list
+    val read_set_ref : Db_ref.t -> Db_cache_types.where_record -> string list
                
        (** [create_row tbl kvpairs ref] create a new row in [tbl] with
                key [ref] and contents [kvpairs] *)
     val create_row :
-        string -> (string * string) list -> string -> unit
+        Db_ref.t -> string -> (string * string) list -> string -> unit
                
        (** [delete_row context tbl ref] deletes row [ref] from table [tbl] *)
-    val delete_row : string -> string -> unit
+    val delete_row : Db_ref.t -> string -> string -> unit
                
        (** [write_field context tbl ref fld val] changes field [fld] to [val] in
                row [ref] in table [tbl] *)
-    val write_field : string -> string -> string -> string -> unit
-               
+    val write_field : Db_ref.t -> string -> string -> string -> string -> unit
+                
        (** [read_field context tbl ref fld] returns the value of field [fld]
                in row [ref] in table [tbl] *)
-    val read_field : string -> string -> string -> string
+    val read_field : Db_ref.t -> string -> string -> string -> string
                
        (** [read_record tbl ref] returns 
                [ (field, value) ] * [ (set_ref fieldname * [ ref ]) ] *)
-       val read_record : string -> string -> db_record
+       val read_record : Db_ref.t -> string -> string -> db_record
                
        (** [read_records_where tbl expr] returns a list of the values returned
                by read_record that match the expression *)
-       val read_records_where : string -> Db_filter_types.expr -> 
+       val read_records_where : Db_ref.t -> string -> Db_filter_types.expr -> 
                (string * db_record) list
                        
        (** [process_structured_field context kv tbl fld ref op] modifies the 
@@ -93,7 +93,7 @@ module type DB_ACCESS = sig
                which may be one of AddSet RemoveSet AddMap RemoveMap with 
                arguments [kv] *)
     val process_structured_field :
-        string * string ->
+        Db_ref.t -> string * string ->
         string -> string -> string -> Db_cache_types.structured_op_t -> unit
 end
 
diff --git a/ocaml/database/db_ref.ml b/ocaml/database/db_ref.ml
new file mode 100644 (file)
index 0000000..b5831a9
--- /dev/null
@@ -0,0 +1,18 @@
+type t = 
+       | In_memory of Db_cache_types.Database.t ref ref
+       | Remote
+
+exception Database_not_in_memory
+
+let in_memory (rf: Db_cache_types.Database.t ref ref) = In_memory rf
+
+let get_database = function
+       | In_memory x -> !(!(x))
+       | Remote -> raise Database_not_in_memory
+
+let update_database t f = match t with
+       | In_memory x ->
+               let d : Db_cache_types.Database.t = f (get_database t) in
+               (!(x)) := d
+       | Remote -> raise Database_not_in_memory
+
index 3e893771153162599d2019def5d7d580695ba056..63f0ccf194feec8600cfc724f0f0c95d982079bc 100644 (file)
@@ -1,3 +1,5 @@
+
+open Threadext
        
 module DBCacheRemoteListener = struct
        open Db_rpc_common_v1
@@ -51,62 +53,63 @@ module DBCacheRemoteListener = struct
                Note that, although the messages still contain the pool_secret for historical reasons,
                access has already been applied by the RBAC code in Xapi_http.add_handler. *)
        let process_xmlrpc xml =
-               Mutex.lock ctr_mutex;
-               calls_processed := !calls_processed + 1;
-               Mutex.unlock ctr_mutex;
+               Mutex.execute ctr_mutex
+                       (fun () -> calls_processed := !calls_processed + 1);
+
                let fn_name, args =
                        match (XMLRPC.From.array (fun x->x) xml) with
                                        [fn_name; _; args] ->
                                                XMLRPC.From.string fn_name, args
                                | _ -> raise DBCacheListenerInvalidMessageReceived in
+               let t = Db_backend.make () in
                try
                        debug "Received [total=%d rx=%d tx=%d] %s" !calls_processed !total_recv_len !total_transmit_len fn_name;
                        match fn_name with
                                        "get_table_from_ref" ->
                                                let s = unmarshall_get_table_from_ref_args args in
-                                               success (marshall_get_table_from_ref_response (DBCache.get_table_from_ref s))
+                                               success (marshall_get_table_from_ref_response (DBCache.get_table_from_ref s))
                                | "is_valid_ref" ->
                                        let s = unmarshall_is_valid_ref_args args in
-                                       success (marshall_is_valid_ref_response (DBCache.is_valid_ref s))
+                                       success (marshall_is_valid_ref_response (DBCache.is_valid_ref s))
                                | "read_refs" ->
                                        let s = unmarshall_read_refs_args args in
-                                       success (marshall_read_refs_response (DBCache.read_refs s))
+                                       success (marshall_read_refs_response (DBCache.read_refs s))
                                | "read_field_where" ->
                                        let w = unmarshall_read_field_where_args args in
-                                       success (marshall_read_field_where_response (DBCache.read_field_where w))
+                                       success (marshall_read_field_where_response (DBCache.read_field_where w))
                                | "read_set_ref" ->
                                        let w = unmarshall_read_set_ref_args args in
-                                       success (marshall_read_set_ref_response (DBCache.read_field_where w))
+                                       success (marshall_read_set_ref_response (DBCache.read_field_where w))
                                | "create_row" ->
                                        let (s1,ssl,s2) = unmarshall_create_row_args args in
-                                       success (marshall_create_row_response (DBCache.create_row s1 ssl s2))
+                                       success (marshall_create_row_response (DBCache.create_row s1 ssl s2))
                                | "delete_row" ->
                                        let (s1,s2) = unmarshall_delete_row_args args in
-                                       success (marshall_delete_row_response (DBCache.delete_row s1 s2))
+                                       success (marshall_delete_row_response (DBCache.delete_row s1 s2))
                                | "write_field" ->
                                        let (s1,s2,s3,s4) = unmarshall_write_field_args args in
-                                       success (marshall_write_field_response (DBCache.write_field s1 s2 s3 s4))
+                                       success (marshall_write_field_response (DBCache.write_field s1 s2 s3 s4))
                                | "read_field" ->
                                        let (s1,s2,s3) = unmarshall_read_field_args args in
-                                       success (marshall_read_field_response (DBCache.read_field s1 s2 s3))
+                                       success (marshall_read_field_response (DBCache.read_field s1 s2 s3))
                                | "find_refs_with_filter" ->
                                        let (s,e) = unmarshall_find_refs_with_filter_args args in
-                                       success (marshall_find_refs_with_filter_response (DBCache.find_refs_with_filter s e))
+                                       success (marshall_find_refs_with_filter_response (DBCache.find_refs_with_filter s e))
                                | "process_structured_field" ->
                                        let (ss,s1,s2,s3,op) = unmarshall_process_structured_field_args args in
-                                       success (marshall_process_structured_field_response (DBCache.process_structured_field ss s1 s2 s3 op))
+                                       success (marshall_process_structured_field_response (DBCache.process_structured_field ss s1 s2 s3 op))
                                | "read_record" ->
                                        let (s1,s2) = unmarshall_read_record_args args in
-                                       success (marshall_read_record_response (DBCache.read_record s1 s2))
+                                       success (marshall_read_record_response (DBCache.read_record s1 s2))
                                | "read_records_where" ->
                                        let (s,e) = unmarshall_read_records_where_args args in
-                                       success (marshall_read_records_where_response (DBCache.read_records_where s e))
+                                       success (marshall_read_records_where_response (DBCache.read_records_where s e))
                                | "db_get_by_uuid" ->
                                        let (s,e) = unmarshall_db_get_by_uuid_args args in
-                                       success (marshall_db_get_by_uuid_response (DBCache.db_get_by_uuid s e))
+                                       success (marshall_db_get_by_uuid_response (DBCache.db_get_by_uuid s e))
                                | "db_get_by_name_label" ->
                                        let (s,e) = unmarshall_db_get_by_name_label_args args in
-                                       success (marshall_db_get_by_name_label_response (DBCache.db_get_by_name_label s e))
+                                       success (marshall_db_get_by_name_label_response (DBCache.db_get_by_name_label s e))
                                | _ -> raise (DBCacheListenerUnknownMessageName fn_name)
                with
                                Duplicate_key (c,f,u,k) ->
@@ -126,6 +129,7 @@ let handler req bio =
        let fd = Buf_io.fd_of bio in (* fd only used for writing *)
        let body = Http_svr.read_body ~limit:Xapi_globs.http_limit_max_rpc_size req bio in
        let body_xml = Xml.parse_string body in
-       let response = Xml.to_bigbuffer (DBCacheRemoteListener.process_xmlrpc body_xml) in
+       let reply_xml = DBCacheRemoteListener.process_xmlrpc body_xml in
+       let response = Xml.to_bigbuffer reply_xml in
        Http_svr.response_fct req fd (Bigbuffer.length response)
                (fun fd -> Bigbuffer.to_fct response (fun s -> ignore(Unix.write fd s 0 (String.length s)))) 
index 1b31372458afacda674dd84014ba5bc0d55731f1..bcff3574f157012ecc735f2eac8cb97a1918e8b3 100644 (file)
@@ -20,40 +20,41 @@ open Db_exn
 (** Convert a marshalled Request Rpc.t into a marshalled Response Rpc.t *)
 let process_rpc (req: Rpc.t) = 
        let module DB = (Db_cache_impl : Db_interface.DB_ACCESS) in
+       let t = Db_backend.make () in
        Response.rpc_of_t
                (try
                        match Request.t_of_rpc req with
                                | Request.Get_table_from_ref x -> 
-                                       Response.Get_table_from_ref (DB.get_table_from_ref x)
+                                       Response.Get_table_from_ref (DB.get_table_from_ref x)
                                | Request.Is_valid_ref x ->
-                                       Response.Is_valid_ref (DB.is_valid_ref x)
+                                       Response.Is_valid_ref (DB.is_valid_ref x)
                                | Request.Read_refs x ->
-                                       Response.Read_refs (DB.read_refs x)
+                                       Response.Read_refs (DB.read_refs x)
                                | Request.Find_refs_with_filter (x, e) ->
-                                       Response.Find_refs_with_filter (DB.find_refs_with_filter x e)
+                                       Response.Find_refs_with_filter (DB.find_refs_with_filter x e)
                                | Request.Read_field_where w ->
-                                       Response.Read_field_where (DB.read_field_where w)
+                                       Response.Read_field_where (DB.read_field_where w)
                                | Request.Db_get_by_uuid (a, b) ->
-                                       Response.Db_get_by_uuid (DB.db_get_by_uuid a b)
+                                       Response.Db_get_by_uuid (DB.db_get_by_uuid a b)
                                | Request.Db_get_by_name_label (a, b) ->
-                                       Response.Db_get_by_name_label (DB.db_get_by_name_label a b)
+                                       Response.Db_get_by_name_label (DB.db_get_by_name_label a b)
                                | Request.Read_set_ref w ->
-                                       Response.Read_set_ref (DB.read_set_ref w)
+                                       Response.Read_set_ref (DB.read_set_ref w)
                                | Request.Create_row (a, b, c) ->
-                                       Response.Create_row (DB.create_row a b c)
+                                       Response.Create_row (DB.create_row a b c)
                                | Request.Delete_row (a, b) ->
-                                       Response.Delete_row (DB.delete_row a b)
+                                       Response.Delete_row (DB.delete_row a b)
                                | Request.Write_field (a, b, c, d) ->
-                                       Response.Write_field (DB.write_field a b c d)
+                                       Response.Write_field (DB.write_field a b c d)
                                | Request.Read_field (a, b, c) ->
-                                       Response.Read_field (DB.read_field a b c)
+                                       Response.Read_field (DB.read_field a b c)
                                | Request.Read_record (a, b) ->
-                                       let a', b' = DB.read_record a b in
+                                       let a', b' = DB.read_record a b in
                                        Response.Read_record (a', b')
                                | Request.Read_records_where (a, b) ->
-                                       Response.Read_records_where (DB.read_records_where a b)
+                                       Response.Read_records_where (DB.read_records_where a b)
                                | Request.Process_structured_field (a, b, c, d, e) ->
-                                       Response.Process_structured_field (DB.process_structured_field a b c d e)
+                                       Response.Process_structured_field (DB.process_structured_field a b c d e)
                with 
                        | DBCache_NotFound (x,y,z) ->
                                Response.Dbcache_notfound (x, y, z)
@@ -72,7 +73,8 @@ let handler req bio =
        let fd = Buf_io.fd_of bio in (* fd only used for writing *)
        let body = Http_svr.read_body ~limit:Xapi_globs.http_limit_max_rpc_size req bio in
        let request_rpc = Jsonrpc.of_string body in
+       let reply_rpc = process_rpc request_rpc in
        (* XXX: need to cope with > 16MiB responses *)
-       let response = Jsonrpc.to_string (process_rpc request_rpc) in
+       let response = Jsonrpc.to_string reply_rpc in
        Http_svr.response_str req fd response
 
index 27891f0428e69f19bf66db782932225d8f200656..f7c936947a86015d31af4053ed2d9b143424652d 100644 (file)
@@ -57,28 +57,28 @@ module Make = functor(RPC: Db_interface.RPC) -> struct
                                        else process_exception_xml resp_xml
                        | _ -> raise Remote_db_server_returned_bad_message
                                
-       let get_table_from_ref x =
+       let get_table_from_ref x =
                do_remote_call
                        marshall_get_table_from_ref_args
                        unmarshall_get_table_from_ref_response
                        "get_table_from_ref"
                        x
                        
-       let is_valid_ref x =
+       let is_valid_ref x =
                do_remote_call
                        marshall_is_valid_ref_args
                        unmarshall_is_valid_ref_response
                        "is_valid_ref"
                        x
                        
-       let read_refs x =
+       let read_refs x =
                do_remote_call
                        marshall_read_refs_args
                        unmarshall_read_refs_response
                        "read_refs"
                        x
                        
-       let read_field_where x =
+       let read_field_where x =
                do_remote_call
                        marshall_read_field_where_args
                        unmarshall_read_field_where_response
@@ -86,21 +86,21 @@ module Make = functor(RPC: Db_interface.RPC) -> struct
                        x
 
 
-       let db_get_by_uuid t u =
+       let db_get_by_uuid t u =
                do_remote_call
                        marshall_db_get_by_uuid_args
                        unmarshall_db_get_by_uuid_response
                        "db_get_by_uuid"
                        (t,u)
                        
-       let db_get_by_name_label t l =
+       let db_get_by_name_label t l =
                do_remote_call
                        marshall_db_get_by_name_label_args
                        unmarshall_db_get_by_name_label_response
                        "db_get_by_name_label"
                        (t,l)
                        
-       let read_set_ref x =
+       let read_set_ref x =
                do_remote_call
                        marshall_read_set_ref_args
                        unmarshall_read_set_ref_response
@@ -108,56 +108,56 @@ module Make = functor(RPC: Db_interface.RPC) -> struct
                        x
                        
                        
-       let create_row x y z =
+       let create_row x y z =
                do_remote_call
                        marshall_create_row_args
                        unmarshall_create_row_response
                        "create_row"
                        (x,y,z)
                        
-       let delete_row x y =
+       let delete_row x y =
                do_remote_call
                        marshall_delete_row_args
                        unmarshall_delete_row_response
                        "delete_row"
                        (x,y)
                        
-       let write_field a b c d =
+       let write_field a b c d =
                do_remote_call
                        marshall_write_field_args
                        unmarshall_write_field_response
                        "write_field"
                        (a,b,c,d)
                        
-       let read_field x y z =
+       let read_field x y z =
                do_remote_call
                        marshall_read_field_args
                        unmarshall_read_field_response
                        "read_field"
                        (x,y,z)
                        
-       let find_refs_with_filter s e =
+       let find_refs_with_filter s e =
                do_remote_call
                        marshall_find_refs_with_filter_args
                        unmarshall_find_refs_with_filter_response
                        "find_refs_with_filter"
                        (s,e)
                        
-       let read_record x y =
+       let read_record x y =
                do_remote_call
                        marshall_read_record_args
                        unmarshall_read_record_response
                        "read_record"
                        (x,y)
                        
-       let read_records_where x e =
+       let read_records_where x e =
                do_remote_call
                        marshall_read_records_where_args
                        unmarshall_read_records_where_response
                        "read_records_where"
                        (x,e)
                
-       let process_structured_field a b c d e =
+       let process_structured_field a b c d e =
                do_remote_call
                        marshall_process_structured_field_args
                        unmarshall_process_structured_field_response
index 01a37e7b150e75a033fef3c0bd0a9bcee4236a64..f372a702557b0e72fa154e7ae2a19dcf9241cbc7 100644 (file)
@@ -18,7 +18,6 @@ open Db_rpc_common_v2
 open Db_exn
 
 module Make = functor(RPC: Db_interface.RPC) -> struct
-
        let initialise = RPC.initialise
        let rpc x = Jsonrpc.of_string (RPC.rpc (Jsonrpc.to_string x))
 
@@ -37,77 +36,77 @@ module Make = functor(RPC: Db_interface.RPC) -> struct
                                raise (Too_many_values (x,y,z))                 
                        | y -> y
 
-       let get_table_from_ref x =
+       let get_table_from_ref x =
                match process (Request.Get_table_from_ref x) with
                        | Response.Get_table_from_ref y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let is_valid_ref x =
+       let is_valid_ref x =
                match process (Request.Is_valid_ref x) with
                        | Response.Is_valid_ref y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
 
-       let read_refs x =
+       let read_refs x =
                match process (Request.Read_refs x) with
                        | Response.Read_refs y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let read_field_where x =
+       let read_field_where x =
                match process (Request.Read_field_where x) with
                        | Response.Read_field_where y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
 
-       let db_get_by_uuid t u =
+       let db_get_by_uuid t u =
                match process (Request.Db_get_by_uuid (t, u)) with
                        | Response.Db_get_by_uuid y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let db_get_by_name_label t l =
+       let db_get_by_name_label t l =
                match process (Request.Db_get_by_name_label (t, l)) with
                        | Response.Db_get_by_name_label y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let read_set_ref x =
+       let read_set_ref x =
                match process (Request.Read_set_ref x) with
                        | Response.Read_set_ref y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let create_row x y z =
+       let create_row x y z =
                match process (Request.Create_row (x, y, z)) with
                        | Response.Create_row y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let delete_row x y =
+       let delete_row x y =
                match process (Request.Delete_row (x, y)) with
                        | Response.Delete_row y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let write_field a b c d =
+       let write_field a b c d =
                match process (Request.Write_field (a, b, c, d)) with
                        | Response.Write_field y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let read_field x y z =
+       let read_field x y z =
                match process (Request.Read_field (x, y, z)) with
                        | Response.Read_field y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let find_refs_with_filter s e =
+       let find_refs_with_filter s e =
                match process (Request.Find_refs_with_filter (s, e)) with
                        | Response.Find_refs_with_filter y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
 
-       let read_record x y =
+       let read_record x y =
                match process (Request.Read_record (x, y)) with
                        | Response.Read_record (x, y) -> x, y
                        | _ -> raise Remote_db_server_returned_bad_message
                        
-       let read_records_where x e =
+       let read_records_where x e =
                match process (Request.Read_records_where (x, e)) with
                        | Response.Read_records_where y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
 
-       let process_structured_field a b c d e =
+       let process_structured_field a b c d e =
                match process (Request.Process_structured_field(a, b, c, d, e)) with
                        | Response.Process_structured_field y -> y
                        | _ -> raise Remote_db_server_returned_bad_message
index e8b64395b5bc5119997c6df5869b1274d49fe1d0..88eca237db1e6bd4c04654e570c347d406d9a6ff 100644 (file)
@@ -79,16 +79,22 @@ let database_callback event db =
        let other_tbl_refs_for_this_field tblname fldname =
                List.filter (fun (_,fld) -> fld=fldname) (other_tbl_refs tblname) in    
 
+       let is_valid_ref r = 
+               try
+                       ignore(Database.table_of_ref r db);
+                       true
+               with _ -> false in
+
        match event with
                | WriteField (tblname, objref, fldname, oldval, newval) ->
                        let events_old_val = 
-                               if Db_cache_impl.is_valid_ref oldval then 
+                               if is_valid_ref oldval then 
                                        events_of_other_tbl_refs
                                                (List.map (fun (tbl,fld) ->
                                                        (tbl, oldval, find_get_record tbl ~__context:context ~self:oldval)) (other_tbl_refs_for_this_field tblname fldname)) 
                                else [] in
                        let events_new_val =
-                               if Db_cache_impl.is_valid_ref newval then
+                               if is_valid_ref newval then
                                        events_of_other_tbl_refs
                                                (List.map (fun (tbl,fld) ->
                                                        (tbl, newval, find_get_record tbl ~__context:context ~self:newval)) (other_tbl_refs_for_this_field tblname fldname)) 
@@ -131,7 +137,7 @@ let database_callback event db =
                        let other_tbl_refs =
                                List.fold_left (fun accu (remote_tbl,fld) ->
                                        let fld_value = List.assoc fld kv in
-                                       if Db_cache_impl.is_valid_ref fld_value 
+                                       if is_valid_ref fld_value 
                                        then (remote_tbl, fld_value, find_get_record remote_tbl ~__context:context ~self:fld_value) :: accu 
                                        else accu) 
                                        [] other_tbl_refs in
@@ -150,7 +156,7 @@ let database_callback event db =
                        let other_tbl_refs =
                                List.fold_left (fun accu (tbl,fld) ->
                                        let fld_value = List.assoc fld kv in
-                                       if Db_cache_impl.is_valid_ref fld_value 
+                                       if is_valid_ref fld_value 
                                        then (tbl, fld_value, find_get_record tbl ~__context:context ~self:fld_value) :: accu
                                        else accu) 
                                        [] other_tbl_refs in
index f0a0f5aac9675ee41eab8d2292ebb8cf6930458c..c0625036a5d8b8fee22ffc2a8eb6dbfa2b44f6e0 100644 (file)
@@ -26,7 +26,8 @@ let string_of (x: indexrec) =
   Printf.sprintf "%s%s" x.uuid (Opt.default "" (Opt.map (fun name -> Printf.sprintf " (%s)" name) x.name_label))
 
 let lookup key =
-       let db = Db_backend.get_database () in
+       let t = Db_backend.make () in
+       let db = Db_ref.get_database t in
        let r (tblname, objref) = 
                let row = Table.find objref (TableSet.find tblname (Database.tableset db)) in {
                        name_label = (try Some (Row.find Db_names.name_label row) with _ -> None);
index e969777ea8e486f2ea0597abbf6a9bed6e5d2263..037e97098de62ae58ff4540e1e82d419ae295e09 100644 (file)
@@ -50,15 +50,12 @@ let initialise_db_connections() =
   dbs
 
 let read_in_database() =
-  (* Make sure we're running in master mode: we cannot be a slave
-     and then access the dbcache *)
-  Db_cache.set_master true;
   let connections = initialise_db_connections() in
   (* Initialiase in-memory database cache *)
-  Db_cache_impl.make connections Schema.empty
+  Db_cache_impl.make (Db_backend.make ()) connections Schema.empty
 
 let write_out_databases() =
-       Db_cache_impl.sync (Db_conn_store.read_db_connections ()) (Db_backend.get_database ())
+       Db_cache_impl.sync (Db_conn_store.read_db_connections ()) (Db_ref.get_database (Db_backend.make ()))
 
 (* should never be thrown due to checking argument at start *)
 exception UnknownFormat
@@ -71,7 +68,7 @@ let write_out_database filename =
                        Parse_db_conf.path=filename;
                        Parse_db_conf.mode=Parse_db_conf.No_limit;
                        Parse_db_conf.compress=(!compress)
-    } ] (Db_backend.get_database ())
+    } ] (Db_ref.get_database (Db_backend.make ()))
 
 let help_pad = "      "
 let operation_list =
@@ -93,7 +90,7 @@ let do_write_database() =
   begin
     read_in_database();
     if !xmltostdout then
-               Db_xml.To.fd (Unix.descr_of_out_channel stdout) (Db_backend.get_database ())
+               Db_xml.To.fd (Unix.descr_of_out_channel stdout) (Db_ref.get_database (Db_backend.make()))
     else
       write_out_database !filename
   end
@@ -101,7 +98,7 @@ let do_write_database() =
 let find_my_host_row() =
   Xapi_inventory.read_inventory ();
   let localhost_uuid = Xapi_inventory.lookup Xapi_inventory._installation_uuid in
-  let db = Db_backend.get_database () in
+  let db = Db_ref.get_database (Db_backend.make ()) in
   let tbl = TableSet.find Db_names.host (Database.tableset db) in
   Table.fold (fun r row acc -> if Row.find Db_names.uuid row = localhost_uuid then (Some (r, row)) else acc) tbl None
 
@@ -136,7 +133,7 @@ let do_write_hostiqn() =
                                  (* ... otherwise add new key/value pair *)
                                  (_iscsi_iqn,new_iqn)::other_config in
                  let other_config = String_marshall_helper.map (fun x->x) (fun x->x) other_config in
-                 Db_backend.update_database (set_field_in_row Db_names.host r Db_names.other_config other_config);
+                 Db_ref.update_database (Db_backend.make ()) (set_field_in_row Db_names.host r Db_names.other_config other_config);
                  write_out_databases()
 
 let do_am_i_in_the_database () = 
index 487722bd0e19824980835b23740a87b6a0a5c54e..4d2ead8f7133912e9f9b7f3a22bd9882039c0ee2 100644 (file)
@@ -82,6 +82,7 @@ SERVER_OBJS = ../../database/escaping locking_helpers \
        $(AUTOGEN_HELPER_DIR)/db_exn \
        $(AUTOGEN_HELPER_DIR)/ref_index \
        $(AUTOGEN_HELPER_DIR)/db_backend \
+       $(AUTOGEN_HELPER_DIR)/db_ref \
        $(AUTOGEN_HELPER_DIR)/backend_xml \
        $(AUTOGEN_HELPER_DIR)/generation \
        $(AUTOGEN_HELPER_DIR)/db_connections \
index 119463c44106b681a5d5276ba8083218551b729c..a923ea77c12f8fe0a1c98ce4127596da12dfbb28 100644 (file)
@@ -39,6 +39,7 @@ type t = { session_id: API.ref_session option;
           forwarded_task : bool;
           origin: origin;
           task_name: string; (* Name for dummy task FIXME: used only for dummy task, as real task as their name in the database *)
+          database: Db_ref.t;
         }
 
 let get_session_id x =
@@ -75,6 +76,8 @@ let string_of x =
     (string_of_origin x.origin)
     x.task_name
 
+let database_of x = x.database
+
 (** Calls coming in from the unix socket are pre-authenticated *)
 let is_unix_socket s =
   match Unix.getpeername s with
@@ -88,6 +91,10 @@ let is_unencrypted s =
     | Unix.ADDR_INET (addr, _) when addr = Unix.inet_addr_loopback -> false
     | Unix.ADDR_INET _ -> true
 
+let default_database () = 
+       if Pool_role.is_master ()
+       then Db_backend.make ()
+       else Db_ref.Remote
 
 let preauth ~__context =
   match __context.origin with
@@ -101,6 +108,7 @@ let initial =
     forwarded_task = false;
     origin = Internal;
     task_name = "initial_task";
+       database = default_database ();
   }
 
 (* ref fn used to break the cyclic dependency between context, db_actions and taskhelper *)
@@ -152,9 +160,11 @@ let from_forwarded_task ?(__context=initial) ?session_id ?(origin=Internal) task
         forwarded_task = true;
         task_in_database = not (Ref.is_dummy task_id);
         origin = origin;
-        task_name = task_name } 
+        task_name = task_name;
+               database = default_database ();
+         } 
 
-let make ?(__context=initial) ?(quiet=false) ?subtask_of ?session_id ?(task_in_database=false) ?task_description ?(origin=Internal) task_name =
+let make ?(__context=initial) ?(quiet=false) ?subtask_of ?session_id ?(database=default_database ()) ?(task_in_database=false) ?task_description ?(origin=Internal) task_name =
   let task_id, task_uuid =
     if task_in_database 
     then !__make_task ~__context ?description:task_description ?session_id ?subtask_of task_name
@@ -177,6 +187,7 @@ let make ?(__context=initial) ?(quiet=false) ?subtask_of ?session_id ?(task_in_d
         | Some subtask_of -> " by task " ^ !__string_of_task "" subtask_of)
     ;
     { session_id = session_id;
+         database = database;
       task_id = task_id;
       task_in_database = task_in_database;
       origin = origin;
index 9bc387e12571ebaa21d810a0e9f442dd35ed2a10..0f1c104b8ecd83fc3900c1d73ff7236909b96df0 100644 (file)
@@ -24,11 +24,12 @@ type origin =
 (** [initial] is the initial context. *)
 val initial : t
 
-(** [make ~__context ~subtask_of ~session_id ~task_in_database ~task_description ~origin name] creates a new context. 
+(** [make ~__context ~subtask_of ~database ~session_id ~task_in_database ~task_description ~origin name] creates a new context. 
     [__context] is the calling context,
        [quiet] silences "task created" log messages,
     [subtask_of] is a reference to the parent task, 
     [session_id] is the current session id,
+       [database] is the database to use in future Db.* operations
     [task_in_database] indicates if the task needs to be stored the task in the database, 
     [task_descrpition] is the description of the task,
     [task_name] is the task name of the created context. *)
@@ -37,6 +38,7 @@ val make :
   ?quiet:bool ->
   ?subtask_of:API.ref_task ->
   ?session_id:API.ref_session ->
+  ?database:Db_ref.t ->
   ?task_in_database:bool ->
   ?task_description:string -> ?origin:origin -> string -> t
 
@@ -75,6 +77,9 @@ val get_origin : t -> string
 (** [string_of __context] returns a string representing the context. *)
 val string_of : t -> string
 
+(** [database_of __context] returns a database handle, which can be used by Db.* *)
+val database_of : t -> Db_ref.t
+
 (** {6 Destructors} *)
 
 val destroy : t -> unit
index 506cffb1e6b4c36350214c05bda3597fa3b90efd..243e6fa881804e0e89f8f2bfcede69faa6bd5704 100644 (file)
@@ -168,9 +168,9 @@ let read_set_ref obj other full_name =
   (* Set(Ref t) is actually stored in the table t *)
   let obj', fld' = look_up_related_table_and_field obj other full_name in
   String.concat "\n" [
-         Printf.sprintf "if not(DB.is_valid_ref %s)" Client._self;
+         Printf.sprintf "if not(DB.is_valid_ref __t %s)" Client._self;
          Printf.sprintf "then raise (Api_errors.Server_error(Api_errors.handle_invalid, [ %s ]))" Client._self;
-         Printf.sprintf "else List.map %s.%s (DB.read_set_ref " _string_to_dm (OU.alias_of_ty (DT.Ref other));
+         Printf.sprintf "else List.map %s.%s (DB.read_set_ref __t " _string_to_dm (OU.alias_of_ty (DT.Ref other));
          Printf.sprintf "    { table = \"%s\"; return=Db_names.ref; " (Escaping.escape_obj obj');
          Printf.sprintf "      where_field = \"%s\"; where_value = %s })" fld' Client._self
   ]
@@ -178,7 +178,7 @@ let read_set_ref obj other full_name =
 let get_record (obj: obj) aux_fn_name =
   let body =
     [
-      Printf.sprintf "let (__regular_fields, __set_refs) = DB.read_record \"%s\" %s in" 
+      Printf.sprintf "let (__regular_fields, __set_refs) = DB.read_record __t \"%s\" %s in" 
        (Escaping.escape_obj obj.DT.name) Client._self;
       aux_fn_name^" ~__regular_fields ~__set_refs";
     ] in
@@ -218,7 +218,9 @@ let make_shallow_copy api (obj: obj) (src: string) (dst: string) (all_fields: fi
       (String.concat "; " (List.map (fun f -> "\"" ^ f ^ "\"") sql_fields))
 *)
 
-let open_db_module = "let module DB = (val (Db_cache.get ()) : Db_interface.DB_ACCESS) in\n"
+let open_db_module = 
+       "let __t = Context.database_of __context in\n" ^
+               "let module DB = (val (Db_cache.get __t) : Db_interface.DB_ACCESS) in\n"
 
 let db_action api : O.Module.t =
   let api = make_db_api api in
@@ -232,7 +234,7 @@ let db_action api : O.Module.t =
       ~name: "get_refs_where"
       ~params: [ Gen_common.context_arg; expr_arg ]
       ~ty: ( OU.alias_of_ty (Ref obj.DT.name) ^ " list")
-      ~body: [ open_db_module; "let refs = (DB.find_refs_with_filter \"" ^ tbl ^ "\" " ^ expr ^ ") in ";
+      ~body: [ open_db_module; "let refs = (DB.find_refs_with_filter __t \"" ^ tbl ^ "\" " ^ expr ^ ") in ";
               "List.map Ref.of_string refs " ] () in
 
     let get_record_aux_fn_body ?(m="API.") (obj: obj) (all_fields: field list) =
@@ -279,7 +281,7 @@ let db_action api : O.Module.t =
          ~params: [ Gen_common.context_arg; expr_arg ]
          ~ty: ("'a")
                ~body: [ open_db_module;
-               Printf.sprintf "let records = DB.read_records_where \"%s\" %s in"
+               Printf.sprintf "let records = DB.read_records_where __t \"%s\" %s in"
                     (Escaping.escape_obj obj.DT.name) expr;
                   Printf.sprintf "List.map (fun (ref,(__regular_fields,__set_refs)) -> Ref.of_string ref, %s __regular_fields __set_refs) records" conversion_fn] () in
       
@@ -307,7 +309,7 @@ let db_action api : O.Module.t =
 
     let body = match tag with
       | FromField(Setter, fld) ->
-         Printf.sprintf "DB.write_field (*__context*) \"%s\" %s \"%s\" value"
+         Printf.sprintf "DB.write_field __t \"%s\" %s \"%s\" value"
            (Escaping.escape_obj obj.DT.name)
            Client._self
            (Escaping.escape_id fld.DT.full_name)
@@ -324,31 +326,31 @@ let db_action api : O.Module.t =
            (Escaping.escape_obj obj') fld' Client._self
 *)
       | FromField(Getter, { DT.ty = ty; full_name = full_name }) ->
-         Printf.sprintf "%s.%s (DB.read_field (*__context*) \"%s\" \"%s\" %s)"
+         Printf.sprintf "%s.%s (DB.read_field __t \"%s\" \"%s\" %s)"
            _string_to_dm (OU.alias_of_ty ty)
            (Escaping.escape_obj obj.DT.name)
            (Escaping.escape_id full_name)
            Client._self
       | FromField(Add, { DT.ty = DT.Map(_, _); full_name = full_name }) ->
-         Printf.sprintf "DB.process_structured_field (*__context*) (%s,%s) \"%s\" \"%s\" %s AddMap"
+         Printf.sprintf "DB.process_structured_field __t (%s,%s) \"%s\" \"%s\" %s AddMap"
             Client._key Client._value
            (Escaping.escape_obj obj.DT.name)
            (Escaping.escape_id full_name)
            Client._self
       | FromField(Add, { DT.ty = DT.Set(_); full_name = full_name }) ->
-         Printf.sprintf "DB.process_structured_field (*__context*) (%s,\"\") \"%s\" \"%s\" %s AddSet"
+         Printf.sprintf "DB.process_structured_field __t (%s,\"\") \"%s\" \"%s\" %s AddSet"
             Client._value
            (Escaping.escape_obj obj.DT.name)
            (Escaping.escape_id full_name)
            Client._self
       | FromField(Remove, { DT.ty = DT.Map(_, _); full_name = full_name }) ->
-         Printf.sprintf "DB.process_structured_field (*__context*) (%s,\"\") \"%s\" \"%s\" %s RemoveMap"
+         Printf.sprintf "DB.process_structured_field __t (%s,\"\") \"%s\" \"%s\" %s RemoveMap"
             Client._key
            (Escaping.escape_obj obj.DT.name)
            (Escaping.escape_id full_name)
            Client._self
       | FromField(Remove, { DT.ty = DT.Set(_); full_name = full_name }) ->
-         Printf.sprintf "DB.process_structured_field (*__context*) (%s,\"\") \"%s\" \"%s\" %s RemoveSet"
+         Printf.sprintf "DB.process_structured_field __t (%s,\"\") \"%s\" \"%s\" %s RemoveSet"
             Client._value
            (Escaping.escape_obj obj.DT.name)
            (Escaping.escape_id full_name)
@@ -357,7 +359,7 @@ let db_action api : O.Module.t =
       | FromField((Add | Remove), _) -> failwith "Cannot generate db add/remove for non sets and maps"
 
       | FromObject(Delete) ->
-         (Printf.sprintf "DB.delete_row (*__context*) \"%s\" %s"
+         (Printf.sprintf "DB.delete_row __t \"%s\" %s"
            (Escaping.escape_obj obj.DT.name) Client._self)
       | FromObject(Make) ->
          let fields = List.filter field_in_this_table (DU.fields_of_obj obj) in
@@ -367,13 +369,13 @@ let db_action api : O.Module.t =
                                OU.escape (OU.ocaml_of_id fld.full_name)) fields  in
          let kvs' = List.map (fun (sql, o) ->
                                 Printf.sprintf "(\"%s\", %s)" sql o) kvs in
-         Printf.sprintf "DB.create_row (*__context*) \"%s\" [ %s ] ref"
+         Printf.sprintf "DB.create_row __t \"%s\" [ %s ] ref"
            (Escaping.escape_obj obj.DT.name)
            (String.concat "; " kvs') 
       | FromObject(GetByUuid) ->
          begin match x.msg_params, x.msg_result with
          | [ {param_type=ty; param_name=name} ], Some (result_ty, _) ->
-             let query = Printf.sprintf "DB.db_get_by_uuid \"%s\" %s"
+             let query = Printf.sprintf "DB.db_get_by_uuid __t \"%s\" %s"
                (Escaping.escape_obj obj.DT.name)
                (OU.escape name) in
              _string_to_dm ^ "." ^ (OU.alias_of_ty result_ty) ^ " (" ^ query ^ ")"
@@ -382,7 +384,7 @@ let db_action api : O.Module.t =
       | FromObject(GetByLabel) ->
          begin match x.msg_params, x.msg_result with
          | [ {param_type=ty; param_name=name} ], Some (Set result_ty, _) ->
-             let query = Printf.sprintf "DB.db_get_by_name_label \"%s\" %s"
+             let query = Printf.sprintf "DB.db_get_by_name_label __t \"%s\" %s"
                (Escaping.escape_obj obj.DT.name)
                (OU.escape name) in
              if DU.obj_has_get_by_name_label obj
@@ -398,7 +400,7 @@ let db_action api : O.Module.t =
             Eventually we'll need to provide user filtering for the public version *)
          begin match x.msg_result with
          | Some (Set result_ty, _) ->
-             let query = Printf.sprintf "DB.read_refs \"%s\""
+             let query = Printf.sprintf "DB.read_refs __t \"%s\""
                (Escaping.escape_obj obj.DT.name) in
              "List.map " ^ _string_to_dm ^ "." ^ (OU.alias_of_ty result_ty) ^ "(" ^ query ^ ")"
          | _ -> failwith "GetAll call needs a result type"
index 4acb7b28e8bc7329f4ea0f8639b958d0e7d137d1..7c04bd1293fc015172c891f3e588bad78d00a73d 100644 (file)
@@ -2837,7 +2837,7 @@ let vm_import fd printer rpc session_id params =
                                                        let host =
                                                                if sr<>Ref.null
                                                                then Importexport.find_host_for_sr ~__context sr
-                                                               else Helpers.get_localhost ()
+                                                               else Helpers.get_localhost __context
                                                        in
                                                        let address = Client.Host.get_address rpc session_id host in
                                                        (* Although it's inefficient use a loopback HTTP connection *)
index a7f006cf6b2f0d66b3326d67c9741cb420aaee8d..2d5508355f1c7743d591ffff06c1d96cb663dee8 100644 (file)
@@ -78,9 +78,10 @@ let console_of_request __context req =
   (* The _ref may be either a VM ref in which case we look for a
      default VNC console or it may be a console ref in which case we
      go for that. *)
+  let db = Context.database_of __context in
   let is_vm, is_console = 
-         let module DB = (val (Db_cache.get ()) : Db_interface.DB_ACCESS) in     
-         match DB.get_table_from_ref _ref with
+         let module DB = (val (Db_cache.get db) : Db_interface.DB_ACCESS) in     
+         match DB.get_table_from_ref db _ref with
        | Some c when c = Db_names.vm -> true, false
        | Some c when c = Db_names.console -> false, true
        | _ ->
index 22628da58eafa1663c119c1b5d7a9a0f58c72fca..fc09070d19a1afb38f05a5b91a5a021c23a36fc7 100644 (file)
@@ -117,7 +117,7 @@ and ensure_domain_zero_console_record ~__context ~domain_zero_ref =
                        create_domain_zero_console_record ~__context ~domain_zero_ref
                | [console_ref] ->
                        (* if there's a single reference but it's invalid, make a new one: *)
-                       if not (Db.is_valid_ref console_ref) then
+                       if not (Db.is_valid_ref __context console_ref) then
                                create_domain_zero_console_record ~__context ~domain_zero_ref
                | _ ->
                        (* if there's more than one console then something strange is *)
@@ -125,7 +125,7 @@ and ensure_domain_zero_console_record ~__context ~domain_zero_ref =
                        create_domain_zero_console_record ~__context ~domain_zero_ref
 
 and ensure_domain_zero_guest_metrics_record ~__context ~domain_zero_ref =
-       if not (Db.is_valid_ref (Db.VM.get_metrics ~__context ~self:domain_zero_ref)) then
+       if not (Db.is_valid_ref __context (Db.VM.get_metrics ~__context ~self:domain_zero_ref)) then
        begin
                debug "Domain 0 record does not have associated guest metrics record. Creating now";
                let metrics_ref = Ref.make() in
index 4956dfb9b5d67d2ddd2bb13a542f64cd4542a0aa..e2025bc6a541777d1bd1eda578b5190bdade60df 100644 (file)
@@ -16,6 +16,7 @@
  *)
  
 include Db_actions.DB_Action
-let is_valid_ref r =
-       let module DB = (val (Db_cache.get ()) : Db_interface.DB_ACCESS) in
-       DB.is_valid_ref (Ref.string_of r)
+let is_valid_ref __context r =
+       let t = Context.database_of __context in
+       let module DB = (val (Db_cache.get t) : Db_interface.DB_ACCESS) in
+       DB.is_valid_ref t (Ref.string_of r)
index a192313ce17083456ea821ffeda40f49cdb4f64b..9d5d19fa7d1ae5f8a50f50969b24e92e5cf90eac 100644 (file)
@@ -39,7 +39,8 @@ let _time = "time"
 let valid_ref x = Db.is_valid_ref x
 
 let gc_connector ~__context get_all get_record valid_ref1 valid_ref2 delete_record =
-       let module DB = (val (Db_cache.get ()) : Db_interface.DB_ACCESS) in       
+       let db = Context.database_of __context in
+       let module DB = (val (Db_cache.get db) : Db_interface.DB_ACCESS) in       
   let all_refs = get_all ~__context in
   let do_gc ref =
     let print_valid b = if b then "valid" else "INVALID" in
@@ -50,7 +51,7 @@ let gc_connector ~__context get_all get_record valid_ref1 valid_ref2 delete_reco
       if not (ref_1_valid && ref_2_valid) then
        begin
          let table,reference,valid1,valid2 = 
-           (match DB.get_table_from_ref (Ref.string_of ref) with
+           (match DB.get_table_from_ref db (Ref.string_of ref) with
                 None -> "UNKNOWN CLASS"
               | Some c -> c),
            (Ref.string_of ref),
@@ -62,7 +63,7 @@ let gc_connector ~__context get_all get_record valid_ref1 valid_ref2 delete_reco
   List.iter do_gc all_refs
 
 let gc_PIFs ~__context =
-  gc_connector ~__context Db.PIF.get_all Db.PIF.get_record (fun x->valid_ref x.pIF_host) (fun x->valid_ref x.pIF_network) 
+  gc_connector ~__context Db.PIF.get_all Db.PIF.get_record (fun x->valid_ref __context x.pIF_host) (fun x->valid_ref __context x.pIF_network) 
     (fun ~__context ~self ->
        (* We need to destroy the PIF, it's metrics and any VLAN/bond records that this PIF was a master of. *)
        (* bonds_to_gc is actually a list which is either empty (not part of a bond) or containing exactly one reference.. *)
@@ -74,10 +75,10 @@ let gc_PIFs ~__context =
        List.iter (fun bond -> (try Db.Bond.destroy ~__context ~self:bond with _ -> ())) bonds_to_gc;
        Db.PIF.destroy ~__context ~self)
 let gc_VBDs ~__context =
-  gc_connector ~__context Db.VBD.get_all Db.VBD.get_record (fun x->valid_ref x.vBD_VM) (fun x->valid_ref x.vBD_VDI || x.vBD_empty) 
+  gc_connector ~__context Db.VBD.get_all Db.VBD.get_record (fun x->valid_ref __context x.vBD_VM) (fun x->valid_ref __context x.vBD_VDI || x.vBD_empty) 
     (fun ~__context ~self ->
       (* When GCing VBDs that are CDs, set them to empty rather than destroy them entirely *)
-      if (valid_ref (Db.VBD.get_VM ~__context ~self)) && (Db.VBD.get_type ~__context ~self = `CD) then
+      if (valid_ref __context (Db.VBD.get_VM ~__context ~self)) && (Db.VBD.get_type ~__context ~self = `CD) then
        begin
          Db.VBD.set_VDI ~__context ~self ~value:Ref.null;
          Db.VBD.set_empty ~__context ~self ~value:true;
@@ -92,22 +93,22 @@ let gc_VBDs ~__context =
 
 let gc_crashdumps ~__context =
   gc_connector ~__context Db.Crashdump.get_all Db.Crashdump.get_record
-    (fun x->valid_ref x.crashdump_VM) (fun x->valid_ref x.crashdump_VDI) Db.Crashdump.destroy
+    (fun x->valid_ref __context x.crashdump_VM) (fun x->valid_ref __context x.crashdump_VDI) Db.Crashdump.destroy
 let gc_VIFs ~__context =
-  gc_connector ~__context Db.VIF.get_all Db.VIF.get_record (fun x->valid_ref x.vIF_VM) (fun x->valid_ref x.vIF_network)
+  gc_connector ~__context Db.VIF.get_all Db.VIF.get_record (fun x->valid_ref __context x.vIF_VM) (fun x->valid_ref __context x.vIF_network)
     (fun ~__context ~self ->
        let metrics = Db.VIF.get_metrics ~__context ~self in
        (try Db.VIF_metrics.destroy ~__context ~self:metrics with _ -> ());
        Db.VIF.destroy ~__context ~self)
 let gc_PBDs ~__context =
-  gc_connector ~__context Db.PBD.get_all Db.PBD.get_record (fun x->valid_ref x.pBD_host) (fun x->valid_ref x.pBD_SR) Db.PBD.destroy
+  gc_connector ~__context Db.PBD.get_all Db.PBD.get_record (fun x->valid_ref __context x.pBD_host) (fun x->valid_ref __context x.pBD_SR) Db.PBD.destroy
 let gc_Host_patches ~__context =
-  gc_connector ~__context Db.Host_patch.get_all Db.Host_patch.get_record (fun x->valid_ref x.host_patch_host) (fun x->valid_ref x.host_patch_pool_patch) Db.Host_patch.destroy
+  gc_connector ~__context Db.Host_patch.get_all Db.Host_patch.get_record (fun x->valid_ref __context x.host_patch_host) (fun x->valid_ref __context x.host_patch_pool_patch) Db.Host_patch.destroy
 let gc_host_cpus ~__context =
   let host_cpus = Db.Host_cpu.get_all ~__context in
     List.iter
       (fun hcpu ->
-        if not (valid_ref (Db.Host_cpu.get_host ~__context ~self:hcpu)) then
+        if not (valid_ref __context (Db.Host_cpu.get_host ~__context ~self:hcpu)) then
           Db.Host_cpu.destroy ~__context ~self:hcpu) host_cpus
 
 (* If the SR record is missing, delete the VDI record *)
index a7f8d9a684fea2fb6449d1baa17079428d8d3315..b0ef35967d26c9f5a75430ee7f1fa14f84e92676 100644 (file)
@@ -36,7 +36,7 @@ let create_host_metrics ~__context =
   List.iter 
     (fun self ->
        let m = Db.Host.get_metrics ~__context ~self in
-       if not(Db.is_valid_ref m) then begin
+       if not(Db.is_valid_ref __context m) then begin
         debug "Creating missing Host_metrics object for Host: %s" (Db.Host.get_uuid ~__context ~self);
         let r = Ref.make () in
         Db.Host_metrics.create ~__context ~ref:r
index c9a081198e83616ec02cbec3aa8f2dc5dc7e503d..be9d30c562c12d5ff758987f93f0e5b9bb51efd1 100644 (file)
@@ -90,7 +90,7 @@ let refresh_console_urls ~__context =
 let reset_vms_running_on_missing_hosts ~__context =
   List.iter (fun vm ->
               let vm_r = Db.VM.get_record ~__context ~self:vm in
-              let valid_resident_on = Db.is_valid_ref vm_r.API.vM_resident_on in
+              let valid_resident_on = Db.is_valid_ref __context vm_r.API.vM_resident_on in
               if not valid_resident_on then begin
                 if vm_r.API.vM_is_control_domain then begin
                   info "Deleting control domain VM uuid '%s' ecause VM.resident_on refers to a Host which is nolonger in the Pool" vm_r.API.vM_uuid;
@@ -200,7 +200,7 @@ let clear_uncooperative_flags_noexn __context = Helpers.log_exn_continue "cleari
 let ensure_vm_metrics_records_exist __context = 
   List.iter (fun vm ->
                                 let m = Db.VM.get_metrics ~__context ~self:vm in
-                                if not(Db.is_valid_ref m) then begin
+                                if not(Db.is_valid_ref __context m) then begin
                                   info "Regenerating missing VM_metrics record for VM %s" (Ref.string_of vm);
                                   let m = Ref.make () in
                                   let uuid = Uuid.to_string (Uuid.make_uuid ()) in
index 81e45584b010d34760a62d6f314efe25d2f08d3d..718c2812b2994ab40cd736ca62afce9d7957ed1c 100644 (file)
@@ -180,7 +180,7 @@ let update_vms ~xal ~__context =
   (* Remove all the scheduled_to_be_resident_on VMs which are resident_on somewhere since that host 'owns' them.
      NB if resident_on this host the VM will still be counted in the all_resident_on_vms set *)
   let really_my_scheduled_to_be_resident_on_vms = 
-    List.filter (fun (_, vm_r) -> not (Db.is_valid_ref vm_r.API.vM_resident_on)) all_scheduled_to_be_resident_on_vms in
+    List.filter (fun (_, vm_r) -> not (Db.is_valid_ref __context vm_r.API.vM_resident_on)) all_scheduled_to_be_resident_on_vms in
   let all_vms_assigned_to_me = Listext.List.setify (all_resident_on_vms @ really_my_scheduled_to_be_resident_on_vms) in
 
   let all_vbds = Db.VBD.get_records_where ~__context ~expr:Db_filter_types.True in
@@ -234,7 +234,7 @@ let update_vms ~xal ~__context =
         List.iter
           (fun vbd ->
              try
-                       if Db.is_valid_ref vbd && not (Db.VBD.get_empty ~__context ~self:vbd)
+                       if Db.is_valid_ref __context vbd && not (Db.VBD.get_empty ~__context ~self:vbd)
                        then Events.Resync.vbd ~__context token vmref vbd
              with e ->
                warn "Caught error resynchronising VBD: %s" (ExnHelper.string_of_exn e)) vm_vbds;
@@ -242,7 +242,7 @@ let update_vms ~xal ~__context =
         List.iter 
           (fun vif ->
              try
-                       if Db.is_valid_ref vif
+                       if Db.is_valid_ref __context vif
                        then Events.Resync.vif ~__context token vmref vif
              with e ->
                warn "Caught error resynchronising VIF: %s" (ExnHelper.string_of_exn e)) vm_vifs
@@ -450,7 +450,7 @@ let remove_all_leaked_vbds __context =
  * For example, this will prevent needless glitches in storage interfaces.
  *)
 let resynchronise_pif_params ~__context =
-       let localhost = Helpers.get_localhost () in
+       let localhost = Helpers.get_localhost ~__context in
        (* 1. Acquire data. We minimise round-trips not bandwidth *)
        let networks = Db.Network.get_all_records ~__context in
        let expr = Db_filter_types.Eq(Db_filter_types.Field "host", Db_filter_types.Literal (Ref.string_of localhost)) in
index 71be3c8a1833f3789a4ba82f6cb018619de7c45e..5f37682acd5376a62be340bac2a4e1f330f14e28 100644 (file)
@@ -76,17 +76,17 @@ let rec update_table ~__context ~include_snapshots ~preserve_power_state ~includ
       end
   in
 
-  if Db.is_valid_ref vm && not (Hashtbl.mem table (Ref.string_of vm)) then begin
+  if Db.is_valid_ref __context vm && not (Hashtbl.mem table (Ref.string_of vm)) then begin
   add vm;
   let vm = Db.VM.get_record ~__context ~self:vm in
   List.iter 
-       (fun vif -> if Db.is_valid_ref vif then begin
+       (fun vif -> if Db.is_valid_ref __context vif then begin
               add vif;
               let vif = Db.VIF.get_record ~__context ~self:vif in
               add vif.API.vIF_network end) 
        vm.API.vM_VIFs;
   List.iter 
-       (fun vbd -> if Db.is_valid_ref vbd then begin
+       (fun vbd -> if Db.is_valid_ref __context vbd then begin
               add vbd;
               let vbd = Db.VBD.get_record ~__context ~self:vbd in
               if not(vbd.API.vBD_empty)
@@ -101,7 +101,7 @@ let rec update_table ~__context ~include_snapshots ~preserve_power_state ~includ
                  vm.API.vM_snapshots;
   (* If VM is suspended then add the suspend_VDI *)
   let vdi = vm.API.vM_suspend_VDI in
-  if preserve_power_state && vm.API.vM_power_state = `Suspended && Db.is_valid_ref vdi then begin
+  if preserve_power_state && vm.API.vM_power_state = `Suspended && Db.is_valid_ref __context vdi then begin
     add_vdi vdi
   end;
   (* Add also the guest metrics *)
index b12a078040d6a0ec373f1832e0db7636a4298884..4a3580ecca740c8360ad812b55fc6d799434e362 100644 (file)
@@ -778,14 +778,14 @@ let touch_file fname =
   with
   | e -> (warn "Unable to touch ready file '%s': %s" fname (Printexc.to_string e))
 
-let vm_to_string vm = 
+let vm_to_string __context vm = 
        let str = Ref.string_of vm in
 
-       if not (Db.is_valid_ref vm)
+       if not (Db.is_valid_ref __context vm)
        then raise (Api_errors.Server_error(Api_errors.invalid_value ,[str]));
-
-       let module DB = (val (Db_cache.get ()) : Db_interface.DB_ACCESS) in
-       let fields = fst (DB.read_record Db_names.vm str) in
+       let t = Context.database_of __context in
+       let module DB = (val (Db_cache.get t) : Db_interface.DB_ACCESS) in
+       let fields = fst (DB.read_record Db_names.vm str) in
        let sexpr = SExpr.Node (List.map (fun (key,value) -> SExpr.Node [SExpr.String key; SExpr.String value]) fields) in
        SExpr.string_of sexpr
 
index f0328c63219f7ef599cc6ee64c16bae9f5d2b08b..2e6e6bdd3f9eeaf826364dfe6832b753c1a51789 100644 (file)
@@ -34,7 +34,7 @@ let vdi_of_req ~__context (req: request) =
                if List.mem_assoc "vdi" all
                then List.assoc "vdi" all
                else raise (Failure "Missing vdi query parameter") in
-       if Db.is_valid_ref (Ref.of_string vdi) 
+       if Db.is_valid_ref __context (Ref.of_string vdi) 
        then Ref.of_string vdi 
        else Db.VDI.get_by_uuid ~__context ~uuid:vdi
 
index db3ecc26eaeef52c7e78df08bce6db82fce29cd8..4b76819f3d92ce109547c272ce9ec16b63ff81aa 100644 (file)
@@ -613,7 +613,7 @@ module Forward = functor(Local: Custom_actions.CUSTOM_ACTIONS) -> struct
       let task_id = Ref.string_of (Context.get_task_id __context) in
       iter_with_drop ~doc:("unmarking VBDs after " ^ doc)
        (fun self -> 
-               if Db.is_valid_ref self then begin
+               if Db.is_valid_ref __context self then begin
                        Db.VBD.remove_from_current_operations ~__context ~self ~key:task_id;
                        Xapi_vbd_helpers.update_allowed_operations ~__context ~self;
                        Early_wakeup.broadcast (Datamodel._vbd, Ref.string_of self);
@@ -651,7 +651,7 @@ module Forward = functor(Local: Custom_actions.CUSTOM_ACTIONS) -> struct
       let task_id = Ref.string_of (Context.get_task_id __context) in
       iter_with_drop ~doc:("unmarking VIFs after " ^ doc)
        (fun self ->
-               if Db.is_valid_ref self then begin 
+               if Db.is_valid_ref __context self then begin 
                        Db.VIF.remove_from_current_operations ~__context ~self ~key:task_id;
                        Xapi_vif_helpers.update_allowed_operations ~__context ~self;
                        Early_wakeup.broadcast (Datamodel._vif, Ref.string_of self);
@@ -1280,7 +1280,7 @@ module Forward = functor(Local: Custom_actions.CUSTOM_ACTIONS) -> struct
 
                let vm = Db.VM.get_snapshot_of ~__context ~self:snapshot in
                let vm = 
-                       if Db.is_valid_ref vm
+                       if Db.is_valid_ref __context vm
                        then vm
                        else Xapi_vm_snapshot.create_vm_from_snapshot ~__context ~snapshot in
 
@@ -1298,7 +1298,7 @@ module Forward = functor(Local: Custom_actions.CUSTOM_ACTIONS) -> struct
                  let pbd = choose_pbd_for_sr ~__context ~self:sr () in
                  let host = Db.PBD.get_host ~__context ~self:pbd in
                  let metrics = Db.Host.get_metrics ~__context ~self:host in
-                 let live = Db.is_valid_ref metrics && (Db.Host_metrics.get_live ~__context ~self:metrics) in
+                 let live = Db.is_valid_ref __context metrics && (Db.Host_metrics.get_live ~__context ~self:metrics) in
                  if not live
                  then raise (Api_errors.Server_error(Api_errors.host_not_live, [ Ref.string_of host ]))
                end;
@@ -2310,7 +2310,7 @@ end
       let task_id = Ref.string_of (Context.get_task_id __context) in
       log_exn ~doc:("unmarking VIF after " ^ doc)
        (fun self ->
-               if Db.is_valid_ref self then begin
+               if Db.is_valid_ref __context self then begin
                        Db.VIF.remove_from_current_operations ~__context ~self ~key:task_id;
                        Xapi_vif_helpers.update_allowed_operations ~__context ~self;
                        Early_wakeup.broadcast (Datamodel._vif, Ref.string_of self);
@@ -2515,7 +2515,7 @@ end
       debug "Unmarking SR after %s (task=%s)" doc task_id;
       log_exn_ignore ~doc:("unmarking SR after " ^ doc)
        (fun self -> 
-               if Db.is_valid_ref self then begin
+               if Db.is_valid_ref __context self then begin
                        Db.SR.remove_from_current_operations ~__context ~self ~key:task_id;
                        Xapi_sr.update_allowed_operations ~__context ~self;
                        Early_wakeup.broadcast (Datamodel._sr, Ref.string_of self);
@@ -2659,7 +2659,7 @@ end
       let task_id = Ref.string_of (Context.get_task_id __context) in
       log_exn_ignore ~doc:("unmarking VDI after " ^ doc)
        (fun self -> 
-               if Db.is_valid_ref self then begin
+               if Db.is_valid_ref __context self then begin
                        Db.VDI.remove_from_current_operations ~__context ~self ~key:task_id;
                        Xapi_vdi.update_allowed_operations ~__context ~self;
                        Early_wakeup.broadcast (Datamodel._vdi, Ref.string_of self);
@@ -2891,7 +2891,7 @@ end
       let task_id = Ref.string_of (Context.get_task_id __context) in
       log_exn ~doc:("unmarking VBD after " ^ doc)
        (fun self -> 
-               if Db.is_valid_ref self then begin
+               if Db.is_valid_ref __context self then begin
                        Db.VBD.remove_from_current_operations ~__context ~self ~key:task_id;
                        Xapi_vbd_helpers.update_allowed_operations ~__context ~self;
                        Early_wakeup.broadcast (Datamodel._vbd, Ref.string_of vbd)
index 75b02704b0196a176fbea796c6562ad3ed520965..520d1b62116f2e12fec7b6cfac61e28cd4a9bb72 100644 (file)
@@ -26,7 +26,7 @@ open D
 let set_vm_metrics ~__context ~vm ~memory ~cpus =
        (* if vm metrics don't exist then make one *)
        let metrics = Db.VM.get_metrics ~__context ~self:vm in
-       if not (Db.is_valid_ref metrics) then
+       if not (Db.is_valid_ref __context metrics) then
          begin
            let ref = Ref.make() in
            Db.VM_metrics.create ~__context ~ref ~uuid:(Uuid.to_string (Uuid.make_uuid ()))
@@ -69,7 +69,7 @@ let update_vm_stats ~__context uuid cpus vbds vifs memory =
 
                           (* if vif metrics don't exist then make one *)
                           let metrics = Db.VIF.get_metrics ~__context ~self in
-                          if not (Db.is_valid_ref metrics) then
+                          if not (Db.is_valid_ref __context metrics) then
                             begin
                               let ref = Ref.make() in
                               Db.VIF_metrics.create ~__context ~ref ~uuid:(Uuid.to_string (Uuid.make_uuid ()))
@@ -96,7 +96,7 @@ let update_vm_stats ~__context uuid cpus vbds vifs memory =
 
                          (* if vbd metrics don't exist then make one *)
                           let metrics = Db.VBD.get_metrics ~__context ~self in
-                          if not (Db.is_valid_ref metrics) then
+                          if not (Db.is_valid_ref __context metrics) then
                             begin
                               let ref = Ref.make() in
                               Db.VBD_metrics.create ~__context ~ref ~uuid:(Uuid.to_string (Uuid.make_uuid ()))
@@ -199,7 +199,7 @@ let update_pifs ~__context host pifs =
                        let pif_stats=List.find (fun p -> p.pif_name = real_device_name) pifs in
                        let metrics = Db.PIF.get_metrics ~__context ~self:pifdev in
                        (* if PIF metrics don't exist then create one: *)
-                       if not (Db.is_valid_ref metrics) then
+                       if not (Db.is_valid_ref __context metrics) then
                          begin
                            let ref = Ref.make() in
                            Db.PIF_metrics.create ~__context ~ref ~uuid:(Uuid.to_string (Uuid.make_uuid ())) ~carrier:false
index 8d2e74e2909458171d0d30fe126e01bf3417291a..75c9e0c50ba04021de59551c4e027c63f4b16904 100644 (file)
@@ -505,7 +505,7 @@ let handler (req: Http.request) s =
                
                (* If the resident_on field is valid, or the request isn't 
                   from dbsync, then redirect *)
-               if Db.is_valid_ref host &&
+               if Db.is_valid_ref __context host &&
                  (not (List.mem_assoc "dbsync" query)) then
                  let address = Db.Host.get_address ~__context ~self:host in
                  let url = Printf.sprintf "https://%s%s?%s" address req.Http.uri (String.concat "&" (List.map (fun (a,b) -> a^"="^b) query)) in
index 29cde4acd0d3c3ca2c52b8106ea695d18d286c63..16457fc49ab8d8e6361f1a500409ea511c7fe354 100644 (file)
@@ -100,16 +100,11 @@ let string_of_process_memory_info (x: process_memory_info) =
   Printf.sprintf "size: %d KiB; rss: %d KiB; data: %d KiB; stack: %d KiB"
     x.size x.rss x.data x.stack
 
-let summarise_db_size () = match Db_cache_impl.stats () with
-  | [] -> "(running as slave; no in-memory db cache)"
-  | xs -> Printf.sprintf "(%s)" (String.concat "; " (List.map (fun (tbl, x) -> Printf.sprintf "%s[%d records]" tbl x) xs))
-
 let one () = 
   let pid = Unix.getpid () in
   let pmi = process_memory_info_of_pid pid in
-  let db = summarise_db_size () in
   let mi = string_of_meminfo (meminfo ()) in
-  debug "Process: %s; Database: %s" (string_of_process_memory_info pmi) db;
+  debug "Process: %s" (string_of_process_memory_info pmi);
   debug "System: %s" mi
     
 let last_log = ref 0.
index cf75fc63a2a117315e959fa79d7b0fb95c5ce75a..4b847d8725c1a883bd75fa86a9f074276f405936 100644 (file)
@@ -103,7 +103,7 @@ let bring_pif_down ~__context (pif: API.ref_PIF) =
        (* Check that the PIF is not in-use *)
        let uuid = Db.PIF.get_uuid ~__context ~self:pif in
        let network = Db.PIF.get_network ~__context ~self:pif in
-       Xapi_network_attach_helpers.assert_network_has_no_vifs_in_use_on_me ~__context ~host:(Helpers.get_localhost()) ~network;
+       Xapi_network_attach_helpers.assert_network_has_no_vifs_in_use_on_me ~__context ~host:(Helpers.get_localhost ~__context) ~network;
        Xapi_network_attach_helpers.assert_pif_disallow_unplug_not_set ~__context pif;
        if Db.PIF.get_currently_attached ~__context ~self:pif = true then begin
                 debug "PIF %s has currently_attached set to true; bringing down now" uuid;
index 05870646a1fd1f1eb077331ab4ad0248d5cad310..abe505b2edc7231cd99cf43d3c032b2f3db28105 100644 (file)
@@ -40,7 +40,7 @@ let write_database (s: Unix.file_descr) ~__context =
                let len = String.length minimally_compliant_miami_database in
                ignore (Unix.write s minimally_compliant_miami_database 0 len)
        else
-               Db_xml.To.fd s (Db_backend.get_database ())
+               Db_xml.To.fd s (Db_ref.get_database (Context.database_of __context))
 
 (** Make sure the backup database version is compatible *)
 let version_check db =
@@ -263,7 +263,7 @@ let pool_db_backup_thread () =
       begin
        let hosts = Db.Host.get_all ~__context in
        let hosts = List.filter (fun hostref -> hostref <> !Xapi_globs.localhost_ref) hosts in
-       let generation = Db_lock.with_lock (fun () -> Manifest.generation (Database.manifest (Db_backend.get_database ()))) in
+       let generation = Db_lock.with_lock (fun () -> Manifest.generation (Database.manifest (Db_ref.get_database (Context.database_of __context)))) in
        let dohost host =
          try
            Thread.delay pool_db_sync_timer;
index 6ce48997d9f37a7a83497fa0dccab38e51615e3d..91e2e0eef0719795d9e770f3c1a3388abc447f05 100644 (file)
@@ -46,7 +46,8 @@ let read_from_redo_log staging_path =
                  let conn = Parse_db_conf.make temp_file in
           (* ideally, the reading from the file would also respect the latest_response_time *)
                  let db = Backend_xml.populate (Schema.of_datamodel ()) conn in
-                 Db_backend.update_database (fun _ -> db);
+                 let t = Db_backend.make () in
+                 Db_ref.update_database t (fun _ -> db);
 
           R.debug "Finished reading database from %s into cache (generation = %Ld)" temp_file gen_count;
 
@@ -93,8 +94,9 @@ let read_from_redo_log staging_path =
                  R.debug "Database from redo log has generation %Ld" generation;
         (* Write the in-memory cache to the file *)
                  (* Make sure the generation count is right -- is this necessary? *)
-                 Db_backend.update_database (Db_cache_types.Database.set_generation generation);
-                 let db = Db_backend.get_database () in
+                 let t = Db_backend.make () in
+                 Db_ref.update_database t (Db_cache_types.Database.set_generation generation);
+                 let db = Db_ref.get_database t in
                  Db_xml.To.file staging_path db;
           Unixext.write_string_to_file (staging_path ^ ".generation") (Generation.to_string generation)
     end
index 5974ce9e38340981f6778e7d7cfbfe8de240828d..26f656d82bfff9135754e26b2aedba7c38dab46e 100644 (file)
@@ -47,10 +47,6 @@ let check_control_domain () =
 let startup_check () =
   Sanitycheck.check_for_bad_link ()
     
-(* Tell the dbcache whether we're a master or a slave *)
-let set_db_mode() =
-       Db_cache.set_master (Pool_role.is_master ())
-
 (* Parse db conf file from disk and use this to initialise database connections. This is done on
    both master and slave. On masters the parsed data is used to flush databases to and to populate
    cache; on the slave the parsed data is used to determine where to put backups.
@@ -75,11 +71,12 @@ let start_database_engine () =
        let schema = Schema.of_datamodel () in
        
        let connections = Db_conn_store.read_db_connections () in
-       Db_cache_impl.make connections schema;
-       Db_cache_impl.sync connections (Db_backend.get_database ());
+       let t = Db_backend.make () in
+       Db_cache_impl.make t connections schema;
+       Db_cache_impl.sync connections (Db_ref.get_database t);
 
-       Db_backend.update_database (Database.register_callback "redo_log" Redo_log.database_callback);
-       Db_backend.update_database (Database.register_callback "events" Eventgen.database_callback);
+       Db_ref.update_database t (Database.register_callback "redo_log" Redo_log.database_callback);
+       Db_ref.update_database t (Database.register_callback "events" Eventgen.database_callback);
 
   debug "Performing initial DB GC";
   Db_gc.single_pass ();
@@ -472,7 +469,7 @@ let resynchronise_ha_state () =
   try
     Server_helpers.exec_with_new_task "resynchronise_ha_state"
       (fun __context ->
-        let pool = Helpers.get_pool () in
+        let pool = Helpers.get_pool ~__context in
         let pool_ha_enabled = Db.Pool.get_ha_enabled ~__context ~self:pool in
         let local_ha_enabled = bool_of_string (Localdb.get Constants.ha_armed) in
         match local_ha_enabled, pool_ha_enabled with
@@ -793,8 +790,6 @@ let server_init() =
     "Registering master-only http handlers", [ Startup.OnlyMaster ], (fun () -> List.iter Xapi_http.add_handler master_only_http_handlers);
     "Listening unix socket", [], listen_unix_socket;
     "Listening localhost", [], listen_localhost;
-    (* Pre-requisite for starting HA since it may temporarily use the DB cache *)
-    "Set DB mode", [], set_db_mode;
     "Checking HA configuration", [], start_ha;
        "Checking for non-HA redo-log", [], start_redo_log;
     (* It is a pre-requisite for starting db engine *)
index 567c89e3d57518c85f366c658e036a1c763c2969..9780ad33829af3192bb9f0ca800f6a8765ce7070 100644 (file)
@@ -127,7 +127,7 @@ let all (lookup: string -> string option) (list: string -> string list) ~__conte
       (* Make sure our cached idea of whether the domain is live or not is correct *)
       let vm_guest_metrics = Db.VM.get_guest_metrics ~__context ~self in
          let live = true
-               && Db.is_valid_ref vm_guest_metrics 
+               && Db.is_valid_ref __context vm_guest_metrics 
                && Db.VM_guest_metrics.get_live ~__context ~self:vm_guest_metrics in
       if live then
        dead_domains := IntSet.remove domid !dead_domains
index 50179004526ca27ae44e6eb9def28f7bf51530c3..fae9aa0daba8281970856f6c65db4e7420dedbdb 100644 (file)
@@ -1049,7 +1049,7 @@ let preconfigure_host __context localhost statevdis metadata_vdi generation =
                ignore(attach_metadata_vdi ~__context metadata_vdi);
        end;
 
-       write_uuid_to_ip_mapping ();
+       write_uuid_to_ip_mapping ~__context;
 
        let base_t = Timeouts.get_base_t ~__context in
        Localdb.put Constants.ha_base_t (string_of_int base_t)
@@ -1237,7 +1237,7 @@ let disable_internal __context =
                   nodes to self-fence if the statefile disappears. *)
                Helpers.log_exn_continue
                        "stopping HA daemon on the master after setting pool state to invalid"
-                       (fun () -> ha_stop_daemon __context (Helpers.get_localhost ())) ();
+                       (fun () -> ha_stop_daemon __context (Helpers.get_localhost ~__context)) ();
 
                (* No node may become the master automatically without the statefile so we can safely change
                   the Pool state to disabled *)
@@ -1393,7 +1393,7 @@ let enable __context heartbeat_srs configuration =
        (* Check also that any PIFs with IP information set are currently attached - it's a non-fatal
           error if they are, but we'll warn with a message *)
        let pifs_with_ip_config = List.filter (fun (_,pifr) -> pifr.API.pIF_ip_configuration_mode <> `None) pifs in
-       let not_bond_slaves = List.filter (fun (_,pifr) -> not (Db.is_valid_ref pifr.API.pIF_bond_slave_of)) pifs_with_ip_config in
+       let not_bond_slaves = List.filter (fun (_,pifr) -> not (Db.is_valid_ref __context pifr.API.pIF_bond_slave_of)) pifs_with_ip_config in
        let without_disallow_unplug = List.filter (fun (_,pifr) -> not (pifr.API.pIF_disallow_unplug || pifr.API.pIF_management)) not_bond_slaves in
        if List.length without_disallow_unplug > 0 then begin
                let pifinfo = List.map (fun (pif,pifr) -> (Db.Host.get_name_label ~__context ~self:pifr.API.pIF_host, pif, pifr)) without_disallow_unplug in
@@ -1403,7 +1403,7 @@ let enable __context heartbeat_srs configuration =
                in
                warn "Warning: A possible network anomaly was found. The following hosts possibly have storage PIFs that can be unplugged: %s"
                        (String.concat ", " bodylines);
-               ignore(Xapi_message.create ~__context ~name:Api_messages.ip_configured_pif_can_unplug ~priority:5L ~cls:`Pool ~obj_uuid:(Db.Pool.get_uuid ~__context ~self:(Helpers.get_pool ()))
+               ignore(Xapi_message.create ~__context ~name:Api_messages.ip_configured_pif_can_unplug ~priority:5L ~cls:`Pool ~obj_uuid:(Db.Pool.get_uuid ~__context ~self:(Helpers.get_pool ~__context))
                        ~body:(String.concat "\n" bodylines))
        end;
 
@@ -1526,7 +1526,7 @@ let enable __context heartbeat_srs configuration =
 
                                (* ... *)
                                (* Make sure everyone's got a fresh database *)
-                               let generation = Db_lock.with_lock (fun () -> Manifest.generation (Database.manifest (Db_backend.get_database ()))) in
+                               let generation = Db_lock.with_lock (fun () -> Manifest.generation (Database.manifest (Db_ref.get_database (Db_backend.make ())))) in
                                let errors = thread_iter_all_exns
                                        (fun host ->
                                                debug "Synchronising database with host '%s' ('%s')" (Db.Host.get_name_label ~__context ~self:host) (Ref.string_of host);
index 69650bbc078bdf65bf1a948f4e31de5f7d909ede..d9667e20352134c89a2a8dfcd74e6d715afd8e6e 100644 (file)
@@ -666,7 +666,7 @@ let emergency_ha_disable ~__context = Xapi_ha.emergency_ha_disable __context
    it really should take a backup *)
    
 let request_backup ~__context ~host ~generation ~force = 
-  if Helpers.get_localhost () <> host
+  if Helpers.get_localhost ~__context <> host
   then failwith "Forwarded to the wrong host";
   let master_address = Helpers.get_main_ip_address __context in
   Pool_db_backup.fetch_database_backup ~master_address:master_address ~pool_secret:!Xapi_globs.pool_secret
@@ -748,7 +748,7 @@ let management_reconfigure ~__context ~pif =
     Xapi_network.attach_internal ~management_interface:true ~__context ~self:net ();
     change_management_interface ~__context bridge;
   
-    Xapi_pif.update_management_flags ~__context ~host:(Helpers.get_localhost ())
+    Xapi_pif.update_management_flags ~__context ~host:(Helpers.get_localhost ~__context)
   end
 
 let management_disable ~__context = 
@@ -763,10 +763,10 @@ let management_disable ~__context =
 
   Xapi_mgmt_iface.stop ();
   (* Make sure all my PIFs are marked appropriately *)
-  Xapi_pif.update_management_flags ~__context ~host:(Helpers.get_localhost ())
+  Xapi_pif.update_management_flags ~__context ~host:(Helpers.get_localhost ~__context)
 
 let get_system_status_capabilities ~__context ~host =
-  if Helpers.get_localhost () <> host
+  if Helpers.get_localhost ~__context <> host
   then failwith "Forwarded to the wrong host";
   System_status.get_capabilities()
 
index 87cfbe784f6bb7bdae027bc91bc1a56244a98794..bbb3a6259be6114d6c0616265373bd83067c9322 100644 (file)
@@ -105,7 +105,7 @@ val abort_new_master : __context:'a -> address:string -> unit
 val update_master : __context:'a -> host:'b -> master_address:'c -> 'd
 val emergency_ha_disable : __context:'a -> unit
 val request_backup :
-  __context:'a -> host:API.ref_host -> generation:int64 -> force:bool -> unit
+  __context:Context.t -> host:API.ref_host -> generation:int64 -> force:bool -> unit
 val request_config_file_sync : __context:'a -> host:'b -> hash:string -> unit
 val syslog_config_write : string -> bool -> bool -> unit
 val syslog_reconfigure : __context:Context.t -> host:'a -> unit
@@ -123,9 +123,9 @@ val management_disable : __context:Context.t -> unit
 (** {2 (Fill in title!)} *)
 
 val get_system_status_capabilities :
-  __context:'a -> host:API.ref_host -> string
+  __context:Context.t -> host:API.ref_host -> string
 val get_diagnostic_timing_stats :
-  __context:'a -> host:'b -> (string * string) list
+  __context:Context.t -> host:'b -> (string * string) list
 val set_hostname_live :
   __context:Context.t -> host:[ `host ] Ref.t -> hostname:string -> unit
 val is_in_emergency_mode : __context:'a -> bool
index 3208b827d402323959ba660bcad058d12c602168..ea5e309befb7b177afc1544bcd351ccaa74f4151 100644 (file)
@@ -136,7 +136,7 @@ let update_host_metrics ~__context ~host ~memory_total ~memory_free =
   let last_updated = Date.of_float (Unix.gettimeofday ()) in
   let m = Db.Host.get_metrics ~__context ~self:host in
   (* Every host should always have a Host_metrics object *)
-  if Db.is_valid_ref m then begin
+  if Db.is_valid_ref __context m then begin
     Db.Host_metrics.set_memory_total ~__context ~self:m ~value:memory_total;
     Db.Host_metrics.set_memory_free ~__context ~self:m ~value:memory_free;
     Db.Host_metrics.set_last_updated ~__context ~self:m ~value:last_updated;
index 2676412ee042feee33cdce867053fc3efb7b6d22..e9a9978a200be9716b5216c1cd6d50e2bbd44e31 100644 (file)
@@ -160,13 +160,14 @@ let with_context ?(dummy=false) label (req: request) (s: Unix.file_descr) f =
     if List.mem_assoc "subtask_of" all
     then Some (Ref.of_string (List.assoc "subtask_of" all))
     else None in
+  let localhost = Server_helpers.exec_with_new_task "with_context" (fun __context -> Helpers.get_localhost ~__context) in
   try
     let session_id,must_logout = 
       if List.mem_assoc "session_id" all
       then Ref.of_string (List.assoc "session_id" all), false
       else 
            if List.mem_assoc "pool_secret" all
-           then Client.Session.slave_login inet_rpc (Helpers.get_localhost ()) (List.assoc "pool_secret" all), true
+           then Client.Session.slave_login inet_rpc localhost (List.assoc "pool_secret" all), true
            else begin
              match req.Http.auth with
                | Some (Http.Basic(username, password)) ->
index e25fdbd58e3c0dbc5292851061cc2d8f8870dd02..87179589e139a4b606551a53b62fe4ef3cb440e0 100644 (file)
@@ -81,7 +81,7 @@ let on_dom0_networking_change ~__context =
      2 Host.address
      3. Console URIs *)
   let new_hostname = Helpers.reget_hostname () in
-  let localhost = Helpers.get_localhost () in
+  let localhost = Helpers.get_localhost ~__context in
   if Db.Host.get_hostname ~__context ~self:localhost <> new_hostname then begin
     debug "Changing Host.hostname in database to: %s" new_hostname;
     Db.Host.set_hostname ~__context ~self:localhost ~value:new_hostname
index d1b63776237903e2f8e531aad0c6e0f6d0cace5f..e8be4c2c17dbdc1b5d27be72a1a65c39f1056ad8 100644 (file)
@@ -31,7 +31,7 @@ let create_internal_bridge ~bridge ~uuid =
   if not(Netdev.Link.is_up bridge) then Netdev.Link.up bridge
 
 let attach_internal ?(management_interface=false) ~__context ~self () =
-  let host = Helpers.get_localhost () in
+  let host = Helpers.get_localhost ~__context in
   let shafted_pifs, local_pifs = 
     Xapi_network_attach_helpers.assert_can_attach_network_on_host ~__context ~self ~host ~overide_management_if_check:management_interface in
 
index b306ed6b60a935d412319aba045abe31b3171920..370c12b03715394c59b3bea4825b5caa00624d2a 100644 (file)
@@ -499,7 +499,7 @@ let calculate_pifs_required_at_start_of_day ~__context =
                true &&
                pifr.API.pIF_host = localhost && (* this host only *)
                Nm.is_dom0_interface pifr &&
-               not (Db.is_valid_ref pifr.API.pIF_bond_slave_of) (* not enslaved by a bond *)
+               not (Db.is_valid_ref __context pifr.API.pIF_bond_slave_of) (* not enslaved by a bond *)
   )
     (Db.PIF.get_all_records ~__context)
 
index 78faaf09ffc2b33a3bdc1e517d3e9cf9d481ec90..49bf72ffef289ed35cf2a956b1b112e733478e17 100644 (file)
@@ -216,7 +216,7 @@ val vLAN_destroy : __context:Context.t -> self:[ `VLAN ] Ref.t -> unit
     interfaces required by storage NICs etc. (these interface are not filtered out at the moment).
  *)
 val calculate_pifs_required_at_start_of_day :
-  __context:'a -> ('b Ref.t * API.pIF_t) list
+  __context:Context.t -> ('b Ref.t * API.pIF_t) list
   
 (** Attempt to bring up (plug) the required PIFs when the host starts up.
  *  Uses {!calculate_pifs_required_at_start_of_day}. *)
index 48ab9ea3e18b6c87db1b21353dc524c906f402ee..84f8633743006e412f27746b4712e0acee624c1f 100644 (file)
@@ -815,7 +815,7 @@ let sync_database ~__context =
        then debug "flushed database to metadata VDI: assuming this is sufficient."
        else begin
         debug "flushing database to all online nodes";
-                  let generation = Db_lock.with_lock (fun () -> Manifest.generation (Database.manifest (Db_backend.get_database ()))) in
+                  let generation = Db_lock.with_lock (fun () -> Manifest.generation (Database.manifest (Db_ref.get_database (Context.database_of __context)))) in
         Threadext.thread_iter
           (fun host ->
              Helpers.call_api_functions ~__context
@@ -836,7 +836,7 @@ let designate_new_master ~__context ~host =
                let all_hosts = Db.Host.get_all ~__context in
                (* We make no attempt to demand a quorum or anything. *)
                let addresses = List.map (fun self -> Db.Host.get_address ~__context ~self) all_hosts in
-               let my_address = Db.Host.get_address ~__context ~self:(Helpers.get_localhost ()) in
+               let my_address = Db.Host.get_address ~__context ~self:(Helpers.get_localhost ~__context) in
                let peers = List.filter (fun x -> x <> my_address) addresses in
                Xapi_pool_transition.attempt_two_phase_commit_of_new_master ~__context true peers my_address
        end
@@ -850,7 +850,7 @@ let is_slave ~__context ~host =
   let is_slave = not (Pool_role.is_master ()) in
   info "Pool.is_slave call received (I'm a %s)" (if is_slave then "slave" else "master");
   debug "About to kick the database connection to make sure it's still working...";
-  Db.is_valid_ref (Ref.of_string "Pool.is_slave checking to see if the database connection is up");
+  Db.is_valid_ref __context (Ref.of_string "Pool.is_slave checking to see if the database connection is up");
   is_slave
 
 let hello ~__context ~host_uuid ~host_address =
index 8f0352432e647ad1023d785f00949ccf77b39c3a..51d2b01b54dfc9946a0722d67df3d2f564a9e158 100644 (file)
@@ -80,7 +80,7 @@ val sync_m : Threadext.Mutex.t
 val sync_database : __context:Context.t -> unit
 val designate_new_master : __context:Context.t -> host:'a -> unit
 val initial_auth : __context:'a -> string
-val is_slave : __context:'a -> host:'b -> bool
+val is_slave : __context:Context.t -> host:'b -> bool
 val hello :
   __context:Context.t ->
   host_uuid:string ->
index 67416ba5ba325fb98a72e95b596e588b1f732575..5404451d898ab0b75baa31b3b42adf4ca0c1a192 100644 (file)
@@ -263,7 +263,7 @@ let destroy  ~__context ~self =
        let metrics = Db.VIF.get_metrics ~__context ~self in
        (* Don't let a failure to destroy the metrics stop us *)
        Helpers.log_exn_continue "VIF_metrics.destroy" 
-         (fun self -> if Db.is_valid_ref self then Db.VIF_metrics.destroy ~__context ~self) metrics;
+         (fun self -> if Db.is_valid_ref __context self then Db.VIF_metrics.destroy ~__context ~self) metrics;
        
        Db.VIF.destroy ~__context ~self
 
index 7ae0ff1e8a275fa4ec41b52b01b4f0a61b183dfb..d0d906247683bdf5747f0674810662ae1d9cb637 100644 (file)
@@ -670,7 +670,7 @@ let power_state_reset ~__context ~vm =
   let power_state = Db.VM.get_power_state ~__context ~self:vm in
   if power_state = `Running || power_state = `Paused then begin
     debug "VM.power_state_reset vm=%s power state is either running or paused: performing sanity checks" (Ref.string_of vm);
-    let localhost = Helpers.get_localhost () in
+    let localhost = Helpers.get_localhost ~__context in
     (* We only query domid, resident_on and Xc.domain_getinfo with the VM lock held to make
        sure the VM isn't in the middle of a migrate/reboot/shutdown. Note we don't hold it for
        the whole of this function which might perform off-box RPCs. *)
@@ -925,7 +925,7 @@ let snapshot_with_quiesce ~__context ~vm ~new_name =
 let revert ~__context ~snapshot =
        let vm = Db.VM.get_snapshot_of ~__context ~self:snapshot in
        let vm = 
-               if Db.is_valid_ref vm 
+               if Db.is_valid_ref __context vm 
                then vm
                else Xapi_vm_snapshot.create_vm_from_snapshot ~__context ~snapshot in
        Xapi_vm_snapshot.revert ~__context ~snapshot ~vm
index f7b629342a78f7b00a3bdba987975830ed05de6a..2f4f94bcfee0b12dd1a2b9066a483a276df0e701 100644 (file)
@@ -88,7 +88,7 @@ val unpause : __context:Context.t -> vm:API.ref_VM -> unit
 val start :
   __context:Context.t ->
   vm:API.ref_VM -> start_paused:bool -> force:'a -> unit
-val assert_host_is_localhost : __context:'a -> host:API.ref_host -> unit
+val assert_host_is_localhost : __context:Context.t -> host:API.ref_host -> unit
 val start_on :
   __context:Context.t ->
   vm:API.ref_VM -> host:API.ref_host -> start_paused:bool -> force:'a -> unit
@@ -213,14 +213,14 @@ val set_memory_dynamic_range :
 val set_memory_target_live :
   __context:'a -> self:API.ref_VM -> target:'b -> unit
 val wait_memory_target_live : __context:Context.t -> self:API.ref_VM -> unit
-val get_cooperative : __context:'a -> self:[ `VM ] Ref.t -> bool
+val get_cooperative : __context:Context.t -> self:[ `VM ] Ref.t -> bool
 val set_HVM_shadow_multiplier :
   __context:Context.t -> self:[ `VM ] Ref.t -> value:float -> unit
 val set_shadow_multiplier_live :
   __context:Context.t -> self:API.ref_VM -> multiplier:float -> unit
 val send_sysrq : __context:'a -> vm:API.ref_VM -> key:'b -> 'c
 val send_trigger : __context:'a -> vm:API.ref_VM -> trigger:'b -> 'c
-val get_boot_record : __context:'a -> self:API.ref_VM -> API.vM_t
+val get_boot_record : __context:Context.t -> self:API.ref_VM -> API.vM_t
 val get_data_sources :
   __context:Context.t -> self:[ `VM ] Ref.t -> API.data_source_t list
 val record_data_source :
index a823e797173c19cd982ef36b8874ee13ab636f86..ca15d2c6c4a4a6801fac0cd09516faa8710c3c82 100644 (file)
@@ -165,9 +165,9 @@ let snapshot_info ~power_state ~is_a_snapshot =
        else
                []
 
-let snapshot_metadata ~vm ~is_a_snapshot =
+let snapshot_metadata ~__context ~vm ~is_a_snapshot =
        if is_a_snapshot then
-               Helpers.vm_to_string vm
+               Helpers.vm_to_string __context vm
        else
                ""
 
@@ -212,7 +212,7 @@ let copy_vm_record ~__context ~vm ~disk_op ~new_name ~new_power_state =
        in
        (* Copy the old metrics if available, otherwise generate a fresh one *)
        let m =
-               if Db.is_valid_ref all.Db_actions.vM_metrics
+               if Db.is_valid_ref __context all.Db_actions.vM_metrics
                then Some (Db.VM_metrics.get_record_internal ~__context ~self:all.Db_actions.vM_metrics)
                else None
        in
@@ -268,7 +268,7 @@ let copy_vm_record ~__context ~vm ~disk_op ~new_name ~new_power_state =
                ~snapshot_of:(if is_a_snapshot then vm else Ref.null)
                ~snapshot_time:(if is_a_snapshot then Date.of_float (Unix.gettimeofday ()) else Date.never)
                ~snapshot_info:(snapshot_info ~power_state ~is_a_snapshot)
-               ~snapshot_metadata:(snapshot_metadata ~vm ~is_a_snapshot)
+               ~snapshot_metadata:(snapshot_metadata ~__context ~vm ~is_a_snapshot)
                ~transportable_snapshot_id:""
                ~parent
                ~resident_on:Ref.null
index 8263da4b17a4ec1e0e299b8a8ee1e471c2ac9900..aa739ec2bac795c67ba8415b7a9a260801b5c47d 100644 (file)
@@ -317,7 +317,7 @@ let check_operation_error ~vmr ~vmgmr ~ref ~clone_suspended_vm_enabled vdis_rese
        current_error
 
 let maybe_get_guest_metrics ~__context ~ref =
-       if Db.is_valid_ref ref
+       if Db.is_valid_ref __context ref
        then Some (Db.VM_guest_metrics.get_record_internal ~__context ~self:ref)
        else None
 
index 9878b5b8c1c9e8a0cae3c2d5da870ad0c47d7c08..f7a6375887888c4d7b3db2f117b27a0fea9cc98c 100644 (file)
@@ -75,7 +75,7 @@ let create_pool_snapshot_summary __context extra_guests pool =
 (** Returns a list of affinity host identifiers for the given [guest]. *)
 let affinity_host_ids_of_guest __context guest =
        let affinity_host = Db.VM.get_affinity ~__context ~self:guest in
-       let affinity_host_is_valid = Db.is_valid_ref affinity_host in
+       let affinity_host_is_valid = Db.is_valid_ref __context affinity_host in
        if affinity_host_is_valid
                then [Db.Host.get_uuid __context affinity_host]
                else []
index b12b0a5e0a50c24236f3e781105638054b98b36a..db66e0f4f09b8ecc3d087d484cbdaaf8f56ec2a4 100644 (file)
@@ -215,7 +215,7 @@ let checkpoint ~__context ~vm ~new_name =
                        try 
                                let suspend_VDI = Db.VM.get_suspend_VDI ~__context ~self:vm in
                                Vmops.resume ~__context ~xc ~xs ~vm;
-                               if Db.is_valid_ref suspend_VDI then begin
+                               if Db.is_valid_ref __context suspend_VDI then begin
                                        Db.VM.set_suspend_VDI ~__context ~self:vm ~value:Ref.null;
                                        Helpers.call_api_functions ~__context (fun rpc session_id -> Client.VDI.destroy rpc session_id suspend_VDI);
                                end;
@@ -249,7 +249,8 @@ let checkpoint ~__context ~vm ~new_name =
 let copy_vm_fields ~__context ~metadata ~dst ~do_not_copy ~default_values =
        assert (Pool_role.is_master ());
        debug "copying metadata into %s" (Ref.string_of dst);
-       let module DB = (val (Db_cache.get ()) : Db_interface.DB_ACCESS) in
+       let db = Context.database_of __context in
+       let module DB = (val (Db_cache.get db) : Db_interface.DB_ACCESS) in
        List.iter
                (fun (key,value) -> 
                        let value = 
@@ -257,21 +258,21 @@ let copy_vm_fields ~__context ~metadata ~dst ~do_not_copy ~default_values =
                                then List.assoc key default_values
                                else value in
                         if not (List.mem key do_not_copy)
-                        then DB.write_field Db_names.vm (Ref.string_of dst) key value)
+                        then DB.write_field db Db_names.vm (Ref.string_of dst) key value)
                metadata
                
 let safe_destroy_vbd ~__context ~rpc ~session_id vbd =
-       if Db.is_valid_ref vbd then begin
+       if Db.is_valid_ref __context vbd then begin
                Client.VBD.destroy rpc session_id vbd
        end
 
 let safe_destroy_vif ~__context ~rpc ~session_id vif =
-       if Db.is_valid_ref vif then begin
+       if Db.is_valid_ref __context vif then begin
                Client.VIF.destroy rpc session_id vif
        end
 
 let safe_destroy_vdi ~__context ~rpc ~session_id vdi =
-       if Db.is_valid_ref vdi then begin
+       if Db.is_valid_ref __context vdi then begin
                let sr = Db.VDI.get_SR ~__context ~self:vdi in
                if not (Db.SR.get_content_type ~__context ~self:sr = "iso") then
                        Client.VDI.destroy rpc session_id vdi
@@ -337,8 +338,8 @@ let update_guest_metrics ~__context ~vm ~snapshot =
        let vm_gm = Db.VM.get_guest_metrics ~__context ~self:vm in
 
        debug "Reverting the guest metrics";
-       if Db.is_valid_ref vm_gm then Db.VM_guest_metrics.destroy ~__context ~self:vm_gm;
-       if Db.is_valid_ref snap_gm then begin
+       if Db.is_valid_ref __context vm_gm then Db.VM_guest_metrics.destroy ~__context ~self:vm_gm;
+       if Db.is_valid_ref __context snap_gm then begin
                let new_gm = Xapi_vm_helpers.copy_guest_metrics ~__context ~vm:snapshot in
                Db.VM.set_guest_metrics ~__context ~self:vm ~value:new_gm
        end
@@ -381,7 +382,7 @@ let revert_vm_fields ~__context ~snapshot ~vm =
        let snap_metadata =
                if post_MNR
                then Helpers.vm_string_to_assoc snap_metadata 
-               else Helpers.vm_string_to_assoc (Helpers.vm_to_string snapshot) in
+               else Helpers.vm_string_to_assoc (Helpers.vm_to_string __context snapshot) in
        let do_not_copy =
                if post_MNR
                then do_not_copy
index 8849b28a93147e81b13e0b8b4af59b623b508e0a..f371fe29013a4f3f3026dd11c6c23865df70eb89 100644 (file)
@@ -68,6 +68,6 @@ open Pervasiveext
 (** Attempt to flush the database to the metadata VDI *)
 let flush_database ~__context = 
   try
-    Redo_log.flush_db_to_redo_log (Db_backend.get_database ());
+    Redo_log.flush_db_to_redo_log (Db_ref.get_database (Db_backend.make ()));
     true
   with _ -> false