Skip to content

Commit

Permalink
Use generics to simplify the TPMDirect interface (#310)
Browse files Browse the repository at this point in the history
This is unfortunately a large change, but I think it does a lot for the ergonomics of the TPMDirect API.

This change uses the new Go 1.18 generics to solve a few problems:

- We want people to be able to provide a flat `[]byte` or actual structure when instantiating TPM2Bs
- We want to avoid people directly manipulating pointer values in the TPMUs or having their TPMUs in an invalid state
- We want a nice Marshal and Unmarshal function (and later, to be able to make a nice Compare function, see #309 )

Generics to the rescue. Here's what this commit does:

- Add a new file called `marshalling.go` that handles a lot of the high level marshalling work. `reflect.go` is still the dirty reflection guts of the library
- Embed a new type called `marshalByReflection` into all the structs that can be marshalled by reflection, as a clear hint to the reflection library
- Add a new interface called `UnmarshallableWithHint` - most of the TPMU implement this, and the old `marshalUnion` and `unmarshalUnion` functions are gone now
- Bonus: I noticed using profiling that the `tags` function was allocating several orders of magnitude more memory than the rest of the library, so I rewrote it
- Introduced a generic TPM2B helper that is aliased by the concrete TPM2B types; there are constructors for instantiating TPM2B from data or structured contents.
- TPMU is public, with private fields. Introduced constructors for these with type constraints.

Fixes #307 and #292.
  • Loading branch information
chrisfenner authored Feb 9, 2023
1 parent b827cbb commit d68ba33
Show file tree
Hide file tree
Showing 33 changed files with 3,200 additions and 1,652 deletions.
1 change: 1 addition & 0 deletions .cirrus.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ lint_task:
--exclude-use-default=false
--exclude stutters
--exclude underscores
--exclude unexported-return
--max-same-issues=0
--max-issues-per-linter=0
./tpmutil/...
Expand Down
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ module github.com/google/go-tpm
go 1.18

require (
github.com/google/go-cmp v0.5.0
github.com/google/go-tpm-tools v0.2.0
golang.org/x/sys v0.0.0-20210629170331-7dc0b73dc9fb
github.com/google/go-cmp v0.5.7
github.com/google/go-tpm-tools v0.3.10
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f
)

require golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 // indirect
require golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
193 changes: 15 additions & 178 deletions go.sum

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions tpm2/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@ func NewAudit(hash TPMIAlgHash) (*CommandAudit, error) {
}, nil
}

// Extend extends the audit digest with the given command and response.
func (a *CommandAudit) Extend(cmd Command, rsp Response) error {
cpHash, err := auditCPHash(a.hash, cmd)
// AuditCommand extends the audit digest with the given command and response.
// Go Generics do not allow type parameters on methods, otherwise this would be
// a method on CommandAudit.
// See https://github.com/golang/go/issues/49085 for more information.
func AuditCommand[C Command[R, *R], R any](a *CommandAudit, cmd C, rsp *R) error {
cc := cmd.Command()
cpHash, err := auditCPHash[R](cc, a.hash, cmd)
if err != nil {
return err
}
rpHash, err := auditRPHash(a.hash, rsp)
rpHash, err := auditRPHash(cc, a.hash, rsp)
if err != nil {
return err
}
Expand All @@ -56,8 +60,7 @@ func (a *CommandAudit) Digest() []byte {
// auditCPHash calculates the command parameter hash for a given command with
// the given hash algorithm. The command is assumed to not have any decrypt
// sessions.
func auditCPHash(h TPMIAlgHash, c Command) ([]byte, error) {
cc := c.Command()
func auditCPHash[R any](cc TPMCC, h TPMIAlgHash, c Command[R, *R]) ([]byte, error) {
names, err := cmdNames(c)
if err != nil {
return nil, err
Expand All @@ -72,13 +75,12 @@ func auditCPHash(h TPMIAlgHash, c Command) ([]byte, error) {
// auditRPHash calculates the response parameter hash for a given response with
// the given hash algorithm. The command is assumed to be successful and to not
// have any encrypt sessions.
func auditRPHash(h TPMIAlgHash, r Response) ([]byte, error) {
cc := r.Response()
func auditRPHash(cc TPMCC, h TPMIAlgHash, r any) ([]byte, error) {
var parms bytes.Buffer
parameters := taggedMembers(reflect.ValueOf(r).Elem(), "handle", true)
for i, parameter := range parameters {
if err := marshal(&parms, parameter); err != nil {
return nil, fmt.Errorf("marshalling parameter %v: %w", i, err)
return nil, fmt.Errorf("marshalling parameter %v: %w", i+1, err)
}
}
return rpHash(h, TPMRCSuccess, cc, parms.Bytes())
Expand Down
128 changes: 128 additions & 0 deletions tpm2/marshalling.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package tpm2

import (
"bytes"
"fmt"
"reflect"
)

// Marshallable represents any TPM type that can be marshalled.
type Marshallable interface {
// marshal will serialize the given value, appending onto the given buffer.
// Returns an error if the value is not marshallable.
marshal(buf *bytes.Buffer)
}

// marshallableWithHint represents any TPM type that can be marshalled,
// but that requires a selector ("hint") value when marshalling. Most TPMU_ are
// an example of this.
type marshallableWithHint interface {
// get will return the corresponding union member by copy. If the union is
// uninitialized, it will initialize a new zero-valued one.
get(hint int64) (reflect.Value, error)
}

// Unmarshallable represents any TPM type that can be marshalled or unmarshalled.
type Unmarshallable interface {
Marshallable
// marshal will deserialize the given value from the given buffer.
// Returns an error if there was an unmarshalling error or if there was not
// enough data in the buffer.
unmarshal(buf *bytes.Buffer) error
}

// unmarshallableWithHint represents any TPM type that can be marshalled or unmarshalled,
// but that requires a selector ("hint") value when unmarshalling. Most TPMU_ are
// an example of this.
type unmarshallableWithHint interface {
marshallableWithHint
// create will instantiate and return the corresponding union member.
create(hint int64) (reflect.Value, error)
}

// Marshal will serialize the given values, returning them as a byte slice.
func Marshal(v Marshallable) []byte {
var buf bytes.Buffer
if err := marshal(&buf, reflect.ValueOf(v)); err != nil {
panic(fmt.Sprintf("unexpected error marshalling %v: %v", reflect.TypeOf(v).Name(), err))
}
return buf.Bytes()
}

// Unmarshal unmarshals the given type from the byte array.
// Returns an error if the buffer does not contain enough data to satisfy the
// types, or if the types are not unmarshallable.
func Unmarshal[T Marshallable, P interface {
*T
Unmarshallable
}](data []byte) (*T, error) {
buf := bytes.NewBuffer(data)
var t T
value := reflect.New(reflect.TypeOf(t))
if err := unmarshal(buf, value.Elem()); err != nil {
return nil, err
}
return value.Interface().(*T), nil
}

// marshallableByReflection is a placeholder interface, to hint to the unmarshalling
// library that it is supposed to use reflection.
type marshallableByReflection interface {
reflectionSafe()
}

// marshalByReflection is embedded into any type that can be marshalled by reflection,
// needing no custom logic.
type marshalByReflection struct{}

func (marshalByReflection) reflectionSafe() {}

// These placeholders are required because a type constraint cannot union another interface
// that contains methods.
// Otherwise, marshalByReflection would not implement Unmarshallable, and the Marshal/Unmarshal
// functions would accept interface{ Marshallable | marshallableByReflection } instead.

// Placeholder: because this type implements the defaultMarshallable interface,
// the reflection library knows not to call this.
func (marshalByReflection) marshal(_ *bytes.Buffer) {
panic("not implemented")
}

// Placeholder: because this type implements the defaultMarshallable interface,
// the reflection library knows not to call this.
func (*marshalByReflection) unmarshal(_ *bytes.Buffer) error {
panic("not implemented")
}

// boxed is a helper type for corner cases such as unions, where all members must be structs.
type boxed[T any] struct {
Contents *T
}

// box will put a value into a box.
func box[T any](contents *T) boxed[T] {
return boxed[T]{
Contents: contents,
}
}

// unbox will take a value out of a box.
func (b *boxed[T]) unbox() *T {
return b.Contents
}

// marshal implements the Marshallable interface.
func (b *boxed[T]) marshal(buf *bytes.Buffer) {
if b.Contents == nil {
var contents T
marshal(buf, reflect.ValueOf(&contents))
} else {
marshal(buf, reflect.ValueOf(b.Contents))
}
}

// unmarshal implements the Unmarshallable interface.
func (b *boxed[T]) unmarshal(buf *bytes.Buffer) error {
b.Contents = new(T)
return unmarshal(buf, reflect.ValueOf(b.Contents))
}
156 changes: 156 additions & 0 deletions tpm2/marshalling_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package tpm2

import (
"bytes"
"testing"
)

func TestMarshal2B(t *testing.T) {
// Define some TPMT_Public
pub := TPMTPublic{
Type: TPMAlgKeyedHash,
NameAlg: TPMAlgSHA256,
ObjectAttributes: TPMAObject{
FixedTPM: true,
FixedParent: true,
UserWithAuth: true,
NoDA: true,
},
}

// Get the wire-format version
pubBytes := Marshal(pub)

// Create two versions of the same 2B:
// one instantiated by the actual TPMTPublic
// one instantiated by the contents
var boxed1 TPM2BPublic
var boxed2 TPM2BPublic
boxed1 = New2B(pub)
boxed2 = BytesAs2B[TPMTPublic](pubBytes)

boxed1Bytes := Marshal(boxed1)
boxed2Bytes := Marshal(boxed2)

if !bytes.Equal(boxed1Bytes, boxed2Bytes) {
t.Errorf("got %x want %x", boxed2Bytes, boxed1Bytes)
}

z, err := Unmarshal[TPM2BPublic](boxed1Bytes)
if err != nil {
t.Fatalf("could not unmarshal TPM2BPublic: %v", err)
}
t.Logf("%v", z)

boxed3Bytes := Marshal(z)
if !bytes.Equal(boxed1Bytes, boxed3Bytes) {
t.Errorf("got %x want %x", boxed3Bytes, boxed1Bytes)
}

// Make a nonsense 2B_Public, demonstrating that the library doesn't have to understand the serialization
BytesAs2B[TPMTPublic]([]byte{0xff})
}

func unwrap[T any](f func() (*T, error)) *T {
t, err := f()
if err != nil {
panic(err.Error())
}
return t
}

func TestMarshalT(t *testing.T) {
// Define some TPMT_Public
pub := TPMTPublic{
Type: TPMAlgECC,
NameAlg: TPMAlgSHA256,
ObjectAttributes: TPMAObject{
SignEncrypt: true,
},
Parameters: NewTPMUPublicParms(
TPMAlgECC,
&TPMSECCParms{
CurveID: TPMECCNistP256,
},
),
Unique: NewTPMUPublicID(
// This happens to be a P256 EKpub from the simulator
TPMAlgECC,
&TPMSECCPoint{
X: TPM2BECCParameter{},
Y: TPM2BECCParameter{},
},
),
}

// Marshal each component of the parameters
symBytes := Marshal(&unwrap(pub.Parameters.ECCDetail).Symmetric)
t.Logf("Symmetric: %x\n", symBytes)
sym, err := Unmarshal[TPMTSymDefObject](symBytes)
if err != nil {
t.Fatalf("could not unmarshal TPMTSymDefObject: %v", err)
}
symBytes2 := Marshal(sym)
if !bytes.Equal(symBytes, symBytes2) {
t.Errorf("want %x\ngot %x", symBytes, symBytes2)
}
schemeBytes := Marshal(&unwrap(pub.Parameters.ECCDetail).Scheme)
t.Logf("Scheme: %x\n", symBytes)
scheme, err := Unmarshal[TPMTECCScheme](schemeBytes)
if err != nil {
t.Fatalf("could not unmarshal TPMTECCScheme: %v", err)
}
schemeBytes2 := Marshal(scheme)
if !bytes.Equal(schemeBytes, schemeBytes2) {
t.Errorf("want %x\ngot %x", schemeBytes, schemeBytes2)
}
kdfBytes := Marshal(&unwrap(pub.Parameters.ECCDetail).KDF)
t.Logf("KDF: %x\n", kdfBytes)
kdf, err := Unmarshal[TPMTKDFScheme](kdfBytes)
if err != nil {
t.Fatalf("could not unmarshal TPMTKDFScheme: %v", err)
}
kdfBytes2 := Marshal(kdf)
if !bytes.Equal(kdfBytes, kdfBytes2) {
t.Errorf("want %x\ngot %x", kdfBytes, kdfBytes2)
}

// Marshal the parameters
parmsBytes := Marshal(unwrap(pub.Parameters.ECCDetail))
t.Logf("Parms: %x\n", parmsBytes)
parms, err := Unmarshal[TPMSECCParms](parmsBytes)
if err != nil {
t.Fatalf("could not unmarshal TPMSECCParms: %v", err)
}
parmsBytes2 := Marshal(parms)
if !bytes.Equal(parmsBytes, parmsBytes2) {
t.Errorf("want %x\ngot %x", parmsBytes, parmsBytes2)
}

// Marshal the unique area
uniqueBytes := Marshal(unwrap(pub.Unique.ECC))
t.Logf("Unique: %x\n", uniqueBytes)
unique, err := Unmarshal[TPMSECCPoint](uniqueBytes)
if err != nil {
t.Fatalf("could not unmarshal TPMSECCPoint: %v", err)
}
uniqueBytes2 := Marshal(unique)
if !bytes.Equal(uniqueBytes, uniqueBytes2) {
t.Errorf("want %x\ngot %x", uniqueBytes, uniqueBytes2)
}

// Get the wire-format version of the whole thing
pubBytes := Marshal(&pub)

pub2, err := Unmarshal[TPMTPublic](pubBytes)
if err != nil {
t.Fatalf("could not unmarshal TPMTPublic: %v", err)
}

// Some default fields might have been populated in the round-trip. Get the wire-format again and compare.
pub2Bytes := Marshal(pub2)

if !bytes.Equal(pubBytes, pub2Bytes) {
t.Errorf("want %x\ngot %x", pubBytes, pub2Bytes)
}
}
8 changes: 5 additions & 3 deletions tpm2/names.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package tpm2

import (
"bytes"
"encoding/binary"
"reflect"
)

// HandleName returns the TPM Name of a PCR, session, or permanent value
Expand Down Expand Up @@ -30,11 +32,11 @@ func objectOrNVName(alg TPMAlgID, pub interface{}) (*TPM2BName, error) {
// Calculate the hash of the entire Public contents and append it to the
// result.
ha := h.New()
marshalledPub, err := Marshal(pub)
if err != nil {
var buf bytes.Buffer
if err := marshal(&buf, reflect.ValueOf(pub)); err != nil {
return nil, err
}
ha.Write(marshalledPub)
ha.Write(buf.Bytes())
result = ha.Sum(result)

return &TPM2BName{
Expand Down
Loading

0 comments on commit d68ba33

Please sign in to comment.