[lxc-devel] [lxd/master] Database logic cleanup (part 1)

freeekanayaka on Github lxc-bot at linuxcontainers.org
Fri May 8 12:15:58 UTC 2020


A non-text attachment was scrubbed...
Name: not available
Type: text/x-mailbox
Size: 301 bytes
Desc: not available
URL: <http://lists.linuxcontainers.org/pipermail/lxc-devel/attachments/20200508/eba3865d/attachment.bin>
-------------- next part --------------
From 19b15efbf341d07c036fc7086a91f384dcbeca93 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Fri, 8 May 2020 10:23:09 +0100
Subject: [PATCH 1/3] shared/generate/db: Fix generation of Exists method

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 shared/generate/db/method.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/shared/generate/db/method.go b/shared/generate/db/method.go
index 2eb80ce77a..3b6ad84c4b 100644
--- a/shared/generate/db/method.go
+++ b/shared/generate/db/method.go
@@ -609,7 +609,7 @@ func (m *Method) exists(buf *file.Buffer) error {
 	m.begin(buf, comment, args, rets)
 	defer m.end(buf)
 
-	buf.L("_, err := c.%sID(%s)", lex.Camel(m.entity), FieldParams(nk))
+	buf.L("_, err := c.Get%sID(%s)", lex.Camel(m.entity), FieldParams(nk))
 	buf.L("if err != nil {")
 	buf.L("        if err == ErrNoSuchObject {")
 	buf.L("                return false, nil")

From 79e1d55773a55238674eeb5acdc6b61bda63aaf1 Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Fri, 8 May 2020 12:25:25 +0100
Subject: [PATCH 2/3] lxd/db: Make generated code stable across "make
 update-schema" runs

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/api_internal.go                 |  4 +-
 lxd/db/containers.go                |  6 +--
 lxd/db/instances.mapper.go          | 74 ++++++++++++++---------------
 lxd/instance.go                     |  2 +-
 lxd/instance/drivers/driver_lxc.go  |  2 +-
 lxd/instance/drivers/driver_qemu.go |  2 +-
 shared/generate/db/method.go        |  2 +-
 shared/generate/db/parse.go         | 31 +++++++++---
 8 files changed, 71 insertions(+), 52 deletions(-)

diff --git a/lxd/api_internal.go b/lxd/api_internal.go
index 1cf6ee7a7b..c0954c8e8d 100644
--- a/lxd/api_internal.go
+++ b/lxd/api_internal.go
@@ -584,7 +584,7 @@ func internalImport(d *Daemon, r *http.Request) response.Response {
 
 	if instanceErr == nil {
 		// Remove the storage volume db entry for the instance since force was specified.
-		err := d.cluster.RemoveInstance(projectName, req.Name)
+		err := d.cluster.DeleteInstance(projectName, req.Name)
 		if err != nil {
 			return response.SmartError(err)
 		}
@@ -690,7 +690,7 @@ func internalImport(d *Daemon, r *http.Request) response.Response {
 		}
 
 		if snapErr == nil {
-			err := d.cluster.RemoveInstance(projectName, snap.Name)
+			err := d.cluster.DeleteInstance(projectName, snap.Name)
 			if err != nil {
 				return response.SmartError(err)
 			}
diff --git a/lxd/db/containers.go b/lxd/db/containers.go
index 4f8a56570a..f107c22b81 100644
--- a/lxd/db/containers.go
+++ b/lxd/db/containers.go
@@ -599,8 +599,8 @@ func (c *ClusterTx) configUpdate(id int, values map[string]string, insertSQL, de
 	return nil
 }
 
-// RemoveInstance removes the instance with the given name from the database.
-func (c *Cluster) RemoveInstance(project, name string) error {
+// DeleteInstance removes the instance with the given name from the database.
+func (c *Cluster) DeleteInstance(project, name string) error {
 	if strings.Contains(name, shared.SnapshotDelimiter) {
 		parts := strings.SplitN(name, shared.SnapshotDelimiter, 2)
 		return c.Transaction(func(tx *ClusterTx) error {
@@ -608,7 +608,7 @@ func (c *Cluster) RemoveInstance(project, name string) error {
 		})
 	}
 	return c.Transaction(func(tx *ClusterTx) error {
-		return tx.RemoveInstance(project, name)
+		return tx.DeleteInstance(project, name)
 	})
 }
 
diff --git a/lxd/db/instances.mapper.go b/lxd/db/instances.mapper.go
index cda25bea17..13ad93223f 100644
--- a/lxd/db/instances.mapper.go
+++ b/lxd/db/instances.mapper.go
@@ -235,11 +235,11 @@ func (c *ClusterTx) GetInstances(filter InstanceFilter) ([]Instance, error) {
 			filter.Node,
 			filter.Name,
 		}
-	} else if criteria["Project"] != nil && criteria["Name"] != nil && criteria["Node"] != nil {
-		stmt = c.stmt(instanceObjectsByProjectAndNameAndNode)
+	} else if criteria["Project"] != nil && criteria["Type"] != nil && criteria["Node"] != nil {
+		stmt = c.stmt(instanceObjectsByProjectAndTypeAndNode)
 		args = []interface{}{
 			filter.Project,
-			filter.Name,
+			filter.Type,
 			filter.Node,
 		}
 	} else if criteria["Project"] != nil && criteria["Type"] != nil && criteria["Name"] != nil {
@@ -256,18 +256,24 @@ func (c *ClusterTx) GetInstances(filter InstanceFilter) ([]Instance, error) {
 			filter.Name,
 			filter.Node,
 		}
-	} else if criteria["Project"] != nil && criteria["Type"] != nil && criteria["Node"] != nil {
-		stmt = c.stmt(instanceObjectsByProjectAndTypeAndNode)
+	} else if criteria["Project"] != nil && criteria["Name"] != nil && criteria["Node"] != nil {
+		stmt = c.stmt(instanceObjectsByProjectAndNameAndNode)
 		args = []interface{}{
 			filter.Project,
-			filter.Type,
+			filter.Name,
 			filter.Node,
 		}
-	} else if criteria["Project"] != nil && criteria["Name"] != nil {
-		stmt = c.stmt(instanceObjectsByProjectAndName)
+	} else if criteria["Project"] != nil && criteria["Type"] != nil {
+		stmt = c.stmt(instanceObjectsByProjectAndType)
 		args = []interface{}{
 			filter.Project,
-			filter.Name,
+			filter.Type,
+		}
+	} else if criteria["Type"] != nil && criteria["Node"] != nil {
+		stmt = c.stmt(instanceObjectsByTypeAndNode)
+		args = []interface{}{
+			filter.Type,
+			filter.Node,
 		}
 	} else if criteria["Project"] != nil && criteria["Node"] != nil {
 		stmt = c.stmt(instanceObjectsByProjectAndNode)
@@ -281,11 +287,11 @@ func (c *ClusterTx) GetInstances(filter InstanceFilter) ([]Instance, error) {
 			filter.Type,
 			filter.Name,
 		}
-	} else if criteria["Project"] != nil && criteria["Type"] != nil {
-		stmt = c.stmt(instanceObjectsByProjectAndType)
+	} else if criteria["Project"] != nil && criteria["Name"] != nil {
+		stmt = c.stmt(instanceObjectsByProjectAndName)
 		args = []interface{}{
 			filter.Project,
-			filter.Type,
+			filter.Name,
 		}
 	} else if criteria["Node"] != nil && criteria["Name"] != nil {
 		stmt = c.stmt(instanceObjectsByNodeAndName)
@@ -293,11 +299,10 @@ func (c *ClusterTx) GetInstances(filter InstanceFilter) ([]Instance, error) {
 			filter.Node,
 			filter.Name,
 		}
-	} else if criteria["Type"] != nil && criteria["Node"] != nil {
-		stmt = c.stmt(instanceObjectsByTypeAndNode)
+	} else if criteria["Type"] != nil {
+		stmt = c.stmt(instanceObjectsByType)
 		args = []interface{}{
 			filter.Type,
-			filter.Node,
 		}
 	} else if criteria["Project"] != nil {
 		stmt = c.stmt(instanceObjectsByProject)
@@ -309,11 +314,6 @@ func (c *ClusterTx) GetInstances(filter InstanceFilter) ([]Instance, error) {
 		args = []interface{}{
 			filter.Node,
 		}
-	} else if criteria["Type"] != nil {
-		stmt = c.stmt(instanceObjectsByType)
-		args = []interface{}{
-			filter.Type,
-		}
 	} else if criteria["Name"] != nil {
 		stmt = c.stmt(instanceObjectsByName)
 		args = []interface{}{
@@ -595,16 +595,16 @@ func (c *ClusterTx) InstanceProfilesRef(filter InstanceFilter) (map[string]map[s
 			filter.Project,
 			filter.Name,
 		}
-	} else if criteria["Node"] != nil {
-		stmt = c.stmt(instanceProfilesRefByNode)
-		args = []interface{}{
-			filter.Node,
-		}
 	} else if criteria["Project"] != nil {
 		stmt = c.stmt(instanceProfilesRefByProject)
 		args = []interface{}{
 			filter.Project,
 		}
+	} else if criteria["Node"] != nil {
+		stmt = c.stmt(instanceProfilesRefByNode)
+		args = []interface{}{
+			filter.Node,
+		}
 	} else {
 		stmt = c.stmt(instanceProfilesRef)
 		args = []interface{}{}
@@ -686,16 +686,16 @@ func (c *ClusterTx) InstanceConfigRef(filter InstanceFilter) (map[string]map[str
 			filter.Project,
 			filter.Name,
 		}
-	} else if criteria["Node"] != nil {
-		stmt = c.stmt(instanceConfigRefByNode)
-		args = []interface{}{
-			filter.Node,
-		}
 	} else if criteria["Project"] != nil {
 		stmt = c.stmt(instanceConfigRefByProject)
 		args = []interface{}{
 			filter.Project,
 		}
+	} else if criteria["Node"] != nil {
+		stmt = c.stmt(instanceConfigRefByNode)
+		args = []interface{}{
+			filter.Node,
+		}
 	} else {
 		stmt = c.stmt(instanceConfigRef)
 		args = []interface{}{}
@@ -782,16 +782,16 @@ func (c *ClusterTx) InstanceDevicesRef(filter InstanceFilter) (map[string]map[st
 			filter.Project,
 			filter.Name,
 		}
-	} else if criteria["Node"] != nil {
-		stmt = c.stmt(instanceDevicesRefByNode)
-		args = []interface{}{
-			filter.Node,
-		}
 	} else if criteria["Project"] != nil {
 		stmt = c.stmt(instanceDevicesRefByProject)
 		args = []interface{}{
 			filter.Project,
 		}
+	} else if criteria["Node"] != nil {
+		stmt = c.stmt(instanceDevicesRefByNode)
+		args = []interface{}{
+			filter.Node,
+		}
 	} else {
 		stmt = c.stmt(instanceDevicesRef)
 		args = []interface{}{}
@@ -878,8 +878,8 @@ func (c *ClusterTx) RenameInstance(project string, name string, to string) error
 	return nil
 }
 
-// RemoveInstance deletes the instance matching the given key parameters.
-func (c *ClusterTx) RemoveInstance(project string, name string) error {
+// DeleteInstance deletes the instance matching the given key parameters.
+func (c *ClusterTx) DeleteInstance(project string, name string) error {
 	stmt := c.stmt(instanceDelete)
 	result, err := stmt.Exec(project, name)
 	if err != nil {
diff --git a/lxd/instance.go b/lxd/instance.go
index 30fa93b606..88d6292bfa 100644
--- a/lxd/instance.go
+++ b/lxd/instance.go
@@ -625,7 +625,7 @@ func instanceCreateInternal(s *state.State, args db.InstanceArgs) (instance.Inst
 			return
 		}
 
-		s.Cluster.RemoveInstance(dbInst.Project, dbInst.Name)
+		s.Cluster.DeleteInstance(dbInst.Project, dbInst.Name)
 	}()
 
 	// Wipe any existing log for this instance name.
diff --git a/lxd/instance/drivers/driver_lxc.go b/lxd/instance/drivers/driver_lxc.go
index 53bbd4cbc7..53346846ad 100644
--- a/lxd/instance/drivers/driver_lxc.go
+++ b/lxd/instance/drivers/driver_lxc.go
@@ -3519,7 +3519,7 @@ func (c *lxc) Delete() error {
 	}
 
 	// Remove the database record of the instance or snapshot instance.
-	if err := c.state.Cluster.RemoveInstance(c.project, c.Name()); err != nil {
+	if err := c.state.Cluster.DeleteInstance(c.project, c.Name()); err != nil {
 		logger.Error("Failed deleting container entry", log.Ctx{"name": c.Name(), "err": err})
 		return err
 	}
diff --git a/lxd/instance/drivers/driver_qemu.go b/lxd/instance/drivers/driver_qemu.go
index 51fdb5965e..040518d37c 100644
--- a/lxd/instance/drivers/driver_qemu.go
+++ b/lxd/instance/drivers/driver_qemu.go
@@ -2970,7 +2970,7 @@ func (vm *qemu) Delete() error {
 	}
 
 	// Remove the database record of the instance or snapshot instance.
-	if err := vm.state.Cluster.RemoveInstance(vm.Project(), vm.Name()); err != nil {
+	if err := vm.state.Cluster.DeleteInstance(vm.Project(), vm.Name()); err != nil {
 		logger.Error("Failed deleting instance entry", log.Ctx{"project": vm.Project(), "instance": vm.Name(), "err": err})
 		return err
 	}
diff --git a/shared/generate/db/method.go b/shared/generate/db/method.go
index 3b6ad84c4b..0ddc903cd9 100644
--- a/shared/generate/db/method.go
+++ b/shared/generate/db/method.go
@@ -885,7 +885,7 @@ func (m *Method) begin(buf *file.Buffer, comment string, args string, rets strin
 	case "Update":
 		name = fmt.Sprintf("Update%s", entity)
 	case "Delete":
-		name = fmt.Sprintf("Remove%s", entity)
+		name = fmt.Sprintf("Delete%s", entity)
 	default:
 		name = fmt.Sprintf("%s%s", entity, m.kind)
 	}
diff --git a/shared/generate/db/parse.go b/shared/generate/db/parse.go
index 7b38733a9b..c82691c1fd 100644
--- a/shared/generate/db/parse.go
+++ b/shared/generate/db/parse.go
@@ -51,11 +51,7 @@ func Filters(pkg *ast.Package, entity string) [][]string {
 		filters = append(filters, strings.Split(rest, "And"))
 	}
 
-	sort.SliceStable(filters, func(i, j int) bool {
-		return len(filters[i]) > len(filters[j])
-	})
-
-	return filters
+	return sortFilters(filters)
 }
 
 // RefFilters parses all filtering statement defined for the given entity reference.
@@ -73,13 +69,36 @@ func RefFilters(pkg *ast.Package, entity string, ref string) [][]string {
 		filters = append(filters, strings.Split(rest, "And"))
 	}
 
+	return sortFilters(filters)
+}
+
+func sortFilters(filters [][]string) [][]string {
 	sort.SliceStable(filters, func(i, j int) bool {
-		return len(filters[i]) > len(filters[j])
+		n1, n2 := len(filters[i]), len(filters[j])
+		if n1 != n2 {
+			return n1 > n2 // Filters with more criteria come first.
+		}
+		// Tie-break on the alphabetically sorted criteria names, in descending
+		// order, stopping at the first difference so the order is consistent.
+		f1, f2 := sortFilter(filters[i]), sortFilter(filters[j])
+		for k := range f1 {
+			if f1[k] != f2[k] {
+				return f1[k] > f2[k]
+			}
+		}
+		return false
 	})
 
 	return filters
 }
 
+// sortFilter returns an alphabetically sorted copy of filter.
+func sortFilter(filter []string) []string {
+	sorted := append(make([]string, 0, len(filter)), filter...)
+	sort.Strings(sorted)
+	return sorted
+}
+
 // Criteria returns a list of criteria
 func Criteria(pkg *ast.Package, entity string) ([]string, error) {
 	name := fmt.Sprintf("%sFilter", lex.Camel(entity))

From ac24b8823e19d7d316105213fb8605fa33b0ab4c Mon Sep 17 00:00:00 2001
From: Free Ekanayaka <free.ekanayaka at canonical.com>
Date: Fri, 8 May 2020 10:06:34 +0100
Subject: [PATCH 3/3] lxd/db: Leverage code-generation for certificates

Signed-off-by: Free Ekanayaka <free.ekanayaka at canonical.com>
---
 lxd/db/certificates.go        | 118 ++++++++--------------
 lxd/db/certificates.mapper.go | 182 ++++++++++++++++++++++++++++++++++
 2 files changed, 221 insertions(+), 79 deletions(-)
 create mode 100644 lxd/db/certificates.mapper.go

diff --git a/lxd/db/certificates.go b/lxd/db/certificates.go
index 40835e24a1..bbb0c5b1cd 100644
--- a/lxd/db/certificates.go
+++ b/lxd/db/certificates.go
@@ -2,51 +2,53 @@
 
 package db
 
-import (
-	"database/sql"
-)
-
-// CertInfo is here to pass the certificates content
+// Code generation directives.
+//
+//go:generate -command mapper lxd-generate db mapper -t certificates.mapper.go
+//go:generate mapper reset
+//
+//go:generate mapper stmt -p db -e certificate objects
+//go:generate mapper stmt -p db -e certificate objects-by-Fingerprint
+//go:generate mapper stmt -p db -e certificate id
+//go:generate mapper stmt -p db -e certificate create struct=Certificate
+//
+//go:generate mapper method -p db -e certificate List
+//go:generate mapper method -p db -e certificate Get
+//go:generate mapper method -p db -e certificate ID struct=Certificate
+//go:generate mapper method -p db -e certificate Exists struct=Certificate
+//go:generate mapper method -p db -e certificate Create struct=Certificate
+
+// Certificate is here to pass the certificates content
 // from the database around
-type CertInfo struct {
+type Certificate struct {
 	ID          int
-	Fingerprint string
+	Fingerprint string `db:"primary=yes"`
 	Type        int
 	Name        string
 	Certificate string
 }
 
+type CertInfo = Certificate
+
+// CertInfo can be used to filter results yielded by GetCertInfos
+type CertificateFilter struct {
+	Fingerprint string
+}
+
 // GetCertificates returns all certificates from the DB as CertBaseInfo objects.
 func (c *Cluster) GetCertificates() (certs []*CertInfo, err error) {
 	err = c.Transaction(func(tx *ClusterTx) error {
-		rows, err := tx.tx.Query(
-			"SELECT id, fingerprint, type, name, certificate FROM certificates",
-		)
+		certificates, err := tx.GetCertificates(CertificateFilter{})
 		if err != nil {
 			return err
 		}
-
-		defer rows.Close()
-
-		for rows.Next() {
-			cert := new(CertInfo)
-			rows.Scan(
-				&cert.ID,
-				&cert.Fingerprint,
-				&cert.Type,
-				&cert.Name,
-				&cert.Certificate,
-			)
-			certs = append(certs, cert)
+		certs = make([]*CertInfo, len(certificates))
+		for i := range certificates {
+			certs[i] = &certificates[i]
 		}
-
-		return rows.Err()
+		return nil
 	})
-	if err != nil {
-		return certs, err
-	}
-
-	return certs, nil
+	return
 }
 
 // GetCertificate gets an CertBaseInfo object from the database.
@@ -55,61 +57,34 @@ func (c *Cluster) GetCertificates() (certs []*CertInfo, err error) {
 // There can never be more than one image with a given fingerprint, as it is
 // enforced by a UNIQUE constraint in the schema.
 func (c *Cluster) GetCertificate(fingerprint string) (cert *CertInfo, err error) {
-	cert = new(CertInfo)
-
-	inargs := []interface{}{fingerprint + "%"}
-	outfmt := []interface{}{
-		&cert.ID,
-		&cert.Fingerprint,
-		&cert.Type,
-		&cert.Name,
-		&cert.Certificate,
-	}
-
-	query := `
-		SELECT
-			id, fingerprint, type, name, certificate
-		FROM
-			certificates
-		WHERE fingerprint LIKE ?`
-
-	if err = dbQueryRowScan(c.db, query, inargs, outfmt); err != nil {
-		if err == sql.ErrNoRows {
-			return nil, ErrNoSuchObject
-		}
-
-		return nil, err
-	}
-
-	return cert, err
+	err = c.Transaction(func(tx *ClusterTx) error {
+		// Match on a fingerprint prefix, like the "fingerprint LIKE ?" query
+		// used before, so that callers can keep passing a short fingerprint.
+		certs, err := tx.GetCertificates(CertificateFilter{})
+		if err != nil {
+			return err
+		}
+		for i := range certs {
+			fp := certs[i].Fingerprint
+			if len(fp) >= len(fingerprint) && fp[:len(fingerprint)] == fingerprint {
+				cert = &certs[i]
+				return nil
+			}
+		}
+		return ErrNoSuchObject
+	})
+	if err != nil {
+		return nil, err
+	}
+	return cert, nil
 }
 
 // CreateCertificate stores a CertInfo object in the db, it will ignore the ID
 // field from the CertInfo.
 func (c *Cluster) CreateCertificate(cert *CertInfo) error {
 	err := c.Transaction(func(tx *ClusterTx) error {
-		stmt, err := tx.tx.Prepare(`
-			INSERT INTO certificates (
-				fingerprint,
-				type,
-				name,
-				certificate
-			) VALUES (?, ?, ?, ?)`,
-		)
-		if err != nil {
-			return err
-		}
-		defer stmt.Close()
-		_, err = stmt.Exec(
-			cert.Fingerprint,
-			cert.Type,
-			cert.Name,
-			cert.Certificate,
-		)
-		if err != nil {
-			return err
-		}
-		return nil
+		_, err := tx.CreateCertificate(*cert)
+		return err
 	})
 	return err
 }
diff --git a/lxd/db/certificates.mapper.go b/lxd/db/certificates.mapper.go
new file mode 100644
index 0000000000..5a0acdc486
--- /dev/null
+++ b/lxd/db/certificates.mapper.go
@@ -0,0 +1,182 @@
+// +build linux,cgo,!agent
+
+package db
+
+// The code below was generated by lxd-generate - DO NOT EDIT!
+
+import (
+	"database/sql"
+	"fmt"
+	"github.com/lxc/lxd/lxd/db/cluster"
+	"github.com/lxc/lxd/lxd/db/query"
+	"github.com/lxc/lxd/shared/api"
+	"github.com/pkg/errors"
+)
+
+var _ = api.ServerEnvironment{}
+
+var certificateObjects = cluster.RegisterStmt(`
+SELECT certificates.id, certificates.fingerprint, certificates.type, certificates.name, certificates.certificate
+  FROM certificates
+  ORDER BY certificates.fingerprint
+`)
+
+var certificateObjectsByFingerprint = cluster.RegisterStmt(`
+SELECT certificates.id, certificates.fingerprint, certificates.type, certificates.name, certificates.certificate
+  FROM certificates
+  WHERE certificates.fingerprint = ? ORDER BY certificates.fingerprint
+`)
+
+var certificateID = cluster.RegisterStmt(`
+SELECT certificates.id FROM certificates
+  WHERE certificates.fingerprint = ?
+`)
+
+var certificateCreate = cluster.RegisterStmt(`
+INSERT INTO certificates (fingerprint, type, name, certificate)
+  VALUES (?, ?, ?, ?)
+`)
+
+// GetCertificates returns all available certificates.
+func (c *ClusterTx) GetCertificates(filter CertificateFilter) ([]Certificate, error) {
+	// Result slice.
+	objects := make([]Certificate, 0)
+
+	// Check which filter criteria are active.
+	criteria := map[string]interface{}{}
+	if filter.Fingerprint != "" {
+		criteria["Fingerprint"] = filter.Fingerprint
+	}
+
+	// Pick the prepared statement and arguments to use based on active criteria.
+	var stmt *sql.Stmt
+	var args []interface{}
+
+	if criteria["Fingerprint"] != nil {
+		stmt = c.stmt(certificateObjectsByFingerprint)
+		args = []interface{}{
+			filter.Fingerprint,
+		}
+	} else {
+		stmt = c.stmt(certificateObjects)
+		args = []interface{}{}
+	}
+
+	// Dest function for scanning a row.
+	dest := func(i int) []interface{} {
+		objects = append(objects, Certificate{})
+		return []interface{}{
+			&objects[i].ID,
+			&objects[i].Fingerprint,
+			&objects[i].Type,
+			&objects[i].Name,
+			&objects[i].Certificate,
+		}
+	}
+
+	// Select.
+	err := query.SelectObjects(stmt, dest, args...)
+	if err != nil {
+		return nil, errors.Wrap(err, "Failed to fetch certificates")
+	}
+
+	return objects, nil
+}
+
+// GetCertificate returns the certificate with the given key.
+func (c *ClusterTx) GetCertificate(fingerprint string) (*Certificate, error) {
+	filter := CertificateFilter{}
+	filter.Fingerprint = fingerprint
+
+	objects, err := c.GetCertificates(filter)
+	if err != nil {
+		return nil, errors.Wrap(err, "Failed to fetch Certificate")
+	}
+
+	switch len(objects) {
+	case 0:
+		return nil, ErrNoSuchObject
+	case 1:
+		return &objects[0], nil
+	default:
+		return nil, fmt.Errorf("More than one certificate matches")
+	}
+}
+
+// GetCertificateID return the ID of the certificate with the given key.
+func (c *ClusterTx) GetCertificateID(fingerprint string) (int64, error) {
+	stmt := c.stmt(certificateID)
+	rows, err := stmt.Query(fingerprint)
+	if err != nil {
+		return -1, errors.Wrap(err, "Failed to get certificate ID")
+	}
+	defer rows.Close()
+
+	// For sanity, make sure we read one and only one row.
+	if !rows.Next() {
+		return -1, ErrNoSuchObject
+	}
+	var id int64
+	err = rows.Scan(&id)
+	if err != nil {
+		return -1, errors.Wrap(err, "Failed to scan ID")
+	}
+	if rows.Next() {
+		return -1, fmt.Errorf("More than one row returned")
+	}
+	err = rows.Err()
+	if err != nil {
+		return -1, errors.Wrap(err, "Result set failure")
+	}
+
+	return id, nil
+}
+
+// CertificateExists checks if a certificate with the given key exists.
+func (c *ClusterTx) CertificateExists(fingerprint string) (bool, error) {
+	_, err := c.GetCertificateID(fingerprint)
+	if err != nil {
+		if err == ErrNoSuchObject {
+			return false, nil
+		}
+		return false, err
+	}
+
+	return true, nil
+}
+
+// CreateCertificate adds a new certificate to the database.
+func (c *ClusterTx) CreateCertificate(object Certificate) (int64, error) {
+	// Check if a certificate with the same key exists.
+	exists, err := c.CertificateExists(object.Fingerprint)
+	if err != nil {
+		return -1, errors.Wrap(err, "Failed to check for duplicates")
+	}
+	if exists {
+		return -1, fmt.Errorf("This certificate already exists")
+	}
+
+	args := make([]interface{}, 4)
+
+	// Populate the statement arguments.
+	args[0] = object.Fingerprint
+	args[1] = object.Type
+	args[2] = object.Name
+	args[3] = object.Certificate
+
+	// Prepared statement to use.
+	stmt := c.stmt(certificateCreate)
+
+	// Execute the statement.
+	result, err := stmt.Exec(args...)
+	if err != nil {
+		return -1, errors.Wrap(err, "Failed to create certificate")
+	}
+
+	id, err := result.LastInsertId()
+	if err != nil {
+		return -1, errors.Wrap(err, "Failed to fetch certificate ID")
+	}
+
+	return id, nil
+}


More information about the lxc-devel mailing list