[lxc-devel] [lxd/master] RFC: initial cluster database plumbing

freeekanayaka on Github lxc-bot at linuxcontainers.org
Fri Nov 10 11:25:48 UTC 2017


From 069200d4575d9054ab8180605095b8f7b4f98b32 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Thu, 14 Sep 2017 13:13:55 +0000
Subject: [PATCH 01/14] Add raft_nodes table

This new table is meant to hold the addresses of the LXD nodes
participating in the dqlite raft cluster. Each node in the cluster
will hold its own local copy of this table, regardless of whether it's
a raft node or not.
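
For illustration, here is a minimal sketch (not part of this patch; the
helper name is hypothetical and assumes database/sql) of how a node could
read the raft cluster addresses from its local copy of the table:

    // raftNodeAddresses returns the network addresses of all nodes that
    // participate in the dqlite raft cluster, as recorded in the local
    // raft_nodes table. Hypothetical helper, for illustration only.
    func raftNodeAddresses(tx *sql.Tx) ([]string, error) {
        rows, err := tx.Query("SELECT address FROM raft_nodes ORDER BY id")
        if err != nil {
            return nil, err
        }
        defer rows.Close()

        addresses := []string{}
        for rows.Next() {
            var address string
            if err := rows.Scan(&address); err != nil {
                return nil, err
            }
            addresses = append(addresses, address)
        }
        return addresses, rows.Err()
    }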

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/db/node/schema.go          |  7 ++++++-
 lxd/db/node/update.go          | 31 +++++++++++++++++++++++++++++++
 lxd/db/node/update_test.go     | 17 +++++++++++++++++
 test/suites/database_update.sh |  2 +-
 4 files changed, 55 insertions(+), 2 deletions(-)
 create mode 100644 lxd/db/node/update_test.go

diff --git a/lxd/db/node/schema.go b/lxd/db/node/schema.go
index cbf863e1c..a9754eeaa 100644
--- a/lxd/db/node/schema.go
+++ b/lxd/db/node/schema.go
@@ -155,6 +155,11 @@ CREATE TABLE profiles_devices_config (
     UNIQUE (profile_device_id, key),
     FOREIGN KEY (profile_device_id) REFERENCES profiles_devices (id) ON DELETE CASCADE
 );
+CREATE TABLE raft_nodes (
+    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+    address TEXT NOT NULL,
+    UNIQUE (address)
+);
 CREATE TABLE storage_pools (
     id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
     name VARCHAR(255) NOT NULL,
@@ -188,5 +193,5 @@ CREATE TABLE storage_volumes_config (
     FOREIGN KEY (storage_volume_id) REFERENCES storage_volumes (id) ON DELETE CASCADE
 );
 
-INSERT INTO schema (version, updated_at) VALUES (36, strftime("%s"))
+INSERT INTO schema (version, updated_at) VALUES (37, strftime("%s"))
 `
diff --git a/lxd/db/node/update.go b/lxd/db/node/update.go
index 299a645e4..95a660202 100644
--- a/lxd/db/node/update.go
+++ b/lxd/db/node/update.go
@@ -84,9 +84,40 @@ var updates = map[int]schema.Update{
 	34: updateFromV33,
 	35: updateFromV34,
 	36: updateFromV35,
+	37: updateFromV36,
 }
 
 // Schema updates begin here
+
+// Add a raft_nodes table to be used when running in clustered mode. It lists
+// the current nodes in the LXD cluster that are participating in the dqlite
+// database Raft cluster.
+//
+// The 'id' column contains the raft server ID of the database node, and the
+// 'address' column its network address. Both are used internally by the raft
+// Go package to manage the cluster.
+//
+// Typical setups will have 3 LXD cluster nodes that participate in the dqlite
+// database Raft cluster, and an arbitrary number of additional LXD cluster
+// nodes that don't. Non-database nodes are not tracked in this table, but rather
+// in the nodes table of the cluster database itself.
+//
+// The data in this table must be replicated by LXD on all nodes of the
+// cluster, regardless of whether they are part of the raft cluster or not, and
+// all nodes will consult this table when they need to find the leader to
+// send SQL queries to.
+func updateFromV36(tx *sql.Tx) error {
+	stmts := `
+CREATE TABLE raft_nodes (
+    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+    address TEXT NOT NULL,
+    UNIQUE (address)
+);
+`
+	_, err := tx.Exec(stmts)
+	return err
+}
+
 func updateFromV35(tx *sql.Tx) error {
 	stmts := `
 CREATE TABLE tmp (
diff --git a/lxd/db/node/update_test.go b/lxd/db/node/update_test.go
new file mode 100644
index 000000000..980ef8bf3
--- /dev/null
+++ b/lxd/db/node/update_test.go
@@ -0,0 +1,17 @@
+package node_test
+
+import (
+	"testing"
+
+	"github.com/lxc/lxd/lxd/db/node"
+	"github.com/stretchr/testify/require"
+)
+
+func TestUpdateFromV36(t *testing.T) {
+	schema := node.Schema()
+	db, err := schema.ExerciseUpdate(37, nil)
+	require.NoError(t, err)
+
+	_, err = db.Exec("INSERT INTO raft_nodes VALUES (1, '1.2.3.4:666')")
+	require.NoError(t, err)
+}
diff --git a/test/suites/database_update.sh b/test/suites/database_update.sh
index 7b3737486..15189bd2f 100644
--- a/test/suites/database_update.sh
+++ b/test/suites/database_update.sh
@@ -9,7 +9,7 @@ test_database_update(){
   spawn_lxd "${LXD_MIGRATE_DIR}" true
 
   # Assert there are enough tables.
-  expected_tables=23
+  expected_tables=24
   tables=$(sqlite3 "${MIGRATE_DB}" ".dump" | grep -c "CREATE TABLE")
   [ "${tables}" -eq "${expected_tables}" ] || { echo "FAIL: Wrong number of tables after database migration. Found: ${tables}, expected ${expected_tables}"; false; }
 

From e47063a1e7a55d81dc33e403c709d44c26c003c6 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Wed, 11 Oct 2017 15:05:22 +0000
Subject: [PATCH 02/14] Add query helpers to select and insert complex objects

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/db/query/objects.go      |  85 ++++++++++++++++++++
 lxd/db/query/objects_test.go | 187 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 272 insertions(+)
 create mode 100644 lxd/db/query/objects.go
 create mode 100644 lxd/db/query/objects_test.go

diff --git a/lxd/db/query/objects.go b/lxd/db/query/objects.go
new file mode 100644
index 000000000..c8e8a19f9
--- /dev/null
+++ b/lxd/db/query/objects.go
@@ -0,0 +1,85 @@
+package query
+
+import (
+	"database/sql"
+	"fmt"
+	"strings"
+)
+
+// SelectObjects executes a statement which must yield rows with a specific
+// column schema. It invokes the given Dest hook for each yielded row.
+func SelectObjects(tx *sql.Tx, dest Dest, query string, args ...interface{}) error {
+	rows, err := tx.Query(query, args...)
+	if err != nil {
+		return err
+	}
+	defer rows.Close()
+
+	for i := 0; rows.Next(); i++ {
+		err := rows.Scan(dest(i)...)
+		if err != nil {
+			return err
+		}
+	}
+
+	err = rows.Err()
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+// Dest is a function that is expected to return the objects to pass to the
+// 'dest' argument of sql.Rows.Scan(). It is invoked by SelectObjects once per
+// yielded row, and it will be passed the index of the row being scanned.
+type Dest func(i int) []interface{}
+
+// UpsertObject inserts or replaces a new row with the given column values
+// into the given table, following the given column order. For example:
+//
+// UpsertObject(tx, "cars", []string{"id", "brand"}, []interface{}{1, "ferrari"})
+//
+// The number of elements in 'columns' must match the number in 'values'.
+func UpsertObject(tx *sql.Tx, table string, columns []string, values []interface{}) (int64, error) {
+	n := len(columns)
+	if n == 0 {
+		return -1, fmt.Errorf("columns lenght is zero")
+	}
+	if n != len(values) {
+		return -1, fmt.Errorf("columns lenght does not match values lenght")
+	}
+
+	stmt := fmt.Sprintf(
+		"INSERT OR REPLACE INTO %s (%s) VALUES %s",
+		table, strings.Join(columns, ", "), exprParams(n))
+	result, err := tx.Exec(stmt, values...)
+	if err != nil {
+		return -1, err
+	}
+	id, err := result.LastInsertId()
+	if err != nil {
+		return -1, err
+	}
+	return id, nil
+}
+
+// DeleteObject removes the row identified by the given ID. The given table
+// must have a primary key column called 'id'.
+//
+// It returns a flag indicating whether a matching row was actually found
+// and deleted.
+func DeleteObject(tx *sql.Tx, table string, id int64) (bool, error) {
+	stmt := fmt.Sprintf("DELETE FROM %s WHERE id=?", table)
+	result, err := tx.Exec(stmt, id)
+	if err != nil {
+		return false, err
+	}
+	n, err := result.RowsAffected()
+	if err != nil {
+		return false, err
+	}
+	if n > 1 {
+		return true, fmt.Errorf("more than one row was deleted")
+	}
+	return n == 1, nil
+}
diff --git a/lxd/db/query/objects_test.go b/lxd/db/query/objects_test.go
new file mode 100644
index 000000000..e19264909
--- /dev/null
+++ b/lxd/db/query/objects_test.go
@@ -0,0 +1,187 @@
+package query_test
+
+import (
+	"database/sql"
+	"testing"
+
+	"github.com/lxc/lxd/lxd/db/query"
+	"github.com/mpvl/subtest"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// Exercise possible failure modes.
+func TestSelectObjects_Error(t *testing.T) {
+	cases := []struct {
+		dest  query.Dest
+		query string
+		error string
+	}{
+		{
+			func(int) []interface{} { return nil },
+			"garbage",
+			"near \"garbage\": syntax error",
+		},
+		{
+			func(int) []interface{} { return make([]interface{}, 1) },
+			"SELECT id, name FROM test",
+			"sql: expected 2 destination arguments in Scan, not 1",
+		},
+	}
+	for _, c := range cases {
+		subtest.Run(t, c.query, func(t *testing.T) {
+			tx := newTxForObjects(t)
+			err := query.SelectObjects(tx, c.dest, c.query)
+			assert.EqualError(t, err, c.error)
+		})
+	}
+}
+
+// Scan rows yielded by the query.
+func TestSelectObjects(t *testing.T) {
+	tx := newTxForObjects(t)
+	objects := make([]struct {
+		ID   int
+		Name string
+	}, 1)
+	object := objects[0]
+
+	dest := func(i int) []interface{} {
+		require.Equal(t, 0, i, "expected at most one row to be yielded")
+		return []interface{}{&object.ID, &object.Name}
+	}
+
+	stmt := "SELECT id, name FROM test WHERE name=?"
+	err := query.SelectObjects(tx, dest, stmt, "bar")
+	require.NoError(t, err)
+
+	assert.Equal(t, 1, object.ID)
+	assert.Equal(t, "bar", object.Name)
+}
+
+// Exercise possible failure modes.
+func TestUpsertObject_Error(t *testing.T) {
+	cases := []struct {
+		columns []string
+		values  []interface{}
+		error   string
+	}{
+		{
+			[]string{},
+			[]interface{}{},
+			"columns lenght is zero",
+		},
+		{
+			[]string{"id"},
+			[]interface{}{2, "egg"},
+			"columns lenght does not match values lenght",
+		},
+	}
+	for _, c := range cases {
+		subtest.Run(t, c.error, func(t *testing.T) {
+			tx := newTxForObjects(t)
+			id, err := query.UpsertObject(tx, "foo", c.columns, c.values)
+			assert.Equal(t, int64(-1), id)
+			assert.EqualError(t, err, c.error)
+		})
+	}
+}
+
+// Insert a new row.
+func TestUpsertObject_Insert(t *testing.T) {
+	tx := newTxForObjects(t)
+
+	id, err := query.UpsertObject(tx, "test", []string{"name"}, []interface{}{"egg"})
+	require.NoError(t, err)
+	assert.Equal(t, int64(2), id)
+
+	objects := make([]struct {
+		ID   int
+		Name string
+	}, 1)
+	object := objects[0]
+
+	dest := func(i int) []interface{} {
+		require.Equal(t, 0, i, "expected at most one row to be yielded")
+		return []interface{}{&object.ID, &object.Name}
+	}
+
+	stmt := "SELECT id, name FROM test WHERE name=?"
+	err = query.SelectObjects(tx, dest, stmt, "egg")
+	require.NoError(t, err)
+
+	assert.Equal(t, 2, object.ID)
+	assert.Equal(t, "egg", object.Name)
+}
+
+// Update an existing row.
+func TestUpsertObject_Update(t *testing.T) {
+	tx := newTxForObjects(t)
+
+	id, err := query.UpsertObject(tx, "test", []string{"id", "name"}, []interface{}{1, "egg"})
+	require.NoError(t, err)
+	assert.Equal(t, int64(1), id)
+
+	objects := make([]struct {
+		ID   int
+		Name string
+	}, 1)
+	object := objects[0]
+
+	dest := func(i int) []interface{} {
+		require.Equal(t, 0, i, "expected at most one row to be yielded")
+		return []interface{}{&object.ID, &object.Name}
+	}
+
+	stmt := "SELECT id, name FROM test WHERE name=?"
+	err = query.SelectObjects(tx, dest, stmt, "egg")
+	require.NoError(t, err)
+
+	assert.Equal(t, 1, object.ID)
+	assert.Equal(t, "egg", object.Name)
+}
+
+// Exercise possible failure modes.
+func TestDeleteObject_Error(t *testing.T) {
+	tx := newTxForObjects(t)
+
+	deleted, err := query.DeleteObject(tx, "foo", 1)
+	assert.False(t, deleted)
+	assert.EqualError(t, err, "no such table: foo")
+}
+
+// If a row was actually deleted, the returned flag is true.
+func TestDeleteObject_Deleted(t *testing.T) {
+	tx := newTxForObjects(t)
+
+	deleted, err := query.DeleteObject(tx, "test", 1)
+	assert.True(t, deleted)
+	assert.NoError(t, err)
+}
+
+// If no row was actually deleted, the returned flag is false.
+func TestDeleteObject_NotDeleted(t *testing.T) {
+	tx := newTxForObjects(t)
+
+	deleted, err := query.DeleteObject(tx, "test", 1000)
+	assert.False(t, deleted)
+	assert.NoError(t, err)
+}
+
+// Return a new transaction against an in-memory SQLite database with a single
+// test table populated with a few rows for testing object-related queries.
+func newTxForObjects(t *testing.T) *sql.Tx {
+	db, err := sql.Open("sqlite3", ":memory:")
+	assert.NoError(t, err)
+
+	_, err = db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
+	assert.NoError(t, err)
+
+	_, err = db.Exec("INSERT INTO test VALUES (0, 'foo'), (1, 'bar')")
+	assert.NoError(t, err)
+
+	tx, err := db.Begin()
+	assert.NoError(t, err)
+
+	return tx
+}
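
Taken together, the helpers are meant to be used roughly like this (a
sketch against the 'test' table used above, not code from this patch):

    // Load all rows from the test table, then upsert and delete a row.
    type row struct {
        ID   int
        Name string
    }

    func example(tx *sql.Tx) error {
        rows := []row{}
        dest := func(i int) []interface{} {
            // Grow the slice by one row per invocation and scan into it.
            rows = append(rows, row{})
            return []interface{}{&rows[i].ID, &rows[i].Name}
        }
        if err := query.SelectObjects(tx, dest, "SELECT id, name FROM test"); err != nil {
            return err
        }

        // Insert a new row, then delete it again by its last-insert ID.
        id, err := query.UpsertObject(tx, "test", []string{"name"}, []interface{}{"egg"})
        if err != nil {
            return err
        }
        _, err = query.DeleteObject(tx, "test", id)
        return err
    }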

From e477a87356ae4bd9d8082d8b5740963bad00b5e6 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Fri, 15 Sep 2017 07:23:30 +0000
Subject: [PATCH 03/14] Add InsertStrings helper to insert rows with a single
 string value

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/db/query/slices.go      | 25 +++++++++++++++++++++++++
 lxd/db/query/slices_test.go | 13 +++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/lxd/db/query/slices.go b/lxd/db/query/slices.go
index 4e58126c7..59d0cc892 100644
--- a/lxd/db/query/slices.go
+++ b/lxd/db/query/slices.go
@@ -2,6 +2,8 @@ package query
 
 import (
 	"database/sql"
+	"fmt"
+	"strings"
 )
 
 // SelectStrings executes a statement which must yield rows with a single string
@@ -48,6 +50,29 @@ func SelectIntegers(tx *sql.Tx, query string) ([]int, error) {
 	return values, nil
 }
 
+// InsertStrings inserts a new row for each of the given strings, using the
+// given insert statement template, which must define exactly one insertion
+// column and one substitution placeholder for the values. For example:
+// InsertStrings(tx, "INSERT INTO foo(name) VALUES %s", []string{"bar"}).
+func InsertStrings(tx *sql.Tx, stmt string, values []string) error {
+	n := len(values)
+
+	if n == 0 {
+		return nil
+	}
+
+	params := make([]string, n)
+	args := make([]interface{}, n)
+	for i, value := range values {
+		params[i] = "(?)"
+		args[i] = value
+	}
+
+	stmt = fmt.Sprintf(stmt, strings.Join(params, ", "))
+	_, err := tx.Exec(stmt, args...)
+	return err
+}
+
 // Execute the given query and ensure that it yields rows with a single column
 // of the given database type. For every row yielded, execute the given
 // scanner.
diff --git a/lxd/db/query/slices_test.go b/lxd/db/query/slices_test.go
index f5bb6549a..36e31a5b9 100644
--- a/lxd/db/query/slices_test.go
+++ b/lxd/db/query/slices_test.go
@@ -6,6 +6,7 @@ import (
 
 	_ "github.com/mattn/go-sqlite3"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/lxc/lxd/lxd/db/query"
 	"github.com/lxc/lxd/shared/subtest"
@@ -51,6 +52,18 @@ func TestIntegers(t *testing.T) {
 	assert.Equal(t, []int{0, 1}, values)
 }
 
+// Insert new rows in bulk.
+func TestInsertStrings(t *testing.T) {
+	tx := newTxForSlices(t)
+
+	err := query.InsertStrings(tx, "INSERT INTO test(name) VALUES %s", []string{"xx", "yy"})
+	require.NoError(t, err)
+
+	values, err := query.SelectStrings(tx, "SELECT name FROM test ORDER BY name DESC LIMIT 2")
+	require.NoError(t, err)
+	assert.Equal(t, values, []string{"yy", "xx"})
+}
+
 // Return a new transaction against an in-memory SQLite database with a single
 // test table populated with a few rows.
 func newTxForSlices(t *testing.T) *sql.Tx {
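
Usage is a one-liner; the statement template expands to one multi-row
insert (a sketch, assuming an open transaction and the 'test' table from
the tests above):

    // Expands to "INSERT INTO test(name) VALUES (?), (?)" and runs it once
    // with the arguments "xx" and "yy", batching the inserts in a single
    // round trip.
    err := query.InsertStrings(tx, "INSERT INTO test(name) VALUES %s", []string{"xx", "yy"})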

From 0dacdcb10c5b7e69601738b23236ea3c35d68f96 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Sun, 1 Oct 2017 17:13:52 +0000
Subject: [PATCH 04/14] Add util.InMemoryNetwork to create in-memory
 listener/dialer pairs.

This is a convenience for creating in-memory networks that implement
the net.Conn interface. It will be used when running a node in
non-clustered mode, where there will be no actual TCP/gRPC connection
to an external dqlite node, but rather just an in-memory connection to
the local dqlite instance (which will be the leader).
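
A minimal sketch of the intended usage (assumes the util package from
this patch plus the standard io package):

    // The listener side echoes back whatever the dialer side writes;
    // both ends are in-memory net.Conn objects created by net.Pipe.
    listener, dial := util.InMemoryNetwork()
    go func() {
        conn, _ := listener.Accept()
        io.Copy(conn, conn) // echo
    }()
    client := dial()
    client.Write([]byte("ping"))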

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/util/net.go      | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 lxd/util/net_test.go | 18 ++++++++++++++++++
 2 files changed, 65 insertions(+)

diff --git a/lxd/util/net.go b/lxd/util/net.go
index 1f96f27f4..38d6acb8b 100644
--- a/lxd/util/net.go
+++ b/lxd/util/net.go
@@ -7,6 +7,53 @@ import (
 	"github.com/lxc/lxd/shared"
 )
 
+// InMemoryNetwork creates a fully in-memory listener and dial function.
+//
+// Each time the dial function is invoked a new pair of net.Conn objects will
+// be created using net.Pipe: the listener's Accept method will unblock and
+// return one end of the pipe and the other end will be returned by the dial
+// function.
+func InMemoryNetwork() (net.Listener, func() net.Conn) {
+	listener := &inMemoryListener{conns: make(chan net.Conn, 16)}
+	dialer := func() net.Conn {
+		server, client := net.Pipe()
+		listener.conns <- server
+		return client
+	}
+	return listener, dialer
+}
+
+type inMemoryListener struct {
+	conns chan net.Conn
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *inMemoryListener) Accept() (net.Conn, error) {
+	return <-l.conns, nil
+}
+
+// Close closes the listener.
+// Any blocked Accept operations will be unblocked and return errors.
+func (l *inMemoryListener) Close() error {
+	return nil
+}
+
+// Addr returns the listener's network address.
+func (l *inMemoryListener) Addr() net.Addr {
+	return &inMemoryAddr{}
+}
+
+type inMemoryAddr struct {
+}
+
+func (a *inMemoryAddr) Network() string {
+	return "memory"
+}
+
+func (a *inMemoryAddr) String() string {
+	return ""
+}
+
 // CanonicalNetworkAddress parses the given network address and returns a
 // string of the form "host:port", possibly filling it with the default port if
 // it's missing.
diff --git a/lxd/util/net_test.go b/lxd/util/net_test.go
index 0b29eb576..a56581464 100644
--- a/lxd/util/net_test.go
+++ b/lxd/util/net_test.go
@@ -6,8 +6,26 @@ import (
 	"github.com/lxc/lxd/lxd/util"
 	"github.com/mpvl/subtest"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
+// The connection returned by the dialer is paired with the one returned by the
+// Accept() method of the listener.
+func TestInMemoryNetwork(t *testing.T) {
+	listener, dialer := util.InMemoryNetwork()
+	client := dialer()
+	server, err := listener.Accept()
+	require.NoError(t, err)
+
+	go client.Write([]byte("hello"))
+	buffer := make([]byte, 5)
+	n, err := server.Read(buffer)
+	require.NoError(t, err)
+
+	assert.Equal(t, 5, n)
+	assert.Equal(t, []byte("hello"), buffer)
+}
+
 func TestCanonicalNetworkAddress(t *testing.T) {
 	cases := map[string]string{
 		"127.0.0.1":                             "127.0.0.1:8443",

From 29d862b03ef5eb919c3e95ef2532db85579b9093 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Mon, 11 Sep 2017 12:19:44 +0000
Subject: [PATCH 05/14] Add db.Cluster with basic initialization

A new Cluster structure has been added to the lxd/db sub-package. It
is meant to mediate access to the dqlite-based cluster database. It
uses the go-grpc-sql package to serialize SQL queries over a gRPC
connection against the target dqlite leader node.
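
A rough sketch of how a consumer is expected to use it (the dialer here
comes from the Gateway introduced later in this series):

    // Open the cluster database through a gRPC SQL dialer and run a
    // transaction against it.
    cluster, err := db.OpenCluster("db.bin", dialer)
    if err != nil {
        return err
    }
    err = cluster.Transaction(func(tx *db.ClusterTx) error {
        // Higher-level accessors on ClusterTx will be added in
        // follow-up patches.
        return nil
    })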

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/db/cluster/open.go   | 48 ++++++++++++++++++++++++++++++++++
 lxd/db/db.go             | 38 ++++++++++++++++++++++++++-
 lxd/db/db_export_test.go |  9 +++++++
 lxd/db/db_test.go        | 16 ++++++++++++
 lxd/db/testing.go        | 67 ++++++++++++++++++++++++++++++++++++++++++++++++
 lxd/db/transaction.go    | 18 +++++++++++++
 6 files changed, 195 insertions(+), 1 deletion(-)
 create mode 100644 lxd/db/cluster/open.go
 create mode 100644 lxd/db/db_export_test.go

diff --git a/lxd/db/cluster/open.go b/lxd/db/cluster/open.go
new file mode 100644
index 000000000..d135dea6f
--- /dev/null
+++ b/lxd/db/cluster/open.go
@@ -0,0 +1,48 @@
+package cluster
+
+import (
+	"database/sql"
+	"fmt"
+	"sync/atomic"
+
+	"github.com/CanonicalLtd/go-grpc-sql"
+)
+
+// Open the cluster database object.
+//
+// The name argument is the name of the cluster database. It defaults to
+// 'db.bin', but can be overridden for testing.
+//
+// The dialer argument is a function that returns a gRPC dialer that can be
+// used to connect to a database node using the gRPC SQL package.
+func Open(name string, dialer grpcsql.Dialer) (*sql.DB, error) {
+	driver := grpcsql.NewDriver(dialer)
+	driverName := grpcSQLDriverName()
+	sql.Register(driverName, driver)
+
+	// Create the cluster db. This won't immediately establish any gRPC
+	// connection, that will happen only when a db transaction is started
+	// (see the database/sql connection pooling code for more details).
+	if name == "" {
+		name = "db.bin"
+	}
+	db, err := sql.Open(driverName, name)
+	if err != nil {
+		return nil, fmt.Errorf("cannot open cluster database: %v", err)
+	}
+
+	return db, nil
+}
+
+// Generate a new name for the grpcsql driver registration. We need it to be
+// unique for testing, see below.
+func grpcSQLDriverName() string {
+	defer atomic.AddUint64(&grpcSQLDriverSerial, 1)
+	return fmt.Sprintf("grpc-%d", grpcSQLDriverSerial)
+}
+
+// Monotonic serial number for registering new instances of grpcsql.Driver
+// using the database/sql stdlib package. This is needed since there's no way
+// to unregister drivers, and in unit tests more than one driver gets
+// registered.
+var grpcSQLDriverSerial uint64
diff --git a/lxd/db/db.go b/lxd/db/db.go
index c43bba0f3..f1eae9653 100644
--- a/lxd/db/db.go
+++ b/lxd/db/db.go
@@ -5,8 +5,10 @@ import (
 	"fmt"
 	"time"
 
+	grpcsql "github.com/CanonicalLtd/go-grpc-sql"
 	"github.com/mattn/go-sqlite3"
 
+	"github.com/lxc/lxd/lxd/db/cluster"
 	"github.com/lxc/lxd/lxd/db/node"
 	"github.com/lxc/lxd/lxd/db/query"
 	"github.com/lxc/lxd/shared/logger"
@@ -30,7 +32,6 @@ var (
 // Node mediates access to LXD's data stored in the node-local SQLite database.
 type Node struct {
 	db *sql.DB // Handle to the node-local SQLite database file.
-
 }
 
 // OpenNode creates a new Node object.
@@ -111,6 +112,41 @@ func (n *Node) Begin() (*sql.Tx, error) {
 	return begin(n.db)
 }
 
+// Cluster mediates access to LXD's data stored in the cluster dqlite database.
+type Cluster struct {
+	db *sql.DB // Handle to the cluster dqlite database, gated behind gRPC SQL.
+}
+
+// OpenCluster creates a new Cluster object for interacting with the dqlite
+// database.
+func OpenCluster(name string, dialer grpcsql.Dialer) (*Cluster, error) {
+	db, err := cluster.Open(name, dialer)
+	if err != nil {
+		return nil, err
+	}
+	cluster := &Cluster{
+		db: db,
+	}
+	return cluster, nil
+}
+
+// Transaction creates a new ClusterTx object and transactionally executes the
+// cluster database interactions invoked by the given function. If the function
+// returns no error, all database changes are committed to the cluster
+// database, otherwise they are rolled back.
+func (c *Cluster) Transaction(f func(*ClusterTx) error) error {
+	clusterTx := &ClusterTx{}
+	return query.Transaction(c.db, func(tx *sql.Tx) error {
+		clusterTx.tx = tx
+		return f(clusterTx)
+	})
+}
+
+// Close the database facade.
+func (c *Cluster) Close() error {
+	return c.db.Close()
+}
+
 // UpdateSchemasDotGo updates the schema.go files in the local/ and cluster/
 // sub-packages.
 func UpdateSchemasDotGo() error {
diff --git a/lxd/db/db_export_test.go b/lxd/db/db_export_test.go
new file mode 100644
index 000000000..a975c9081
--- /dev/null
+++ b/lxd/db/db_export_test.go
@@ -0,0 +1,9 @@
+package db
+
+import "database/sql"
+
+// DB returns the low level database handle to the cluster gRPC SQL database
+// handler. Used by tests for introspecting the database with raw SQL.
+func (c *Cluster) DB() *sql.DB {
+	return c.db
+}
diff --git a/lxd/db/db_test.go b/lxd/db/db_test.go
index 243d48de0..cf2eeb6df 100644
--- a/lxd/db/db_test.go
+++ b/lxd/db/db_test.go
@@ -23,3 +23,19 @@ func TestNode_Schema(t *testing.T) {
 	assert.NoError(t, rows.Scan(&n))
 	assert.Equal(t, 1, n)
 }
+
+// A gRPC SQL connection is established when starting to interact with the
+// cluster database.
+func TestCluster_Setup(t *testing.T) {
+	cluster, cleanup := db.NewTestCluster(t)
+	defer cleanup()
+
+	db := cluster.DB()
+	rows, err := db.Query("SELECT COUNT(*) FROM sqlite_master")
+	assert.NoError(t, err)
+	defer rows.Close()
+	assert.Equal(t, true, rows.Next())
+	var n uint
+	assert.NoError(t, rows.Scan(&n))
+	assert.Zero(t, n)
+}
diff --git a/lxd/db/testing.go b/lxd/db/testing.go
index 188e6f630..1cb6344d3 100644
--- a/lxd/db/testing.go
+++ b/lxd/db/testing.go
@@ -2,10 +2,16 @@ package db
 
 import (
 	"io/ioutil"
+	"net"
 	"os"
 	"testing"
+	"time"
 
+	"github.com/CanonicalLtd/go-grpc-sql"
+	"github.com/lxc/lxd/lxd/util"
+	"github.com/mattn/go-sqlite3"
 	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc"
 )
 
 // NewTestNode creates a new Node for testing purposes, along with a function
@@ -43,3 +49,64 @@ func NewTestNodeTx(t *testing.T) (*NodeTx, func()) {
 
 	return nodeTx, cleanup
 }
+
+// NewTestCluster creates a new Cluster for testing purposes, along with a function
+// that can be used to clean it up when done.
+func NewTestCluster(t *testing.T) (*Cluster, func()) {
+	// Create an in-memory gRPC SQL server and dialer.
+	server, dialer := newGrpcServer()
+
+	cluster, err := OpenCluster(":memory:", dialer)
+	require.NoError(t, err)
+
+	cleanup := func() {
+		require.NoError(t, cluster.Close())
+		server.Stop()
+	}
+
+	return cluster, cleanup
+}
+
+// NewTestClusterTx returns a fresh ClusterTx object, along with a function that can
+// be called to cleanup state when done with it.
+func NewTestClusterTx(t *testing.T) (*ClusterTx, func()) {
+	cluster, clusterCleanup := NewTestCluster(t)
+
+	var err error
+
+	clusterTx := &ClusterTx{}
+	clusterTx.tx, err = cluster.db.Begin()
+	require.NoError(t, err)
+
+	cleanup := func() {
+		err := clusterTx.tx.Commit()
+		require.NoError(t, err)
+		clusterCleanup()
+	}
+
+	return clusterTx, cleanup
+}
+
+// Create a new in-memory gRPC server attached to a grpc-sql gateway backed by a
+// SQLite driver.
+//
+// Return the newly created gRPC server and a dialer that can be used to
+// connect to it.
+func newGrpcServer() (*grpc.Server, grpcsql.Dialer) {
+	listener, dial := util.InMemoryNetwork()
+	server := grpcsql.NewServer(&sqlite3.SQLiteDriver{})
+
+	// Setup an in-memory gRPC dialer.
+	options := []grpc.DialOption{
+		grpc.WithInsecure(),
+		grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
+			return dial(), nil
+		}),
+	}
+	dialer := func() (*grpc.ClientConn, error) {
+		return grpc.Dial("", options...)
+	}
+
+	go server.Serve(listener)
+	return server, dialer
+}
diff --git a/lxd/db/transaction.go b/lxd/db/transaction.go
index 4e1d89c66..de30c11f7 100644
--- a/lxd/db/transaction.go
+++ b/lxd/db/transaction.go
@@ -9,3 +9,21 @@ import "database/sql"
 type NodeTx struct {
 	tx *sql.Tx // Handle to a transaction in the node-level SQLite database.
 }
+
+// Tx returns the low level database handle to the node-local SQLite
+// transaction.
+//
+// FIXME: this is a transitional method needed for compatibility with some
+//        legacy call sites. It should be removed when there are no more
+//        consumers.
+func (n *NodeTx) Tx() *sql.Tx {
+	return n.tx
+}
+
+// ClusterTx models a single interaction with a LXD cluster database.
+//
+// It wraps low-level sql.Tx objects and offers a high-level API to fetch and
+// update data.
+type ClusterTx struct {
+	tx *sql.Tx // Handle to a transaction in the cluster dqlite database.
+}

From c6c312023f7ac064f24def9c593fb1716af5327b Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Sat, 14 Oct 2017 11:38:10 +0000
Subject: [PATCH 06/14] Add cluster.Gateway to manage the lifecycle of the
 cluster database

This is a first version of the Gateway object, an API that the daemon
will use to 1) run a dqlite node (if appropriate) and 2) connect to
the leader dqlite node via gRPC.

For now there's no actual dqlite plumbing in place, and all the
Gateway does is expose a regular SQLite db over an in-memory gRPC
network (client/server).
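
A minimal sketch of how the daemon is expected to drive it (cert and
node-local db obtained as in the tests below):

    // Create a gateway from the node-local database, then obtain a gRPC
    // connection to the (currently in-memory) SQL endpoint.
    gateway, err := cluster.NewGateway(node, cert, 1.0)
    if err != nil {
        return err
    }
    conn, err := gateway.Dialer()()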

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/cluster/gateway.go      | 103 ++++++++++++++++++++++++++++++++++++++++++++
 lxd/cluster/gateway_test.go |  40 +++++++++++++++++
 lxd/db/db.go                |  11 ++++-
 3 files changed, 152 insertions(+), 2 deletions(-)
 create mode 100644 lxd/cluster/gateway.go
 create mode 100644 lxd/cluster/gateway_test.go

diff --git a/lxd/cluster/gateway.go b/lxd/cluster/gateway.go
new file mode 100644
index 000000000..e3dfe5d67
--- /dev/null
+++ b/lxd/cluster/gateway.go
@@ -0,0 +1,103 @@
+package cluster
+
+import (
+	"net"
+	"time"
+
+	"github.com/CanonicalLtd/go-grpc-sql"
+	"github.com/lxc/lxd/lxd/db"
+	"github.com/lxc/lxd/lxd/util"
+	"github.com/lxc/lxd/shared"
+	"github.com/mattn/go-sqlite3"
+	"google.golang.org/grpc"
+)
+
+// NewGateway creates a new Gateway for managing access to the dqlite cluster.
+//
+// When a new gateway is created, the node-level database is queried to check
+// what kind of role this node plays and if it's exposed over the network. It
+// will initialize internal data structures accordingly, for example starting a
+// dqlite driver if this node is a database node.
+//
+// After creation, the Daemon is expected to expose whatever HTTP handlers the
+// HandlerFuncs method returns and to access the dqlite cluster using the gRPC
+// dialer returned by the Dialer method.
+func NewGateway(db *db.Node, cert *shared.CertInfo, latency float64) (*Gateway, error) {
+	gateway := &Gateway{
+		db:      db,
+		cert:    cert,
+		latency: latency,
+	}
+
+	err := gateway.init()
+	if err != nil {
+		return nil, err
+	}
+
+	return gateway, nil
+}
+
+// Gateway mediates access to the dqlite cluster using a gRPC SQL client, and
+// possibly runs a dqlite replica on this LXD node (if we're configured to do
+// so).
+type Gateway struct {
+	db      *db.Node
+	cert    *shared.CertInfo
+	latency float64
+
+	// The gRPC server exposing the dqlite driver created by this
+	// gateway. It's nil if this LXD node is not supposed to be part of the
+	// raft cluster.
+	server *grpc.Server
+
+	// A dialer that will connect to the gRPC server using an in-memory
+	// net.Conn. It's non-nil when clustering is not enabled on this LXD
+	// node, and so we don't expose any dqlite or raft network endpoint,
+	// but still we want to use dqlite as backend for the "cluster"
+	// database, to minimize the difference between code paths in
+	// clustering and non-clustering modes.
+	memoryDial func() (*grpc.ClientConn, error)
+}
+
+// Dialer returns a gRPC dial function that can be used to connect to one of
+// the dqlite nodes via gRPC.
+func (g *Gateway) Dialer() grpcsql.Dialer {
+	return func() (*grpc.ClientConn, error) {
+		// Memory connection.
+		return g.memoryDial()
+	}
+}
+
+// Shutdown this gateway, stopping the gRPC server and possibly the raft factory.
+func (g *Gateway) Shutdown() error {
+	if g.server != nil {
+		g.server.Stop()
+		// Unset the memory dial, since Shutdown() is also called for
+		// switching between in-memory and network mode.
+		g.memoryDial = nil
+	}
+	return nil
+}
+
+// Initialize the gateway, creating a new raft factory and gRPC server (if this
+// node is a database node), and a gRPC dialer.
+func (g *Gateway) init() error {
+	g.server = grpcsql.NewServer(&sqlite3.SQLiteDriver{})
+	listener, dial := util.InMemoryNetwork()
+	go g.server.Serve(listener)
+	g.memoryDial = grpcMemoryDial(dial)
+	return nil
+}
+
+// Convert a raw in-memory dial function into a gRPC one.
+func grpcMemoryDial(dial func() net.Conn) func() (*grpc.ClientConn, error) {
+	options := []grpc.DialOption{
+		grpc.WithInsecure(),
+		grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
+			return dial(), nil
+		}),
+	}
+	return func() (*grpc.ClientConn, error) {
+		return grpc.Dial("", options...)
+	}
+}
diff --git a/lxd/cluster/gateway_test.go b/lxd/cluster/gateway_test.go
new file mode 100644
index 000000000..33072e993
--- /dev/null
+++ b/lxd/cluster/gateway_test.go
@@ -0,0 +1,40 @@
+package cluster_test
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/lxc/lxd/lxd/cluster"
+	"github.com/lxc/lxd/lxd/db"
+	"github.com/lxc/lxd/shared"
+	"github.com/lxc/lxd/shared/logging"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// Basic creation and shutdown. By default, the gateway runs an in-memory gRPC
+// server.
+func TestGateway_Single(t *testing.T) {
+	db, cleanup := db.NewTestNode(t)
+	defer cleanup()
+
+	cert := shared.TestingKeyPair()
+	gateway := newGateway(t, db, cert)
+	defer gateway.Shutdown()
+
+	dialer := gateway.Dialer()
+	conn, err := dialer()
+	assert.NoError(t, err)
+	assert.NotNil(t, conn)
+}
+
+// Create a new test Gateway with the given parameters, and ensure no error
+// happens.
+func newGateway(t *testing.T, db *db.Node, certInfo *shared.CertInfo) *cluster.Gateway {
+	logging.Testing(t)
+	require.NoError(t, os.Mkdir(filepath.Join(db.Dir(), "raft"), 0755))
+	gateway, err := cluster.NewGateway(db, certInfo, 0.2)
+	require.NoError(t, err)
+	return gateway
+}
diff --git a/lxd/db/db.go b/lxd/db/db.go
index f1eae9653..9c6add273 100644
--- a/lxd/db/db.go
+++ b/lxd/db/db.go
@@ -31,7 +31,8 @@ var (
 
 // Node mediates access to LXD's data stored in the node-local SQLite database.
 type Node struct {
-	db *sql.DB // Handle to the node-local SQLite database file.
+	db  *sql.DB // Handle to the node-local SQLite database file.
+	dir string  // Reference to the directory where the database file lives.
 }
 
 // OpenNode creates a new Node object.
@@ -55,7 +56,8 @@ func OpenNode(dir string, fresh func(*Node) error, legacyPatches map[int]*Legacy
 	}
 
 	node := &Node{
-		db: db,
+		db:  db,
+		dir: dir,
 	}
 
 	if initial == 0 {
@@ -90,6 +92,11 @@ func (n *Node) DB() *sql.DB {
 	return n.db
 }
 
+// Dir returns the directory of the underlying SQLite database file.
+func (n *Node) Dir() string {
+	return n.dir
+}
+
 // Transaction creates a new NodeTx object and transactionally executes the
 // node-level database interactions invoked by the given function. If the
 // function returns no error, all database changes are committed to the

From e2dedbeee8e9417fea4fd0d65c4f07478e0dea5e Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Sat, 14 Oct 2017 12:01:54 +0000
Subject: [PATCH 07/14] Wire cluster.Gateway into Daemon

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/daemon.go                  | 71 +++++++++++++++++++++++++++++++-----------
 lxd/daemon_integration_test.go |  4 ++-
 lxd/main_daemon.go             |  5 ++-
 3 files changed, 58 insertions(+), 22 deletions(-)

diff --git a/lxd/daemon.go b/lxd/daemon.go
index 6968e08ae..67785d552 100644
--- a/lxd/daemon.go
+++ b/lxd/daemon.go
@@ -25,6 +25,7 @@ import (
 	"gopkg.in/macaroon-bakery.v2/bakery/identchecker"
 	"gopkg.in/macaroon-bakery.v2/httpbakery"
 
+	"github.com/lxc/lxd/lxd/cluster"
 	"github.com/lxc/lxd/lxd/db"
 	"github.com/lxc/lxd/lxd/endpoints"
 	"github.com/lxc/lxd/lxd/state"
@@ -43,6 +44,7 @@ type Daemon struct {
 	clientCerts  []x509.Certificate
 	os           *sys.OS
 	db           *db.Node
+	cluster      *db.Cluster
 	readyChan    chan bool
 	shutdownChan chan bool
 
@@ -56,6 +58,7 @@ type Daemon struct {
 
 	config    *DaemonConfig
 	endpoints *endpoints.Endpoints
+	gateway   *cluster.Gateway
 
 	proxy func(req *http.Request) (*url.URL, error)
 
@@ -69,7 +72,8 @@ type externalAuth struct {
 
 // DaemonConfig holds configuration values for Daemon.
 type DaemonConfig struct {
-	Group string // Group name the local unix socket should be chown'ed to
+	Group       string  // Group name the local unix socket should be chown'ed to
+	RaftLatency float64 // Coarse grain measure of the cluster latency
 }
 
 // NewDaemon returns a new Daemon object with the given configuration.
@@ -80,9 +84,16 @@ func NewDaemon(config *DaemonConfig, os *sys.OS) *Daemon {
 	}
 }
 
+// DefaultDaemonConfig returns a DaemonConfig object with default values.
+func DefaultDaemonConfig() *DaemonConfig {
+	return &DaemonConfig{
+		RaftLatency: 1.0,
+	}
+}
+
 // DefaultDaemon returns a new, un-initialized Daemon object with default values.
 func DefaultDaemon() *Daemon {
-	config := &DaemonConfig{}
+	config := DefaultDaemonConfig()
 	os := sys.DefaultOS()
 	return NewDaemon(config, os)
 }
@@ -362,6 +373,37 @@ func (d *Daemon) init() error {
 		return err
 	}
 
+	/* Setup server certificate */
+	certInfo, err := shared.KeyPairAndCA(d.os.VarDir, "server", shared.CertServer)
+	if err != nil {
+		return err
+	}
+
+	/* Setup dqlite */
+	d.gateway, err = cluster.NewGateway(d.db, certInfo, d.config.RaftLatency)
+	if err != nil {
+		return err
+	}
+
+	/* Setup some mounts (nice to have) */
+	if !d.os.MockMode {
+		// Attempt to mount the shmounts tmpfs
+		setupSharedMounts()
+
+		// Attempt to Mount the devlxd tmpfs
+		devlxd := filepath.Join(d.os.VarDir, "devlxd")
+		if !shared.IsMountPoint(devlxd) {
+			syscall.Mount("tmpfs", devlxd, "tmpfs", 0, "size=100k,mode=0755")
+		}
+	}
+
+	/* Open the cluster database */
+	clusterFilename := filepath.Join(d.os.VarDir, "db.bin")
+	d.cluster, err = db.OpenCluster(clusterFilename, d.gateway.Dialer())
+	if err != nil {
+		return err
+	}
+
 	/* Read the storage pools */
 	err = SetupStorageDriver(d.State(), false)
 	if err != nil {
@@ -396,17 +438,6 @@ func (d *Daemon) init() error {
 		daemonConfig["core.proxy_ignore_hosts"].Get(),
 	)
 
-	/* Setup some mounts (nice to have) */
-	if !d.os.MockMode {
-		// Attempt to mount the shmounts tmpfs
-		setupSharedMounts()
-
-		// Attempt to Mount the devlxd tmpfs
-		if !shared.IsMountPoint(shared.VarPath("devlxd")) {
-			syscall.Mount("tmpfs", shared.VarPath("devlxd"), "tmpfs", 0, "size=100k,mode=0755")
-		}
-	}
-
 	if !d.os.MockMode {
 		/* Start the scheduler */
 		go deviceEventListener(d.State())
@@ -419,11 +450,6 @@ func (d *Daemon) init() error {
 	}
 
 	/* Setup the web server */
-	certInfo, err := shared.KeyPairAndCA(d.os.VarDir, "server", shared.CertServer)
-	if err != nil {
-		return err
-	}
-
 	config := &endpoints.Config{
 		Dir:                  d.os.VarDir,
 		Cert:                 certInfo,
@@ -531,6 +557,15 @@ func (d *Daemon) Stop() error {
 		logger.Infof("Closing the database")
 		trackError(d.db.Close())
 	}
+	if d.cluster != nil {
+		trackError(d.cluster.Close())
+	}
+	if d.gateway != nil {
+		trackError(d.gateway.Shutdown())
+	}
+	if d.endpoints != nil {
+		trackError(d.endpoints.Down())
+	}
 
 	logger.Infof("Saving simplestreams cache")
 	trackError(imageSaveStreamCache(d.os))
diff --git a/lxd/daemon_integration_test.go b/lxd/daemon_integration_test.go
index 79e8700b3..0f689dfa5 100644
--- a/lxd/daemon_integration_test.go
+++ b/lxd/daemon_integration_test.go
@@ -55,7 +55,9 @@ func newDaemon(t *testing.T) (*Daemon, func()) {
 
 // Create a new DaemonConfig object for testing purposes.
 func newConfig() *DaemonConfig {
-	return &DaemonConfig{}
+	return &DaemonConfig{
+		RaftLatency: 0.2,
+	}
 }
 
 // Create a new sys.OS object for testing purposes.
diff --git a/lxd/main_daemon.go b/lxd/main_daemon.go
index 4b0948544..7b9d84372 100644
--- a/lxd/main_daemon.go
+++ b/lxd/main_daemon.go
@@ -38,9 +38,8 @@ func cmdDaemon(args *Args) error {
 		}
 
 	}
-	c := &DaemonConfig{
-		Group: args.Group,
-	}
+	c := DefaultDaemonConfig()
+	c.Group = args.Group
 	d := NewDaemon(c, sys.DefaultOS())
 	err = d.Init()
 	if err != nil {

From a7528870ee2517aed400511271c9bf65dbbb4b23 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Thu, 12 Oct 2017 20:12:21 +0000
Subject: [PATCH 08/14] Add V1 cluster schema

This is an initial pass at creating the first version of the cluster
database schema.

A new updateFromV0 patch has been added, which for now only creates a
single table ("nodes") holding the list of all LXD nodes
participating in the cluster.
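
For illustration, a row in the new table would be created along these
lines (values are made up):

    // Register a node in the cluster database (sketch, not part of this
    // patch): name and address are unique, while schema and api_extensions
    // record the versions the node is running.
    _, err := tx.Exec(`
    INSERT INTO nodes (name, address, schema, api_extensions)
    VALUES (?, ?, ?, ?)`, "node1", "10.0.0.1:8443", 1, 32)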

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/db/cluster/open.go        |  9 +++++++++
 lxd/db/cluster/schema.go      | 22 +++++++++++++++++++++
 lxd/db/cluster/update.go      | 46 +++++++++++++++++++++++++++++++++++++++++++
 lxd/db/cluster/update_test.go | 26 ++++++++++++++++++++++++
 lxd/db/db.go                  |  6 +++++-
 5 files changed, 108 insertions(+), 1 deletion(-)
 create mode 100644 lxd/db/cluster/schema.go
 create mode 100644 lxd/db/cluster/update.go
 create mode 100644 lxd/db/cluster/update_test.go

diff --git a/lxd/db/cluster/open.go b/lxd/db/cluster/open.go
index d135dea6f..bf05f8790 100644
--- a/lxd/db/cluster/open.go
+++ b/lxd/db/cluster/open.go
@@ -34,6 +34,15 @@ func Open(name string, dialer grpcsql.Dialer) (*sql.DB, error) {
 	return db, nil
 }
 
+// EnsureSchema applies all relevant schema updates to the cluster database.
+//
+// Return the initial schema version found before starting the update, along
+// with any error that occurred.
+func EnsureSchema(db *sql.DB) (int, error) {
+	schema := Schema()
+	return schema.Ensure(db)
+}
+
 // Generate a new name for the grpcsql driver registration. We need it to be
 // unique for testing, see below.
 func grpcSQLDriverName() string {
diff --git a/lxd/db/cluster/schema.go b/lxd/db/cluster/schema.go
new file mode 100644
index 000000000..90a358e96
--- /dev/null
+++ b/lxd/db/cluster/schema.go
@@ -0,0 +1,22 @@
+package cluster
+
+// DO NOT EDIT BY HAND
+//
+// This code was generated by the schema.DotGo function. If you need to
+// modify the database schema, please add a new schema update to update.go
+// and then run 'make update-schema'.
+const freshSchema = `
+CREATE TABLE nodes (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    description TEXT DEFAULT '',
+    address TEXT NOT NULL,
+    schema INTEGER NOT NULL,
+    api_extensions INTEGER NOT NULL,
+    heartbeat DATETIME DEFAULT CURRENT_TIMESTAMP,
+    UNIQUE (name),
+    UNIQUE (address)
+);
+
+INSERT INTO schema (version, updated_at) VALUES (1, strftime("%s"))
+`
diff --git a/lxd/db/cluster/update.go b/lxd/db/cluster/update.go
new file mode 100644
index 000000000..3d43e9b2e
--- /dev/null
+++ b/lxd/db/cluster/update.go
@@ -0,0 +1,46 @@
+package cluster
+
+import (
+	"database/sql"
+
+	"github.com/lxc/lxd/lxd/db/schema"
+)
+
+// Schema for the cluster database.
+func Schema() *schema.Schema {
+	schema := schema.NewFromMap(updates)
+	schema.Fresh(freshSchema)
+	return schema
+}
+
+// SchemaDotGo refreshes the schema.go file in this package, using the updates
+// defined here.
+func SchemaDotGo() error {
+	return schema.DotGo(updates, "schema")
+}
+
+// SchemaVersion is the current version of the cluster database schema.
+var SchemaVersion = len(updates)
+
+var updates = map[int]schema.Update{
+	1: updateFromV0,
+}
+
+func updateFromV0(tx *sql.Tx) error {
+	// v0..v1 the dawn of clustering
+	stmt := `
+CREATE TABLE nodes (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    description TEXT DEFAULT '',
+    address TEXT NOT NULL,
+    schema INTEGER NOT NULL,
+    api_extensions INTEGER NOT NULL,
+    heartbeat DATETIME DEFAULT CURRENT_TIMESTAMP,
+    UNIQUE (name),
+    UNIQUE (address)
+);
+`
+	_, err := tx.Exec(stmt)
+	return err
+}
diff --git a/lxd/db/cluster/update_test.go b/lxd/db/cluster/update_test.go
new file mode 100644
index 000000000..c80a51574
--- /dev/null
+++ b/lxd/db/cluster/update_test.go
@@ -0,0 +1,26 @@
+package cluster_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/lxc/lxd/lxd/db/cluster"
+	"github.com/stretchr/testify/require"
+)
+
+func TestUpdateFromV0(t *testing.T) {
+	schema := cluster.Schema()
+	db, err := schema.ExerciseUpdate(1, nil)
+	require.NoError(t, err)
+
+	_, err = db.Exec("INSERT INTO nodes VALUES (1, 'foo', 'blah', '1.2.3.4:666', 1, 32, ?)", time.Now())
+	require.NoError(t, err)
+
+	// Unique constraint on name
+	_, err = db.Exec("INSERT INTO nodes VALUES (2, 'foo', 'gosh', '5.6.7.8:666', 5, 20, ?)", time.Now())
+	require.Error(t, err)
+
+	// Unique constraint on address
+	_, err = db.Exec("INSERT INTO nodes VALUES (3, 'bar', 'gasp', '1.2.3.4:666', 9, 11), ?)", time.Now())
+	require.Error(t, err)
+}
diff --git a/lxd/db/db.go b/lxd/db/db.go
index 9c6add273..0bc0d0e39 100644
--- a/lxd/db/db.go
+++ b/lxd/db/db.go
@@ -159,7 +159,11 @@ func (c *Cluster) Close() error {
 func UpdateSchemasDotGo() error {
 	err := node.SchemaDotGo()
 	if err != nil {
-		return fmt.Errorf("failed to update local schema.go: %v", err)
+		return fmt.Errorf("failed to update node schema.go: %v", err)
+	}
+	err = cluster.SchemaDotGo()
+	if err != nil {
+		return fmt.Errorf("failed to update cluster schema.go: %v", err)
 	}
 
 	return nil

From f2643e71a782abf47a6a2eb5280842cd6504a647 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Sat, 14 Oct 2017 12:24:49 +0000
Subject: [PATCH 09/14] Wire cluster.EnsureSchema into db.OpenCluster

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/db/db.go      | 10 +++++++++-
 lxd/db/db_test.go | 32 +++++++++++++++++++-------------
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/lxd/db/db.go b/lxd/db/db.go
index 0bc0d0e39..f07ced010 100644
--- a/lxd/db/db.go
+++ b/lxd/db/db.go
@@ -7,6 +7,7 @@ import (
 
 	grpcsql "github.com/CanonicalLtd/go-grpc-sql"
 	"github.com/mattn/go-sqlite3"
+	"github.com/pkg/errors"
 
 	"github.com/lxc/lxd/lxd/db/cluster"
 	"github.com/lxc/lxd/lxd/db/node"
@@ -129,11 +130,18 @@ type Cluster struct {
 func OpenCluster(name string, dialer grpcsql.Dialer) (*Cluster, error) {
 	db, err := cluster.Open(name, dialer)
 	if err != nil {
-		return nil, err
+		return nil, errors.Wrap(err, "failed to open database")
+	}
+
+	_, err = cluster.EnsureSchema(db)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to ensure schema")
 	}
+
 	cluster := &Cluster{
 		db: db,
 	}
+
 	return cluster, nil
 }
 
diff --git a/lxd/db/db_test.go b/lxd/db/db_test.go
index cf2eeb6df..33b27c003 100644
--- a/lxd/db/db_test.go
+++ b/lxd/db/db_test.go
@@ -4,7 +4,9 @@ import (
 	"testing"
 
 	"github.com/lxc/lxd/lxd/db"
+	"github.com/lxc/lxd/lxd/db/query"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 // Node database objects automatically initialize their schema as needed.
@@ -15,13 +17,14 @@ func TestNode_Schema(t *testing.T) {
 	// The underlying node-level database has exactly one row in the schema
 	// table.
 	db := node.DB()
-	rows, err := db.Query("SELECT COUNT(*) FROM schema")
-	assert.NoError(t, err)
-	defer rows.Close()
-	assert.Equal(t, true, rows.Next())
-	var n int
-	assert.NoError(t, rows.Scan(&n))
+	tx, err := db.Begin()
+	require.NoError(t, err)
+	n, err := query.Count(tx, "schema", "")
+	require.NoError(t, err)
 	assert.Equal(t, 1, n)
+
+	assert.NoError(t, tx.Commit())
+	assert.NoError(t, db.Close())
 }
 
 // A gRPC SQL connection is established when starting to interact with the
@@ -30,12 +33,15 @@ func TestCluster_Setup(t *testing.T) {
 	cluster, cleanup := db.NewTestCluster(t)
 	defer cleanup()
 
+	// The underlying node-level database has exactly one row in the schema
+	// table.
 	db := cluster.DB()
-	rows, err := db.Query("SELECT COUNT(*) FROM sqlite_master")
-	assert.NoError(t, err)
-	defer rows.Close()
-	assert.Equal(t, true, rows.Next())
-	var n uint
-	assert.NoError(t, rows.Scan(&n))
-	assert.Zero(t, n)
+	tx, err := db.Begin()
+	require.NoError(t, err)
+	n, err := query.Count(tx, "schema", "")
+	require.NoError(t, err)
+	assert.Equal(t, 1, n)
+
+	assert.NoError(t, tx.Commit())
+	assert.NoError(t, db.Close())
 }

From 37c8f23e1ec65b51d4b18e4fadb3086a65e0690d Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Fri, 13 Oct 2017 10:21:56 +0000
Subject: [PATCH 10/14] Check the versions of other nodes in
 cluster.EnsureSchema

Modify cluster.EnsureSchema to also check that all other nodes in the
cluster have a schema version and an API extensions count that match
the ones of this node.
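
The comparison rules implemented by the new compareVersions helper can
be summarized as follows (illustrative calls; the function is internal
to the cluster package):

    // compareVersions([2]int{2, 10}, [2]int{2, 10}) // 0: versions match
    // compareVersions([2]int{2, 10}, [2]int{1, 9})  // 1: we are ahead, wait for others
    // compareVersions([2]int{1, 9}, [2]int{2, 10})  // 2: we are behind, upgrade
    // compareVersions([2]int{2, 9}, [2]int{1, 10})  // error: inconsistent versions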

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/daemon.go                        |   9 +-
 lxd/db/cluster/open.go               | 124 +++++++++++++++++++++++-
 lxd/db/cluster/open_test.go          | 180 +++++++++++++++++++++++++++++++++++
 lxd/db/cluster/query.go              |  50 ++++++++++
 lxd/db/cluster/schema_export_test.go |   3 +
 lxd/db/db.go                         |  16 +++-
 lxd/db/testing.go                    |   2 +-
 7 files changed, 374 insertions(+), 10 deletions(-)
 create mode 100644 lxd/db/cluster/open_test.go
 create mode 100644 lxd/db/cluster/query.go
 create mode 100644 lxd/db/cluster/schema_export_test.go

diff --git a/lxd/daemon.go b/lxd/daemon.go
index 67785d552..ce296417d 100644
--- a/lxd/daemon.go
+++ b/lxd/daemon.go
@@ -19,6 +19,7 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/juju/idmclient"
 	_ "github.com/mattn/go-sqlite3"
+	"github.com/pkg/errors"
 	"golang.org/x/net/context"
 	"gopkg.in/macaroon-bakery.v2/bakery"
 	"gopkg.in/macaroon-bakery.v2/bakery/checkers"
@@ -397,11 +398,13 @@ func (d *Daemon) init() error {
 		}
 	}
 
+	address := daemonConfig["core.https_address"].Get()
+
 	/* Open the cluster database */
 	clusterFilename := filepath.Join(d.os.VarDir, "db.bin")
-	d.cluster, err = db.OpenCluster(clusterFilename, d.gateway.Dialer())
+	d.cluster, err = db.OpenCluster(clusterFilename, d.gateway.Dialer(), address)
 	if err != nil {
-		return err
+		return errors.Wrap(err, "failed to open cluster database")
 	}
 
 	/* Read the storage pools */
@@ -456,7 +459,7 @@ func (d *Daemon) init() error {
 		RestServer:           RestServer(d),
 		DevLxdServer:         DevLxdServer(d),
 		LocalUnixSocketGroup: d.config.Group,
-		NetworkAddress:       daemonConfig["core.https_address"].Get(),
+		NetworkAddress:       address,
 	}
 	d.endpoints, err = endpoints.Up(config)
 	if err != nil {
diff --git a/lxd/db/cluster/open.go b/lxd/db/cluster/open.go
index bf05f8790..f9b3139e7 100644
--- a/lxd/db/cluster/open.go
+++ b/lxd/db/cluster/open.go
@@ -6,6 +6,9 @@ import (
 	"sync/atomic"
 
 	"github.com/CanonicalLtd/go-grpc-sql"
+	"github.com/lxc/lxd/lxd/db/schema"
+	"github.com/lxc/lxd/shared/version"
+	"github.com/pkg/errors"
 )
 
 // Open the cluster database object.
@@ -36,11 +39,58 @@ func Open(name string, dialer grpcsql.Dialer) (*sql.DB, error) {
 
 // EnsureSchema applies all relevant schema updates to the cluster database.
 //
-// Return the initial schema version found before starting the update, along
-// with any error occurred.
-func EnsureSchema(db *sql.DB) (int, error) {
+// Before actually doing anything, this function will make sure that all nodes
+// in the cluster have a schema version and a number of API extensions that
+// match ours. If that's not the case, we either return an error (if some
+// nodes have a version greater than ours and we need to be upgraded), or return
+// false and no error (if some nodes have a lower version, and we need to wait
+// till they get upgraded and restarted).
+func EnsureSchema(db *sql.DB, address string) (bool, error) {
+	someNodesAreBehind := false
+	apiExtensions := len(version.APIExtensions)
+
+	check := func(current int, tx *sql.Tx) error {
+		// If we're bootstrapping a fresh schema, skip any check, since
+		// it's safe to assume we are the only node.
+		if current == 0 {
+			return nil
+		}
+
+		// Check if we're clustered
+		n, err := selectNodesCount(tx)
+		if err != nil {
+			return errors.Wrap(err, "failed to fetch current nodes count")
+		}
+		if n == 0 {
+			return nil // Nothing to do.
+		}
+
+		// Update the schema and api_extension columns of ourselves.
+		err = updateNodeVersion(tx, address, apiExtensions)
+		if err != nil {
+			return errors.Wrap(err, "failed to update node version")
+
+		}
+
+		err = checkClusterIsUpgradable(tx, [2]int{len(updates), apiExtensions})
+		if err == errSomeNodesAreBehind {
+			someNodesAreBehind = true
+			return schema.ErrGracefulAbort
+		}
+		return err
+	}
+
 	schema := Schema()
-	return schema.Ensure(db)
+	schema.Check(check)
+
+	_, err := schema.Ensure(db)
+	if someNodesAreBehind {
+		return false, nil
+	}
+	if err != nil {
+		return false, err
+	}
+	return true, err
 }
 
 // Generate a new name for the grpcsql driver registration. We need it to be
@@ -55,3 +105,69 @@ func grpcSQLDriverName() string {
 // to unregister drivers, and in unit tests more than one driver gets
 // registered.
 var grpcSQLDriverSerial uint64
+
+func checkClusterIsUpgradable(tx *sql.Tx, target [2]int) error {
+	// Get the current versions in the nodes table.
+	versions, err := selectNodesVersions(tx)
+	if err != nil {
+		return errors.Wrap(err, "failed to fetch current nodes versions")
+	}
+
+	for _, version := range versions {
+		n, err := compareVersions(target, version)
+		if err != nil {
+			return err
+		}
+		switch n {
+		case 0:
+			// Versions are equal, there's hope for the
+			// update. Let's check the next node.
+			continue
+		case 1:
+			// Our version is bigger, we should stop here
+			// and wait for other nodes to be upgraded and
+			// restarted.
+			return errSomeNodesAreBehind
+		case 2:
+			// Another node has a version greater than ours
+			// and presumably is waiting for other nodes
+			// to upgrade. Let's error out and shutdown
+			// since we need a greater version.
+			return fmt.Errorf("this node's version is behind, please upgrade")
+		default:
+			// Sanity.
+			panic("unexpected return value from compareVersions")
+		}
+	}
+	return nil
+}
+
+// Compare two nodes versions.
+//
+// A version consists of the version of the node's schema and the number of API
+// extensions it supports.
+//
+// Return 0 if they are equal, 1 if the first version is greater than the second
+// and 2 if the second is greater than the first.
+//
+// Return an error if inconsistent versions are detected, for example the first
+// node's schema is greater than the second's, but the number of extensions is
+// smaller.
+func compareVersions(version1, version2 [2]int) (int, error) {
+	schema1, extensions1 := version1[0], version1[1]
+	schema2, extensions2 := version2[0], version2[1]
+
+	if schema1 == schema2 && extensions1 == extensions2 {
+		return 0, nil
+	}
+	if schema1 >= schema2 && extensions1 >= extensions2 {
+		return 1, nil
+	}
+	if schema1 <= schema2 && extensions1 <= extensions2 {
+		return 2, nil
+	}
+
+	return -1, fmt.Errorf("nodes have inconsistent versions")
+}
+
+var errSomeNodesAreBehind = fmt.Errorf("some nodes are behind this node's version")
diff --git a/lxd/db/cluster/open_test.go b/lxd/db/cluster/open_test.go
new file mode 100644
index 000000000..f858d7b35
--- /dev/null
+++ b/lxd/db/cluster/open_test.go
@@ -0,0 +1,180 @@
+package cluster_test
+
+import (
+	"database/sql"
+	"fmt"
+	"testing"
+
+	"github.com/lxc/lxd/lxd/db/cluster"
+	"github.com/lxc/lxd/lxd/db/query"
+	"github.com/lxc/lxd/shared/version"
+	"github.com/mpvl/subtest"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// If the node is not clustered, the schema update works normally.
+func TestEnsureSchema_NoClustered(t *testing.T) {
+	db := newDB(t)
+	ready, err := cluster.EnsureSchema(db, "1.2.3.4:666")
+	assert.True(t, ready)
+	assert.NoError(t, err)
+}
+
+// Exercise EnsureSchema failures when the cluster can't be upgraded right now.
+func TestEnsureSchema_ClusterNotUpgradable(t *testing.T) {
+	schema := cluster.SchemaVersion
+	apiExtensions := len(version.APIExtensions)
+
+	cases := []struct {
+		title string
+		setup func(*testing.T, *sql.DB)
+		ready bool
+		error string
+	}{
+		{
+			`a node's schema version is behind`,
+			func(t *testing.T, db *sql.DB) {
+				addNode(t, db, "1", schema, apiExtensions)
+				addNode(t, db, "2", schema-1, apiExtensions)
+			},
+			false, // The schema was not updated
+			"",    // No error is returned
+		},
+		{
+			`a node's number of API extensions is behind`,
+			func(t *testing.T, db *sql.DB) {
+				addNode(t, db, "1", schema, apiExtensions)
+				addNode(t, db, "2", schema, apiExtensions-1)
+			},
+			false, // The schema was not updated
+			"",    // No error is returned
+		},
+		{
+			`this node's schema is behind`,
+			func(t *testing.T, db *sql.DB) {
+				addNode(t, db, "1", schema, apiExtensions)
+				addNode(t, db, "2", schema+1, apiExtensions)
+			},
+			false,
+			"this node's version is behind, please upgrade",
+		},
+		{
+			`this node's number of API extensions is behind`,
+			func(t *testing.T, db *sql.DB) {
+				addNode(t, db, "1", schema, apiExtensions)
+				addNode(t, db, "2", schema, apiExtensions+1)
+			},
+			false,
+			"this node's version is behind, please upgrade",
+		},
+		{
+			`inconsistent schema version and API extensions number`,
+			func(t *testing.T, db *sql.DB) {
+				addNode(t, db, "1", schema, apiExtensions)
+				addNode(t, db, "2", schema+1, apiExtensions-1)
+			},
+			false,
+			"nodes have inconsistent versions",
+		},
+	}
+	for _, c := range cases {
+		subtest.Run(t, c.title, func(t *testing.T) {
+			db := newDB(t)
+			c.setup(t, db)
+			ready, err := cluster.EnsureSchema(db, "1")
+			assert.Equal(t, c.ready, ready)
+			if c.error == "" {
+				assert.NoError(t, err)
+			} else {
+				assert.EqualError(t, err, c.error)
+			}
+		})
+	}
+}
+
+// Regardless of whether the schema could actually be upgraded or not, the
+// version of this node gets updated.
+func TestEnsureSchema_UpdateNodeVersion(t *testing.T) {
+	schema := cluster.SchemaVersion
+	apiExtensions := len(version.APIExtensions)
+
+	cases := []struct {
+		setup func(*testing.T, *sql.DB)
+		ready bool
+	}{
+		{
+			func(t *testing.T, db *sql.DB) {},
+			true,
+		},
+		{
+			func(t *testing.T, db *sql.DB) {
+				// Add a node which is behind.
+				addNode(t, db, "2", schema, apiExtensions-1)
+			},
+			true,
+		},
+	}
+	for _, c := range cases {
+		subtest.Run(t, fmt.Sprintf("%v", c.ready), func(t *testing.T) {
+			db := newDB(t)
+
+			// Add ourselves with an older schema version and API
+			// extensions number.
+			addNode(t, db, "1", schema-1, apiExtensions-1)
+
+			// Ensure the schema.
+			ready, err := cluster.EnsureSchema(db, "1")
+			assert.NoError(t, err)
+			assert.Equal(t, c.ready, ready)
+
+			// Check that the nodes table was updated with our new
+			// schema version and API extensions number.
+			assertNode(t, db, "1", schema, apiExtensions)
+		})
+	}
+}
+
+// Create a new in-memory SQLite database with a fresh cluster schema.
+func newDB(t *testing.T) *sql.DB {
+	db, err := sql.Open("sqlite3", ":memory:")
+	require.NoError(t, err)
+
+	createTableSchema := `
+CREATE TABLE schema (
+    id         INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+    version    INTEGER NOT NULL,
+    updated_at DATETIME NOT NULL,
+    UNIQUE (version)
+);
+`
+	_, err = db.Exec(createTableSchema + cluster.FreshSchema)
+	require.NoError(t, err)
+
+	return db
+}
+
+// Add a new node with the given address, schema version and number of API extensions.
+func addNode(t *testing.T, db *sql.DB, address string, schema int, apiExtensions int) {
+	err := query.Transaction(db, func(tx *sql.Tx) error {
+		stmt := `
+INSERT INTO nodes(name, address, schema, api_extensions) VALUES (?, ?, ?, ?)
+`
+		name := fmt.Sprintf("node at %s", address)
+		_, err := tx.Exec(stmt, name, address, schema, apiExtensions)
+		return err
+	})
+	require.NoError(t, err)
+}
+
+// Assert that the node with the given address has the given schema version and API
+// extensions number.
+func assertNode(t *testing.T, db *sql.DB, address string, schema int, apiExtensions int) {
+	err := query.Transaction(db, func(tx *sql.Tx) error {
+		where := "address=? AND schema=? AND api_extensions=?"
+		n, err := query.Count(tx, "nodes", where, address, schema, apiExtensions)
+		assert.Equal(t, 1, n, "node does not have expected version")
+		return err
+	})
+	require.NoError(t, err)
+}
diff --git a/lxd/db/cluster/query.go b/lxd/db/cluster/query.go
new file mode 100644
index 000000000..286ffe2db
--- /dev/null
+++ b/lxd/db/cluster/query.go
@@ -0,0 +1,50 @@
+package cluster
+
+import (
+	"database/sql"
+	"fmt"
+
+	"github.com/lxc/lxd/lxd/db/query"
+)
+
+// Update the schema and api_extensions columns of the row in the nodes table
+// that matches the given address.
+//
+// If no such row is found, an error is returned.
+func updateNodeVersion(tx *sql.Tx, address string, apiExtensions int) error {
+	stmt := "UPDATE nodes SET schema=?, api_extensions=? WHERE address=?"
+	result, err := tx.Exec(stmt, len(updates), apiExtensions, address)
+	if err != nil {
+		return err
+	}
+	n, err := result.RowsAffected()
+	if err != nil {
+		return err
+	}
+	if n != 1 {
+		return fmt.Errorf("updated %d rows instead of 1", n)
+	}
+	return nil
+}
+
+// Return the number of rows in the nodes table.
+func selectNodesCount(tx *sql.Tx) (int, error) {
+	return query.Count(tx, "nodes", "")
+}
+
+// Return a slice of integer pairs. Each pair contains the schema version and
+// the number of API extensions of a node in the cluster.
+func selectNodesVersions(tx *sql.Tx) ([][2]int, error) {
+	versions := [][2]int{}
+
+	dest := func(i int) []interface{} {
+		versions = append(versions, [2]int{})
+		return []interface{}{&versions[i][0], &versions[i][1]}
+	}
+
+	err := query.SelectObjects(tx, dest, "SELECT schema, api_extensions FROM nodes")
+	if err != nil {
+		return nil, err
+	}
+	return versions, nil
+}
diff --git a/lxd/db/cluster/schema_export_test.go b/lxd/db/cluster/schema_export_test.go
new file mode 100644
index 000000000..d2041016a
--- /dev/null
+++ b/lxd/db/cluster/schema_export_test.go
@@ -0,0 +1,3 @@
+package cluster
+
+var FreshSchema = freshSchema
diff --git a/lxd/db/db.go b/lxd/db/db.go
index f07ced010..6b4a49b6d 100644
--- a/lxd/db/db.go
+++ b/lxd/db/db.go
@@ -28,6 +28,8 @@ var (
 	 * already do.
 	 */
 	NoSuchObjectError = fmt.Errorf("No such object")
+
+	Upgrading = fmt.Errorf("The cluster database is upgrading")
 )
 
 // Node mediates access to LXD's data stored in the node-local SQLite database.
@@ -127,13 +129,23 @@ type Cluster struct {
 
 // OpenCluster creates a new Cluster object for interacting with the dqlite
 // database.
-func OpenCluster(name string, dialer grpcsql.Dialer) (*Cluster, error) {
+//
+// - name: Basename of the database file holding the data. Typically "db.bin".
+// - dialer: Function used to connect to the dqlite backend via gRPC SQL.
+// - address: Network address of this node (or empty string).
+//
+// The address parameter is used to determine whether the cluster database
+// matches this node's version, and possibly to trigger a schema update.
+//
+// If the schema update can't be performed right now, because some nodes are
+// still behind, an Upgrading error is returned.
+func OpenCluster(name string, dialer grpcsql.Dialer, address string) (*Cluster, error) {
 	db, err := cluster.Open(name, dialer)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to open database")
 	}
 
-	_, err = cluster.EnsureSchema(db)
+	_, err = cluster.EnsureSchema(db, address)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to ensure schema")
 	}
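
For context, a caller might special-case the new Upgrading sentinel
roughly as follows. This is only a sketch: whether the sentinel reaches
callers unwrapped, and the go-grpc-sql import path, are assumptions
here, not something this patch establishes.

package main

import (
	grpcsql "github.com/CanonicalLtd/go-grpc-sql" // assumed import path
	"github.com/lxc/lxd/lxd/db"
)

// openClusterDB treats db.Upgrading as a soft failure that the daemon
// could retry once the other nodes have been upgraded and restarted.
func openClusterDB(dialer grpcsql.Dialer, address string) (*db.Cluster, error) {
	cluster, err := db.OpenCluster("db.bin", dialer, address)
	if err == db.Upgrading {
		// Some nodes are still behind: keep running without the
		// cluster database and try again later.
		return nil, err
	}
	return cluster, err
}
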
diff --git a/lxd/db/testing.go b/lxd/db/testing.go
index 1cb6344d3..65c5ddcae 100644
--- a/lxd/db/testing.go
+++ b/lxd/db/testing.go
@@ -56,7 +56,7 @@ func NewTestCluster(t *testing.T) (*Cluster, func()) {
 	// Create an in-memory gRPC SQL server and dialer.
 	server, dialer := newGrpcServer()
 
-	cluster, err := OpenCluster(":memory:", dialer)
+	cluster, err := OpenCluster(":memory:", dialer, "1")
 	require.NoError(t, err)
 
 	cleanup := func() {

From 2f59b144b3f0c7d71bfaed2e1dbf78cf45449db4 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Thu, 12 Oct 2017 16:18:14 +0000
Subject: [PATCH 11/14] Rename State.DB to State.Node and add State.Cluster

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/container.go             | 34 +++++++++++++++++-----------------
 lxd/container_lxc.go         | 18 +++++++++---------
 lxd/containers.go            |  8 ++++----
 lxd/containers_get.go        |  2 +-
 lxd/daemon.go                |  2 +-
 lxd/devices.go               |  8 ++++----
 lxd/logging.go               |  2 +-
 lxd/networks.go              |  8 ++++----
 lxd/networks_utils.go        |  4 ++--
 lxd/profiles.go              |  6 +++---
 lxd/state/state.go           | 12 +++++++-----
 lxd/storage.go               | 28 ++++++++++++++--------------
 lxd/storage_ceph.go          |  2 +-
 lxd/storage_lvm_utils.go     |  6 +++---
 lxd/storage_pools_utils.go   | 10 +++++-----
 lxd/storage_volumes_utils.go | 16 ++++++++--------
 16 files changed, 84 insertions(+), 82 deletions(-)

diff --git a/lxd/container.go b/lxd/container.go
index e5e5ef885..e91238b61 100644
--- a/lxd/container.go
+++ b/lxd/container.go
@@ -537,7 +537,7 @@ func containerCreateEmptySnapshot(s *state.State, args db.ContainerArgs) (contai
 	// Now create the empty snapshot
 	err = c.Storage().ContainerSnapshotCreateEmpty(c)
 	if err != nil {
-		s.DB.ContainerRemove(args.Name)
+		s.Node.ContainerRemove(args.Name)
 		return nil, err
 	}
 
@@ -546,7 +546,7 @@ func containerCreateEmptySnapshot(s *state.State, args db.ContainerArgs) (contai
 
 func containerCreateFromImage(s *state.State, args db.ContainerArgs, hash string) (container, error) {
 	// Get the image properties
-	_, img, err := s.DB.ImageGet(hash, false, false)
+	_, img, err := s.Node.ImageGet(hash, false, false)
 	if err != nil {
 		return nil, err
 	}
@@ -567,16 +567,16 @@ func containerCreateFromImage(s *state.State, args db.ContainerArgs, hash string
 		return nil, err
 	}
 
-	err = s.DB.ImageLastAccessUpdate(hash, time.Now().UTC())
+	err = s.Node.ImageLastAccessUpdate(hash, time.Now().UTC())
 	if err != nil {
-		s.DB.ContainerRemove(args.Name)
+		s.Node.ContainerRemove(args.Name)
 		return nil, fmt.Errorf("Error updating image last use date: %s", err)
 	}
 
 	// Now create the storage from an image
 	err = c.Storage().ContainerCreateFromImage(c, hash)
 	if err != nil {
-		s.DB.ContainerRemove(args.Name)
+		s.Node.ContainerRemove(args.Name)
 		return nil, err
 	}
 
@@ -601,7 +601,7 @@ func containerCreateAsCopy(s *state.State, args db.ContainerArgs, sourceContaine
 	if !containerOnly {
 		snapshots, err := sourceContainer.Snapshots()
 		if err != nil {
-			s.DB.ContainerRemove(args.Name)
+			s.Node.ContainerRemove(args.Name)
 			return nil, err
 		}
 
@@ -633,9 +633,9 @@ func containerCreateAsCopy(s *state.State, args db.ContainerArgs, sourceContaine
 	err = ct.Storage().ContainerCopy(ct, sourceContainer, containerOnly)
 	if err != nil {
 		for _, v := range csList {
-			s.DB.ContainerRemove((*v).Name())
+			s.Node.ContainerRemove((*v).Name())
 		}
-		s.DB.ContainerRemove(args.Name)
+		s.Node.ContainerRemove(args.Name)
 		return nil, err
 	}
 
@@ -704,7 +704,7 @@ func containerCreateAsSnapshot(s *state.State, args db.ContainerArgs, sourceCont
 	// Clone the container
 	err = sourceContainer.Storage().ContainerSnapshotCreate(c, sourceContainer)
 	if err != nil {
-		s.DB.ContainerRemove(args.Name)
+		s.Node.ContainerRemove(args.Name)
 		return nil, err
 	}
 
@@ -767,7 +767,7 @@ func containerCreateInternal(s *state.State, args db.ContainerArgs) (container,
 	}
 
 	// Validate container devices
-	err = containerValidDevices(s.DB, args.Devices, false, false)
+	err = containerValidDevices(s.Node, args.Devices, false, false)
 	if err != nil {
 		return nil, err
 	}
@@ -783,7 +783,7 @@ func containerCreateInternal(s *state.State, args db.ContainerArgs) (container,
 	}
 
 	// Validate profiles
-	profiles, err := s.DB.Profiles()
+	profiles, err := s.Node.Profiles()
 	if err != nil {
 		return nil, err
 	}
@@ -795,7 +795,7 @@ func containerCreateInternal(s *state.State, args db.ContainerArgs) (container,
 	}
 
 	// Create the container entry
-	id, err := s.DB.ContainerCreate(args)
+	id, err := s.Node.ContainerCreate(args)
 	if err != nil {
 		if err == db.DbErrAlreadyDefined {
 			thing := "Container"
@@ -813,9 +813,9 @@ func containerCreateInternal(s *state.State, args db.ContainerArgs) (container,
 	args.Id = id
 
 	// Read the timestamp from the database
-	dbArgs, err := s.DB.ContainerGet(args.Name)
+	dbArgs, err := s.Node.ContainerGet(args.Name)
 	if err != nil {
-		s.DB.ContainerRemove(args.Name)
+		s.Node.ContainerRemove(args.Name)
 		return nil, err
 	}
 	args.CreationDate = dbArgs.CreationDate
@@ -824,7 +824,7 @@ func containerCreateInternal(s *state.State, args db.ContainerArgs) (container,
 	// Setup the container struct and finish creation (storage and idmap)
 	c, err := containerLXCCreate(s, args)
 	if err != nil {
-		s.DB.ContainerRemove(args.Name)
+		s.Node.ContainerRemove(args.Name)
 		return nil, err
 	}
 
@@ -879,7 +879,7 @@ func containerConfigureInternal(c container) error {
 
 func containerLoadById(s *state.State, id int) (container, error) {
 	// Get the DB record
-	name, err := s.DB.ContainerName(id)
+	name, err := s.Node.ContainerName(id)
 	if err != nil {
 		return nil, err
 	}
@@ -889,7 +889,7 @@ func containerLoadById(s *state.State, id int) (container, error) {
 
 func containerLoadByName(s *state.State, name string) (container, error) {
 	// Get the DB record
-	args, err := s.DB.ContainerGet(name)
+	args, err := s.Node.ContainerGet(name)
 	if err != nil {
 		return nil, err
 	}
diff --git a/lxd/container_lxc.go b/lxd/container_lxc.go
index e23b4fcb1..1e4751d95 100644
--- a/lxd/container_lxc.go
+++ b/lxd/container_lxc.go
@@ -257,7 +257,7 @@ func containerLXCCreate(s *state.State, args db.ContainerArgs) (container, error
 	// Create the container struct
 	c := &containerLXC{
 		state:        s,
-		db:           s.DB,
+		db:           s.Node,
 		id:           args.Id,
 		name:         args.Name,
 		description:  args.Description,
@@ -293,7 +293,7 @@ func containerLXCCreate(s *state.State, args db.ContainerArgs) (container, error
 		return nil, err
 	}
 
-	err = containerValidDevices(s.DB, c.expandedDevices, false, true)
+	err = containerValidDevices(s.Node, c.expandedDevices, false, true)
 	if err != nil {
 		c.Delete()
 		logger.Error("Failed creating container", ctxMap)
@@ -315,7 +315,7 @@ func containerLXCCreate(s *state.State, args db.ContainerArgs) (container, error
 	storagePool := rootDiskDevice["pool"]
 
 	// Get the storage pool ID for the container
-	poolID, pool, err := s.DB.StoragePoolGet(storagePool)
+	poolID, pool, err := s.Node.StoragePoolGet(storagePool)
 	if err != nil {
 		c.Delete()
 		return nil, err
@@ -329,7 +329,7 @@ func containerLXCCreate(s *state.State, args db.ContainerArgs) (container, error
 	}
 
 	// Create a new database entry for the container's storage volume
-	_, err = s.DB.StoragePoolVolumeCreate(args.Name, "", storagePoolVolumeTypeContainer, poolID, volumeConfig)
+	_, err = s.Node.StoragePoolVolumeCreate(args.Name, "", storagePoolVolumeTypeContainer, poolID, volumeConfig)
 	if err != nil {
 		c.Delete()
 		return nil, err
@@ -339,7 +339,7 @@ func containerLXCCreate(s *state.State, args db.ContainerArgs) (container, error
 	cStorage, err := storagePoolVolumeContainerCreateInit(s, storagePool, args.Name)
 	if err != nil {
 		c.Delete()
-		s.DB.StoragePoolVolumeDelete(args.Name, storagePoolVolumeTypeContainer, poolID)
+		s.Node.StoragePoolVolumeDelete(args.Name, storagePoolVolumeTypeContainer, poolID)
 		logger.Error("Failed to initialize container storage", ctxMap)
 		return nil, err
 	}
@@ -425,7 +425,7 @@ func containerLXCLoad(s *state.State, args db.ContainerArgs) (container, error)
 	// Create the container struct
 	c := &containerLXC{
 		state:        s,
-		db:           s.DB,
+		db:           s.Node,
 		id:           args.Id,
 		name:         args.Name,
 		description:  args.Description,
@@ -710,7 +710,7 @@ func findIdmap(state *state.State, cName string, isolatedStr string, configBase
 	idmapLock.Lock()
 	defer idmapLock.Unlock()
 
-	cs, err := state.DB.ContainersList(db.CTypeRegular)
+	cs, err := state.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		return nil, 0, err
 	}
@@ -3211,12 +3211,12 @@ func writeBackupFile(c container) error {
 	}
 
 	s := c.DaemonState()
-	poolID, pool, err := s.DB.StoragePoolGet(poolName)
+	poolID, pool, err := s.Node.StoragePoolGet(poolName)
 	if err != nil {
 		return err
 	}
 
-	_, volume, err := s.DB.StoragePoolVolumeGetType(c.Name(), storagePoolVolumeTypeContainer, poolID)
+	_, volume, err := s.Node.StoragePoolVolumeGetType(c.Name(), storagePoolVolumeTypeContainer, poolID)
 	if err != nil {
 		return err
 	}
diff --git a/lxd/containers.go b/lxd/containers.go
index c4a169770..b80c7ebcb 100644
--- a/lxd/containers.go
+++ b/lxd/containers.go
@@ -99,7 +99,7 @@ func (slice containerAutostartList) Swap(i, j int) {
 
 func containersRestart(s *state.State) error {
 	// Get all the containers
-	result, err := s.DB.ContainersList(db.CTypeRegular)
+	result, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		return err
 	}
@@ -146,13 +146,13 @@ func containersShutdown(s *state.State) error {
 	var wg sync.WaitGroup
 
 	// Get all the containers
-	results, err := s.DB.ContainersList(db.CTypeRegular)
+	results, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		return err
 	}
 
 	// Reset all container states
-	err = s.DB.ContainersResetState()
+	err = s.Node.ContainersResetState()
 	if err != nil {
 		return err
 	}
@@ -200,7 +200,7 @@ func containerDeleteSnapshots(s *state.State, cname string) error {
 	logger.Debug("containerDeleteSnapshots",
 		log.Ctx{"container": cname})
 
-	results, err := s.DB.ContainerGetSnapshots(cname)
+	results, err := s.Node.ContainerGetSnapshots(cname)
 	if err != nil {
 		return err
 	}
diff --git a/lxd/containers_get.go b/lxd/containers_get.go
index b86dbb336..9ae37928b 100644
--- a/lxd/containers_get.go
+++ b/lxd/containers_get.go
@@ -34,7 +34,7 @@ func containersGet(d *Daemon, r *http.Request) Response {
 }
 
 func doContainersGet(s *state.State, recursion bool) (interface{}, error) {
-	result, err := s.DB.ContainersList(db.CTypeRegular)
+	result, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		return nil, err
 	}
diff --git a/lxd/daemon.go b/lxd/daemon.go
index ce296417d..3be36d97d 100644
--- a/lxd/daemon.go
+++ b/lxd/daemon.go
@@ -184,7 +184,7 @@ func isJSONRequest(r *http.Request) bool {
 
 // State creates a new State instance linked to our internal db and os.
 func (d *Daemon) State() *state.State {
-	return state.NewState(d.db, d.os)
+	return state.NewState(d.db, d.cluster, d.os)
 }
 
 // UnixSocket returns the full path to the unix.socket file that this daemon is
diff --git a/lxd/devices.go b/lxd/devices.go
index fd9211d3b..a75275b65 100644
--- a/lxd/devices.go
+++ b/lxd/devices.go
@@ -595,7 +595,7 @@ func deviceTaskBalance(s *state.State) {
 	}
 
 	// Iterate through the containers
-	containers, err := s.DB.ContainersList(db.CTypeRegular)
+	containers, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		logger.Error("problem loading containers list", log.Ctx{"err": err})
 		return
@@ -721,7 +721,7 @@ func deviceNetworkPriority(s *state.State, netif string) {
 		return
 	}
 
-	containers, err := s.DB.ContainersList(db.CTypeRegular)
+	containers, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		return
 	}
@@ -752,7 +752,7 @@ func deviceNetworkPriority(s *state.State, netif string) {
 }
 
 func deviceUSBEvent(s *state.State, usb usbDevice) {
-	containers, err := s.DB.ContainersList(db.CTypeRegular)
+	containers, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		logger.Error("problem loading containers list", log.Ctx{"err": err})
 		return
@@ -838,7 +838,7 @@ func deviceEventListener(s *state.State) {
 
 			logger.Debugf("Scheduler: network: %s has been added: updating network priorities", e[0])
 			deviceNetworkPriority(s, e[0])
-			networkAutoAttach(s.DB, e[0])
+			networkAutoAttach(s.Node, e[0])
 		case e := <-chUSB:
 			deviceUSBEvent(s, e)
 		case e := <-deviceSchedRebalance:
diff --git a/lxd/logging.go b/lxd/logging.go
index 1b9d3447e..fab90adcf 100644
--- a/lxd/logging.go
+++ b/lxd/logging.go
@@ -41,7 +41,7 @@ func expireLogs(ctx context.Context, state *state.State) error {
 	var containers []string
 	ch := make(chan struct{})
 	go func() {
-		containers, err = state.DB.ContainersList(db.CTypeRegular)
+		containers, err = state.Node.ContainersList(db.CTypeRegular)
 		ch <- struct{}{}
 	}()
 	select {
diff --git a/lxd/networks.go b/lxd/networks.go
index 94d80fd98..05971e969 100644
--- a/lxd/networks.go
+++ b/lxd/networks.go
@@ -392,19 +392,19 @@ var networkCmd = Command{name: "networks/{name}", get: networkGet, delete: netwo
 
 // The network structs and functions
 func networkLoadByName(s *state.State, name string) (*network, error) {
-	id, dbInfo, err := s.DB.NetworkGet(name)
+	id, dbInfo, err := s.Node.NetworkGet(name)
 	if err != nil {
 		return nil, err
 	}
 
-	n := network{db: s.DB, state: s, id: id, name: name, description: dbInfo.Description, config: dbInfo.Config}
+	n := network{db: s.Node, state: s, id: id, name: name, description: dbInfo.Description, config: dbInfo.Config}
 
 	return &n, nil
 }
 
 func networkStartup(s *state.State) error {
 	// Get a list of managed networks
-	networks, err := s.DB.Networks()
+	networks, err := s.Node.Networks()
 	if err != nil {
 		return err
 	}
@@ -428,7 +428,7 @@ func networkStartup(s *state.State) error {
 
 func networkShutdown(s *state.State) error {
 	// Get a list of managed networks
-	networks, err := s.DB.Networks()
+	networks, err := s.Node.Networks()
 	if err != nil {
 		return err
 	}
diff --git a/lxd/networks_utils.go b/lxd/networks_utils.go
index fa7c0fd9b..66fd3b60d 100644
--- a/lxd/networks_utils.go
+++ b/lxd/networks_utils.go
@@ -744,7 +744,7 @@ func networkUpdateStatic(s *state.State, networkName string) error {
 	defer networkStaticLock.Unlock()
 
 	// Get all the containers
-	containers, err := s.DB.ContainersList(db.CTypeRegular)
+	containers, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		return err
 	}
@@ -753,7 +753,7 @@ func networkUpdateStatic(s *state.State, networkName string) error {
 	var networks []string
 	if networkName == "" {
 		var err error
-		networks, err = s.DB.Networks()
+		networks, err = s.Node.Networks()
 		if err != nil {
 			return err
 		}
diff --git a/lxd/profiles.go b/lxd/profiles.go
index 510908f84..cb4b63090 100644
--- a/lxd/profiles.go
+++ b/lxd/profiles.go
@@ -105,12 +105,12 @@ var profilesCmd = Command{
 	post: profilesPost}
 
 func doProfileGet(s *state.State, name string) (*api.Profile, error) {
-	_, profile, err := s.DB.ProfileGet(name)
+	_, profile, err := s.Node.ProfileGet(name)
 	if err != nil {
 		return nil, err
 	}
 
-	cts, err := s.DB.ProfileContainersGet(name)
+	cts, err := s.Node.ProfileContainersGet(name)
 	if err != nil {
 		return nil, err
 	}
@@ -139,7 +139,7 @@ func profileGet(d *Daemon, r *http.Request) Response {
 func getContainersWithProfile(s *state.State, profile string) []container {
 	results := []container{}
 
-	output, err := s.DB.ProfileContainersGet(profile)
+	output, err := s.Node.ProfileContainersGet(profile)
 	if err != nil {
 		return results
 	}
diff --git a/lxd/state/state.go b/lxd/state/state.go
index 62b0afd72..dc49a823b 100644
--- a/lxd/state/state.go
+++ b/lxd/state/state.go
@@ -9,15 +9,17 @@ import (
 // and the operating system. It's typically used by model entities such as
 // containers, volumes, etc. in order to perform changes.
 type State struct {
-	DB *db.Node
-	OS *sys.OS
+	Node    *db.Node
+	Cluster *db.Cluster
+	OS      *sys.OS
 }
 
 // NewState returns a new State object with the given database and operating
 // system components.
-func NewState(db *db.Node, os *sys.OS) *State {
+func NewState(node *db.Node, cluster *db.Cluster, os *sys.OS) *State {
 	return &State{
-		DB: db,
-		OS: os,
+		Node:    node,
+		Cluster: cluster,
+		OS:      os,
 	}
 }
diff --git a/lxd/storage.go b/lxd/storage.go
index aaf581c0a..582bd6403 100644
--- a/lxd/storage.go
+++ b/lxd/storage.go
@@ -284,7 +284,7 @@ func storageCoreInit(driver string) (storage, error) {
 
 func storageInit(s *state.State, poolName string, volumeName string, volumeType int) (storage, error) {
 	// Load the storage pool.
-	poolID, pool, err := s.DB.StoragePoolGet(poolName)
+	poolID, pool, err := s.Node.StoragePoolGet(poolName)
 	if err != nil {
 		return nil, err
 	}
@@ -299,7 +299,7 @@ func storageInit(s *state.State, poolName string, volumeName string, volumeType
 	// Load the storage volume.
 	volume := &api.StorageVolume{}
 	if volumeName != "" && volumeType >= 0 {
-		_, volume, err = s.DB.StoragePoolVolumeGetType(volumeName, volumeType, poolID)
+		_, volume, err = s.Node.StoragePoolVolumeGetType(volumeName, volumeType, poolID)
 		if err != nil {
 			return nil, err
 		}
@@ -317,7 +317,7 @@ func storageInit(s *state.State, poolName string, volumeName string, volumeType
 		btrfs.pool = pool
 		btrfs.volume = volume
 		btrfs.s = s
-		btrfs.db = s.DB
+		btrfs.db = s.Node
 		err = btrfs.StoragePoolInit()
 		if err != nil {
 			return nil, err
@@ -329,7 +329,7 @@ func storageInit(s *state.State, poolName string, volumeName string, volumeType
 		dir.pool = pool
 		dir.volume = volume
 		dir.s = s
-		dir.db = s.DB
+		dir.db = s.Node
 		err = dir.StoragePoolInit()
 		if err != nil {
 			return nil, err
@@ -341,7 +341,7 @@ func storageInit(s *state.State, poolName string, volumeName string, volumeType
 		ceph.pool = pool
 		ceph.volume = volume
 		ceph.s = s
-		ceph.db = s.DB
+		ceph.db = s.Node
 		err = ceph.StoragePoolInit()
 		if err != nil {
 			return nil, err
@@ -353,7 +353,7 @@ func storageInit(s *state.State, poolName string, volumeName string, volumeType
 		lvm.pool = pool
 		lvm.volume = volume
 		lvm.s = s
-		lvm.db = s.DB
+		lvm.db = s.Node
 		err = lvm.StoragePoolInit()
 		if err != nil {
 			return nil, err
@@ -365,7 +365,7 @@ func storageInit(s *state.State, poolName string, volumeName string, volumeType
 		mock.pool = pool
 		mock.volume = volume
 		mock.s = s
-		mock.db = s.DB
+		mock.db = s.Node
 		err = mock.StoragePoolInit()
 		if err != nil {
 			return nil, err
@@ -377,7 +377,7 @@ func storageInit(s *state.State, poolName string, volumeName string, volumeType
 		zfs.pool = pool
 		zfs.volume = volume
 		zfs.s = s
-		zfs.db = s.DB
+		zfs.db = s.Node
 		err = zfs.StoragePoolInit()
 		if err != nil {
 			return nil, err
@@ -518,11 +518,11 @@ func storagePoolVolumeAttachInit(s *state.State, poolName string, volumeName str
 
 	st.SetStoragePoolVolumeWritable(&poolVolumePut)
 
-	poolID, err := s.DB.StoragePoolGetID(poolName)
+	poolID, err := s.Node.StoragePoolGetID(poolName)
 	if err != nil {
 		return nil, err
 	}
-	err = s.DB.StoragePoolVolumeUpdate(volumeName, volumeType, poolID, poolVolumePut.Description, poolVolumePut.Config)
+	err = s.Node.StoragePoolVolumeUpdate(volumeName, volumeType, poolID, poolVolumePut.Description, poolVolumePut.Config)
 	if err != nil {
 		return nil, err
 	}
@@ -545,7 +545,7 @@ func storagePoolVolumeContainerCreateInit(s *state.State, poolName string, conta
 
 func storagePoolVolumeContainerLoadInit(s *state.State, containerName string) (storage, error) {
 	// Get the storage pool of a given container.
-	poolName, err := s.DB.ContainerPool(containerName)
+	poolName, err := s.Node.ContainerPool(containerName)
 	if err != nil {
 		return nil, err
 	}
@@ -811,7 +811,7 @@ func StorageProgressWriter(op *operation, key string, description string) func(i
 }
 
 func SetupStorageDriver(s *state.State, forceCheck bool) error {
-	pools, err := s.DB.StoragePools()
+	pools, err := s.Node.StoragePools()
 	if err != nil {
 		if err == db.NoSuchObjectError {
 			logger.Debugf("No existing storage pools detected.")
@@ -828,7 +828,7 @@ func SetupStorageDriver(s *state.State, forceCheck bool) error {
 	// but the upgrade somehow got messed up then there will be no
 	// "storage_api" entry in the db.
 	if len(pools) > 0 && !forceCheck {
-		appliedPatches, err := s.DB.Patches()
+		appliedPatches, err := s.Node.Patches()
 		if err != nil {
 			return err
 		}
@@ -864,7 +864,7 @@ func SetupStorageDriver(s *state.State, forceCheck bool) error {
 	// appropriate. (Should be cheaper than querying the db all the time,
 	// especially if we keep adding more storage drivers.)
 	if !storagePoolDriversCacheInitialized {
-		tmp, err := s.DB.StoragePoolsGetDrivers()
+		tmp, err := s.Node.StoragePoolsGetDrivers()
 		if err != nil && err != db.NoSuchObjectError {
 			return nil
 		}
diff --git a/lxd/storage_ceph.go b/lxd/storage_ceph.go
index 32b0a5cf7..840f50f0f 100644
--- a/lxd/storage_ceph.go
+++ b/lxd/storage_ceph.go
@@ -972,7 +972,7 @@ func (s *storageCeph) ContainerCreateFromImage(container container, fingerprint
 			fingerprint, storagePoolVolumeTypeNameImage, s.UserName)
 
 		if ok {
-			_, volume, err := s.s.DB.StoragePoolVolumeGetType(fingerprint, db.StoragePoolVolumeTypeImage, s.poolID)
+			_, volume, err := s.s.Node.StoragePoolVolumeGetType(fingerprint, db.StoragePoolVolumeTypeImage, s.poolID)
 			if err != nil {
 				return err
 			}
diff --git a/lxd/storage_lvm_utils.go b/lxd/storage_lvm_utils.go
index 3eed137a9..c6bf53718 100644
--- a/lxd/storage_lvm_utils.go
+++ b/lxd/storage_lvm_utils.go
@@ -488,7 +488,7 @@ func (s *storageLvm) containerCreateFromImageThinLv(c container, fp string) erro
 		var imgerr error
 		ok, _ := storageLVExists(imageLvmDevPath)
 		if ok {
-			_, volume, err := s.s.DB.StoragePoolVolumeGetType(fp, db.StoragePoolVolumeTypeImage, s.poolID)
+			_, volume, err := s.s.Node.StoragePoolVolumeGetType(fp, db.StoragePoolVolumeTypeImage, s.poolID)
 			if err != nil {
 				return err
 			}
@@ -675,7 +675,7 @@ func storageLVMThinpoolExists(vgName string, poolName string) (bool, error) {
 func storageLVMGetThinPoolUsers(s *state.State) ([]string, error) {
 	results := []string{}
 
-	cNames, err := s.DB.ContainersList(db.CTypeRegular)
+	cNames, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		return results, err
 	}
@@ -693,7 +693,7 @@ func storageLVMGetThinPoolUsers(s *state.State) ([]string, error) {
 		}
 	}
 
-	imageNames, err := s.DB.ImagesGet(false)
+	imageNames, err := s.Node.ImagesGet(false)
 	if err != nil {
 		return results, err
 	}
diff --git a/lxd/storage_pools_utils.go b/lxd/storage_pools_utils.go
index 1059d3765..849100675 100644
--- a/lxd/storage_pools_utils.go
+++ b/lxd/storage_pools_utils.go
@@ -62,7 +62,7 @@ func storagePoolUpdate(state *state.State, name, newDescription string, newConfi
 
 	// Update the database if something changed
 	if len(changedConfig) != 0 || newDescription != oldDescription {
-		err = state.DB.StoragePoolUpdate(name, newDescription, newConfig)
+		err = state.Node.StoragePoolUpdate(name, newDescription, newConfig)
 		if err != nil {
 			return err
 		}
@@ -164,7 +164,7 @@ func storagePoolDBCreate(s *state.State, poolName, poolDescription string, drive
 	}
 
 	// Check that the storage pool does not already exist.
-	_, err = s.DB.StoragePoolGetID(poolName)
+	_, err = s.Node.StoragePoolGetID(poolName)
 	if err == nil {
 		return fmt.Errorf("The storage pool already exists")
 	}
@@ -187,7 +187,7 @@ func storagePoolDBCreate(s *state.State, poolName, poolDescription string, drive
 	}
 
 	// Create the database entry for the storage pool.
-	_, err = dbStoragePoolCreateAndUpdateCache(s.DB, poolName, poolDescription, driver, config)
+	_, err = dbStoragePoolCreateAndUpdateCache(s.Node, poolName, poolDescription, driver, config)
 	if err != nil {
 		return fmt.Errorf("Error inserting %s into database: %s", poolName, err)
 	}
@@ -209,7 +209,7 @@ func storagePoolCreateInternal(state *state.State, poolName, poolDescription str
 		if !tryUndo {
 			return
 		}
-		dbStoragePoolDeleteAndUpdateCache(state.DB, poolName)
+		dbStoragePoolDeleteAndUpdateCache(state.Node, poolName)
 	}()
 
 	s, err := storagePoolInit(state, poolName)
@@ -238,7 +238,7 @@ func storagePoolCreateInternal(state *state.State, poolName, poolDescription str
 	configDiff, _ := storageConfigDiff(config, postCreateConfig)
 	if len(configDiff) > 0 {
 		// Create the database entry for the storage pool.
-		err = state.DB.StoragePoolUpdate(poolName, poolDescription, postCreateConfig)
+		err = state.Node.StoragePoolUpdate(poolName, poolDescription, postCreateConfig)
 		if err != nil {
 			return fmt.Errorf("Error inserting %s into database: %s", poolName, err)
 		}
diff --git a/lxd/storage_volumes_utils.go b/lxd/storage_volumes_utils.go
index 7e690e60b..c79b1e461 100644
--- a/lxd/storage_volumes_utils.go
+++ b/lxd/storage_volumes_utils.go
@@ -151,14 +151,14 @@ func storagePoolVolumeUpdate(state *state.State, poolName string, volumeName str
 		s.SetStoragePoolVolumeWritable(&newWritable)
 	}
 
-	poolID, err := state.DB.StoragePoolGetID(poolName)
+	poolID, err := state.Node.StoragePoolGetID(poolName)
 	if err != nil {
 		return err
 	}
 
 	// Update the database if something changed
 	if len(changedConfig) != 0 || newDescription != oldDescription {
-		err = state.DB.StoragePoolVolumeUpdate(volumeName, volumeType, poolID, newDescription, newConfig)
+		err = state.Node.StoragePoolVolumeUpdate(volumeName, volumeType, poolID, newDescription, newConfig)
 		if err != nil {
 			return err
 		}
@@ -172,7 +172,7 @@ func storagePoolVolumeUpdate(state *state.State, poolName string, volumeName str
 
 func storagePoolVolumeUsedByContainersGet(s *state.State, volumeName string,
 	volumeTypeName string) ([]string, error) {
-	cts, err := s.DB.ContainersList(db.CTypeRegular)
+	cts, err := s.Node.ContainersList(db.CTypeRegular)
 	if err != nil {
 		return []string{}, err
 	}
@@ -233,7 +233,7 @@ func storagePoolVolumeUsedByGet(s *state.State, volumeName string, volumeTypeNam
 			fmt.Sprintf("/%s/containers/%s", version.APIVersion, ct))
 	}
 
-	profiles, err := profilesUsingPoolVolumeGetNames(s.DB, volumeName, volumeTypeName)
+	profiles, err := profilesUsingPoolVolumeGetNames(s.Node, volumeName, volumeTypeName)
 	if err != nil {
 		return []string{}, err
 	}
@@ -302,14 +302,14 @@ func storagePoolVolumeDBCreate(s *state.State, poolName string, volumeName, volu
 	}
 
 	// Load storage pool the volume will be attached to.
-	poolID, poolStruct, err := s.DB.StoragePoolGet(poolName)
+	poolID, poolStruct, err := s.Node.StoragePoolGet(poolName)
 	if err != nil {
 		return err
 	}
 
 	// Check that a storage volume of the same storage volume type does not
 	// already exist.
-	volumeID, _ := s.DB.StoragePoolVolumeGetTypeID(volumeName, volumeType, poolID)
+	volumeID, _ := s.Node.StoragePoolVolumeGetTypeID(volumeName, volumeType, poolID)
 	if volumeID > 0 {
 		return fmt.Errorf("a storage volume of type %s does already exist", volumeTypeName)
 	}
@@ -331,7 +331,7 @@ func storagePoolVolumeDBCreate(s *state.State, poolName string, volumeName, volu
 	}
 
 	// Create the database entry for the storage volume.
-	_, err = s.DB.StoragePoolVolumeCreate(volumeName, volumeDescription, volumeType, poolID, volumeConfig)
+	_, err = s.Node.StoragePoolVolumeCreate(volumeName, volumeDescription, volumeType, poolID, volumeConfig)
 	if err != nil {
 		return fmt.Errorf("Error inserting %s of type %s into database: %s", poolName, volumeTypeName, err)
 	}
@@ -361,7 +361,7 @@ func storagePoolVolumeCreateInternal(state *state.State, poolName string, volume
 	// Create storage volume.
 	err = s.StoragePoolVolumeCreate()
 	if err != nil {
-		state.DB.StoragePoolVolumeDelete(volumeName, volumeType, poolID)
+		state.Node.StoragePoolVolumeDelete(volumeName, volumeType, poolID)
 		return err
 	}
 

From 4bf68ed55930d2955379d63c57e73e7377c9499c Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Thu, 12 Oct 2017 16:52:25 +0000
Subject: [PATCH 12/14] Add testing facilities for state.State and sys.OS

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/state/testing.go | 29 +++++++++++++++++++++++++++++
 lxd/sys/testing.go   | 28 ++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+)
 create mode 100644 lxd/state/testing.go
 create mode 100644 lxd/sys/testing.go

diff --git a/lxd/state/testing.go b/lxd/state/testing.go
new file mode 100644
index 000000000..f49cebd09
--- /dev/null
+++ b/lxd/state/testing.go
@@ -0,0 +1,29 @@
+package state
+
+import (
+	"testing"
+
+	"github.com/lxc/lxd/lxd/db"
+	"github.com/lxc/lxd/lxd/sys"
+)
+
+// NewTestState returns a State object initialized with testable instances of
+// the node/cluster databases and of the OS facade.
+//
+// Return the newly created State object, along with a function that can be
+// used for cleaning it up.
+func NewTestState(t *testing.T) (*State, func()) {
+	node, nodeCleanup := db.NewTestNode(t)
+	cluster, clusterCleanup := db.NewTestCluster(t)
+	os, osCleanup := sys.NewTestOS(t)
+
+	cleanup := func() {
+		nodeCleanup()
+		clusterCleanup()
+		osCleanup()
+	}
+
+	state := NewState(node, cluster, os)
+
+	return state, cleanup
+}
diff --git a/lxd/sys/testing.go b/lxd/sys/testing.go
new file mode 100644
index 000000000..b0bb8a42a
--- /dev/null
+++ b/lxd/sys/testing.go
@@ -0,0 +1,28 @@
+package sys
+
+import (
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+// NewTestOS returns a new OS instance initialized with test values.
+func NewTestOS(t *testing.T) (*OS, func()) {
+	dir, err := ioutil.TempDir("", "lxd-sys-os-test-")
+	require.NoError(t, err)
+
+	cleanup := func() {
+		require.NoError(t, os.RemoveAll(dir))
+	}
+
+	os := &OS{
+		VarDir:   dir,
+		CacheDir: filepath.Join(dir, "cache"),
+		LogDir:   filepath.Join(dir, "log"),
+	}
+
+	return os, cleanup
+}
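
A hedged usage sketch for these helpers (the test name is illustrative):

package state_test

import (
	"testing"

	"github.com/lxc/lxd/lxd/state"
)

// TestWithState gets a fully wired State backed by throw-away node and
// cluster databases plus a temporary directory, and tears it all down.
func TestWithState(t *testing.T) {
	st, cleanup := state.NewTestState(t)
	defer cleanup()

	// st.Node, st.Cluster and st.OS are now safe to exercise.
	_ = st
}
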

From d71bab43542184304ba1878fa7f7833ba73a71f8 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Fri, 15 Sep 2017 07:24:43 +0000
Subject: [PATCH 13/14] Add db APIs to read and update the raft_nodes table

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/db/query/slices.go |  12 ++---
 lxd/db/raft.go         | 114 ++++++++++++++++++++++++++++++++++++++++++
 lxd/db/raft_test.go    | 133 +++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 253 insertions(+), 6 deletions(-)
 create mode 100644 lxd/db/raft.go
 create mode 100644 lxd/db/raft_test.go

diff --git a/lxd/db/query/slices.go b/lxd/db/query/slices.go
index 59d0cc892..6cd9a7934 100644
--- a/lxd/db/query/slices.go
+++ b/lxd/db/query/slices.go
@@ -8,7 +8,7 @@ import (
 
 // SelectStrings executes a statement which must yield rows with a single string
 // column. It returns the list of column values.
-func SelectStrings(tx *sql.Tx, query string) ([]string, error) {
+func SelectStrings(tx *sql.Tx, query string, args ...interface{}) ([]string, error) {
 	values := []string{}
 	scan := func(rows *sql.Rows) error {
 		var value string
@@ -20,7 +20,7 @@ func SelectStrings(tx *sql.Tx, query string) ([]string, error) {
 		return nil
 	}
 
-	err := scanSingleColumn(tx, query, "TEXT", scan)
+	err := scanSingleColumn(tx, query, args, "TEXT", scan)
 	if err != nil {
 		return nil, err
 	}
@@ -30,7 +30,7 @@ func SelectStrings(tx *sql.Tx, query string) ([]string, error) {
 
 // SelectIntegers executes a statement which must yield rows with a single integer
 // column. It returns the list of column values.
-func SelectIntegers(tx *sql.Tx, query string) ([]int, error) {
+func SelectIntegers(tx *sql.Tx, query string, args ...interface{}) ([]int, error) {
 	values := []int{}
 	scan := func(rows *sql.Rows) error {
 		var value int
@@ -42,7 +42,7 @@ func SelectIntegers(tx *sql.Tx, query string) ([]int, error) {
 		return nil
 	}
 
-	err := scanSingleColumn(tx, query, "INTEGER", scan)
+	err := scanSingleColumn(tx, query, args, "INTEGER", scan)
 	if err != nil {
 		return nil, err
 	}
@@ -76,8 +76,8 @@ func InsertStrings(tx *sql.Tx, stmt string, values []string) error {
 // Execute the given query and ensure that it yields rows with a single column
 // of the given database type. For every row yielded, execute the given
 // scanner.
-func scanSingleColumn(tx *sql.Tx, query string, typeName string, scan scanFunc) error {
-	rows, err := tx.Query(query)
+func scanSingleColumn(tx *sql.Tx, query string, args []interface{}, typeName string, scan scanFunc) error {
+	rows, err := tx.Query(query, args...)
 	if err != nil {
 		return err
 	}
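
With the variadic args parameter, callers can now bind query parameters
instead of interpolating them into the SQL string; a small sketch:

package main

import (
	"database/sql"

	"github.com/lxc/lxd/lxd/db/query"
)

// addressesByID binds id as a query parameter via the new variadic
// argument of SelectStrings.
func addressesByID(tx *sql.Tx, id int64) ([]string, error) {
	return query.SelectStrings(tx, "SELECT address FROM raft_nodes WHERE id=?", id)
}
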
diff --git a/lxd/db/raft.go b/lxd/db/raft.go
new file mode 100644
index 000000000..40d6b29cb
--- /dev/null
+++ b/lxd/db/raft.go
@@ -0,0 +1,114 @@
+package db
+
+import (
+	"fmt"
+
+	"github.com/lxc/lxd/lxd/db/query"
+	"github.com/pkg/errors"
+)
+
+// RaftNode holds information about a single node in the dqlite raft cluster.
+type RaftNode struct {
+	ID      int64  // Stable node identifier
+	Address string // Network address of the node
+}
+
+// RaftNodes returns information about all LXD nodes that are members of the
+// dqlite Raft cluster (possibly including the local node). If this LXD
+// instance is not running in clustered mode, an empty list is returned.
+func (n *NodeTx) RaftNodes() ([]RaftNode, error) {
+	nodes := []RaftNode{}
+	dest := func(i int) []interface{} {
+		nodes = append(nodes, RaftNode{})
+		return []interface{}{&nodes[i].ID, &nodes[i].Address}
+	}
+	err := query.SelectObjects(n.tx, dest, "SELECT id, address FROM raft_nodes ORDER BY id")
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to fetch raft nodes")
+	}
+	return nodes, nil
+}
+
+// RaftNodeAddresses returns the addresses of all LXD nodes that are members of
+// the dqlite Raft cluster (possibly including the local node). If this LXD
+// instance is not running in clustered mode, an empty list is returned.
+func (n *NodeTx) RaftNodeAddresses() ([]string, error) {
+	return query.SelectStrings(n.tx, "SELECT address FROM raft_nodes")
+}
+
+// RaftNodeAddress returns the address of the LXD raft node with the given ID,
+// if any matching row exists.
+func (n *NodeTx) RaftNodeAddress(id int64) (string, error) {
+	stmt := "SELECT address FROM raft_nodes WHERE id=?"
+	addresses, err := query.SelectStrings(n.tx, stmt, id)
+	if err != nil {
+		return "", err
+	}
+	switch len(addresses) {
+	case 0:
+		return "", NoSuchObjectError
+	case 1:
+		return addresses[0], nil
+	default:
+		// This should never happen since we have a UNIQUE constraint
+		// on the raft_nodes.id column.
+		return "", fmt.Errorf("more than one match found")
+	}
+}
+
+// RaftNodeFirst adds the first node of the cluster. It ensures that the
+// database ID is 1, to match the server ID of the first raft log entry.
+//
+// This method is supposed to be called when there are no rows in raft_nodes,
+// and it will replace whatever existing row has ID 1.
+func (n *NodeTx) RaftNodeFirst(address string) error {
+	columns := []string{"id", "address"}
+	values := []interface{}{int64(1), address}
+	id, err := query.UpsertObject(n.tx, "raft_nodes", columns, values)
+	if err != nil {
+		return err
+	}
+	if id != 1 {
+		return fmt.Errorf("could not set raft node ID to 1")
+	}
+	return nil
+}
+
+// RaftNodeAdd adds a node to the current list of LXD nodes that are part of the
+// dqlite Raft cluster. It returns the ID of the newly inserted row.
+func (n *NodeTx) RaftNodeAdd(address string) (int64, error) {
+	columns := []string{"address"}
+	values := []interface{}{address}
+	return query.UpsertObject(n.tx, "raft_nodes", columns, values)
+}
+
+// RaftNodeDelete removes a node from the current list of LXD nodes that are
+// part of the dqlite Raft cluster.
+func (n *NodeTx) RaftNodeDelete(id int64) error {
+	deleted, err := query.DeleteObject(n.tx, "raft_nodes", id)
+	if err != nil {
+		return err
+	}
+	if !deleted {
+		return NoSuchObjectError
+	}
+	return nil
+}
+
+// RaftNodesReplace replaces the current list of raft nodes.
+func (n *NodeTx) RaftNodesReplace(nodes []RaftNode) error {
+	_, err := n.tx.Exec("DELETE FROM raft_nodes")
+	if err != nil {
+		return err
+	}
+
+	columns := []string{"id", "address"}
+	for _, node := range nodes {
+		values := []interface{}{node.ID, node.Address}
+		_, err := query.UpsertObject(n.tx, "raft_nodes", columns, values)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
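
To make the intended call pattern concrete, here is a sketch of
bootstrapping and then growing the raft membership with these helpers
(addresses made up; transaction plumbing elided):

package main

import "github.com/lxc/lxd/lxd/db"

// bootstrapRaft records the first raft member, which must get ID 1 to
// match the first raft log entry, and then a second member, which just
// gets the next autoincremented ID.
func bootstrapRaft(tx *db.NodeTx) error {
	if err := tx.RaftNodeFirst("10.0.0.1:8443"); err != nil {
		return err
	}
	if _, err := tx.RaftNodeAdd("10.0.0.2:8443"); err != nil {
		return err
	}
	return nil
}
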
diff --git a/lxd/db/raft_test.go b/lxd/db/raft_test.go
new file mode 100644
index 000000000..dd74b8237
--- /dev/null
+++ b/lxd/db/raft_test.go
@@ -0,0 +1,133 @@
+package db_test
+
+import (
+	"testing"
+
+	"github.com/lxc/lxd/lxd/db"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// Fetch all raft nodes.
+func TestRaftNodes(t *testing.T) {
+	tx, cleanup := db.NewTestNodeTx(t)
+	defer cleanup()
+
+	id1, err := tx.RaftNodeAdd("1.2.3.4:666")
+	require.NoError(t, err)
+
+	id2, err := tx.RaftNodeAdd("5.6.7.8:666")
+	require.NoError(t, err)
+
+	nodes, err := tx.RaftNodes()
+	require.NoError(t, err)
+
+	assert.Equal(t, id1, nodes[0].ID)
+	assert.Equal(t, id2, nodes[1].ID)
+	assert.Equal(t, "1.2.3.4:666", nodes[0].Address)
+	assert.Equal(t, "5.6.7.8:666", nodes[1].Address)
+}
+
+// Fetch the addresses of all raft nodes.
+func TestRaftNodeAddresses(t *testing.T) {
+	tx, cleanup := db.NewTestNodeTx(t)
+	defer cleanup()
+
+	_, err := tx.RaftNodeAdd("1.2.3.4:666")
+	require.NoError(t, err)
+
+	_, err = tx.RaftNodeAdd("5.6.7.8:666")
+	require.NoError(t, err)
+
+	addresses, err := tx.RaftNodeAddresses()
+	require.NoError(t, err)
+
+	assert.Equal(t, []string{"1.2.3.4:666", "5.6.7.8:666"}, addresses)
+}
+
+// Fetch the address of the raft node with the given ID.
+func TestRaftNodeAddress(t *testing.T) {
+	tx, cleanup := db.NewTestNodeTx(t)
+	defer cleanup()
+
+	_, err := tx.RaftNodeAdd("1.2.3.4:666")
+	require.NoError(t, err)
+
+	id, err := tx.RaftNodeAdd("5.6.7.8:666")
+	require.NoError(t, err)
+
+	address, err := tx.RaftNodeAddress(id)
+	require.NoError(t, err)
+	assert.Equal(t, "5.6.7.8:666", address)
+}
+
+// Add the first raft node.
+func TestRaftNodeFirst(t *testing.T) {
+	tx, cleanup := db.NewTestNodeTx(t)
+	defer cleanup()
+
+	err := tx.RaftNodeFirst("1.2.3.4:666")
+	assert.NoError(t, err)
+
+	err = tx.RaftNodeDelete(1)
+	assert.NoError(t, err)
+
+	err = tx.RaftNodeFirst("5.6.7.8:666")
+	assert.NoError(t, err)
+
+	address, err := tx.RaftNodeAddress(1)
+	require.NoError(t, err)
+	assert.Equal(t, "5.6.7.8:666", address)
+}
+
+// Add a new raft node.
+func TestRaftNodeAdd(t *testing.T) {
+	tx, cleanup := db.NewTestNodeTx(t)
+	defer cleanup()
+
+	id, err := tx.RaftNodeAdd("1.2.3.4:666")
+	assert.Equal(t, int64(1), id)
+	assert.NoError(t, err)
+}
+
+// Delete an existing raft node.
+func TestRaftNodeDelete(t *testing.T) {
+	tx, cleanup := db.NewTestNodeTx(t)
+	defer cleanup()
+
+	id, err := tx.RaftNodeAdd("1.2.3.4:666")
+	require.NoError(t, err)
+
+	err = tx.RaftNodeDelete(id)
+	assert.NoError(t, err)
+}
+
+// Delete a non-existing raft node returns an error.
+func TestRaftNodeDelete_NonExisting(t *testing.T) {
+	tx, cleanup := db.NewTestNodeTx(t)
+	defer cleanup()
+
+	err := tx.RaftNodeDelete(1)
+	assert.Equal(t, db.NoSuchObjectError, err)
+}
+
+// Replace all existing raft nodes.
+func TestRaftNodesReplace(t *testing.T) {
+	tx, cleanup := db.NewTestNodeTx(t)
+	defer cleanup()
+
+	_, err := tx.RaftNodeAdd("1.2.3.4:666")
+	require.NoError(t, err)
+
+	nodes := []db.RaftNode{
+		{ID: 2, Address: "2.2.2.2:666"},
+		{ID: 3, Address: "3.3.3.3:666"},
+	}
+	err = tx.RaftNodesReplace(nodes)
+	assert.NoError(t, err)
+
+	newNodes, err := tx.RaftNodes()
+	require.NoError(t, err)
+
+	assert.Equal(t, nodes, newNodes)
+}

From 97f17bc9f03c675494196ce09cf24d10e0bf476b Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Wed, 11 Oct 2017 13:34:20 +0000
Subject: [PATCH 14/14] Add node.DetermineRaftNode function to figure out
 what role a node plays

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/node/raft.go      | 60 +++++++++++++++++++++++++++++++++++++++
 lxd/node/raft_test.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 137 insertions(+)
 create mode 100644 lxd/node/raft.go
 create mode 100644 lxd/node/raft_test.go

diff --git a/lxd/node/raft.go b/lxd/node/raft.go
new file mode 100644
index 000000000..8b4605356
--- /dev/null
+++ b/lxd/node/raft.go
@@ -0,0 +1,60 @@
+package node
+
+import "github.com/lxc/lxd/lxd/db"
+
+// DetermineRaftNode figures out what raft node ID and address we have, if any.
+//
+// This decision is based on the values of the core.https_address config key
+// and on the rows in the raft_nodes table, both stored in the node-level
+// SQLite database.
+//
+// The following rules are applied:
+//
+// - If no core.https_address config key is set, this is a non-clustered node
+//   and the returned RaftNode will have ID 1 but no address, to signal that
+//   the node should set up an in-memory raft cluster where the node itself
+//   is the only member and leader.
+//
+// - If the core.https_address config key is set, but there is no row in the
+//   raft_nodes table, this is a non-clustered node as well, and the same
+//   behavior as in the previous case applies.
+//
+// - If the core.https_address config key is set and there is at least one row
+//   in the raft_nodes table, then this node is considered a raft node if
+//   core.https_address matches one of the rows in raft_nodes. In that case
+//   the matching db.RaftNode row is returned, otherwise nil.
+func DetermineRaftNode(tx *db.NodeTx) (*db.RaftNode, error) {
+	config, err := ConfigLoad(tx)
+	if err != nil {
+		return nil, err
+	}
+
+	address := config.HTTPSAddress()
+
+	// If core.https_address is the empty string, then this LXD instance is
+	// not running in clustering mode.
+	if address == "" {
+		return &db.RaftNode{ID: 1}, nil
+	}
+
+	nodes, err := tx.RaftNodes()
+	if err != nil {
+		return nil, err
+	}
+
+	// If core.https_address is set, but raft_nodes has no rows, this is
+	// still an instance not running in clustering mode.
+	if len(nodes) == 0 {
+		return &db.RaftNode{ID: 1}, nil
+	}
+
+	// If there are one or more rows in raft_nodes, try to find a
+	// matching one.
+	for _, node := range nodes {
+		if node.Address == address {
+			return &node, nil
+		}
+	}
+
+	return nil, nil
+}
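
A sketch of how a caller might interpret the three possible outcomes
(the role strings are illustrative, not part of the patch):

package main

import (
	"github.com/lxc/lxd/lxd/db"
	"github.com/lxc/lxd/lxd/node"
)

// raftRole maps DetermineRaftNode's result to a human-readable role.
func raftRole(tx *db.NodeTx) (string, error) {
	info, err := node.DetermineRaftNode(tx)
	if err != nil {
		return "", err
	}
	switch {
	case info == nil:
		return "clustered, but not a raft member", nil
	case info.Address == "":
		return "standalone, in-memory raft", nil
	default:
		return "raft cluster member", nil
	}
}
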
diff --git a/lxd/node/raft_test.go b/lxd/node/raft_test.go
new file mode 100644
index 000000000..b376bdc3f
--- /dev/null
+++ b/lxd/node/raft_test.go
@@ -0,0 +1,77 @@
+package node_test
+
+import (
+	"testing"
+
+	"github.com/lxc/lxd/lxd/db"
+	"github.com/lxc/lxd/lxd/node"
+	"github.com/mpvl/subtest"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// The raft identity (ID and address) of a node depends on the value of
+// core.https_address and the entries of the raft_nodes table.
+func TestDetermineRaftNode(t *testing.T) {
+	cases := []struct {
+		title     string
+		address   string       // Value of core.https_address
+		addresses []string     // Entries in raft_nodes
+		node      *db.RaftNode // Expected node value
+	}{
+		{
+			`no core.https_address set`,
+			"",
+			[]string{},
+			&db.RaftNode{ID: 1},
+		},
+		{
+			`core.https_address set and no raft_nodes rows`,
+			"1.2.3.4:8443",
+			[]string{},
+			&db.RaftNode{ID: 1},
+		},
+		{
+			`core.https_address set and matching the one and only raft_nodes row`,
+			"1.2.3.4:8443",
+			[]string{"1.2.3.4:8443"},
+			&db.RaftNode{ID: 1, Address: "1.2.3.4:8443"},
+		},
+		{
+			`core.https_address set and matching one of many raft_nodes rows`,
+			"5.6.7.8:999",
+			[]string{"1.2.3.4:666", "5.6.7.8:999"},
+			&db.RaftNode{ID: 2, Address: "5.6.7.8:999"},
+		},
+		{
+			`core.https_address set and no matching raft_nodes row`,
+			"1.2.3.4:666",
+			[]string{"5.6.7.8:999"},
+			nil,
+		},
+	}
+
+	for _, c := range cases {
+		subtest.Run(t, c.title, func(t *testing.T) {
+			tx, cleanup := db.NewTestNodeTx(t)
+			defer cleanup()
+
+			err := tx.UpdateConfig(map[string]string{"core.https_address": c.address})
+			require.NoError(t, err)
+
+			for _, address := range c.addresses {
+				_, err := tx.RaftNodeAdd(address)
+				require.NoError(t, err)
+			}
+
+			node, err := node.DetermineRaftNode(tx)
+			require.NoError(t, err)
+			if c.node == nil {
+				assert.Nil(t, node)
+			} else {
+				assert.Equal(t, c.node.ID, node.ID)
+				assert.Equal(t, c.node.Address, node.Address)
+			}
+		})
+	}
+}

