From d68ba3356cd92210724c141baab40ccbea95e822 Mon Sep 17 00:00:00 2001 From: Chris Fenner Date: Thu, 9 Feb 2023 07:40:25 -0800 Subject: [PATCH] Use generics to simplify the TPMDirect interface (#310) This is unfortunately a large change, but I think it does a lot for the ergonomics of the TPMDirect API. This change uses the new Go 1.18 generics to solve a few problems: - We want people to be able to provide a flat `[]byte` or actual structure when instantiating TPM2Bs - We want to avoid people directly manipulating pointer values in the TPMUs or having their TPMUs in an invalid state - We want a nice Marshal and Unmarshal function (and later, to be able to make a nice Compare function, see #309 ) Generics to the rescue. Here's what this commit does: - Add a new file called `marshalling.go` that handles a lot of the high level marshalling work. `reflect.go` is still the dirty reflection guts of the library - Embed a new type called `marshalByReflection` into all the structs that can be marshalled by reflection, as a clear hint to the reflection library - Add a new interface called `UnmarshallableWithHint` - most of the TPMU implement this, and the old `marshalUnion` and `unmarshalUnion` functions are gone now - Bonus: I noticed using profiling that the `tags` function was allocating several orders of magnitude more memory than the rest of the library, so I rewrote it - Introduced a generic TPM2B helper that is aliased by the concrete TPM2B types; there are constructors for instantiating TPM2B from data or structured contents. - TPMU is public, with private fields. Introduced constructors for these with type constraints. Fixes #307 and #292. --- .cirrus.yml | 1 + go.mod | 8 +- go.sum | 193 +-- tpm2/audit.go | 20 +- tpm2/marshalling.go | 128 ++ tpm2/marshalling_test.go | 156 +++ tpm2/names.go | 8 +- tpm2/policy.go | 12 +- tpm2/reflect.go | 259 ++-- tpm2/reflect_test.go | 243 ---- tpm2/sessions.go | 88 +- tpm2/structures.go | 1812 +++++++++++++++++++++++-- tpm2/templates.go | 112 +- tpm2/test/activate_credential_test.go | 12 +- tpm2/test/audit_test.go | 71 +- tpm2/test/certify_test.go | 286 ++-- tpm2/test/clear_test.go | 8 +- tpm2/test/combined_context_test.go | 48 +- tpm2/test/commit_test.go | 30 +- tpm2/test/create_loaded_test.go | 93 +- tpm2/test/ecdh_test.go | 82 +- tpm2/test/ek_test.go | 55 +- tpm2/test/load_external_test.go | 68 +- tpm2/test/names_test.go | 27 +- tpm2/test/nv_test.go | 54 +- tpm2/test/pcr_test.go | 2 +- tpm2/test/policy_test.go | 90 +- tpm2/test/read_public_test.go | 67 +- tpm2/test/sealing_test.go | 43 +- tpm2/test/sign_test.go | 75 +- tpm2/tpm2.go | 608 ++++----- tpm2/tpm2b.go | 83 ++ tpm2/wrappers.go | 10 - 33 files changed, 3200 insertions(+), 1652 deletions(-) create mode 100644 tpm2/marshalling.go create mode 100644 tpm2/marshalling_test.go delete mode 100644 tpm2/reflect_test.go create mode 100644 tpm2/tpm2b.go delete mode 100644 tpm2/wrappers.go diff --git a/.cirrus.yml b/.cirrus.yml index a944cccf..f98596d0 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -27,6 +27,7 @@ lint_task: --exclude-use-default=false --exclude stutters --exclude underscores + --exclude unexported-return --max-same-issues=0 --max-issues-per-linter=0 ./tpmutil/... diff --git a/go.mod b/go.mod index 26bcf5ed..b9e097da 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module github.com/google/go-tpm go 1.18 require ( - github.com/google/go-cmp v0.5.0 - github.com/google/go-tpm-tools v0.2.0 - golang.org/x/sys v0.0.0-20210629170331-7dc0b73dc9fb + github.com/google/go-cmp v0.5.7 + github.com/google/go-tpm-tools v0.3.10 + golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f ) -require golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 // indirect +require golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect diff --git a/go.sum b/go.sum index e2b5f2a6..efb357cc 100644 --- a/go.sum +++ b/go.sum @@ -1,179 +1,16 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= -github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-tpm v0.1.2-0.20190725015402-ae6dd98980d4/go.mod h1:H9HbmUG2YgV/PHITkO7p6wxEEj/v5nlsVWIwumwH2NI= -github.com/google/go-tpm v0.3.0/go.mod h1:iVLWvrPp/bHeEkxTFi9WG6K9w0iy2yIszHwZGHPbzAw= -github.com/google/go-tpm-tools v0.0.0-20190906225433-1614c142f845/go.mod h1:AVfHadzbdzHo54inR2x1v640jdi1YSi3NauM2DUsxk0= -github.com/google/go-tpm-tools v0.2.0 h1:pBflcn8x5iFohPScqlmLaImrC7ts/EUJa7ZY4FkTFq4= -github.com/google/go-tpm-tools v0.2.0/go.mod h1:npUd03rQ60lxN7tzeBJreG38RvWwme2N1reF/eeiBk4= -github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= -github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= -github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= -github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/cobra v1.0.0/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= -github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20210629170331-7dc0b73dc9fb h1:sgcyLNYiHqEd8eFVh0PflG5ABPTGcPSJacD3s19RTcY= -golang.org/x/sys v0.0.0-20210629170331-7dc0b73dc9fb/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-sev-guest v0.4.1 h1:IjxtGAvzR+zSyAqMc1FWfYKCg1cwPkBly9+Xog3YMZc= +github.com/google/go-tpm-tools v0.3.10 h1:hz9EoyG4Ewa0leT3OvxlWprq14Lw0RBmfFcH9H9+Yas= +github.com/google/go-tpm-tools v0.3.10/go.mod h1:HQfQboO+M8pRtBfO5U3KMhwzfC/XC3TaMCgRfTpII8Q= +github.com/google/logger v1.1.1 h1:+6Z2geNxc9G+4D4oDO9njjjn2d0wN5d7uOo0vOIW1NQ= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= +github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= diff --git a/tpm2/audit.go b/tpm2/audit.go index 658ab1ac..95ad2cb0 100644 --- a/tpm2/audit.go +++ b/tpm2/audit.go @@ -26,13 +26,17 @@ func NewAudit(hash TPMIAlgHash) (*CommandAudit, error) { }, nil } -// Extend extends the audit digest with the given command and response. -func (a *CommandAudit) Extend(cmd Command, rsp Response) error { - cpHash, err := auditCPHash(a.hash, cmd) +// AuditCommand extends the audit digest with the given command and response. +// Go Generics do not allow type parameters on methods, otherwise this would be +// a method on CommandAudit. +// See https://github.com/golang/go/issues/49085 for more information. +func AuditCommand[C Command[R, *R], R any](a *CommandAudit, cmd C, rsp *R) error { + cc := cmd.Command() + cpHash, err := auditCPHash[R](cc, a.hash, cmd) if err != nil { return err } - rpHash, err := auditRPHash(a.hash, rsp) + rpHash, err := auditRPHash(cc, a.hash, rsp) if err != nil { return err } @@ -56,8 +60,7 @@ func (a *CommandAudit) Digest() []byte { // auditCPHash calculates the command parameter hash for a given command with // the given hash algorithm. The command is assumed to not have any decrypt // sessions. -func auditCPHash(h TPMIAlgHash, c Command) ([]byte, error) { - cc := c.Command() +func auditCPHash[R any](cc TPMCC, h TPMIAlgHash, c Command[R, *R]) ([]byte, error) { names, err := cmdNames(c) if err != nil { return nil, err @@ -72,13 +75,12 @@ func auditCPHash(h TPMIAlgHash, c Command) ([]byte, error) { // auditRPHash calculates the response parameter hash for a given response with // the given hash algorithm. The command is assumed to be successful and to not // have any encrypt sessions. -func auditRPHash(h TPMIAlgHash, r Response) ([]byte, error) { - cc := r.Response() +func auditRPHash(cc TPMCC, h TPMIAlgHash, r any) ([]byte, error) { var parms bytes.Buffer parameters := taggedMembers(reflect.ValueOf(r).Elem(), "handle", true) for i, parameter := range parameters { if err := marshal(&parms, parameter); err != nil { - return nil, fmt.Errorf("marshalling parameter %v: %w", i, err) + return nil, fmt.Errorf("marshalling parameter %v: %w", i+1, err) } } return rpHash(h, TPMRCSuccess, cc, parms.Bytes()) diff --git a/tpm2/marshalling.go b/tpm2/marshalling.go new file mode 100644 index 00000000..04d16074 --- /dev/null +++ b/tpm2/marshalling.go @@ -0,0 +1,128 @@ +package tpm2 + +import ( + "bytes" + "fmt" + "reflect" +) + +// Marshallable represents any TPM type that can be marshalled. +type Marshallable interface { + // marshal will serialize the given value, appending onto the given buffer. + // Returns an error if the value is not marshallable. + marshal(buf *bytes.Buffer) +} + +// marshallableWithHint represents any TPM type that can be marshalled, +// but that requires a selector ("hint") value when marshalling. Most TPMU_ are +// an example of this. +type marshallableWithHint interface { + // get will return the corresponding union member by copy. If the union is + // uninitialized, it will initialize a new zero-valued one. + get(hint int64) (reflect.Value, error) +} + +// Unmarshallable represents any TPM type that can be marshalled or unmarshalled. +type Unmarshallable interface { + Marshallable + // marshal will deserialize the given value from the given buffer. + // Returns an error if there was an unmarshalling error or if there was not + // enough data in the buffer. + unmarshal(buf *bytes.Buffer) error +} + +// unmarshallableWithHint represents any TPM type that can be marshalled or unmarshalled, +// but that requires a selector ("hint") value when unmarshalling. Most TPMU_ are +// an example of this. +type unmarshallableWithHint interface { + marshallableWithHint + // create will instantiate and return the corresponding union member. + create(hint int64) (reflect.Value, error) +} + +// Marshal will serialize the given values, returning them as a byte slice. +func Marshal(v Marshallable) []byte { + var buf bytes.Buffer + if err := marshal(&buf, reflect.ValueOf(v)); err != nil { + panic(fmt.Sprintf("unexpected error marshalling %v: %v", reflect.TypeOf(v).Name(), err)) + } + return buf.Bytes() +} + +// Unmarshal unmarshals the given type from the byte array. +// Returns an error if the buffer does not contain enough data to satisfy the +// types, or if the types are not unmarshallable. +func Unmarshal[T Marshallable, P interface { + *T + Unmarshallable +}](data []byte) (*T, error) { + buf := bytes.NewBuffer(data) + var t T + value := reflect.New(reflect.TypeOf(t)) + if err := unmarshal(buf, value.Elem()); err != nil { + return nil, err + } + return value.Interface().(*T), nil +} + +// marshallableByReflection is a placeholder interface, to hint to the unmarshalling +// library that it is supposed to use reflection. +type marshallableByReflection interface { + reflectionSafe() +} + +// marshalByReflection is embedded into any type that can be marshalled by reflection, +// needing no custom logic. +type marshalByReflection struct{} + +func (marshalByReflection) reflectionSafe() {} + +// These placeholders are required because a type constraint cannot union another interface +// that contains methods. +// Otherwise, marshalByReflection would not implement Unmarshallable, and the Marshal/Unmarshal +// functions would accept interface{ Marshallable | marshallableByReflection } instead. + +// Placeholder: because this type implements the defaultMarshallable interface, +// the reflection library knows not to call this. +func (marshalByReflection) marshal(_ *bytes.Buffer) { + panic("not implemented") +} + +// Placeholder: because this type implements the defaultMarshallable interface, +// the reflection library knows not to call this. +func (*marshalByReflection) unmarshal(_ *bytes.Buffer) error { + panic("not implemented") +} + +// boxed is a helper type for corner cases such as unions, where all members must be structs. +type boxed[T any] struct { + Contents *T +} + +// box will put a value into a box. +func box[T any](contents *T) boxed[T] { + return boxed[T]{ + Contents: contents, + } +} + +// unbox will take a value out of a box. +func (b *boxed[T]) unbox() *T { + return b.Contents +} + +// marshal implements the Marshallable interface. +func (b *boxed[T]) marshal(buf *bytes.Buffer) { + if b.Contents == nil { + var contents T + marshal(buf, reflect.ValueOf(&contents)) + } else { + marshal(buf, reflect.ValueOf(b.Contents)) + } +} + +// unmarshal implements the Unmarshallable interface. +func (b *boxed[T]) unmarshal(buf *bytes.Buffer) error { + b.Contents = new(T) + return unmarshal(buf, reflect.ValueOf(b.Contents)) +} diff --git a/tpm2/marshalling_test.go b/tpm2/marshalling_test.go new file mode 100644 index 00000000..bb4c45dc --- /dev/null +++ b/tpm2/marshalling_test.go @@ -0,0 +1,156 @@ +package tpm2 + +import ( + "bytes" + "testing" +) + +func TestMarshal2B(t *testing.T) { + // Define some TPMT_Public + pub := TPMTPublic{ + Type: TPMAlgKeyedHash, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + FixedParent: true, + UserWithAuth: true, + NoDA: true, + }, + } + + // Get the wire-format version + pubBytes := Marshal(pub) + + // Create two versions of the same 2B: + // one instantiated by the actual TPMTPublic + // one instantiated by the contents + var boxed1 TPM2BPublic + var boxed2 TPM2BPublic + boxed1 = New2B(pub) + boxed2 = BytesAs2B[TPMTPublic](pubBytes) + + boxed1Bytes := Marshal(boxed1) + boxed2Bytes := Marshal(boxed2) + + if !bytes.Equal(boxed1Bytes, boxed2Bytes) { + t.Errorf("got %x want %x", boxed2Bytes, boxed1Bytes) + } + + z, err := Unmarshal[TPM2BPublic](boxed1Bytes) + if err != nil { + t.Fatalf("could not unmarshal TPM2BPublic: %v", err) + } + t.Logf("%v", z) + + boxed3Bytes := Marshal(z) + if !bytes.Equal(boxed1Bytes, boxed3Bytes) { + t.Errorf("got %x want %x", boxed3Bytes, boxed1Bytes) + } + + // Make a nonsense 2B_Public, demonstrating that the library doesn't have to understand the serialization + BytesAs2B[TPMTPublic]([]byte{0xff}) +} + +func unwrap[T any](f func() (*T, error)) *T { + t, err := f() + if err != nil { + panic(err.Error()) + } + return t +} + +func TestMarshalT(t *testing.T) { + // Define some TPMT_Public + pub := TPMTPublic{ + Type: TPMAlgECC, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + SignEncrypt: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ + CurveID: TPMECCNistP256, + }, + ), + Unique: NewTPMUPublicID( + // This happens to be a P256 EKpub from the simulator + TPMAlgECC, + &TPMSECCPoint{ + X: TPM2BECCParameter{}, + Y: TPM2BECCParameter{}, + }, + ), + } + + // Marshal each component of the parameters + symBytes := Marshal(&unwrap(pub.Parameters.ECCDetail).Symmetric) + t.Logf("Symmetric: %x\n", symBytes) + sym, err := Unmarshal[TPMTSymDefObject](symBytes) + if err != nil { + t.Fatalf("could not unmarshal TPMTSymDefObject: %v", err) + } + symBytes2 := Marshal(sym) + if !bytes.Equal(symBytes, symBytes2) { + t.Errorf("want %x\ngot %x", symBytes, symBytes2) + } + schemeBytes := Marshal(&unwrap(pub.Parameters.ECCDetail).Scheme) + t.Logf("Scheme: %x\n", symBytes) + scheme, err := Unmarshal[TPMTECCScheme](schemeBytes) + if err != nil { + t.Fatalf("could not unmarshal TPMTECCScheme: %v", err) + } + schemeBytes2 := Marshal(scheme) + if !bytes.Equal(schemeBytes, schemeBytes2) { + t.Errorf("want %x\ngot %x", schemeBytes, schemeBytes2) + } + kdfBytes := Marshal(&unwrap(pub.Parameters.ECCDetail).KDF) + t.Logf("KDF: %x\n", kdfBytes) + kdf, err := Unmarshal[TPMTKDFScheme](kdfBytes) + if err != nil { + t.Fatalf("could not unmarshal TPMTKDFScheme: %v", err) + } + kdfBytes2 := Marshal(kdf) + if !bytes.Equal(kdfBytes, kdfBytes2) { + t.Errorf("want %x\ngot %x", kdfBytes, kdfBytes2) + } + + // Marshal the parameters + parmsBytes := Marshal(unwrap(pub.Parameters.ECCDetail)) + t.Logf("Parms: %x\n", parmsBytes) + parms, err := Unmarshal[TPMSECCParms](parmsBytes) + if err != nil { + t.Fatalf("could not unmarshal TPMSECCParms: %v", err) + } + parmsBytes2 := Marshal(parms) + if !bytes.Equal(parmsBytes, parmsBytes2) { + t.Errorf("want %x\ngot %x", parmsBytes, parmsBytes2) + } + + // Marshal the unique area + uniqueBytes := Marshal(unwrap(pub.Unique.ECC)) + t.Logf("Unique: %x\n", uniqueBytes) + unique, err := Unmarshal[TPMSECCPoint](uniqueBytes) + if err != nil { + t.Fatalf("could not unmarshal TPMSECCPoint: %v", err) + } + uniqueBytes2 := Marshal(unique) + if !bytes.Equal(uniqueBytes, uniqueBytes2) { + t.Errorf("want %x\ngot %x", uniqueBytes, uniqueBytes2) + } + + // Get the wire-format version of the whole thing + pubBytes := Marshal(&pub) + + pub2, err := Unmarshal[TPMTPublic](pubBytes) + if err != nil { + t.Fatalf("could not unmarshal TPMTPublic: %v", err) + } + + // Some default fields might have been populated in the round-trip. Get the wire-format again and compare. + pub2Bytes := Marshal(pub2) + + if !bytes.Equal(pubBytes, pub2Bytes) { + t.Errorf("want %x\ngot %x", pubBytes, pub2Bytes) + } +} diff --git a/tpm2/names.go b/tpm2/names.go index 9738efc2..4b5741d0 100644 --- a/tpm2/names.go +++ b/tpm2/names.go @@ -1,7 +1,9 @@ package tpm2 import ( + "bytes" "encoding/binary" + "reflect" ) // HandleName returns the TPM Name of a PCR, session, or permanent value @@ -30,11 +32,11 @@ func objectOrNVName(alg TPMAlgID, pub interface{}) (*TPM2BName, error) { // Calculate the hash of the entire Public contents and append it to the // result. ha := h.New() - marshalledPub, err := Marshal(pub) - if err != nil { + var buf bytes.Buffer + if err := marshal(&buf, reflect.ValueOf(pub)); err != nil { return nil, err } - ha.Write(marshalledPub) + ha.Write(buf.Bytes()) result = ha.Sum(result) return &TPM2BName{ diff --git a/tpm2/policy.go b/tpm2/policy.go index 41ffd26f..c64f17e4 100644 --- a/tpm2/policy.go +++ b/tpm2/policy.go @@ -1,7 +1,9 @@ package tpm2 import ( + "bytes" "crypto" + "reflect" ) // PolicyCalculator represents a TPM 2.0 policy that needs to be calculated @@ -36,11 +38,13 @@ func (p *PolicyCalculator) Reset() { func (p *PolicyCalculator) Update(data ...interface{}) error { hash := p.hash.New() hash.Write(p.state) - serialized, err := Marshal(data...) - if err != nil { - return err + var buf bytes.Buffer + for _, d := range data { + if err := marshal(&buf, reflect.ValueOf(d)); err != nil { + return err + } } - hash.Write(serialized) + hash.Write(buf.Bytes()) p.state = hash.Sum(nil) return nil } diff --git a/tpm2/reflect.go b/tpm2/reflect.go index c9447be1..e948e461 100644 --- a/tpm2/reflect.go +++ b/tpm2/reflect.go @@ -22,11 +22,8 @@ const ( ) // execute sends the provided command and returns the TPM's response. -func execute(t transport.TPM, cmd Command, rsp Response, extraSess ...Session) error { +func execute[R any](t transport.TPM, cmd Command[R, *R], rsp *R, extraSess ...Session) error { cc := cmd.Command() - if rsp.Response() != cc { - return fmt.Errorf("cmd and rsp must be for same command: %v != %v", cc, rsp.Response()) - } sess, err := cmdAuths(cmd) if err != nil { return err @@ -118,25 +115,45 @@ func execute(t transport.TPM, cmd Command, rsp Response, extraSess ...Session) e return nil } -// Marshal will serialize the given values, returning them as a byte slice. -// Returns an error if any of the values are not marshallable. -func Marshal(vs ...interface{}) ([]byte, error) { - var reflects []reflect.Value - for _, v := range vs { - reflects = append(reflects, reflect.ValueOf(v)) - } - var buf bytes.Buffer - for _, reflect := range reflects { - if err := marshal(&buf, reflect); err != nil { - return nil, err +func isMarshalledByReflection(v reflect.Value) bool { + var mbr marshallableByReflection + if v.Type().AssignableTo(reflect.TypeOf(&mbr).Elem()) { + return true + } + // basic types are also marshalled by reflection, as are empty structs + switch v.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Array, reflect.Slice, reflect.Ptr: + return true + case reflect.Struct: + if v.NumField() == 0 { + return true } } - return buf.Bytes(), nil + return false } // marshal will serialize the given value, appending onto the given buffer. // Returns an error if the value is not marshallable. func marshal(buf *bytes.Buffer, v reflect.Value) error { + // If the type is not marshalled by reflection, try to call the custom marshal method. + if !isMarshalledByReflection(v) { + u, ok := v.Interface().(Marshallable) + if ok { + u.marshal(buf) + return nil + } + if v.CanAddr() { + // Maybe we got an addressable value whose pointer implements Marshallable + pu, ok := v.Addr().Interface().(Marshallable) + if ok { + pu.marshal(buf) + return nil + } + } + return fmt.Errorf("can't marshal: type %v does not implement Marshallable or marshallableByReflection", v.Type().Name()) + } + + // Otherwise, use reflection. switch v.Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return marshalNumeric(buf, v) @@ -145,7 +162,7 @@ func marshal(buf *bytes.Buffer, v reflect.Value) error { case reflect.Struct: return marshalStruct(buf, v) case reflect.Ptr: - return marshalStruct(buf, v.Elem()) + return marshal(buf, v.Elem()) case reflect.Interface: // Special case: there are very few TPM types which, for TPM spec // backwards-compatibility reasons, are implemented as Go interfaces @@ -257,7 +274,7 @@ func marshalStruct(buf *bytes.Buffer, v reflect.Value) error { list := hasTag(v.Type().Field(i), "list") sized := hasTag(v.Type().Field(i), "sized") sized8 := hasTag(v.Type().Field(i), "sized8") - tag := tags(v.Type().Field(i))["tag"] + tag, _ := tag(v.Type().Field(i), "tag") // Serialize to a temporary buffer, in case we need to size it // (Better to simplify this complex reflection-based marshalling // code than to save some unnecessary copying before talking to @@ -270,13 +287,23 @@ func marshalStruct(buf *bytes.Buffer, v reflect.Value) error { // Check that the tagged value was present (and numeric // and smaller than MaxInt64) tagValue, ok := possibleSelectors[tag] + // Don't marshal anything if the tag value was TPM_ALG_NULL + if tagValue == int64(TPMAlgNull) { + continue + } if !ok { return fmt.Errorf("union tag '%v' for member '%v' of struct '%v' did not reference "+ "a numeric field of int64-compatible value", tag, v.Type().Field(i).Name, v.Type().Name()) } - if err := marshalUnion(&res, v.Field(i), tagValue); err != nil { - return err + if u, ok := v.Field(i).Interface().(marshallableWithHint); ok { + v, err := u.get(tagValue) + if err != nil { + return err + } + if err := marshal(buf, v); err != nil { + return err + } } } else if v.Field(i).IsZero() && v.Field(i).Kind() == reflect.Uint32 && hasTag(v.Type().Field(i), "nullable") { // Special case: Anything with the same underlying type @@ -357,57 +384,18 @@ func marshalBitwise(buf *bytes.Buffer, v reflect.Value) error { return nil } -// Marshals the member of the given union struct corresponding to the given -// selector. Marshals nothing if the selector is equal to TPM_ALG_NULL (0x0010). -func marshalUnion(buf *bytes.Buffer, v reflect.Value, selector int64) error { - // Special case: TPM_ALG_NULL as a selector means marshal nothing - if selector == int64(TPMAlgNull) { - return nil - } - for i := 0; i < v.NumField(); i++ { - sel, ok := numericTag(v.Type().Field(i), "selector") - if !ok { - return fmt.Errorf("'%v' union member '%v' did not have a selector tag", v.Type().Name(), v.Type().Field(i).Name) - } - if sel == selector { - if v.Field(i).IsNil() { - // Special case: if the selected value is found - // but nil, marshal the zero-value instead - return marshal(buf, reflect.New(v.Field(i).Type().Elem()).Elem()) - } - return marshal(buf, v.Field(i).Elem()) - } - } - return fmt.Errorf("selector value '%v' not handled for type '%v'", selector, v.Type().Name()) -} - -// Unmarshal deserializes the given values from the byte slice. -// Returns an error if the buffer does not contain enough data to satisfy the -// types, or if the types are not unmarshallable. -func Unmarshal(data []byte, vs ...interface{}) error { - var reflects []reflect.Value - for _, v := range vs { - if reflect.ValueOf(v).Kind() != reflect.Ptr { - return fmt.Errorf("all parameters to Unmarshal must be pointers") - } - reflects = append(reflects, reflect.ValueOf(v).Elem()) - } - var buf bytes.Buffer - if _, err := buf.Write(data); err != nil { - return err - } - for _, reflect := range reflects { - if err := unmarshal(&buf, reflect); err != nil { - return err - } - } - return nil -} - // unmarshal will deserialize the given value from the given buffer. // Returns an error if the buffer does not contain enough data to satisfy the // type. func unmarshal(buf *bytes.Buffer, v reflect.Value) error { + // If the type is not marshalled by reflection, try to call the custom unmarshal method. + if !isMarshalledByReflection(v) { + if u, ok := v.Addr().Interface().(Unmarshallable); ok { + return u.unmarshal(buf) + } + return fmt.Errorf("can't unmarshal: type %v does not implement Unmarshallable or marshallableByReflection", v.Type().Name()) + } + switch v.Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if err := unmarshalNumeric(buf, v); err != nil { @@ -439,14 +427,13 @@ func unmarshal(buf *bytes.Buffer, v reflect.Value) error { return err } v.Set(tmp) + return nil case reflect.Array: - if err := unmarshalArray(buf, v); err != nil { - return err - } + return unmarshalArray(buf, v) case reflect.Struct: - if err := unmarshalStruct(buf, v); err != nil { - return err - } + return unmarshalStruct(buf, v) + case reflect.Ptr: + return unmarshal(buf, v.Elem()) default: return fmt.Errorf("not unmarshallable: %v", v.Type()) } @@ -566,7 +553,7 @@ func unmarshalStruct(buf *bytes.Buffer, v reflect.Value) error { } bufToReadFrom = bytes.NewBuffer(sizedBufArray) } - tag := tags(v.Type().Field(i))["tag"] + tag, _ := tag(v.Type().Field(i), "tag") if tag != "" { // Make a pass to create a map of tag values // UInt64-valued fields with values greater than @@ -586,13 +573,36 @@ func unmarshalStruct(buf *bytes.Buffer, v reflect.Value) error { // Check that the tagged value was present (and numeric // and smaller than MaxInt64) tagValue, ok := possibleSelectors[tag] + // Don't marshal anything if the tag value was TPM_ALG_NULL + if tagValue == int64(TPMAlgNull) { + continue + } if !ok { return fmt.Errorf("union tag '%v' for member '%v' of struct '%v' did not reference "+ "a numeric field of in64-compatible value", tag, v.Type().Field(i).Name, v.Type().Name()) } - if err := unmarshalUnion(bufToReadFrom, v.Field(i), tagValue); err != nil { - return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err) + var uwh unmarshallableWithHint + if v.Field(i).CanAddr() && v.Field(i).Addr().Type().AssignableTo(reflect.TypeOf(&uwh).Elem()) { + u := v.Field(i).Addr().Interface().(unmarshallableWithHint) + contents, err := u.create(tagValue) + if err != nil { + return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err) + } + err = unmarshal(buf, contents) + if err != nil { + return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err) + } + } else if v.Field(i).Type().AssignableTo(reflect.TypeOf(&uwh).Elem()) { + u := v.Field(i).Interface().(unmarshallableWithHint) + contents, err := u.create(tagValue) + if err != nil { + return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err) + } + err = unmarshal(buf, contents) + if err != nil { + return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err) + } } } else { if err := unmarshal(bufToReadFrom, v.Field(i)); err != nil { @@ -655,39 +665,14 @@ func unmarshalBitwise(buf *bytes.Buffer, v reflect.Value) error { return nil } -// Unmarshals the member of the given union struct corresponding to the given -// selector. Unmarshals nothing if the selector is TPM_ALG_NULL (0x0010). -func unmarshalUnion(buf *bytes.Buffer, v reflect.Value, selector int64) error { - // Special case: TPM_ALG_NULL as a selector means unmarshal nothing - if selector == int64(TPMAlgNull) { - return nil - } - for i := 0; i < v.NumField(); i++ { - sel, ok := numericTag(v.Type().Field(i), "selector") - if !ok { - return fmt.Errorf("'%v' union member '%v' did not have a selector tag", v.Type().Name(), v.Type().Field(i).Name) - } - if sel == selector { - val := reflect.New(v.Type().Field(i).Type.Elem()) - if err := unmarshal(buf, val.Elem()); err != nil { - return fmt.Errorf("unmarshalling '%v' union member '%v': %w", v.Type().Name(), v.Type().Field(i).Name, err) - } - v.Field(i).Set(val) - return nil - } - } - return fmt.Errorf("selector value '%v' not handled for type '%v'", selector, v.Type().Name()) -} - -// Returns all the gotpm tags on a field as a map. +// Looks up the given gotpm tag on a field. // Some tags are settable (with "="). For these, the value is the RHS. // For all others, the value is the empty string. -func tags(t reflect.StructField) map[string]string { +func tag(t reflect.StructField, query string) (string, bool) { allTags, ok := t.Tag.Lookup("gotpm") if !ok { - return nil + return "", false } - result := make(map[string]string) tags := strings.Split(allTags, ",") for _, tag := range tags { // Split on the equals sign for settable tags. @@ -695,45 +680,29 @@ func tags(t reflect.StructField) map[string]string { // un-settable tag or an empty tag (which we'll ignore). // If the split returns a slice of length 2, this is a settable // tag. - assignment := strings.SplitN(tag, "=", 2) - val := "" - if len(assignment) > 1 { - val = assignment[1] + if tag == query { + return "", true } - if len(assignment) > 0 && assignment[0] != "" { - key := assignment[0] - result[key] = val + if strings.HasPrefix(tag, query+"=") { + assignment := strings.SplitN(tag, "=", 2) + return assignment[1], true } } - return result + return "", false } // hasTag looks up to see if the type's gotpm-namespaced tag contains the // given value. // Returns false if there is no gotpm-namespaced tag on the type. -func hasTag(t reflect.StructField, tag string) bool { - ts := tags(t) - _, ok := ts[tag] +func hasTag(t reflect.StructField, query string) bool { + _, ok := tag(t, query) return ok } -// Returns the numeric tag value, or false if the tag is not present. -func numericTag(t reflect.StructField, tag string) (int64, bool) { - val, ok := tags(t)[tag] - if !ok { - return 0, false - } - v, err := strconv.ParseInt(val, 0, 64) - if err != nil { - return 0, false - } - return v, true -} - // Returns the range on a tag like 4:3 or 4. // If there is no colon, the low and high part of the range are equal. -func rangeTag(t reflect.StructField, tag string) (int, int, bool) { - val, ok := tags(t)[tag] +func rangeTag(t reflect.StructField, query string) (int, int, bool) { + val, ok := tag(t, query) if !ok { return 0, 0, false } @@ -776,8 +745,8 @@ func taggedMembers(v reflect.Value, tag string, invert bool) []reflect.Value { } // cmdAuths returns the authorization sessions of the command. -func cmdAuths(cmd Command) ([]Session, error) { - authHandles := taggedMembers(reflect.ValueOf(cmd).Elem(), "auth", false) +func cmdAuths[R any](cmd Command[R, *R]) ([]Session, error) { + authHandles := taggedMembers(reflect.ValueOf(cmd), "auth", false) var result []Session for i, authHandle := range authHandles { // TODO: A cleaner way to do this would be to have an interface method that @@ -797,8 +766,8 @@ func cmdAuths(cmd Command) ([]Session, error) { } // cmdHandles returns the handles area of the command. -func cmdHandles(cmd Command) ([]byte, error) { - handles := taggedMembers(reflect.ValueOf(cmd).Elem(), "handle", false) +func cmdHandles[R any](cmd Command[R, *R]) ([]byte, error) { + handles := taggedMembers(reflect.ValueOf(cmd), "handle", false) // Initial capacity is enough to hold 3 handles result := bytes.NewBuffer(make([]byte, 0, 12)) @@ -823,8 +792,8 @@ func cmdHandles(cmd Command) ([]byte, error) { } // cmdNames returns the names of the entities referenced by the handles of the command. -func cmdNames(cmd Command) ([]TPM2BName, error) { - handles := taggedMembers(reflect.ValueOf(cmd).Elem(), "handle", false) +func cmdNames[R any](cmd Command[R, *R]) ([]TPM2BName, error) { + handles := taggedMembers(reflect.ValueOf(cmd), "handle", false) var result []TPM2BName for i, maybeHandle := range handles { h, ok := maybeHandle.Interface().(handle) @@ -852,13 +821,13 @@ func cmdNames(cmd Command) ([]TPM2BName, error) { // TODO: Extract the logic of "marshal the Nth field of some struct after the handles" // For now, we duplicate some logic from marshalStruct here. -func marshalParameter(buf *bytes.Buffer, cmd Command, i int) error { - numHandles := len(taggedMembers(reflect.ValueOf(cmd).Elem(), "handle", false)) - if numHandles+i >= reflect.TypeOf(cmd).Elem().NumField() { +func marshalParameter[R any](buf *bytes.Buffer, cmd Command[R, *R], i int) error { + numHandles := len(taggedMembers(reflect.ValueOf(cmd), "handle", false)) + if numHandles+i >= reflect.TypeOf(cmd).NumField() { return fmt.Errorf("invalid parameter index %v", i) } - parm := reflect.ValueOf(cmd).Elem().Field(numHandles + i) - field := reflect.TypeOf(cmd).Elem().Field(numHandles + i) + parm := reflect.ValueOf(cmd).Field(numHandles + i) + field := reflect.TypeOf(cmd).Field(numHandles + i) if hasTag(field, "optional") { return marshalOptional(buf, parm) } else if parm.IsZero() && parm.Kind() == reflect.Uint32 && hasTag(field, "nullable") { @@ -872,8 +841,8 @@ func marshalParameter(buf *bytes.Buffer, cmd Command, i int) error { // cmdParameters returns the parameters area of the command. // The first parameter may be encrypted by one of the sessions. -func cmdParameters(cmd Command, sess []Session) ([]byte, error) { - parms := taggedMembers(reflect.ValueOf(cmd).Elem(), "handle", true) +func cmdParameters[R any](cmd Command[R, *R], sess []Session) ([]byte, error) { + parms := taggedMembers(reflect.ValueOf(cmd), "handle", true) if len(parms) == 0 { return nil, nil } @@ -1007,7 +976,7 @@ func rspHeader(rsp *bytes.Buffer) error { // If there is a mismatch between the expected and actual amount of handles, // returns an error here. // rsp is updated to point to the rest of the response after the handles. -func rspHandles(rsp *bytes.Buffer, rspStruct Response) error { +func rspHandles(rsp *bytes.Buffer, rspStruct any) error { handles := taggedMembers(reflect.ValueOf(rspStruct).Elem(), "handle", false) for i, handle := range handles { if err := unmarshal(rsp, handle); err != nil { @@ -1072,7 +1041,7 @@ func rspSessions(rsp *bytes.Buffer, rc TPMRC, cc TPMCC, names []TPM2BName, parms // rspParameters decrypts (if needed) the parameters area of the response // into the response structure. If there is a mismatch between the expected // and actual response structure, returns an error here. -func rspParameters(parms []byte, sess []Session, rspStruct Response) error { +func rspParameters(parms []byte, sess []Session, rspStruct any) error { numHandles := len(taggedMembers(reflect.ValueOf(rspStruct).Elem(), "handle", false)) // Use the heuristic of "does interpreting the first 2 bytes of response diff --git a/tpm2/reflect_test.go b/tpm2/reflect_test.go deleted file mode 100644 index f194c11d..00000000 --- a/tpm2/reflect_test.go +++ /dev/null @@ -1,243 +0,0 @@ -package tpm2 - -import ( - "bytes" - "fmt" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -func marshalUnmarshal(t *testing.T, v interface{}, want []byte) { - t.Helper() - var buf bytes.Buffer - marshal(&buf, reflect.ValueOf(v)) - if !bytes.Equal(buf.Bytes(), want) { - t.Errorf("want %x got %x", want, buf.Bytes()) - } - got := reflect.New(reflect.TypeOf(v)) - err := unmarshal(&buf, got.Elem()) - if err != nil { - t.Fatalf("want nil, got %v", err) - } - var opts []cmp.Option - if reflect.TypeOf(v).Kind() == reflect.Struct { - opts = append(opts, cmpopts.IgnoreUnexported(v)) - } - if !cmp.Equal(v, got.Elem().Interface(), opts...) { - t.Errorf("want %#v, got %#v\n%v", v, got.Elem().Interface(), cmp.Diff(v, got.Elem().Interface(), opts...)) - } -} - -func TestMarshalNumeric(t *testing.T) { - vals := map[interface{}][]byte{ - false: []byte{0}, - byte(1): []byte{1}, - int8(2): []byte{2}, - uint8(3): []byte{3}, - int16(260): []byte{1, 4}, - uint16(261): []byte{1, 5}, - int32(65542): []byte{0, 1, 0, 6}, - uint32(65543): []byte{0, 1, 0, 7}, - int64(4294967304): []byte{0, 0, 0, 1, 0, 0, 0, 8}, - uint64(4294967305): []byte{0, 0, 0, 1, 0, 0, 0, 9}, - } - for v, want := range vals { - t.Run(fmt.Sprintf("%v-%v", reflect.TypeOf(v), v), func(t *testing.T) { - marshalUnmarshal(t, v, want) - }) - } -} - -func TestMarshalArray(t *testing.T) { - vals := []struct { - Data interface{} - Serialization []byte - }{ - {[4]int8{1, 2, 3, 4}, []byte{1, 2, 3, 4}}, - {[3]uint16{5, 6, 7}, []byte{0, 5, 0, 6, 0, 7}}, - } - for _, val := range vals { - v, want := val.Data, val.Serialization - t.Run(fmt.Sprintf("%v-%v", reflect.TypeOf(v), v), func(t *testing.T) { - marshalUnmarshal(t, v, want) - }) - } -} - -func TestMarshalSlice(t *testing.T) { - // Slices in reflect/gotpm must be tagged marshalled/unmarshalled as - // part of a struct with the 'list' tag - type sliceWrapper struct { - Elems []uint32 `gotpm:"list"` - } - vals := []struct { - Name string - Data sliceWrapper - Serialization []byte - }{ - {"3", sliceWrapper{[]uint32{1, 2, 3}}, []byte{0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3}}, - {"1", sliceWrapper{[]uint32{4}}, []byte{0, 0, 0, 1, 0, 0, 0, 4}}, - {"empty", sliceWrapper{[]uint32{}}, []byte{0, 0, 0, 0}}, - } - for _, val := range vals { - v, want := val.Data, val.Serialization - t.Run(val.Name, func(t *testing.T) { - marshalUnmarshal(t, v, want) - }) - } -} - -func unmarshalReserved(t *testing.T, data []byte, want interface{}) { - t.Helper() - - // Attempt to unmarshal data to the type of want, and compare - // Want is assumed to be a bitfield that may have reserved bits. - // Reserved bits are not going to be present in the input structure, - // or the accessible fields of what we marshalled. - got := reflect.New(reflect.TypeOf(want)) - err := Unmarshal(data, got.Interface()) - if err != nil { - t.Fatalf("want nil, got %v", err) - } - var opts []cmp.Option - if reflect.TypeOf(want).Kind() == reflect.Struct { - opts = append(opts, cmpopts.IgnoreUnexported(want)) - } - if !cmp.Equal(want, got.Elem().Interface(), opts...) { - t.Errorf("want %#v, got %#v\n%v", want, got.Elem().Interface(), cmp.Diff(want, got.Elem().Interface(), opts...)) - } - - // Re-marshal what we unmarshalled and ensure that it contains the - // original serialization (i.e., any reserved bits are still there). - result, err := Marshal(got.Interface()) - if err != nil { - t.Fatalf("error marshalling %v: %v", got, err) - } - if !bytes.Equal(result, data) { - t.Errorf("want %x got %x", data, result) - } -} - -func TestMarshalBitfield(t *testing.T) { - t.Run("8bit", func(t *testing.T) { - v := TPMASession{ - ContinueSession: true, - AuditExclusive: true, - AuditReset: false, - Decrypt: true, - Encrypt: true, - Audit: false, - } - want := []byte{0x63} - marshalUnmarshal(t, v, want) - unmarshalReserved(t, []byte{0x7b}, v) - }) - t.Run("full8bit", func(t *testing.T) { - v := TPMALocality{ - TPMLocZero: true, - TPMLocOne: true, - TPMLocTwo: false, - TPMLocThree: true, - TPMLocFour: false, - Extended: 1, - } - want := []byte{0x2b} - marshalUnmarshal(t, v, want) - }) - t.Run("32bit", func(t *testing.T) { - v := TPMACC{ - CommandIndex: 6, - NV: true, - } - want := []byte{0x00, 0x40, 0x00, 0x06} - marshalUnmarshal(t, v, want) - unmarshalReserved(t, []byte{0x80, 0x41, 0x00, 0x06}, v) - }) - t.Run("TPMAObject", func(t *testing.T) { - v := TPMAObject{ - FixedTPM: true, - STClear: true, - FixedParent: true, - } - want := []byte{0x00, 0x00, 0x00, 0x16} - marshalUnmarshal(t, v, want) - unmarshalReserved(t, []byte{0xff, 0x00, 0x00, 0x16}, v) - }) -} - -func TestMarshalUnion(t *testing.T) { - type valStruct struct { - First bool - Second int32 - } - type unionValue struct { - Val8 *uint8 `gotpm:"selector=8"` - Val64 *uint64 `gotpm:"selector=0x00000040"` - ValStruct *valStruct `gotpm:"selector=5"` // 5 for '5truct' - } - type unionEnvelope struct { - Type uint8 - OtherThing uint32 - Value unionValue `gotpm:"tag=Type"` - } - eight := uint8(8) - sixtyFour := uint64(64) - cases := []struct { - Name string - Data unionEnvelope - Serialization []byte - }{ - { - Name: "8", - Data: unionEnvelope{ - Type: 8, - OtherThing: 0xabcd1234, - Value: unionValue{ - Val8: &eight, - }, - }, - Serialization: []byte{ - 0x08, 0xab, 0xcd, 0x12, 0x34, 0x08, - }, - }, - { - Name: "64", - Data: unionEnvelope{ - Type: 64, - OtherThing: 0xffffffff, - Value: unionValue{ - Val64: &sixtyFour, - }, - }, - Serialization: []byte{ - 0x40, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, - }, - }, - { - Name: "Struct", - Data: unionEnvelope{ - Type: 5, - OtherThing: 0x11111111, - Value: unionValue{ - ValStruct: &valStruct{ - First: true, - Second: 65537, - }, - }, - }, - Serialization: []byte{ - 0x05, 0x11, 0x11, 0x11, 0x11, 0x01, 0x00, 0x01, 0x00, 0x01, - }, - }, - } - - for _, c := range cases { - v, want := c.Data, c.Serialization - t.Run(c.Name, func(t *testing.T) { - marshalUnmarshal(t, v, want) - }) - } -} diff --git a/tpm2/sessions.go b/tpm2/sessions.go index 227aa315..f0fbda68 100644 --- a/tpm2/sessions.go +++ b/tpm2/sessions.go @@ -55,7 +55,7 @@ type Session interface { // CPHash calculates the TPM command parameter hash for a given Command. // N.B. Authorization sessions on handles are ignored, but names aren't. -func CPHash(alg TPMIAlgHash, cmd Command) (*TPM2BDigest, error) { +func CPHash[R any](alg TPMIAlgHash, cmd Command[R, *R]) (*TPM2BDigest, error) { cc := cmd.Command() names, err := cmdNames(cmd) if err != nil { @@ -272,12 +272,14 @@ func AESEncryption(keySize TPMKeyBits, dir parameterEncryptiontpm2ion) AuthOptio o.attrs.Encrypt = (dir == EncryptOut || dir == EncryptInOut) o.symmetric = TPMTSymDef{ Algorithm: TPMAlgAES, - KeyBits: TPMUSymKeyBits{ - AES: NewKeyBits(keySize), - }, - Mode: TPMUSymMode{ - AES: NewAlgID(TPMAlgCFB), - }, + KeyBits: NewTPMUSymKeyBits( + TPMAlgAES, + TPMKeyBits(keySize), + ), + Mode: NewTPMUSymMode( + TPMAlgAES, + TPMAlgCFB, + ), } } } @@ -361,7 +363,8 @@ func HMACSession(t transport.TPM, hash TPMIAlgHash, nonceSize int, opts ...AuthO } closer := func() error { - return (&FlushContext{FlushHandle: sess.handle}).Execute(t) + _, err := (&FlushContext{FlushHandle: sess.handle}).Execute(t) + return err } return &sess, closer, nil @@ -378,13 +381,25 @@ func getEncryptedSaltRSA(nameAlg TPMIAlgHash, parms *TPMSRSAParms, pub *TPM2BPub var hAlg TPMIAlgHash switch parms.Scheme.Scheme { case TPMAlgRSASSA: - hAlg = parms.Scheme.Details.RSASSA.HashAlg + rsassa, err := parms.Scheme.Details.RSASSA() + if err != nil { + return nil, nil, err + } + hAlg = rsassa.HashAlg case TPMAlgRSAES: hAlg = nameAlg case TPMAlgRSAPSS: - hAlg = parms.Scheme.Details.RSAPSS.HashAlg + rsapss, err := parms.Scheme.Details.RSAPSS() + if err != nil { + return nil, nil, err + } + hAlg = rsapss.HashAlg case TPMAlgOAEP: - hAlg = parms.Scheme.Details.OAEP.HashAlg + oaep, err := parms.Scheme.Details.OAEP() + if err != nil { + return nil, nil, err + } + hAlg = oaep.HashAlg case TPMAlgNull: hAlg = nameAlg default: @@ -449,9 +464,25 @@ func getEncryptedSaltECC(nameAlg TPMIAlgHash, parms *TPMSECCParms, pub *TPMSECCP func getEncryptedSalt(pub TPMTPublic) (*TPM2BEncryptedSecret, []byte, error) { switch pub.Type { case TPMAlgRSA: - return getEncryptedSaltRSA(pub.NameAlg, pub.Parameters.RSADetail, pub.Unique.RSA) + rsaParms, err := pub.Parameters.RSADetail() + if err != nil { + return nil, nil, err + } + rsaPub, err := pub.Unique.RSA() + if err != nil { + return nil, nil, err + } + return getEncryptedSaltRSA(pub.NameAlg, rsaParms, rsaPub) case TPMAlgECC: - return getEncryptedSaltECC(pub.NameAlg, pub.Parameters.ECCDetail, pub.Unique.ECC) + eccParms, err := pub.Parameters.ECCDetail() + if err != nil { + return nil, nil, err + } + eccPub, err := pub.Unique.ECC() + if err != nil { + return nil, nil, err + } + return getEncryptedSaltECC(pub.NameAlg, eccParms, eccPub) default: return nil, nil, fmt.Errorf("salt encryption alg '%v' not supported", pub.Type) } @@ -519,7 +550,7 @@ func (s *hmacSession) CleanupFailure(t transport.TPM) error { return nil } fc := FlushContext{FlushHandle: s.handle} - if err := fc.Execute(t); err != nil { + if _, err := fc.Execute(t); err != nil { return err } s.handle = TPMRHNull @@ -686,7 +717,11 @@ func (s *hmacSession) Encrypt(parameter []byte) error { return nil } // Only AES-CFB is supported. - keyBytes := *s.symmetric.KeyBits.AES / 8 + bits, err := s.symmetric.KeyBits.AES() + if err != nil { + return err + } + keyBytes := *bits / 8 keyIVBytes := int(keyBytes) + 16 var sessionValue []byte sessionValue = append(sessionValue, s.sessionKey...) @@ -712,7 +747,11 @@ func (s *hmacSession) Decrypt(parameter []byte) error { return nil } // Only AES-CFB is supported. - keyBytes := *s.symmetric.KeyBits.AES / 8 + bits, err := s.symmetric.KeyBits.AES() + if err != nil { + return err + } + keyBytes := *bits / 8 keyIVBytes := int(keyBytes) + 16 // Part 1, 21.1 var sessionValue []byte @@ -803,7 +842,8 @@ func PolicySession(t transport.TPM, hash TPMIAlgHash, nonceSize int, opts ...Aut } closer := func() error { - return (&FlushContext{sess.handle}).Execute(t) + _, err := (&FlushContext{sess.handle}).Execute(t) + return err } return &sess, closer, nil @@ -884,7 +924,7 @@ func (s *policySession) CleanupFailure(t transport.TPM) error { return nil } fc := FlushContext{FlushHandle: s.handle} - if err := fc.Execute(t); err != nil { + if _, err := fc.Execute(t); err != nil { return err } s.handle = TPMRHNull @@ -998,7 +1038,11 @@ func (s *policySession) Encrypt(parameter []byte) error { return nil } // Only AES-CFB is supported. - keyBytes := *s.symmetric.KeyBits.AES / 8 + bits, err := s.symmetric.KeyBits.AES() + if err != nil { + return err + } + keyBytes := *bits / 8 keyIVBytes := int(keyBytes) + 16 var sessionValue []byte sessionValue = append(sessionValue, s.sessionKey...) @@ -1024,7 +1068,11 @@ func (s *policySession) Decrypt(parameter []byte) error { return nil } // Only AES-CFB is supported. - keyBytes := *s.symmetric.KeyBits.AES / 8 + bits, err := s.symmetric.KeyBits.AES() + if err != nil { + return err + } + keyBytes := *bits / 8 keyIVBytes := int(keyBytes) + 16 // Part 1, 21.1 var sessionValue []byte diff --git a/tpm2/structures.go b/tpm2/structures.go index 474825fe..2853bc38 100644 --- a/tpm2/structures.go +++ b/tpm2/structures.go @@ -2,9 +2,11 @@ package tpm2 import ( + "bytes" "crypto" "crypto/elliptic" "encoding/binary" + "reflect" // Register the relevant hash implementations. _ "crypto/sha1" @@ -16,6 +18,7 @@ import ( // TPMCmdHeader is the header structure in front of any TPM command. // It is described in Part 1, Architecture. type TPMCmdHeader struct { + marshalByReflection Tag TPMISTCommandTag Length uint32 CommandCode TPMCC @@ -24,6 +27,7 @@ type TPMCmdHeader struct { // TPMRspHeader is the header structure in front of any TPM response. // It is described in Part 1, Architecture. type TPMRspHeader struct { + marshalByReflection Tag TPMISTCommandTag Length uint32 ResponseCode TPMRC @@ -165,6 +169,7 @@ func (h TPMHandle) KnownName() *TPM2BName { // See definition in Part 2: Structures, section 8.2. type TPMAAlgorithm struct { bitfield32 + marshalByReflection // SET (1): an asymmetric algorithm with public and private portions // CLEAR (0): not an asymmetric algorithm Asymmetric bool `gotpm:"bit=0"` @@ -195,6 +200,7 @@ type TPMAAlgorithm struct { // See definition in Part 2: Structures, section 8.3.2. type TPMAObject struct { bitfield32 + marshalByReflection // SET (1): The hierarchy of the object, as indicated by its // Qualified Name, may not change. // CLEAR (0): The hierarchy of the object may change as a result @@ -266,6 +272,7 @@ type TPMAObject struct { // See definition in Part 2: Structures, section 8.4. type TPMASession struct { bitfield8 + marshalByReflection // SET (1): In a command, this setting indicates that the session // is to remain active after successful completion of the command. // In a response, it indicates that the session is still active. @@ -327,6 +334,7 @@ type TPMASession struct { // See definition in Part 2: Structures, section 8.5. type TPMALocality struct { bitfield8 + marshalByReflection TPMLocZero bool `gotpm:"bit=0"` TPMLocOne bool `gotpm:"bit=1"` TPMLocTwo bool `gotpm:"bit=2"` @@ -340,6 +348,7 @@ type TPMALocality struct { // See definition in Part 2: Structures, section 8.9. type TPMACC struct { bitfield32 + marshalByReflection // indicates the command being selected CommandIndex uint16 `gotpm:"bit=15:0"` // SET (1): indicates that the command may write to NV @@ -364,6 +373,7 @@ type TPMACC struct { // See definition in Part 2: Structures, section 8.12. type TPMAACT struct { bitfield32 + marshalByReflection // SET (1): The ACT has signaled // CLEAR (0): The ACT has not signaled Signaled bool `gotpm:"bit=0"` @@ -480,9 +490,6 @@ func (a TPMIAlgHash) Hash() (crypto.Hash, error) { return crypto.SHA256, fmt.Errorf("unsupported hash algorithm: %v", a) } -// TODO: Provide a placeholder interface here so we can explicitly enumerate -// these for compile-time protection. - // TPMIAlgSym represents a TPMI_ALG_SYM. // See definition in Part 2: Structures, section 9.29. type TPMIAlgSym = TPMAlgID @@ -509,11 +516,14 @@ type TPMISTCommandTag = TPMST // TPMSEmpty represents a TPMS_EMPTY. // See definition in Part 2: Structures, section 10.1. -type TPMSEmpty = struct{} +type TPMSEmpty struct { + marshalByReflection +} // TPMTHA represents a TPMT_HA. // See definition in Part 2: Structures, section 10.3.2. type TPMTHA struct { + marshalByReflection // selector of the hash contained in the digest that implies the size of the digest HashAlg TPMIAlgHash `gotpm:"nullable"` // the digest data @@ -528,6 +538,7 @@ type TPM2BDigest TPM2BData // TPM2BData represents a TPM2B_DATA. // See definition in Part 2: Structures, section 10.4.3. type TPM2BData struct { + marshalByReflection // size in octets of the buffer field; may be 0 Buffer []byte `gotpm:"sized"` } @@ -570,6 +581,7 @@ type TPM2BName TPM2BData // TPMSPCRSelection represents a TPMS_PCR_SELECTION. // See definition in Part 2: Structures, section 10.6.2. type TPMSPCRSelection struct { + marshalByReflection Hash TPMIAlgHash PCRSelect []byte `gotpm:"sized8"` } @@ -577,6 +589,7 @@ type TPMSPCRSelection struct { // TPMTTKCreation represents a TPMT_TK_CREATION. // See definition in Part 2: Structures, section 10.7.3. type TPMTTKCreation struct { + marshalByReflection // ticket structure tag Tag TPMST // the hierarchy containing name @@ -588,6 +601,7 @@ type TPMTTKCreation struct { // TPMTTKVerified represents a TPMT_TK_Verified. // See definition in Part 2: Structures, section 10.7.4. type TPMTTKVerified struct { + marshalByReflection // ticket structure tag Tag TPMST // the hierarchy containing keyName @@ -599,6 +613,7 @@ type TPMTTKVerified struct { // TPMTTKAuth represents a TPMT_TK_AUTH. // See definition in Part 2: Structures, section 10.7.5. type TPMTTKAuth struct { + marshalByReflection // ticket structure tag Tag TPMST // the hierarchy of the object used to produce the ticket @@ -610,6 +625,7 @@ type TPMTTKAuth struct { // TPMTTKHashCheck represents a TPMT_TK_HASHCHECK. // See definition in Part 2: Structures, section 10.7.6. type TPMTTKHashCheck struct { + marshalByReflection // ticket structure tag Tag TPMST // the hierarchy @@ -621,6 +637,7 @@ type TPMTTKHashCheck struct { // TPMSAlgProperty represents a TPMS_ALG_PROPERTY. // See definition in Part 2: Structures, section 10.8.1. type TPMSAlgProperty struct { + marshalByReflection // an algorithm identifier Alg TPMAlgID // the attributes of the algorithm @@ -630,6 +647,7 @@ type TPMSAlgProperty struct { // TPMSTaggedProperty represents a TPMS_TAGGED_PROPERTY. // See definition in Part 2: Structures, section 10.8.2. type TPMSTaggedProperty struct { + marshalByReflection // a property identifier Property TPMPT // the value of the property @@ -639,6 +657,7 @@ type TPMSTaggedProperty struct { // TPMSTaggedPCRSelect represents a TPMS_TAGGED_PCR_SELECT. // See definition in Part 2: Structures, section 10.8.3. type TPMSTaggedPCRSelect struct { + marshalByReflection // the property identifier Tag TPMPTPCR // the bit map of PCR with the identified property @@ -648,6 +667,7 @@ type TPMSTaggedPCRSelect struct { // TPMSTaggedPolicy represents a TPMS_TAGGED_POLICY. // See definition in Part 2: Structures, section 10.8.4. type TPMSTaggedPolicy struct { + marshalByReflection // a permanent handle Handle TPMHandle // the policy algorithm and hash @@ -657,6 +677,7 @@ type TPMSTaggedPolicy struct { // TPMSACTData represents a TPMS_ACT_DATA. // See definition in Part 2: Structures, section 10.8.5. type TPMSACTData struct { + marshalByReflection // a permanent handle Handle TPMHandle // the current timeout of the ACT @@ -668,30 +689,35 @@ type TPMSACTData struct { // TPMLCC represents a TPML_CC. // See definition in Part 2: Structures, section 10.9.1. type TPMLCC struct { + marshalByReflection CommandCodes []TPMCC `gotpm:"list"` } // TPMLCCA represents a TPML_CCA. // See definition in Part 2: Structures, section 10.9.2. type TPMLCCA struct { + marshalByReflection CommandAttributes []TPMACC `gotpm:"list"` } -// TPMLAlg represents a TPMLALG. +// TPMLAlg represents a TPML_ALG. // See definition in Part 2: Structures, section 10.9.3. type TPMLAlg struct { + marshalByReflection Algorithms []TPMAlgID `gotpm:"list"` } // TPMLHandle represents a TPML_HANDLE. // See definition in Part 2: Structures, section 10.9.4. type TPMLHandle struct { + marshalByReflection Handle []TPMHandle `gotpm:"list"` } // TPMLDigest represents a TPML_DIGEST. // See definition in Part 2: Structures, section 10.9.5. type TPMLDigest struct { + marshalByReflection // a list of digests Digests []TPM2BDigest `gotpm:"list"` } @@ -699,6 +725,7 @@ type TPMLDigest struct { // TPMLDigestValues represents a TPML_DIGEST_VALUES. // See definition in Part 2: Structures, section 10.9.6. type TPMLDigestValues struct { + marshalByReflection // a list of tagged digests Digests []TPMTHA `gotpm:"list"` } @@ -706,64 +733,293 @@ type TPMLDigestValues struct { // TPMLPCRSelection represents a TPML_PCR_SELECTION. // See definition in Part 2: Structures, section 10.9.7. type TPMLPCRSelection struct { + marshalByReflection PCRSelections []TPMSPCRSelection `gotpm:"list"` } // TPMLAlgProperty represents a TPML_ALG_PROPERTY. // See definition in Part 2: Structures, section 10.9.8. type TPMLAlgProperty struct { + marshalByReflection AlgProperties []TPMSAlgProperty `gotpm:"list"` } // TPMLTaggedTPMProperty represents a TPML_TAGGED_TPM_PROPERTY. // See definition in Part 2: Structures, section 10.9.9. type TPMLTaggedTPMProperty struct { + marshalByReflection TPMProperty []TPMSTaggedProperty `gotpm:"list"` } // TPMLTaggedPCRProperty represents a TPML_TAGGED_PCR_PROPERTY. // See definition in Part 2: Structures, section 10.9.10. type TPMLTaggedPCRProperty struct { + marshalByReflection PCRProperty []TPMSTaggedPCRSelect `gotpm:"list"` } // TPMLECCCurve represents a TPML_ECC_CURVE. // See definition in Part 2: Structures, section 10.9.11. type TPMLECCCurve struct { + marshalByReflection ECCCurves []TPMECCCurve `gotpm:"list"` } // TPMLTaggedPolicy represents a TPML_TAGGED_POLICY. // See definition in Part 2: Structures, section 10.9.12. type TPMLTaggedPolicy struct { + marshalByReflection Policies []TPMSTaggedPolicy `gotpm:"list"` } // TPMLACTData represents a TPML_ACT_DATA. // See definition in Part 2: Structures, section 10.9.13. type TPMLACTData struct { + marshalByReflection ACTData []TPMSACTData `gotpm:"list"` } // TPMUCapabilities represents a TPMU_CAPABILITIES. // See definition in Part 2: Structures, section 10.10.1. type TPMUCapabilities struct { - Algorithms *TPMLAlgProperty `gotpm:"selector=0x00000000"` // TPM_CAP_ALGS - Handles *TPMLHandle `gotpm:"selector=0x00000001"` // TPM_CAP_HANDLES - Command *TPMLCCA `gotpm:"selector=0x00000002"` // TPM_CAP_COMMANDS - PPCommands *TPMLCC `gotpm:"selector=0x00000003"` // TPM_CAP_PP_COMMANDS - AuditCommands *TPMLCC `gotpm:"selector=0x00000004"` // TPM_CAP_AUDIT_COMMANDS - AssignedPCR *TPMLPCRSelection `gotpm:"selector=0x00000005"` // TPM_CAP_PCRS - TPMProperties *TPMLTaggedTPMProperty `gotpm:"selector=0x00000006"` // TPM_CAP_TPM_PROPERTIES - PCRProperties *TPMLTaggedPCRProperty `gotpm:"selector=0x00000007"` // TPM_CAP_PCR_PROPERTIES - ECCCurves *TPMLECCCurve `gotpm:"selector=0x00000008"` // TPM_CAP_ECC_CURVES - AuthPolicies *TPMLTaggedPolicy `gotpm:"selector=0x00000009"` // TPM_CAP_AUTH_POLICIES - ACTData *TPMLACTData `gotpm:"selector=0x0000000A"` // TPM_CAP_ACT + selector TPMCap + contents Marshallable +} + +// CapabilitiesContents is a type constraint representing the possible contents of TPMUCapabilities. +type CapabilitiesContents interface { + Marshallable + *TPMLAlgProperty | *TPMLHandle | *TPMLCCA | *TPMLCC | *TPMLPCRSelection | *TPMLTaggedTPMProperty | + *TPMLTaggedPCRProperty | *TPMLECCCurve | *TPMLTaggedPolicy | *TPMLACTData +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUCapabilities) create(hint int64) (reflect.Value, error) { + switch TPMCap(hint) { + case TPMCapAlgs: + contents := TPMLAlgProperty{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapHandles: + contents := TPMLHandle{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapCommands: + contents := TPMLCCA{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapPPCommands, TPMCapAuditCommands: + contents := TPMLCC{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapPCRs: + contents := TPMLPCRSelection{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapTPMProperties: + contents := TPMLTaggedTPMProperty{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapPCRProperties: + contents := TPMLTaggedPCRProperty{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapECCCurves: + contents := TPMLECCCurve{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapAuthPolicies: + contents := TPMLTaggedPolicy{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + case TPMCapACT: + contents := TPMLACTData{} + u.contents = &contents + u.selector = TPMCap(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUCapabilities) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMCap(hint) { + case TPMCapAlgs: + contents := TPMLAlgProperty{} + if u.contents != nil { + contents = *u.contents.(*TPMLAlgProperty) + } + return reflect.ValueOf(&contents), nil + case TPMCapHandles: + contents := TPMLHandle{} + if u.contents != nil { + contents = *u.contents.(*TPMLHandle) + } + return reflect.ValueOf(&contents), nil + case TPMCapCommands: + contents := TPMLCCA{} + if u.contents != nil { + contents = *u.contents.(*TPMLCCA) + } + return reflect.ValueOf(&contents), nil + case TPMCapPPCommands, TPMCapAuditCommands: + contents := TPMLCC{} + if u.contents != nil { + contents = *u.contents.(*TPMLCC) + } + return reflect.ValueOf(&contents), nil + case TPMCapPCRs: + contents := TPMLPCRSelection{} + if u.contents != nil { + contents = *u.contents.(*TPMLPCRSelection) + } + return reflect.ValueOf(&contents), nil + case TPMCapTPMProperties: + contents := TPMLTaggedTPMProperty{} + if u.contents != nil { + contents = *u.contents.(*TPMLTaggedTPMProperty) + } + return reflect.ValueOf(&contents), nil + case TPMCapPCRProperties: + contents := TPMLTaggedPCRProperty{} + if u.contents != nil { + contents = *u.contents.(*TPMLTaggedPCRProperty) + } + return reflect.ValueOf(&contents), nil + case TPMCapECCCurves: + contents := TPMLECCCurve{} + if u.contents != nil { + contents = *u.contents.(*TPMLECCCurve) + } + return reflect.ValueOf(&contents), nil + case TPMCapAuthPolicies: + contents := TPMLTaggedPolicy{} + if u.contents != nil { + contents = *u.contents.(*TPMLTaggedPolicy) + } + return reflect.ValueOf(&contents), nil + case TPMCapACT: + contents := TPMLACTData{} + if u.contents != nil { + contents = *u.contents.(*TPMLACTData) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUCapabilities instantiates a TPMUCapabilities with the given contents. +func NewTPMUCapabilities[C CapabilitiesContents](selector TPMCap, contents C) TPMUCapabilities { + return TPMUCapabilities{ + selector: selector, + contents: contents, + } +} + +// Algorithms returns the 'algorithms' member of the union. +func (u *TPMUCapabilities) Algorithms() (*TPMLAlgProperty, error) { + if u.selector == TPMCapAlgs { + return u.contents.(*TPMLAlgProperty), nil + } + return nil, fmt.Errorf("did not contain algorithms (selector value was %v)", u.selector) +} + +// Handles returns the 'handles' member of the union. +func (u *TPMUCapabilities) Handles() (*TPMLHandle, error) { + if u.selector == TPMCapHandles { + return u.contents.(*TPMLHandle), nil + } + return nil, fmt.Errorf("did not contain handles (selector value was %v)", u.selector) +} + +// Command returns the 'command' member of the union. +func (u *TPMUCapabilities) Command() (*TPMLCCA, error) { + if u.selector == TPMCapAlgs { + return u.contents.(*TPMLCCA), nil + } + return nil, fmt.Errorf("did not contain command (selector value was %v)", u.selector) +} + +// PPCommands returns the 'ppCommands' member of the union. +func (u *TPMUCapabilities) PPCommands() (*TPMLCC, error) { + if u.selector == TPMCapPPCommands { + return u.contents.(*TPMLCC), nil + } + return nil, fmt.Errorf("did not contain ppCommands (selector value was %v)", u.selector) +} + +// AuditCommands returns the 'auditCommands' member of the union. +func (u *TPMUCapabilities) AuditCommands() (*TPMLCC, error) { + if u.selector == TPMCapAuditCommands { + return u.contents.(*TPMLCC), nil + } + return nil, fmt.Errorf("did not contain auditCommands (selector value was %v)", u.selector) +} + +// AssignedPCR returns the 'assignedPCR' member of the union. +func (u *TPMUCapabilities) AssignedPCR() (*TPMLPCRSelection, error) { + if u.selector == TPMCapPCRs { + return u.contents.(*TPMLPCRSelection), nil + } + return nil, fmt.Errorf("did not contain assignedPCR (selector value was %v)", u.selector) +} + +// TPMProperties returns the 'tpmProperties' member of the union. +func (u *TPMUCapabilities) TPMProperties() (*TPMLTaggedTPMProperty, error) { + if u.selector == TPMCapTPMProperties { + return u.contents.(*TPMLTaggedTPMProperty), nil + } + return nil, fmt.Errorf("did not contain tpmProperties (selector value was %v)", u.selector) +} + +// PCRProperties returns the 'pcrProperties' member of the union. +func (u *TPMUCapabilities) PCRProperties() (*TPMLTaggedPCRProperty, error) { + if u.selector == TPMCapPCRProperties { + return u.contents.(*TPMLTaggedPCRProperty), nil + } + return nil, fmt.Errorf("did not contain pcrProperties (selector value was %v)", u.selector) +} + +// ECCCurves returns the 'eccCurves' member of the union. +func (u *TPMUCapabilities) ECCCurves() (*TPMLECCCurve, error) { + if u.selector == TPMCapECCCurves { + return u.contents.(*TPMLECCCurve), nil + } + return nil, fmt.Errorf("did not contain eccCurves (selector value was %v)", u.selector) +} + +// AuthPolicies returns the 'authPolicies' member of the union. +func (u *TPMUCapabilities) AuthPolicies() (*TPMLTaggedPolicy, error) { + if u.selector == TPMCapAuthPolicies { + return u.contents.(*TPMLTaggedPolicy), nil + } + return nil, fmt.Errorf("did not contain authPolicies (selector value was %v)", u.selector) +} + +// ACTData returns the 'actData' member of the union. +func (u *TPMUCapabilities) ACTData() (*TPMLACTData, error) { + if u.selector == TPMCapAuthPolicies { + return u.contents.(*TPMLACTData), nil + } + return nil, fmt.Errorf("did not contain actData (selector value was %v)", u.selector) } // TPMSCapabilityData represents a TPMS_CAPABILITY_DATA. // See definition in Part 2: Structures, section 10.10.2. type TPMSCapabilityData struct { + marshalByReflection // the capability Capability TPMCap // the capability data @@ -773,6 +1029,7 @@ type TPMSCapabilityData struct { // TPMSClockInfo represents a TPMS_CLOCK_INFO. // See definition in Part 2: Structures, section 10.11.1. type TPMSClockInfo struct { + marshalByReflection // time value in milliseconds that advances while the TPM is powered Clock uint64 // number of occurrences of TPM Reset since the last TPM2_Clear() @@ -788,6 +1045,7 @@ type TPMSClockInfo struct { // TPMSTimeInfo represents a TPMS_TIMEzINFO. // See definition in Part 2: Structures, section 10.11.6. type TPMSTimeInfo struct { + marshalByReflection // time in milliseconds since the TIme circuit was last reset Time uint64 // a structure containing the clock information @@ -797,6 +1055,7 @@ type TPMSTimeInfo struct { // TPMSTimeAttestInfo represents a TPMS_TIME_ATTEST_INFO. // See definition in Part 2: Structures, section 10.12.2. type TPMSTimeAttestInfo struct { + marshalByReflection // the Time, Clock, resetCount, restartCount, and Safe indicator Time TPMSTimeInfo // a TPM vendor-specific value indicating the version number of the firmware @@ -806,6 +1065,7 @@ type TPMSTimeAttestInfo struct { // TPMSCertifyInfo represents a TPMS_CERTIFY_INFO. // See definition in Part 2: Structures, section 10.12.3. type TPMSCertifyInfo struct { + marshalByReflection // Name of the certified object Name TPM2BName // Qualified Name of the certified object @@ -815,6 +1075,7 @@ type TPMSCertifyInfo struct { // TPMSQuoteInfo represents a TPMS_QUOTE_INFO. // See definition in Part 2: Structures, section 10.12.4. type TPMSQuoteInfo struct { + marshalByReflection // information on algID, PCR selected and digest PCRSelect TPMLPCRSelection // digest of the selected PCR using the hash of the signing key @@ -824,6 +1085,7 @@ type TPMSQuoteInfo struct { // TPMSCommandAuditInfo represents a TPMS_COMMAND_AUDIT_INFO. // See definition in Part 2: Structures, section 10.12.5. type TPMSCommandAuditInfo struct { + marshalByReflection // the monotonic audit counter AuditCounter uint64 // hash algorithm used for the command audit @@ -837,6 +1099,7 @@ type TPMSCommandAuditInfo struct { // TPMSSessionAuditInfo represents a TPMS_SESSION_AUDIT_INFO. // See definition in Part 2: Structures, section 10.12.6. type TPMSSessionAuditInfo struct { + marshalByReflection // current exclusive status of the session ExclusiveSession TPMIYesNo // the current value of the session audit digest @@ -846,6 +1109,7 @@ type TPMSSessionAuditInfo struct { // TPMSCreationInfo represents a TPMS_CREATION_INFO. // See definition in Part 2: Structures, section 10.12.7. type TPMSCreationInfo struct { + marshalByReflection // Name of the object ObjectName TPM2BName // creationHash @@ -855,6 +1119,7 @@ type TPMSCreationInfo struct { // TPMSNVCertifyInfo represents a TPMS_NV_CERTIFY_INFO. // See definition in Part 2: Structures, section 10.12.8. type TPMSNVCertifyInfo struct { + marshalByReflection // Name of the NV Index IndexName TPM2BName // the offset parameter of TPM2_NV_Certify() @@ -866,6 +1131,7 @@ type TPMSNVCertifyInfo struct { // TPMSNVDigestCertifyInfo represents a TPMS_NV_DIGEST_CERTIFY_INFO. // See definition in Part 2: Structures, section 10.12.9. type TPMSNVDigestCertifyInfo struct { + marshalByReflection // Name of the NV Index IndexName TPM2BName // hash of the contents of the index @@ -879,19 +1145,198 @@ type TPMISTAttest = TPMST // TPMUAttest represents a TPMU_ATTEST. // See definition in Part 2: Structures, section 10.12.11. type TPMUAttest struct { - NV *TPMSNVCertifyInfo `gotpm:"selector=0x8014"` // TPM_ST_ATTEST_NV - CommandAudit *TPMSCommandAuditInfo `gotpm:"selector=0x8015"` // TPM_ST_ATTEST_COMMAND_AUDIT - SessionAudit *TPMSSessionAuditInfo `gotpm:"selector=0x8016"` // TPM_ST_ATTEST_SESSION_AUDIT - Certify *TPMSCertifyInfo `gotpm:"selector=0x8017"` // TPM_ST_ATTEST_CERTIFY - Quote *TPMSQuoteInfo `gotpm:"selector=0x8018"` // TPM_ST_ATTEST_QUOTE - Time *TPMSTimeAttestInfo `gotpm:"selector=0x8019"` // TPM_ST_ATTEST_TIME - Creation *TPMSCreationInfo `gotpm:"selector=0x801A"` // TPM_ST_ATTEST_CREATION - NVDigest *TPMSNVDigestCertifyInfo `gotpm:"selector=0x801C"` // TPM_ST_ATTEST_NV_DIGEST + selector TPMST + contents Marshallable +} + +// AttestContents is a type constraint representing the possible contents of TPMUAttest. +type AttestContents interface { + Marshallable + *TPMSNVCertifyInfo | *TPMSCommandAuditInfo | *TPMSSessionAuditInfo | *TPMSCertifyInfo | + *TPMSQuoteInfo | *TPMSTimeAttestInfo | *TPMSCreationInfo | *TPMSNVDigestCertifyInfo +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUAttest) create(hint int64) (reflect.Value, error) { + switch TPMST(hint) { + case TPMSTAttestNV: + contents := TPMSNVCertifyInfo{} + u.contents = &contents + u.selector = TPMST(hint) + return reflect.ValueOf(&contents), nil + case TPMSTAttestCommandAudit: + contents := TPMSCommandAuditInfo{} + u.contents = &contents + u.selector = TPMST(hint) + return reflect.ValueOf(&contents), nil + case TPMSTAttestSessionAudit: + contents := TPMSSessionAuditInfo{} + u.contents = &contents + u.selector = TPMST(hint) + return reflect.ValueOf(&contents), nil + case TPMSTAttestCertify: + contents := TPMSCertifyInfo{} + u.contents = &contents + u.selector = TPMST(hint) + return reflect.ValueOf(&contents), nil + case TPMSTAttestQuote: + contents := TPMSQuoteInfo{} + u.contents = &contents + u.selector = TPMST(hint) + return reflect.ValueOf(&contents), nil + case TPMSTAttestTime: + contents := TPMSTimeAttestInfo{} + u.contents = &contents + u.selector = TPMST(hint) + return reflect.ValueOf(&contents), nil + case TPMSTAttestCreation: + contents := TPMSCreationInfo{} + u.contents = &contents + u.selector = TPMST(hint) + return reflect.ValueOf(&contents), nil + case TPMSTAttestNVDigest: + contents := TPMSNVDigestCertifyInfo{} + u.contents = &contents + u.selector = TPMST(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUAttest) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMST(hint) { + case TPMSTAttestNV: + contents := TPMSNVCertifyInfo{} + if u.contents != nil { + contents = *u.contents.(*TPMSNVCertifyInfo) + } + return reflect.ValueOf(&contents), nil + case TPMSTAttestCommandAudit: + contents := TPMSCommandAuditInfo{} + if u.contents != nil { + contents = *u.contents.(*TPMSCommandAuditInfo) + } + return reflect.ValueOf(&contents), nil + case TPMSTAttestSessionAudit: + contents := TPMSSessionAuditInfo{} + if u.contents != nil { + contents = *u.contents.(*TPMSSessionAuditInfo) + } + return reflect.ValueOf(&contents), nil + case TPMSTAttestCertify: + contents := TPMSCertifyInfo{} + if u.contents != nil { + contents = *u.contents.(*TPMSCertifyInfo) + } + return reflect.ValueOf(&contents), nil + case TPMSTAttestQuote: + contents := TPMSQuoteInfo{} + if u.contents != nil { + contents = *u.contents.(*TPMSQuoteInfo) + } + return reflect.ValueOf(&contents), nil + case TPMSTAttestTime: + contents := TPMSTimeAttestInfo{} + if u.contents != nil { + contents = *u.contents.(*TPMSTimeAttestInfo) + } + return reflect.ValueOf(&contents), nil + case TPMSTAttestCreation: + contents := TPMSCreationInfo{} + if u.contents != nil { + contents = *u.contents.(*TPMSCreationInfo) + } + return reflect.ValueOf(&contents), nil + case TPMSTAttestNVDigest: + contents := TPMSNVDigestCertifyInfo{} + if u.contents != nil { + contents = *u.contents.(*TPMSNVDigestCertifyInfo) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUAttest instantiates a TPMUAttest with the given contents. +func NewTPMUAttest[C AttestContents](selector TPMST, contents C) TPMUAttest { + return TPMUAttest{ + selector: selector, + contents: contents, + } +} + +// Certify returns the 'certify' member of the union. +func (u *TPMUAttest) Certify() (*TPMSCertifyInfo, error) { + if u.selector == TPMSTAttestCertify { + return u.contents.(*TPMSCertifyInfo), nil + } + return nil, fmt.Errorf("did not contain certify (selector value was %v)", u.selector) +} + +// Creation returns the 'creation' member of the union. +func (u *TPMUAttest) Creation() (*TPMSCreationInfo, error) { + if u.selector == TPMSTAttestCreation { + return u.contents.(*TPMSCreationInfo), nil + } + return nil, fmt.Errorf("did not contain creation (selector value was %v)", u.selector) +} + +// Quote returns the 'quote' member of the union. +func (u *TPMUAttest) Quote() (*TPMSQuoteInfo, error) { + if u.selector == TPMSTAttestQuote { + return u.contents.(*TPMSQuoteInfo), nil + } + return nil, fmt.Errorf("did not contain quote (selector value was %v)", u.selector) +} + +// CommandAudit returns the 'commandAudit' member of the union. +func (u *TPMUAttest) CommandAudit() (*TPMSCommandAuditInfo, error) { + if u.selector == TPMSTAttestCommandAudit { + return u.contents.(*TPMSCommandAuditInfo), nil + } + return nil, fmt.Errorf("did not contain commandAudit (selector value was %v)", u.selector) +} + +// SessionAudit returns the 'sessionAudit' member of the union. +func (u *TPMUAttest) SessionAudit() (*TPMSSessionAuditInfo, error) { + if u.selector == TPMSTAttestSessionAudit { + return u.contents.(*TPMSSessionAuditInfo), nil + } + return nil, fmt.Errorf("did not contain sessionAudit (selector value was %v)", u.selector) +} + +// Time returns the 'time' member of the union. +func (u *TPMUAttest) Time() (*TPMSTimeAttestInfo, error) { + if u.selector == TPMSTAttestTime { + return u.contents.(*TPMSTimeAttestInfo), nil + } + return nil, fmt.Errorf("did not contain time (selector value was %v)", u.selector) +} + +// NV returns the 'nv' member of the union. +func (u *TPMUAttest) NV() (*TPMSNVCertifyInfo, error) { + if u.selector == TPMSTAttestNV { + return u.contents.(*TPMSNVCertifyInfo), nil + } + return nil, fmt.Errorf("did not contain nv (selector value was %v)", u.selector) +} + +// NVDigest returns the 'nvDigest' member of the union. +func (u *TPMUAttest) NVDigest() (*TPMSNVDigestCertifyInfo, error) { + if u.selector == TPMSTAttestNVDigest { + return u.contents.(*TPMSNVDigestCertifyInfo), nil + } + return nil, fmt.Errorf("did not contain nvDigest (selector value was %v)", u.selector) } // TPMSAttest represents a TPMS_ATTEST. // See definition in Part 2: Structures, section 10.12.12. type TPMSAttest struct { + marshalByReflection // the indication that this structure was created by a TPM (always TPM_GENERATED_VALUE) Magic TPMGenerated `gotpm:"check"` // type of the attestation structure @@ -910,16 +1355,12 @@ type TPMSAttest struct { // TPM2BAttest represents a TPM2B_ATTEST. // See definition in Part 2: Structures, section 10.12.13. -// Note that in the spec, this is just a 2B_DATA with enough room for an S_ATTEST. -// For ergonomics, pretend that TPM2B_Attest wraps a TPMS_Attest just like other 2Bs. -type TPM2BAttest struct { - // the signed structure - AttestationData TPMSAttest `gotpm:"sized"` -} +type TPM2BAttest = TPM2B[TPMSAttest, *TPMSAttest] // TPMSAuthCommand represents a TPMS_AUTH_COMMAND. // See definition in Part 2: Structures, section 10.13.2. type TPMSAuthCommand struct { + marshalByReflection Handle TPMISHAuthSession Nonce TPM2BNonce Attributes TPMASession @@ -929,6 +1370,7 @@ type TPMSAuthCommand struct { // TPMSAuthResponse represents a TPMS_AUTH_RESPONSE. // See definition in Part 2: Structures, section 10.13.3. type TPMSAuthResponse struct { + marshalByReflection Nonce TPM2BNonce Attributes TPMASession Authorization TPM2BData @@ -937,33 +1379,208 @@ type TPMSAuthResponse struct { // TPMUSymKeyBits represents a TPMU_SYM_KEY_BITS. // See definition in Part 2: Structures, section 11.1.3. type TPMUSymKeyBits struct { - // TODO: The rest of the symmetric algorithms get their own entry - // in this union. - AES *TPMKeyBits `gotpm:"selector=0x0006"` // TPM_ALG_AES - XOR *TPMIAlgHash `gotpm:"selector=0x000A"` // TPM_ALG_XOR + selector TPMAlgID + contents Marshallable +} + +// SymKeyBitsContents is a type constraint representing the possible contents of TPMUSymKeyBits. +type SymKeyBitsContents interface { + TPMKeyBits | TPMAlgID +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUSymKeyBits) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgAES: + var contents boxed[TPMKeyBits] + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgXOR: + var contents boxed[TPMAlgID] + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUSymKeyBits) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgAES: + var contents boxed[TPMKeyBits] + if u.contents != nil { + contents = *u.contents.(*boxed[TPMKeyBits]) + } + return reflect.ValueOf(&contents), nil + case TPMAlgXOR: + var contents boxed[TPMAlgID] + if u.contents != nil { + contents = *u.contents.(*boxed[TPMAlgID]) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUSymKeyBits instantiates a TPMUSymKeyBits with the given contents. +func NewTPMUSymKeyBits[C SymKeyBitsContents](selector TPMAlgID, contents C) TPMUSymKeyBits { + boxed := box(&contents) + return TPMUSymKeyBits{ + selector: selector, + contents: &boxed, + } +} + +// AES returns the 'aes' member of the union. +func (u *TPMUSymKeyBits) AES() (*TPMKeyBits, error) { + if u.selector == TPMAlgAES { + value := u.contents.(*boxed[TPMKeyBits]).unbox() + return value, nil + } + return nil, fmt.Errorf("did not contain aes (selector value was %v)", u.selector) +} + +// XOR returns the 'xor' member of the union. +func (u *TPMUSymKeyBits) XOR() (*TPMAlgID, error) { + if u.selector == TPMAlgXOR { + value := u.contents.(*boxed[TPMAlgID]).unbox() + return value, nil + } + return nil, fmt.Errorf("did not contain xor (selector value was %v)", u.selector) } // TPMUSymMode represents a TPMU_SYM_MODE. // See definition in Part 2: Structures, section 11.1.4. type TPMUSymMode struct { - // TODO: The rest of the symmetric algorithms get their own entry - // in this union. - AES *TPMIAlgSymMode `gotpm:"selector=0x0006"` // TPM_ALG_AES - XOR *struct{} `gotpm:"selector=0x000A"` // TPM_ALG_XOR + selector TPMAlgID + contents Marshallable +} + +// SymModeContents is a type constraint representing the possible contents of TPMUSymMode. +type SymModeContents interface { + TPMIAlgSymMode | TPMSEmpty +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUSymMode) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgAES: + var contents boxed[TPMAlgID] + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgXOR: + var contents boxed[TPMSEmpty] + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUSymMode) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgAES: + var contents boxed[TPMAlgID] + if u.contents != nil { + contents = *u.contents.(*boxed[TPMAlgID]) + } + return reflect.ValueOf(&contents), nil + case TPMAlgXOR: + var contents boxed[TPMSEmpty] + if u.contents != nil { + contents = *u.contents.(*boxed[TPMSEmpty]) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUSymMode instantiates a TPMUSymMode with the given contents. +func NewTPMUSymMode[C SymModeContents](selector TPMAlgID, contents C) TPMUSymMode { + boxed := box(&contents) + return TPMUSymMode{ + selector: selector, + contents: &boxed, + } +} + +// AES returns the 'aes' member of the union. +func (u *TPMUSymMode) AES() (*TPMIAlgSymMode, error) { + if u.selector == TPMAlgAES { + value := u.contents.(*boxed[TPMIAlgSymMode]).unbox() + return value, nil + } + return nil, fmt.Errorf("did not contain aes (selector value was %v)", u.selector) } // TPMUSymDetails represents a TPMU_SYM_DETAILS. // See definition in Part 2: Structures, section 11.1.5. type TPMUSymDetails struct { - // TODO: The rest of the symmetric algorithms get their own entry - // in this union. - AES *struct{} `gotpm:"selector=0x0006"` // TPM_ALG_AES - XOR *struct{} `gotpm:"selector=0x000A"` // TPM_ALG_XOR + selector TPMAlgID + contents Marshallable +} + +// SymDetailsContents is a type constraint representing the possible contents of TPMUSymDetails. +type SymDetailsContents interface { + TPMSEmpty +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUSymDetails) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgAES: + var contents boxed[TPMSEmpty] + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgXOR: + var contents boxed[TPMSEmpty] + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUSymDetails) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgAES, TPMAlgXOR: + var contents boxed[TPMSEmpty] + if u.contents != nil { + contents = *u.contents.(*boxed[TPMSEmpty]) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUSymDetails instantiates a TPMUSymDetails with the given contents. +func NewTPMUSymDetails[C SymDetailsContents](selector TPMAlgID, contents C) TPMUSymMode { + boxed := box(&contents) + return TPMUSymMode{ + selector: selector, + contents: &boxed, + } } // TPMTSymDef represents a TPMT_SYM_DEF. // See definition in Part 2: Structures, section 11.1.6. type TPMTSymDef struct { + marshalByReflection // indicates a symmetric algorithm Algorithm TPMIAlgSym `gotpm:"nullable"` // the key size @@ -977,6 +1594,7 @@ type TPMTSymDef struct { // TPMTSymDefObject represents a TPMT_SYM_DEF_OBJECT. // See definition in Part 2: Structures, section 11.1.7. type TPMTSymDefObject struct { + marshalByReflection // selects a symmetric block cipher // When used in the parameter area of a parent object, this shall // be a supported block cipher and not TPM_ALG_NULL @@ -998,6 +1616,7 @@ type TPM2BSymKey TPM2BData // TPMSSymCipherParms represents a TPMS_SYMCIPHER_PARMS. // See definition in Part 2: Structures, section 11.1.9. type TPMSSymCipherParms struct { + marshalByReflection // a symmetric block cipher Sym TPMTSymDefObject } @@ -1009,25 +1628,42 @@ type TPM2BLabel TPM2BData // TPMSDerive represents a TPMS_DERIVE. // See definition in Part 2: Structures, section 11.1.11. type TPMSDerive struct { + marshalByReflection Label TPM2BLabel Context TPM2BLabel } // TPM2BDerive represents a TPM2B_DERIVE. // See definition in Part 2: Structures, section 11.1.12. -type TPM2BDerive struct { - Buffer TPMSDerive `gotpm:"sized"` -} +type TPM2BDerive = TPM2B[TPMSDerive, *TPMSDerive] // TPMUSensitiveCreate represents a TPMU_SENSITIVE_CREATE. // See definition in Part 2: Structures, section 11.1.13. -// Since the TPM cannot return this type, it can be an interface. -type TPMUSensitiveCreate interface { - tpmusensitivecreate() +type TPMUSensitiveCreate struct { + contents Marshallable } -func (TPM2BSensitiveData) tpmusensitivecreate() {} -func (TPM2BDerive) tpmusensitivecreate() {} +// SensitiveCreateContents is a type constraint representing the possible contents of TPMUSensitiveCreate. +type SensitiveCreateContents interface { + Marshallable + *TPM2BDerive | *TPM2BSensitiveData +} + +// marshal implements the Marshallable interface. +func (u TPMUSensitiveCreate) marshal(buf *bytes.Buffer) { + if u.contents != nil { + buf.Write(Marshal(u.contents)) + } else { + // If this is a zero-valued structure, marshal a default TPM2BSensitiveData. + var defaultValue TPM2BSensitiveData + buf.Write(Marshal(&defaultValue)) + } +} + +// NewTPMUSensitiveCreate instantiates a TPMUSensitiveCreate with the given contents. +func NewTPMUSensitiveCreate[C SensitiveCreateContents](contents C) TPMUSensitiveCreate { + return TPMUSensitiveCreate{contents: contents} +} // TPM2BSensitiveData represents a TPM2B_SENSITIVE_DATA. // See definition in Part 2: Structures, section 11.1.14. @@ -1036,6 +1672,7 @@ type TPM2BSensitiveData TPM2BData // TPMSSensitiveCreate represents a TPMS_SENSITIVE_CREATE. // See definition in Part 2: Structures, section 11.1.15. type TPMSSensitiveCreate struct { + marshalByReflection // the USER auth secret value. UserAuth TPM2BAuth // data to be sealed, a key, or derivation values. @@ -1044,14 +1681,33 @@ type TPMSSensitiveCreate struct { // TPM2BSensitiveCreate represents a TPM2B_SENSITIVE_CREATE. // See definition in Part 2: Structures, section 11.1.16. +// This is a structure instead of an alias to TPM2B[TPMSSensitiveCreate], +// because it has custom marshalling logic for zero-valued parameters. type TPM2BSensitiveCreate struct { - // data to be sealed or a symmetric key value. - Sensitive TPMSSensitiveCreate `gotpm:"sized"` + Sensitive *TPMSSensitiveCreate +} + +// Quirk: When this structure is zero-valued, we need to marshal +// a 2B-wrapped zero-valued TPMS_SENSITIVE_CREATE instead of +// [0x00, 0x00] (a zero-valued 2B). +func (c TPM2BSensitiveCreate) marshal(buf *bytes.Buffer) { + var marshalled TPM2B[TPMSSensitiveCreate, *TPMSSensitiveCreate] + if c.Sensitive != nil { + marshalled = New2B(*c.Sensitive) + } else { + // If no value was provided (i.e., this is a zero-valued structure), + // provide an 2B containing a zero-valued TPMS_SensitiveCreate. + marshalled = New2B(TPMSSensitiveCreate{ + Data: NewTPMUSensitiveCreate(&TPM2BSensitiveData{}), + }) + } + marshalled.marshal(buf) } // TPMSSchemeHash represents a TPMS_SCHEME_HASH. // See definition in Part 2: Structures, section 11.1.17. type TPMSSchemeHash struct { + marshalByReflection // the hash algorithm used to digest the message HashAlg TPMIAlgHash } @@ -1059,6 +1715,7 @@ type TPMSSchemeHash struct { // TPMSSchemeECDAA represents a TPMS_SCHEME_ECDAA. // See definition in Part 2: Structures, section 11.1.18. type TPMSSchemeECDAA struct { + marshalByReflection // the hash algorithm used to digest the message HashAlg TPMIAlgHash // the counter value that is used between TPM2_Commit() @@ -1077,6 +1734,7 @@ type TPMSSchemeHMAC TPMSSchemeHash // TPMSSchemeXOR represents a TPMS_SCHEME_XOR. // See definition in Part 2: Structures, section 11.1.21. type TPMSSchemeXOR struct { + marshalByReflection // the hash algorithm used to digest the message HashAlg TPMIAlgHash // the key derivation function @@ -1086,13 +1744,85 @@ type TPMSSchemeXOR struct { // TPMUSchemeKeyedHash represents a TPMU_SCHEME_KEYEDHASH. // See definition in Part 2: Structures, section 11.1.22. type TPMUSchemeKeyedHash struct { - HMAC *TPMSSchemeHMAC `gotpm:"selector=0x0005"` // TPM_ALG_HMAC - XOR *TPMSSchemeXOR `gotpm:"selector=0x000A"` // TPM_ALG_XOR + selector TPMAlgID + contents Marshallable +} + +// SchemeKeyedHashContents is a type constraint representing the possible contents of TPMUSchemeKeyedHash. +type SchemeKeyedHashContents interface { + Marshallable + *TPMSSchemeHMAC | *TPMSSchemeXOR +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUSchemeKeyedHash) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgHMAC: + var contents TPMSSchemeHMAC + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgXOR: + var contents TPMSSchemeXOR + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUSchemeKeyedHash) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgHMAC: + var contents TPMSSchemeHMAC + if u.contents != nil { + contents = *u.contents.(*TPMSSchemeHMAC) + } + return reflect.ValueOf(&contents), nil + case TPMAlgXOR: + var contents TPMSSchemeXOR + if u.contents != nil { + contents = *u.contents.(*TPMSSchemeXOR) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUSchemeKeyedHash instantiates a TPMUSchemeKeyedHash with the given contents. +func NewTPMUSchemeKeyedHash[C SchemeKeyedHashContents](selector TPMAlgID, contents C) TPMUSchemeKeyedHash { + return TPMUSchemeKeyedHash{ + selector: selector, + contents: contents, + } +} + +// HMAC returns the 'hmac' member of the union. +func (u *TPMUSchemeKeyedHash) HMAC() (*TPMSSchemeHMAC, error) { + if u.selector == TPMAlgHMAC { + value := u.contents.(*TPMSSchemeHMAC) + return value, nil + } + return nil, fmt.Errorf("did not contain hmac (selector value was %v)", u.selector) +} + +// XOR returns the 'xor' member of the union. +func (u *TPMUSchemeKeyedHash) XOR() (*TPMSSchemeXOR, error) { + if u.selector == TPMAlgXOR { + value := u.contents.(*TPMSSchemeXOR) + return value, nil + } + return nil, fmt.Errorf("did not contain xor (selector value was %v)", u.selector) } // TPMTKeyedHashScheme represents a TPMT_KEYEDHASH_SCHEME. // See definition in Part 2: Structures, section 11.1.23. type TPMTKeyedHashScheme struct { + marshalByReflection Scheme TPMIAlgKeyedHashScheme `gotpm:"nullable"` Details TPMUSchemeKeyedHash `gotpm:"tag=Scheme"` } @@ -1109,23 +1839,121 @@ type TPMSSigSchemeRSAPSS TPMSSchemeHash // See definition in Part 2: Structures, section 11.2.1.3. type TPMSSigSchemeECDSA TPMSSchemeHash -// TPMSSigSchemeECDAA represents a TPMS_SIG_SCHEME_ECDAA. -// See definition in Part 2: Structures, section 11.2.1.3. -type TPMSSigSchemeECDAA TPMSSchemeECDAA - // TPMUSigScheme represents a TPMU_SIG_SCHEME. // See definition in Part 2: Structures, section 11.2.1.4. type TPMUSigScheme struct { - HMAC *TPMSSchemeHMAC `gotpm:"selector=0x0005"` // TPM_ALG_HMAC - RSASSA *TPMSSchemeHash `gotpm:"selector=0x0014"` // TPM_ALG_RSASSA - RSAPSS *TPMSSchemeHash `gotpm:"selector=0x0016"` // TPM_ALG_RSAPSS - ECDSA *TPMSSchemeHash `gotpm:"selector=0x0018"` // TPM_ALG_ECDSA - ECDAA *TPMSSchemeECDAA `gotpm:"selector=0x001a"` // TPM_ALG_ECDAA + selector TPMAlgID + contents Marshallable +} + +// SigSchemeContents is a type constraint representing the possible contents of TPMUSigScheme. +type SigSchemeContents interface { + Marshallable + *TPMSSchemeHMAC | *TPMSSchemeHash | *TPMSSchemeECDAA +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUSigScheme) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgHMAC: + var contents TPMSSchemeHMAC + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgRSASSA, TPMAlgRSAPSS, TPMAlgECDSA: + var contents TPMSSchemeHash + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECDAA: + var contents TPMSSchemeECDAA + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUSigScheme) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgHMAC: + var contents TPMSSchemeHMAC + if u.contents != nil { + contents = *u.contents.(*TPMSSchemeHMAC) + } + return reflect.ValueOf(&contents), nil + case TPMAlgRSASSA, TPMAlgRSAPSS, TPMAlgECDSA: + var contents TPMSSchemeHash + if u.contents != nil { + contents = *u.contents.(*TPMSSchemeHash) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECDAA: + var contents TPMSSchemeECDAA + if u.contents != nil { + contents = *u.contents.(*TPMSSchemeECDAA) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUSigScheme instantiates a TPMUSigScheme with the given contents. +func NewTPMUSigScheme[C SigSchemeContents](selector TPMAlgID, contents C) TPMUSigScheme { + return TPMUSigScheme{ + selector: selector, + contents: contents, + } +} + +// HMAC returns the 'hmac' member of the union. +func (u *TPMUSigScheme) HMAC() (*TPMSSchemeHMAC, error) { + if u.selector == TPMAlgHMAC { + return u.contents.(*TPMSSchemeHMAC), nil + } + return nil, fmt.Errorf("did not contain hmac (selector value was %v)", u.selector) +} + +// RSASSA returns the 'rsassa' member of the union. +func (u *TPMUSigScheme) RSASSA() (*TPMSSchemeHash, error) { + if u.selector == TPMAlgRSASSA { + return u.contents.(*TPMSSchemeHash), nil + } + return nil, fmt.Errorf("did not contain rsassa (selector value was %v)", u.selector) +} + +// RSAPSS returns the 'rsapss' member of the union. +func (u *TPMUSigScheme) RSAPSS() (*TPMSSchemeHash, error) { + if u.selector == TPMAlgRSAPSS { + return u.contents.(*TPMSSchemeHash), nil + } + return nil, fmt.Errorf("did not contain rsapss (selector value was %v)", u.selector) +} + +// ECDSA returns the 'ecdsa' member of the union. +func (u *TPMUSigScheme) ECDSA() (*TPMSSchemeHash, error) { + if u.selector == TPMAlgECDSA { + return u.contents.(*TPMSSchemeHash), nil + } + return nil, fmt.Errorf("did not contain ecdsa (selector value was %v)", u.selector) +} + +// ECDAA returns the 'ecdaa' member of the union. +func (u *TPMUSigScheme) ECDAA() (*TPMSSchemeECDAA, error) { + if u.selector == TPMAlgECDAA { + return u.contents.(*TPMSSchemeECDAA), nil + } + return nil, fmt.Errorf("did not contain ecdaa (selector value was %v)", u.selector) } // TPMTSigScheme represents a TPMT_SIG_SCHEME. // See definition in Part 2: Structures, section 11.2.1.5. type TPMTSigScheme struct { + marshalByReflection Scheme TPMIAlgSigScheme `gotpm:"nullable"` Details TPMUSigScheme `gotpm:"tag=Scheme"` } @@ -1165,16 +1993,142 @@ type TPMSKDFSchemeKDF1SP800108 TPMSSchemeHash // TPMUKDFScheme represents a TPMU_KDF_SCHEME. // See definition in Part 2: Structures, section 11.2.3.2. type TPMUKDFScheme struct { - MGF1 *TPMSKDFSchemeMGF1 `gotpm:"selector=0x0007"` // TPM_ALG_MGF1 - ECDH *TPMSKDFSchemeECDH `gotpm:"selector=0x0019"` // TPM_ALG_ECDH - KDF1SP80056A *TPMSKDFSchemeKDF1SP80056A `gotpm:"selector=0x0020"` // TPM_ALG_KDF1_SP800_56A - KDF2 *TPMSKDFSchemeKDF2 `gotpm:"selector=0x0021"` // TPM_ALG_KDF2 - KDF1SP800108 *TPMSKDFSchemeKDF1SP800108 `gotpm:"selector=0x0022"` // TPM_ALG_KDF1_SP800_108 + selector TPMAlgID + contents Marshallable +} + +// KDFSchemeContents is a type constraint representing the possible contents of TPMUKDFScheme. +type KDFSchemeContents interface { + Marshallable + *TPMSKDFSchemeMGF1 | *TPMSKDFSchemeECDH | *TPMSKDFSchemeKDF1SP80056A | + *TPMSKDFSchemeKDF2 | *TPMSKDFSchemeKDF1SP800108 +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUKDFScheme) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgMGF1: + var contents TPMSKDFSchemeMGF1 + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECDH: + var contents TPMSKDFSchemeECDH + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgKDF1SP80056A: + var contents TPMSKDFSchemeKDF1SP80056A + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgKDF2: + var contents TPMSKDFSchemeKDF2 + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgKDF1SP800108: + var contents TPMSKDFSchemeKDF1SP800108 + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUKDFScheme) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgMGF1: + var contents TPMSKDFSchemeMGF1 + if u.contents != nil { + contents = *u.contents.(*TPMSKDFSchemeMGF1) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECDH: + var contents TPMSKDFSchemeECDH + if u.contents != nil { + contents = *u.contents.(*TPMSKDFSchemeECDH) + } + return reflect.ValueOf(&contents), nil + case TPMAlgKDF1SP80056A: + var contents TPMSKDFSchemeKDF1SP80056A + if u.contents != nil { + contents = *u.contents.(*TPMSKDFSchemeKDF1SP80056A) + } + return reflect.ValueOf(&contents), nil + case TPMAlgKDF2: + var contents TPMSKDFSchemeKDF2 + if u.contents != nil { + contents = *u.contents.(*TPMSKDFSchemeKDF2) + } + return reflect.ValueOf(&contents), nil + + case TPMAlgKDF1SP800108: + var contents TPMSKDFSchemeKDF1SP800108 + if u.contents != nil { + contents = *u.contents.(*TPMSKDFSchemeKDF1SP800108) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUKDFScheme instantiates a TPMUKDFScheme with the given contents. +func NewTPMUKDFScheme[C KDFSchemeContents](selector TPMAlgID, contents C) TPMUKDFScheme { + return TPMUKDFScheme{ + selector: selector, + contents: contents, + } +} + +// MGF1 returns the 'mgf1' member of the union. +func (u *TPMUKDFScheme) MGF1() (*TPMSKDFSchemeMGF1, error) { + if u.selector == TPMAlgMGF1 { + return u.contents.(*TPMSKDFSchemeMGF1), nil + } + return nil, fmt.Errorf("did not contain mgf1 (selector value was %v)", u.selector) +} + +// ECDH returns the 'ecdh' member of the union. +func (u *TPMUKDFScheme) ECDH() (*TPMSKDFSchemeECDH, error) { + if u.selector == TPMAlgECDH { + return u.contents.(*TPMSKDFSchemeECDH), nil + } + return nil, fmt.Errorf("did not contain ecdh (selector value was %v)", u.selector) +} + +// KDF1SP80056A returns the 'kdf1sp80056a' member of the union. +func (u *TPMUKDFScheme) KDF1SP80056A() (*TPMSKDFSchemeKDF1SP80056A, error) { + if u.selector == TPMAlgMGF1 { + return u.contents.(*TPMSKDFSchemeKDF1SP80056A), nil + } + return nil, fmt.Errorf("did not contain kdf1sp80056a (selector value was %v)", u.selector) +} + +// KDF2 returns the 'kdf2' member of the union. +func (u *TPMUKDFScheme) KDF2() (*TPMSKDFSchemeKDF2, error) { + if u.selector == TPMAlgMGF1 { + return u.contents.(*TPMSKDFSchemeKDF2), nil + } + return nil, fmt.Errorf("did not contain mgf1 (selector value was %v)", u.selector) +} + +// KDF1SP800108 returns the 'kdf1sp800108' member of the union. +func (u *TPMUKDFScheme) KDF1SP800108() (*TPMSKDFSchemeKDF1SP800108, error) { + if u.selector == TPMAlgMGF1 { + return u.contents.(*TPMSKDFSchemeKDF1SP800108), nil + } + return nil, fmt.Errorf("did not contain kdf1sp800108 (selector value was %v)", u.selector) } // TPMTKDFScheme represents a TPMT_KDF_SCHEME. // See definition in Part 2: Structures, section 11.2.3.3. type TPMTKDFScheme struct { + marshalByReflection // scheme selector Scheme TPMIAlgKDF `gotpm:"nullable"` // scheme parameters @@ -1184,14 +2138,173 @@ type TPMTKDFScheme struct { // TPMUAsymScheme represents a TPMU_ASYM_SCHEME. // See definition in Part 2: Structures, section 11.2.3.5. type TPMUAsymScheme struct { - // TODO every asym scheme gets an entry in this union. - RSASSA *TPMSSigSchemeRSASSA `gotpm:"selector=0x0014"` // TPM_ALG_RSASSA - RSAES *TPMSEncSchemeRSAES `gotpm:"selector=0x0015"` // TPM_ALG_RSAES - RSAPSS *TPMSSigSchemeRSAPSS `gotpm:"selector=0x0016"` // TPM_ALG_RSAPSS - OAEP *TPMSEncSchemeOAEP `gotpm:"selector=0x0017"` // TPM_ALG_OAEP - ECDSA *TPMSSigSchemeECDSA `gotpm:"selector=0x0018"` // TPM_ALG_ECDSA - ECDH *TPMSKeySchemeECDH `gotpm:"selector=0x0019"` // TPM_ALG_ECDH - ECDAA *TPMSSigSchemeECDAA `gotpm:"selector=0x001a"` // TPM_ALG_ECDAA + selector TPMAlgID + contents Marshallable +} + +// AsymSchemeContents is a type constraint representing the possible contents of TPMUAsymScheme. +type AsymSchemeContents interface { + Marshallable + *TPMSSigSchemeRSASSA | *TPMSEncSchemeRSAES | *TPMSSigSchemeRSAPSS | *TPMSEncSchemeOAEP | + *TPMSSigSchemeECDSA | *TPMSKeySchemeECDH | *TPMSSchemeECDAA +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUAsymScheme) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgRSASSA: + var contents TPMSSigSchemeRSASSA + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgRSAES: + var contents TPMSEncSchemeRSAES + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgRSAPSS: + var contents TPMSSigSchemeRSAPSS + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgOAEP: + var contents TPMSEncSchemeOAEP + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECDSA: + var contents TPMSSigSchemeECDSA + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECDH: + var contents TPMSKeySchemeECDH + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECDAA: + var contents TPMSSchemeECDAA + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUAsymScheme) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgRSASSA: + var contents TPMSSigSchemeRSASSA + if u.contents != nil { + contents = *u.contents.(*TPMSSigSchemeRSASSA) + } + return reflect.ValueOf(&contents), nil + case TPMAlgRSAES: + var contents TPMSEncSchemeRSAES + if u.contents != nil { + contents = *u.contents.(*TPMSEncSchemeRSAES) + } + return reflect.ValueOf(&contents), nil + case TPMAlgRSAPSS: + var contents TPMSSigSchemeRSAPSS + if u.contents != nil { + contents = *u.contents.(*TPMSSigSchemeRSAPSS) + } + return reflect.ValueOf(&contents), nil + case TPMAlgOAEP: + var contents TPMSEncSchemeOAEP + if u.contents != nil { + contents = *u.contents.(*TPMSEncSchemeOAEP) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECDSA: + var contents TPMSSigSchemeECDSA + if u.contents != nil { + contents = *u.contents.(*TPMSSigSchemeECDSA) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECDH: + var contents TPMSKeySchemeECDH + if u.contents != nil { + contents = *u.contents.(*TPMSKeySchemeECDH) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECDAA: + var contents TPMSSchemeECDAA + if u.contents != nil { + contents = *u.contents.(*TPMSSchemeECDAA) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUAsymScheme instantiates a TPMUAsymScheme with the given contents. +func NewTPMUAsymScheme[C AsymSchemeContents](selector TPMAlgID, contents C) TPMUAsymScheme { + return TPMUAsymScheme{ + selector: selector, + contents: contents, + } +} + +// RSASSA returns the 'rsassa' member of the union. +func (u *TPMUAsymScheme) RSASSA() (*TPMSSigSchemeRSASSA, error) { + if u.selector == TPMAlgRSASSA { + return u.contents.(*TPMSSigSchemeRSASSA), nil + } + return nil, fmt.Errorf("did not contain rsassa (selector value was %v)", u.selector) +} + +// RSAES returns the 'rsaes' member of the union. +func (u *TPMUAsymScheme) RSAES() (*TPMSEncSchemeRSAES, error) { + if u.selector == TPMAlgRSAES { + return u.contents.(*TPMSEncSchemeRSAES), nil + } + return nil, fmt.Errorf("did not contain rsaes (selector value was %v)", u.selector) +} + +// RSAPSS returns the 'rsapss' member of the union. +func (u *TPMUAsymScheme) RSAPSS() (*TPMSSigSchemeRSAPSS, error) { + if u.selector == TPMAlgRSAPSS { + return u.contents.(*TPMSSigSchemeRSAPSS), nil + } + return nil, fmt.Errorf("did not contain rsapss (selector value was %v)", u.selector) +} + +// OAEP returns the 'oaep' member of the union. +func (u *TPMUAsymScheme) OAEP() (*TPMSEncSchemeOAEP, error) { + if u.selector == TPMAlgOAEP { + return u.contents.(*TPMSEncSchemeOAEP), nil + } + return nil, fmt.Errorf("did not contain oaep (selector value was %v)", u.selector) +} + +// ECDSA returns the 'ecdsa' member of the union. +func (u *TPMUAsymScheme) ECDSA() (*TPMSSigSchemeECDSA, error) { + if u.selector == TPMAlgECDSA { + return u.contents.(*TPMSSigSchemeECDSA), nil + } + return nil, fmt.Errorf("did not contain rsassa (selector value was %v)", u.selector) +} + +// ECDH returns the 'ecdh' member of the union. +func (u *TPMUAsymScheme) ECDH() (*TPMSKeySchemeECDH, error) { + if u.selector == TPMAlgRSASSA { + return u.contents.(*TPMSKeySchemeECDH), nil + } + return nil, fmt.Errorf("did not contain ecdh (selector value was %v)", u.selector) +} + +// ECDAA returns the 'ecdaa' member of the union. +func (u *TPMUAsymScheme) ECDAA() (*TPMSSchemeECDAA, error) { + if u.selector == TPMAlgECDAA { + return u.contents.(*TPMSSchemeECDAA), nil + } + return nil, fmt.Errorf("did not contain rsassa (selector value was %v)", u.selector) } // TPMIAlgRSAScheme represents a TPMI_ALG_RSA_SCHEME. @@ -1201,6 +2314,7 @@ type TPMIAlgRSAScheme = TPMAlgID // TPMTRSAScheme represents a TPMT_RSA_SCHEME. // See definition in Part 2: Structures, section 11.2.4.2. type TPMTRSAScheme struct { + marshalByReflection // scheme selector Scheme TPMIAlgRSAScheme `gotpm:"nullable"` // scheme parameters @@ -1226,6 +2340,7 @@ type TPM2BECCParameter TPM2BData // TPMSECCPoint represents a TPMS_ECC_POINT. // See definition in Part 2: Structures, section 11.2.5.2. type TPMSECCPoint struct { + marshalByReflection // X coordinate X TPM2BECCParameter // Y coordinate @@ -1234,9 +2349,7 @@ type TPMSECCPoint struct { // TPM2BECCPoint represents a TPM2B_ECC_POINT. // See definition in Part 2: Structures, section 11.2.5.3. -type TPM2BECCPoint struct { - Point TPMSECCPoint `gotpm:"sized"` -} +type TPM2BECCPoint = TPM2B[TPMSECCPoint, *TPMSECCPoint] // TPMIAlgECCScheme represents a TPMI_ALG_ECC_SCHEME. // See definition in Part 2: Structures, section 11.2.5.4. @@ -1249,6 +2362,7 @@ type TPMIECCCurve = TPMECCCurve // TPMTECCScheme represents a TPMT_ECC_SCHEME. // See definition in Part 2: Structures, section 11.2.5.6. type TPMTECCScheme struct { + marshalByReflection // scheme selector Scheme TPMIAlgECCScheme `gotpm:"nullable"` // scheme parameters @@ -1258,6 +2372,7 @@ type TPMTECCScheme struct { // TPMSSignatureRSA represents a TPMS_SIGNATURE_RSA. // See definition in Part 2: Structures, section 11.3.1. type TPMSSignatureRSA struct { + marshalByReflection // the hash algorithm used to digest the message Hash TPMIAlgHash // The signature is the size of a public key. @@ -1267,6 +2382,7 @@ type TPMSSignatureRSA struct { // TPMSSignatureECC represents a TPMS_SIGNATURE_ECC. // See definition in Part 2: Structures, section 11.3.2. type TPMSSignatureECC struct { + marshalByReflection // the hash algorithm used in the signature process Hash TPMIAlgHash SignatureR TPM2BECCParameter @@ -1276,16 +2392,118 @@ type TPMSSignatureECC struct { // TPMUSignature represents a TPMU_SIGNATURE. // See definition in Part 2: Structures, section 11.3.3. type TPMUSignature struct { - HMAC *TPMTHA `gotpm:"selector=0x0005"` // TPM_ALG_HMAC - RSASSA *TPMSSignatureRSA `gotpm:"selector=0x0014"` // TPM_ALG_RSASSA - RSAPSS *TPMSSignatureRSA `gotpm:"selector=0x0016"` // TPM_ALG_RSAPSS - ECDSA *TPMSSignatureECC `gotpm:"selector=0x0018"` // TPM_ALG_ECDSA - ECDAA *TPMSSignatureECC `gotpm:"selector=0x001a"` // TPM_ALG_ECDAA + selector TPMAlgID + contents Marshallable +} + +// SignatureContents is a type constraint representing the possible contents of TPMUSignature. +type SignatureContents interface { + Marshallable + *TPMTHA | *TPMSSignatureRSA | *TPMSSignatureECC +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUSignature) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgHMAC: + var contents TPMTHA + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgRSASSA, TPMAlgRSAPSS: + var contents TPMSSignatureRSA + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECDSA, TPMAlgECDAA: + var contents TPMSSignatureECC + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUSignature) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgHMAC: + var contents TPMTHA + if u.contents != nil { + contents = *u.contents.(*TPMTHA) + } + return reflect.ValueOf(&contents), nil + case TPMAlgRSASSA, TPMAlgRSAPSS: + var contents TPMSSignatureRSA + if u.contents != nil { + contents = *u.contents.(*TPMSSignatureRSA) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECDSA, TPMAlgECDAA: + var contents TPMSSignatureECC + if u.contents != nil { + contents = *u.contents.(*TPMSSignatureECC) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUSignature instantiates a TPMUSignature with the given contents. +func NewTPMUSignature[C SignatureContents](selector TPMAlgID, contents C) TPMUSignature { + return TPMUSignature{ + selector: selector, + contents: contents, + } +} + +// HMAC returns the 'hmac' member of the union. +func (u *TPMUSignature) HMAC() (*TPMTHA, error) { + if u.selector == TPMAlgHMAC { + return u.contents.(*TPMTHA), nil + } + return nil, fmt.Errorf("did not contain hmac (selector value was %v)", u.selector) +} + +// RSASSA returns the 'rsassa' member of the union. +func (u *TPMUSignature) RSASSA() (*TPMSSignatureRSA, error) { + if u.selector == TPMAlgRSASSA { + return u.contents.(*TPMSSignatureRSA), nil + } + return nil, fmt.Errorf("did not contain rsassa (selector value was %v)", u.selector) +} + +// RSAPSS returns the 'rsapss' member of the union. +func (u *TPMUSignature) RSAPSS() (*TPMSSignatureRSA, error) { + if u.selector == TPMAlgRSAPSS { + return u.contents.(*TPMSSignatureRSA), nil + } + return nil, fmt.Errorf("did not contain rsapss (selector value was %v)", u.selector) +} + +// ECDSA returns the 'ecdsa' member of the union. +func (u *TPMUSignature) ECDSA() (*TPMSSignatureECC, error) { + if u.selector == TPMAlgECDSA { + return u.contents.(*TPMSSignatureECC), nil + } + return nil, fmt.Errorf("did not contain ecdsa (selector value was %v)", u.selector) +} + +// ECDAA returns the 'ecdaa' member of the union. +func (u *TPMUSignature) ECDAA() (*TPMSSignatureECC, error) { + if u.selector == TPMAlgRSASSA { + return u.contents.(*TPMSSignatureECC), nil + } + return nil, fmt.Errorf("did not contain ecdaa (selector value was %v)", u.selector) } // TPMTSignature represents a TPMT_SIGNATURE. // See definition in Part 2: Structures, section 11.3.4. type TPMTSignature struct { + marshalByReflection // selector of the algorithm used to construct the signature SigAlg TPMIAlgSigScheme `gotpm:"nullable"` // This shall be the actual signature information. @@ -1303,15 +2521,121 @@ type TPMIAlgPublic = TPMAlgID // TPMUPublicID represents a TPMU_PUBLIC_ID. // See definition in Part 2: Structures, section 12.2.3.2. type TPMUPublicID struct { - KeyedHash *TPM2BDigest `gotpm:"selector=0x0008"` // TPM_ALG_KEYEDHASH - Sym *TPM2BDigest `gotpm:"selector=0x0025"` // TPM_ALG_SYMCIPHER - RSA *TPM2BPublicKeyRSA `gotpm:"selector=0x0001"` // TPM_ALG_RSA - ECC *TPMSECCPoint `gotpm:"selector=0x0023"` // TPM_ALG_ECC + selector TPMAlgID + contents Marshallable +} + +// PublicIDContents is a type constraint representing the possible contents of TPMUPublicID. +type PublicIDContents interface { + Marshallable + *TPM2BDigest | *TPM2BPublicKeyRSA | *TPMSECCPoint +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUPublicID) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgKeyedHash: + var contents TPM2BDigest + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgSymCipher: + var contents TPM2BDigest + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgRSA: + var contents TPM2BPublicKeyRSA + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECC: + var contents TPMSECCPoint + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUPublicID) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgKeyedHash: + var contents TPM2BDigest + if u.contents != nil { + contents = *u.contents.(*TPM2BDigest) + } + return reflect.ValueOf(&contents), nil + case TPMAlgSymCipher: + var contents TPM2BDigest + if u.contents != nil { + contents = *u.contents.(*TPM2BDigest) + } + return reflect.ValueOf(&contents), nil + case TPMAlgRSA: + var contents TPM2BPublicKeyRSA + if u.contents != nil { + contents = *u.contents.(*TPM2BPublicKeyRSA) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECC: + var contents TPMSECCPoint + if u.contents != nil { + contents = *u.contents.(*TPMSECCPoint) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUPublicID instantiates a TPMUPublicID with the given contents. +func NewTPMUPublicID[C PublicIDContents](selector TPMAlgID, contents C) TPMUPublicID { + return TPMUPublicID{ + selector: selector, + contents: contents, + } +} + +// KeyedHash returns the 'keyedHash' member of the union. +func (u *TPMUPublicID) KeyedHash() (*TPM2BDigest, error) { + if u.selector == TPMAlgKeyedHash { + return u.contents.(*TPM2BDigest), nil + } + return nil, fmt.Errorf("did not contain keyedHash (selector value was %v)", u.selector) +} + +// SymCipher returns the 'symCipher' member of the union. +func (u *TPMUPublicID) SymCipher() (*TPM2BDigest, error) { + if u.selector == TPMAlgSymCipher { + return u.contents.(*TPM2BDigest), nil + } + return nil, fmt.Errorf("did not contain symCipher (selector value was %v)", u.selector) +} + +// RSA returns the 'rsa' member of the union. +func (u *TPMUPublicID) RSA() (*TPM2BPublicKeyRSA, error) { + if u.selector == TPMAlgRSA { + return u.contents.(*TPM2BPublicKeyRSA), nil + } + return nil, fmt.Errorf("did not contain rsa (selector value was %v)", u.selector) +} + +// ECC returns the 'ecc' member of the union. +func (u *TPMUPublicID) ECC() (*TPMSECCPoint, error) { + if u.selector == TPMAlgECC { + return u.contents.(*TPMSECCPoint), nil + } + return nil, fmt.Errorf("did not contain ecc (selector value was %v)", u.selector) } // TPMSKeyedHashParms represents a TPMS_KEYEDHASH_PARMS. // See definition in Part 2: Structures, section 12.2.3.3. type TPMSKeyedHashParms struct { + marshalByReflection // Indicates the signing method used for a keyedHash signing // object. This field also determines the size of the data field // for a data object created with TPM2_Create() or @@ -1322,6 +2646,7 @@ type TPMSKeyedHashParms struct { // TPMSRSAParms represents a TPMS_RSA_PARMS. // See definition in Part 2: Structures, section 12.2.3.5. type TPMSRSAParms struct { + marshalByReflection // for a restricted decryption key, shall be set to a supported // symmetric algorithm, key size, and mode. // if the key is not a restricted decryption key, this field shall @@ -1346,6 +2671,7 @@ type TPMSRSAParms struct { // TPMSECCParms represents a TPMS_ECC_PARMS. // See definition in Part 2: Structures, section 12.2.3.6. type TPMSECCParms struct { + marshalByReflection // for a restricted decryption key, shall be set to a supported // symmetric algorithm, key size. and mode. // if the key is not a restricted decryption key, this field shall @@ -1366,19 +2692,122 @@ type TPMSECCParms struct { // TPMUPublicParms represents a TPMU_PUBLIC_PARMS. // See definition in Part 2: Structures, section 12.2.3.7. type TPMUPublicParms struct { - // sign | decrypt | neither - KeyedHashDetail *TPMSKeyedHashParms `gotpm:"selector=0x0008"` // TPM_ALG_KEYEDHASH - // sign | decrypt | neither - SymCipherDetail *TPMSSymCipherParms `gotpm:"selector=0x0025"` // TPM_ALG_SYMCIPHER - // decrypt + sign - RSADetail *TPMSRSAParms `gotpm:"selector=0x0001"` // TPM_ALG_RSA - // decrypt + sign - ECCDetail *TPMSECCParms `gotpm:"selector=0x0023"` // TPM_ALG_ECC + selector TPMAlgID + contents Marshallable +} + +// PublicParmsContents is a type constraint representing the possible contents of TPMUPublicParms. +type PublicParmsContents interface { + Marshallable + *TPMSKeyedHashParms | *TPMSSymCipherParms | *TPMSRSAParms | + *TPMSECCParms +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUPublicParms) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgKeyedHash: + var contents TPMSKeyedHashParms + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgSymCipher: + var contents TPMSSymCipherParms + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgRSA: + var contents TPMSRSAParms + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECC: + var contents TPMSECCParms + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUPublicParms) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgKeyedHash: + var contents TPMSKeyedHashParms + if u.contents != nil { + contents = *u.contents.(*TPMSKeyedHashParms) + } + return reflect.ValueOf(&contents), nil + case TPMAlgSymCipher: + var contents TPMSSymCipherParms + if u.contents != nil { + contents = *u.contents.(*TPMSSymCipherParms) + } + return reflect.ValueOf(&contents), nil + case TPMAlgRSA: + var contents TPMSRSAParms + if u.contents != nil { + contents = *u.contents.(*TPMSRSAParms) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECC: + var contents TPMSECCParms + if u.contents != nil { + contents = *u.contents.(*TPMSECCParms) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// NewTPMUPublicParms instantiates a TPMUPublicParms with the given contents. +func NewTPMUPublicParms[C PublicParmsContents](selector TPMAlgID, contents C) TPMUPublicParms { + return TPMUPublicParms{ + selector: selector, + contents: contents, + } +} + +// KeyedHashDetail returns the 'keyedHashDetail' member of the union. +func (u *TPMUPublicParms) KeyedHashDetail() (*TPMSKeyedHashParms, error) { + if u.selector == TPMAlgKeyedHash { + return u.contents.(*TPMSKeyedHashParms), nil + } + return nil, fmt.Errorf("did not contain keyedHashDetail (selector value was %v)", u.selector) +} + +// SymDetail returns the 'symDetail' member of the union. +func (u *TPMUPublicParms) SymDetail() (*TPMSSymCipherParms, error) { + if u.selector == TPMAlgSymCipher { + return u.contents.(*TPMSSymCipherParms), nil + } + return nil, fmt.Errorf("did not contain symDetail (selector value was %v)", u.selector) +} + +// RSADetail returns the 'rsaDetail' member of the union. +func (u *TPMUPublicParms) RSADetail() (*TPMSRSAParms, error) { + if u.selector == TPMAlgRSA { + return u.contents.(*TPMSRSAParms), nil + } + return nil, fmt.Errorf("did not contain rsaDetail (selector value was %v)", u.selector) +} + +// ECCDetail returns the 'eccDetail' member of the union. +func (u *TPMUPublicParms) ECCDetail() (*TPMSECCParms, error) { + if u.selector == TPMAlgECC { + return u.contents.(*TPMSECCParms), nil + } + return nil, fmt.Errorf("did not contain eccDetail (selector value was %v)", u.selector) } // TPMTPublic represents a TPMT_PUBLIC. // See definition in Part 2: Structures, section 12.2.4. type TPMTPublic struct { + marshalByReflection // “algorithm” associated with this object Type TPMIAlgPublic // algorithm used for computing the Name of the object @@ -1396,10 +2825,25 @@ type TPMTPublic struct { Unique TPMUPublicID `gotpm:"tag=Type"` } +// TPM2BPublic represents a TPM2B_PUBLIC. +// See definition in Part 2: Structures, section 12.2.5. +type TPM2BPublic = TPM2B[TPMTPublic, *TPMTPublic] + +// TPM2BTemplate represents a TPM2B_TEMPLATE. +// See definition in Part 2: Structures, section 12.2.6. +type TPM2BTemplate TPM2BData + +// TemplateContents is a type constraint representing the possible contents of TPMUTemplate. +type TemplateContents interface { + Marshallable + *TPMTPublic | *TPMTTemplate +} + // TPMTTemplate represents a TPMT_TEMPLATE. It is not defined in the spec. // It represents the alternate form of TPMT_PUBLIC for TPM2B_TEMPLATE as // described in Part 2: Structures, 12.2.6. type TPMTTemplate struct { + marshalByReflection // “algorithm” associated with this object Type TPMIAlgPublic // algorithm used for computing the Name of the object @@ -1416,49 +2860,131 @@ type TPMTTemplate struct { Unique TPMSDerive } -// TPM2BPublic represents a TPM2B_PUBLIC. -// See definition in Part 2: Structures, section 12.2.5. -type TPM2BPublic struct { - // the public area - PublicArea TPMTPublic `gotpm:"sized"` +// New2BTemplate creates a TPM2BTemplate with the given data. +func New2BTemplate[C TemplateContents](data C) TPM2BTemplate { + return TPM2BTemplate{ + Buffer: Marshal(data), + } } -// TPM2BTemplate represents a TPM2B_TEMPLATE. -// See definition in Part 2: Structures, section 12.2.6. -type TPM2BTemplate struct { - Template TPMUTemplate `gotpm:"sized"` +// TPMUSensitiveComposite represents a TPMU_SENSITIVE_COMPOSITE. +// See definition in Part 2: Structures, section 12.3.2.3. +type TPMUSensitiveComposite struct { + selector TPMAlgID + contents Marshallable +} + +// SensitiveCompositeContents is a type constraint representing the possible contents of TPMUSensitiveComposite. +type SensitiveCompositeContents interface { + Marshallable + *TPM2BPrivateKeyRSA | *TPM2BECCParameter | *TPM2BSensitiveData | *TPM2BSymKey +} + +// create implements the unmarshallableWithHint interface. +func (u *TPMUSensitiveComposite) create(hint int64) (reflect.Value, error) { + switch TPMAlgID(hint) { + case TPMAlgRSA: + var contents TPM2BPrivateKeyRSA + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgECC: + var contents TPM2BECCParameter + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgKeyedHash: + var contents TPM2BSensitiveData + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + case TPMAlgSymCipher: + var contents TPM2BSymKey + u.contents = &contents + u.selector = TPMAlgID(hint) + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) +} + +// get implements the marshallableWithHint interface. +func (u TPMUSensitiveComposite) get(hint int64) (reflect.Value, error) { + if u.selector != 0 && hint != int64(u.selector) { + return reflect.ValueOf(nil), fmt.Errorf("incorrect union tag %v, is %v", hint, u.selector) + } + switch TPMAlgID(hint) { + case TPMAlgRSA: + var contents TPM2BPrivateKeyRSA + if u.contents != nil { + contents = *u.contents.(*TPM2BPrivateKeyRSA) + } + return reflect.ValueOf(&contents), nil + case TPMAlgECC: + var contents TPM2BECCParameter + if u.contents != nil { + contents = *u.contents.(*TPM2BECCParameter) + } + return reflect.ValueOf(&contents), nil + case TPMAlgKeyedHash: + var contents TPM2BSensitiveData + if u.contents != nil { + contents = *u.contents.(*TPM2BSensitiveData) + } + return reflect.ValueOf(&contents), nil + case TPMAlgSymCipher: + var contents TPM2BSymKey + if u.contents != nil { + contents = *u.contents.(*TPM2BSymKey) + } + return reflect.ValueOf(&contents), nil + } + return reflect.ValueOf(nil), fmt.Errorf("no union member for tag %v", hint) } -// TPMUTemplate represents the possible contents of a TPM2B_Template. It is not -// defined or named in the spec, which instead describes how its contents may -// differ in the case of CreateLoaded with a derivation parent. -// Since the TPM cannot return this type, it can be an interface. -type TPMUTemplate interface { - tpmutemplate() - defaultMarshalling() []byte +// NewTPMUSensitiveComposite instantiates a TPMUSensitiveComposite with the given contents. +func NewTPMUSensitiveComposite[C SensitiveCompositeContents](selector TPMAlgID, contents C) TPMUSensitiveComposite { + return TPMUSensitiveComposite{ + selector: selector, + contents: contents, + } } -func (TPMTPublic) tpmutemplate() {} -func (TPMTPublic) defaultMarshalling() []byte { return nil } -func (TPMTTemplate) tpmutemplate() {} -func (TPMTTemplate) defaultMarshalling() []byte { return nil } +// RSA returns the 'rsa' member of the union. +func (u *TPMUKDFScheme) RSA() (*TPM2BPrivateKeyRSA, error) { + if u.selector == TPMAlgRSA { + return u.contents.(*TPM2BPrivateKeyRSA), nil + } + return nil, fmt.Errorf("did not contain rsa (selector value was %v)", u.selector) +} -// TPMUSensitiveComposite represents a TPMU_SENSITIVE_COMPOSITE. -// See definition in Part 2: Structures, section 12.3.2.3. -type TPMUSensitiveComposite struct { - // a prime factor of the public key - RSA *TPM2BPrivateKeyRSA `gotpm:"selector=0x0001"` // TPM_ALG_RSA - // the integer private key - ECC *TPM2BECCParameter `gotpm:"selector=0x0023"` // TPM_ALG_ECC - // the private data - Bits *TPM2BSensitiveData `gotpm:"selector=0x0008"` // TPM_ALG_KEYEDHASH - // the symmetric key - Sym *TPM2BSymKey `gotpm:"selector=0x0025"` // TPM_ALG_SYMCIPHER +// ECC returns the 'ecc' member of the union. +func (u *TPMUKDFScheme) ECC() (*TPM2BECCParameter, error) { + if u.selector == TPMAlgECC { + return u.contents.(*TPM2BECCParameter), nil + } + return nil, fmt.Errorf("did not contain ecc (selector value was %v)", u.selector) +} + +// Bits returns the 'bits' member of the union. +func (u *TPMUKDFScheme) Bits() (*TPM2BSensitiveData, error) { + if u.selector == TPMAlgKeyedHash { + return u.contents.(*TPM2BSensitiveData), nil + } + return nil, fmt.Errorf("did not contain bits (selector value was %v)", u.selector) +} + +// Sym returns the 'sym' member of the union. +func (u *TPMUKDFScheme) Sym() (*TPM2BSymKey, error) { + if u.selector == TPMAlgSymCipher { + return u.contents.(*TPM2BSymKey), nil + } + return nil, fmt.Errorf("did not contain sym (selector value was %v)", u.selector) } // TPMTSensitive represents a TPMT_SENSITIVE. // See definition in Part 2: Structures, section 12.3.2.4. type TPMTSensitive struct { + marshalByReflection // identifier for the sensitive area SensitiveType TPMIAlgPublic // user authorization data @@ -1472,10 +2998,7 @@ type TPMTSensitive struct { // TPM2BSensitive represents a TPM2B_SENSITIVE. // See definition in Part 2: Structures, section 12.3.3. -type TPM2BSensitive struct { - // an unencrypted sensitive area - SensitiveArea TPMTSensitive `gotpm:"sized"` -} +type TPM2BSensitive = TPM2B[TPMTSensitive, *TPMTSensitive] // TPM2BPrivate represents a TPM2B_PRIVATE. // See definition in Part 2: Structures, section 12.3.7. @@ -1484,6 +3007,7 @@ type TPM2BPrivate TPM2BData // TPMSCreationData represents a TPMS_CREATION_DATA. // See definition in Part 2: Structures, section 15.1. type TPMSCreationData struct { + marshalByReflection // list indicating the PCR included in pcrDigest PCRSelect TPMLPCRSelection // digest of the selected PCR using nameAlg of the object for which @@ -1513,6 +3037,7 @@ type TPMNT uint8 // See definition in Part 2: Structures, section 13.4. type TPMANV struct { bitfield32 + marshalByReflection // SET (1): The Index data can be written if Platform Authorization is // provided. // CLEAR (0): Writing of the Index data cannot be authorized with @@ -1621,6 +3146,7 @@ type TPMANV struct { // TPMSNVPublic represents a TPMS_NV_PUBLIC. // See definition in Part 2: Structures, section 13.5. type TPMSNVPublic struct { + marshalByReflection // the handle of the data area NVIndex TPMIRHNVIndex // hash algorithm used to compute the name of the Index and used for @@ -1637,21 +3163,16 @@ type TPMSNVPublic struct { // TPM2BNVPublic represents a TPM2B_NV_PUBLIC. // See definition in Part 2: Structures, section 13.6. -type TPM2BNVPublic struct { - NVPublic TPMSNVPublic `gotpm:"sized"` -} +type TPM2BNVPublic = TPM2B[TPMSNVPublic, *TPMSNVPublic] // TPM2BContextSensitive represents a TPM2B_CONTEXT_SENSITIVE // See definition in Part 2: Structures, section 14.2. -type TPM2BContextSensitive struct { - Size uint16 - // the sensitive data - Buffer []byte -} +type TPM2BContextSensitive TPM2BData // TPMSContextData represents a TPMS_CONTEXT_DATA // See definition in Part 2: Structures, section 14.3. type TPMSContextData struct { + marshalByReflection // the integrity value Integrity TPM2BDigest // the sensitive area @@ -1660,13 +3181,14 @@ type TPMSContextData struct { // TPM2BContextData represents a TPM2B_CONTEXT_DATA // See definition in Part 2: Structures, section 14.4. -type TPM2BContextData struct { - Buffer TPMSContextData `gotpm:"sized"` -} +// Represented here as a flat buffer because how a TPM chooses +// to represent its context data is implementation-dependent. +type TPM2BContextData TPM2BData // TPMSContext represents a TPMS_CONTEXT // See definition in Part 2: Structures, section 14.5. type TPMSContext struct { + marshalByReflection // the sequence number of the context Sequence uint64 // a handle indicating if the context is a session, object, or sequence object @@ -1677,8 +3199,4 @@ type TPMSContext struct { ContextBlob TPM2BContextData } -// TPM2BCreationData represents a TPM2B_CREATION_DATA. -// See definition in Part 2: Structures, section 15.2. -type TPM2BCreationData struct { - CreationData TPMSCreationData `gotpm:"sized"` -} +type tpm2bCreationData = TPM2B[TPMSCreationData, *TPMSCreationData] diff --git a/tpm2/templates.go b/tpm2/templates.go index 5d6e3c66..a0cfa811 100644 --- a/tpm2/templates.go +++ b/tpm2/templates.go @@ -19,25 +19,29 @@ var ( Decrypt: true, SignEncrypt: false, }, - Parameters: TPMUPublicParms{ - RSADetail: &TPMSRSAParms{ + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ Symmetric: TPMTSymDefObject{ Algorithm: TPMAlgAES, - KeyBits: TPMUSymKeyBits{ - AES: NewKeyBits(128), - }, - Mode: TPMUSymMode{ - AES: NewAlgID(TPMAlgCFB), - }, + KeyBits: NewTPMUSymKeyBits( + TPMAlgAES, + TPMKeyBits(128), + ), + Mode: NewTPMUSymMode( + TPMAlgAES, + TPMAlgCFB, + ), }, KeyBits: 2048, }, - }, - Unique: TPMUPublicID{ - RSA: &TPM2BPublicKeyRSA{ + ), + Unique: NewTPMUPublicID( + TPMAlgRSA, + &TPM2BPublicKeyRSA{ Buffer: make([]byte, 256), }, - }, + ), } // RSAEKTemplate contains the TCG reference RSA-2048 EK template. RSAEKTemplate = TPMTPublic{ @@ -65,25 +69,29 @@ var ( 0xF2, 0xA1, 0xDA, 0x1B, 0x33, 0x14, 0x69, 0xAA, }, }, - Parameters: TPMUPublicParms{ - RSADetail: &TPMSRSAParms{ + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ Symmetric: TPMTSymDefObject{ Algorithm: TPMAlgAES, - KeyBits: TPMUSymKeyBits{ - AES: NewKeyBits(128), - }, - Mode: TPMUSymMode{ - AES: NewAlgID(TPMAlgCFB), - }, + KeyBits: NewTPMUSymKeyBits( + TPMAlgAES, + TPMKeyBits(128), + ), + Mode: NewTPMUSymMode( + TPMAlgAES, + TPMAlgCFB, + ), }, KeyBits: 2048, }, - }, - Unique: TPMUPublicID{ - RSA: &TPM2BPublicKeyRSA{ + ), + Unique: NewTPMUPublicID( + TPMAlgRSA, + &TPM2BPublicKeyRSA{ Buffer: make([]byte, 256), }, - }, + ), } // ECCSRKTemplate contains the TCG reference ECC-P256 SRK template. @@ -104,22 +112,26 @@ var ( Decrypt: true, SignEncrypt: false, }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ Symmetric: TPMTSymDefObject{ Algorithm: TPMAlgAES, - KeyBits: TPMUSymKeyBits{ - AES: NewKeyBits(128), - }, - Mode: TPMUSymMode{ - AES: NewAlgID(TPMAlgCFB), - }, + KeyBits: NewTPMUSymKeyBits( + TPMAlgAES, + TPMKeyBits(128), + ), + Mode: NewTPMUSymMode( + TPMAlgAES, + TPMAlgCFB, + ), }, CurveID: TPMECCNistP256, }, - }, - Unique: TPMUPublicID{ - ECC: &TPMSECCPoint{ + ), + Unique: NewTPMUPublicID( + TPMAlgECC, + &TPMSECCPoint{ X: TPM2BECCParameter{ Buffer: make([]byte, 32), }, @@ -127,7 +139,7 @@ var ( Buffer: make([]byte, 32), }, }, - }, + ), } // ECCEKTemplate contains the TCG reference ECC-P256 EK template. @@ -156,22 +168,26 @@ var ( 0xF2, 0xA1, 0xDA, 0x1B, 0x33, 0x14, 0x69, 0xAA, }, }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ Symmetric: TPMTSymDefObject{ Algorithm: TPMAlgAES, - KeyBits: TPMUSymKeyBits{ - AES: NewKeyBits(128), - }, - Mode: TPMUSymMode{ - AES: NewAlgID(TPMAlgCFB), - }, + KeyBits: NewTPMUSymKeyBits( + TPMAlgAES, + TPMKeyBits(128), + ), + Mode: NewTPMUSymMode( + TPMAlgAES, + TPMAlgCFB, + ), }, CurveID: TPMECCNistP256, }, - }, - Unique: TPMUPublicID{ - ECC: &TPMSECCPoint{ + ), + Unique: NewTPMUPublicID( + TPMAlgECC, + &TPMSECCPoint{ X: TPM2BECCParameter{ Buffer: make([]byte, 32), }, @@ -179,6 +195,6 @@ var ( Buffer: make([]byte, 32), }, }, - }, + ), } ) diff --git a/tpm2/test/activate_credential_test.go b/tpm2/test/activate_credential_test.go index 01af1f80..b81b5b42 100644 --- a/tpm2/test/activate_credential_test.go +++ b/tpm2/test/activate_credential_test.go @@ -17,9 +17,7 @@ func TestActivateCredential(t *testing.T) { ekCreate := CreatePrimary{ PrimaryHandle: TPMRHEndorsement, - InPublic: TPM2BPublic{ - PublicArea: ECCEKTemplate, - }, + InPublic: New2B(ECCEKTemplate), } ekCreateRsp, err := ekCreate.Execute(thetpm) @@ -30,7 +28,7 @@ func TestActivateCredential(t *testing.T) { flush := FlushContext{ FlushHandle: ekCreateRsp.ObjectHandle, } - err := flush.Execute(thetpm) + _, err := flush.Execute(thetpm) if err != nil { t.Fatalf("could not flush EK: %v", err) } @@ -38,9 +36,7 @@ func TestActivateCredential(t *testing.T) { srkCreate := CreatePrimary{ PrimaryHandle: TPMRHOwner, - InPublic: TPM2BPublic{ - PublicArea: ECCSRKTemplate, - }, + InPublic: New2B(ECCSRKTemplate), } srkCreateRsp, err := srkCreate.Execute(thetpm) @@ -51,7 +47,7 @@ func TestActivateCredential(t *testing.T) { flush := FlushContext{ FlushHandle: srkCreateRsp.ObjectHandle, } - err := flush.Execute(thetpm) + _, err := flush.Execute(thetpm) if err != nil { t.Fatalf("could not flush SRK: %v", err) } diff --git a/tpm2/test/audit_test.go b/tpm2/test/audit_test.go index 8b1923de..31fd6d10 100644 --- a/tpm2/test/audit_test.go +++ b/tpm2/test/audit_test.go @@ -25,38 +25,39 @@ func TestAuditSession(t *testing.T) { // Create the AK for audit createAKCmd := CreatePrimary{ PrimaryHandle: TPMRHOwner, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgECC, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - FixedTPM: true, - STClear: false, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - AdminWithPolicy: false, - NoDA: true, - EncryptedDuplication: false, - Restricted: true, - Decrypt: false, - SignEncrypt: true, - }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ - Scheme: TPMTECCScheme{ - Scheme: TPMAlgECDSA, - Details: TPMUAsymScheme{ - ECDSA: &TPMSSigSchemeECDSA{ - HashAlg: TPMAlgSHA256, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgECC, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + STClear: false, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + AdminWithPolicy: false, + NoDA: true, + EncryptedDuplication: false, + Restricted: true, + Decrypt: false, + SignEncrypt: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ + Scheme: TPMTECCScheme{ + Scheme: TPMAlgECDSA, + Details: NewTPMUAsymScheme( + TPMAlgECDSA, + &TPMSSigSchemeECDSA{ + HashAlg: TPMAlgSHA256, }, - }, - CurveID: TPMECCNistP256, + ), }, + CurveID: TPMECCNistP256, }, - }, + ), }, + ), } createAKRsp, err := createAKCmd.Execute(thetpm) if err != nil { @@ -65,7 +66,7 @@ func TestAuditSession(t *testing.T) { defer func() { // Flush the AK flush := FlushContext{FlushHandle: createAKRsp.ObjectHandle} - if err := flush.Execute(thetpm); err != nil { + if _, err := flush.Execute(thetpm); err != nil { t.Errorf("%v", err) } }() @@ -94,7 +95,7 @@ func TestAuditSession(t *testing.T) { if err != nil { t.Fatalf("%v", err) } - if err := audit.Extend(&getCmd, getRsp); err != nil { + if err := AuditCommand(audit, getCmd, getRsp); err != nil { t.Fatalf("%v", err) } // Get the audit digest signed by the AK @@ -112,9 +113,13 @@ func TestAuditSession(t *testing.T) { t.Fatalf("%v", err) } // TODO check the signature with the AK pub - aud := getAuditRsp.AuditInfo.AttestationData.Attested.SessionAudit - if aud == nil { - t.Fatalf("got nil session audit attestation") + attest, err := getAuditRsp.AuditInfo.Contents() + if err != nil { + t.Fatalf("%v", err) + } + aud, err := attest.Attested.SessionAudit() + if err != nil { + t.Fatalf("%v", err) } want := audit.Digest() got := aud.SessionDigest.Buffer diff --git a/tpm2/test/certify_test.go b/tpm2/test/certify_test.go index 5260547b..d70a4bcf 100644 --- a/tpm2/test/certify_test.go +++ b/tpm2/test/certify_test.go @@ -25,33 +25,34 @@ func TestCertify(t *testing.T) { if err != nil { t.Fatalf("Failed to create PCRSelection") } - public := TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgRSA, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - SignEncrypt: true, - Restricted: true, - FixedTPM: true, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - }, - Parameters: TPMUPublicParms{ - RSADetail: &TPMSRSAParms{ - Scheme: TPMTRSAScheme{ - Scheme: TPMAlgRSASSA, - Details: TPMUAsymScheme{ - RSASSA: &TPMSSigSchemeRSASSA{ - HashAlg: TPMAlgSHA256, - }, + public := New2B(TPMTPublic{ + Type: TPMAlgRSA, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + SignEncrypt: true, + Restricted: true, + FixedTPM: true, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ + Scheme: TPMTRSAScheme{ + Scheme: TPMAlgRSASSA, + Details: NewTPMUAsymScheme( + TPMAlgRSASSA, + &TPMSSigSchemeRSASSA{ + HashAlg: TPMAlgSHA256, }, - }, - KeyBits: 2048, + ), }, + KeyBits: 2048, }, - }, - } + ), + }, + ) pcrSelection := TPMLPCRSelection{ PCRSelections: []TPMSPCRSelection{ @@ -65,7 +66,7 @@ func TestCertify(t *testing.T) { createPrimarySigner := CreatePrimary{ PrimaryHandle: TPMRHOwner, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: Auth, }, @@ -84,7 +85,7 @@ func TestCertify(t *testing.T) { createPrimarySubject := CreatePrimary{ PrimaryHandle: TPMRHOwner, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: Auth, }, @@ -93,12 +94,17 @@ func TestCertify(t *testing.T) { InPublic: public, CreationPCR: pcrSelection, } - unique := TPMUPublicID{ - RSA: &TPM2BPublicKeyRSA{ + unique := NewTPMUPublicID( + TPMAlgRSA, + &TPM2BPublicKeyRSA{ Buffer: []byte("subject key"), }, + ) + inPub, err := createPrimarySubject.InPublic.Contents() + if err != nil { + t.Fatalf("%v", err) } - createPrimarySubject.InPublic.PublicArea.Unique = unique + inPub.Unique = unique rspSubject, err := createPrimarySubject.Execute(thetpm) if err != nil { @@ -133,23 +139,38 @@ func TestCertify(t *testing.T) { t.Fatalf("Failed to certify: %v", err) } - info, err := Marshal(rspCert.CertifyInfo.AttestationData) + certifyInfo, err := rspCert.CertifyInfo.Contents() if err != nil { - t.Fatalf("Failed to marshal: %v", err) + t.Fatalf("%v", err) } + info := Marshal(certifyInfo) attestHash := sha256.Sum256(info) - pub := rspSigner.OutPublic.PublicArea - rsaPub, err := RSAPub(pub.Parameters.RSADetail, pub.Unique.RSA) + pub, err := rspSigner.OutPublic.Contents() if err != nil { - t.Fatalf("Failed to retrieve Public Key: %v", err) + t.Fatalf("%v", err) + } + rsaDetail, err := pub.Parameters.RSADetail() + if err != nil { + t.Fatalf("%v", err) + } + rsaUnique, err := pub.Unique.RSA() + if err != nil { + t.Fatalf("%v", err) + } + rsaPub, err := RSAPub(rsaDetail, rsaUnique) + if err != nil { + t.Fatalf("%v", err) } - if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, attestHash[:], rspCert.Signature.Signature.RSASSA.Sig.Buffer); err != nil { + rsassa, err := rspCert.Signature.Signature.RSASSA() + if err != nil { + t.Fatalf("%v", err) + } + if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, attestHash[:], rsassa.Sig.Buffer); err != nil { t.Errorf("Signature verification failed: %v", err) } - - if !cmp.Equal(originalBuffer, rspCert.CertifyInfo.AttestationData.ExtraData.Buffer) { + if !cmp.Equal(originalBuffer, certifyInfo.ExtraData.Buffer) { t.Errorf("Attested buffer is different from original buffer") } } @@ -161,34 +182,34 @@ func TestCreateAndCertifyCreation(t *testing.T) { } defer thetpm.Close() - public := TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgRSA, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - SignEncrypt: true, - Restricted: true, - FixedTPM: true, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - NoDA: true, - }, - Parameters: TPMUPublicParms{ - RSADetail: &TPMSRSAParms{ - Scheme: TPMTRSAScheme{ - Scheme: TPMAlgRSASSA, - Details: TPMUAsymScheme{ - RSASSA: &TPMSSigSchemeRSASSA{ - HashAlg: TPMAlgSHA256, - }, + public := New2B(TPMTPublic{ + Type: TPMAlgRSA, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + SignEncrypt: true, + Restricted: true, + FixedTPM: true, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + NoDA: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ + Scheme: TPMTRSAScheme{ + Scheme: TPMAlgRSASSA, + Details: NewTPMUAsymScheme( + TPMAlgRSASSA, + &TPMSSigSchemeRSASSA{ + HashAlg: TPMAlgSHA256, }, - }, - KeyBits: 2048, + ), }, + KeyBits: 2048, }, - }, - } + ), + }) PCR7, err := CreatePCRSelection([]int{7}) if err != nil { @@ -217,11 +238,12 @@ func TestCreateAndCertifyCreation(t *testing.T) { inScheme := TPMTSigScheme{ Scheme: TPMAlgRSASSA, - Details: TPMUSigScheme{ - RSASSA: &TPMSSchemeHash{ + Details: NewTPMUSigScheme( + TPMAlgRSASSA, + &TPMSSchemeHash{ HashAlg: TPMAlgSHA256, }, - }, + ), } certifyCreation := CertifyCreation{ @@ -244,26 +266,50 @@ func TestCreateAndCertifyCreation(t *testing.T) { t.Fatalf("Failed to certify creation: %v", err) } - attName := rspCC.CertifyInfo.AttestationData.Attested.Creation.ObjectName.Buffer + certifyInfo, err := rspCC.CertifyInfo.Contents() + if err != nil { + t.Fatalf("%v", err) + } + creationInfo, err := certifyInfo.Attested.Creation() + if err != nil { + t.Fatalf("%v", err) + } + attName := creationInfo.ObjectName.Buffer pubName := rspCP.Name.Buffer if !bytes.Equal(attName, pubName) { t.Fatalf("Attested name: %v does not match returned public key: %v.", attName, pubName) } - info, err := Marshal(rspCC.CertifyInfo.AttestationData) + info := Marshal(certifyInfo) if err != nil { t.Fatalf("Failed to marshal: %v", err) } attestHash := sha256.Sum256(info) - pub := rspCP.OutPublic.PublicArea - rsaPub, err := RSAPub(pub.Parameters.RSADetail, pub.Unique.RSA) + pub, err := rspCP.OutPublic.Contents() + if err != nil { + t.Fatalf("%v", err) + } + rsaDetail, err := pub.Parameters.RSADetail() + if err != nil { + t.Fatalf("%v", err) + } + rsaUnique, err := pub.Unique.RSA() + if err != nil { + t.Fatalf("%v", err) + } + + rsaPub, err := RSAPub(rsaDetail, rsaUnique) if err != nil { t.Fatalf("Failed to retrieve Public Key: %v", err) } - if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, attestHash[:], rspCC.Signature.Signature.RSASSA.Sig.Buffer); err != nil { + rsassa, err := rspCC.Signature.Signature.RSASSA() + if err != nil { + t.Fatalf("%v", err) + } + if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, attestHash[:], rsassa.Sig.Buffer); err != nil { t.Errorf("Signature verification failed: %v", err) } } @@ -277,38 +323,38 @@ func TestNVCertify(t *testing.T) { Auth := []byte("password") - public := TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgRSA, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - SignEncrypt: true, - Restricted: true, - FixedTPM: true, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - }, - Parameters: TPMUPublicParms{ - RSADetail: &TPMSRSAParms{ - Scheme: TPMTRSAScheme{ - Scheme: TPMAlgRSASSA, - Details: TPMUAsymScheme{ - RSASSA: &TPMSSigSchemeRSASSA{ - HashAlg: TPMAlgSHA256, - }, + public := New2B(TPMTPublic{ + Type: TPMAlgRSA, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + SignEncrypt: true, + Restricted: true, + FixedTPM: true, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ + Scheme: TPMTRSAScheme{ + Scheme: TPMAlgRSASSA, + Details: NewTPMUAsymScheme( + TPMAlgRSASSA, + &TPMSSigSchemeRSASSA{ + HashAlg: TPMAlgSHA256, }, - }, - KeyBits: 2048, + ), }, + KeyBits: 2048, }, - }, - } + ), + }) createPrimarySigner := CreatePrimary{ PrimaryHandle: TPMRHOwner, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: Auth, }, @@ -325,8 +371,8 @@ func TestNVCertify(t *testing.T) { def := NVDefineSpace{ AuthHandle: TPMRHOwner, - PublicInfo: TPM2BNVPublic{ - NVPublic: TPMSNVPublic{ + PublicInfo: New2B( + TPMSNVPublic{ NVIndex: TPMHandle(0x0180000F), NameAlg: TPMAlgSHA256, Attributes: TPMANV{ @@ -338,10 +384,9 @@ func TestNVCertify(t *testing.T) { NoDA: true, }, DataSize: 4, - }, - }, + }), } - if err := def.Execute(thetpm); err != nil { + if _, err := def.Execute(thetpm); err != nil { t.Fatalf("Calling TPM2_NV_DefineSpace: %v", err) } @@ -352,15 +397,19 @@ func TestNVCertify(t *testing.T) { if err != nil { t.Fatalf("Calling TPM2_NV_ReadPublic: %v", err) } + nvPublic, err := def.PublicInfo.Contents() + if err != nil { + t.Fatalf("%v", err) + } prewrite := NVWrite{ AuthHandle: AuthHandle{ - Handle: def.PublicInfo.NVPublic.NVIndex, + Handle: nvPublic.NVIndex, Name: nvPub.NVName, Auth: PasswordAuth(nil), }, NVIndex: NamedHandle{ - Handle: def.PublicInfo.NVPublic.NVIndex, + Handle: nvPublic.NVIndex, Name: nvPub.NVName, }, Data: TPM2BMaxNVBuffer{ @@ -368,7 +417,7 @@ func TestNVCertify(t *testing.T) { }, Offset: 0, } - if err := prewrite.Execute(thetpm); err != nil { + if _, err := prewrite.Execute(thetpm); err != nil { t.Errorf("Calling TPM2_NV_Write: %v", err) } @@ -400,24 +449,41 @@ func TestNVCertify(t *testing.T) { if err != nil { t.Fatalf("Failed to certify: %v", err) } - - info, err := Marshal(rspCert.CertifyInfo.AttestationData) + certInfo, err := rspCert.CertifyInfo.Contents() if err != nil { - t.Fatalf("Failed to marshal: %v", err) + t.Fatalf("%v", err) } + info := Marshal(certInfo) + attestHash := sha256.Sum256(info) - pub := rspSigner.OutPublic.PublicArea - rsaPub, err := RSAPub(pub.Parameters.RSADetail, pub.Unique.RSA) + pub, err := rspSigner.OutPublic.Contents() + if err != nil { + t.Fatalf("%v", err) + } + rsaDetail, err := pub.Parameters.RSADetail() + if err != nil { + t.Fatalf("%v", err) + } + rsaUnique, err := pub.Unique.RSA() + if err != nil { + t.Fatalf("%v", err) + } + + rsaPub, err := RSAPub(rsaDetail, rsaUnique) if err != nil { t.Fatalf("Failed to retrieve Public Key: %v", err) } - if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, attestHash[:], rspCert.Signature.Signature.RSASSA.Sig.Buffer); err != nil { + rsassa, err := rspCert.Signature.Signature.RSASSA() + if err != nil { + t.Fatalf("%v", err) + } + if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, attestHash[:], rsassa.Sig.Buffer); err != nil { t.Errorf("Signature verification failed: %v", err) } - if !cmp.Equal([]byte("nonce"), rspCert.CertifyInfo.AttestationData.ExtraData.Buffer) { + if !cmp.Equal([]byte("nonce"), certInfo.ExtraData.Buffer) { t.Errorf("Attested buffer is different from original buffer") } } diff --git a/tpm2/test/clear_test.go b/tpm2/test/clear_test.go index 87625560..bdac11f5 100644 --- a/tpm2/test/clear_test.go +++ b/tpm2/test/clear_test.go @@ -17,9 +17,7 @@ func TestClear(t *testing.T) { srkCreate := CreatePrimary{ PrimaryHandle: TPMRHOwner, - InPublic: TPM2BPublic{ - PublicArea: ECCSRKTemplate, - }, + InPublic: New2B(ECCSRKTemplate), } srkCreateRsp, err := srkCreate.Execute(thetpm) @@ -35,7 +33,7 @@ func TestClear(t *testing.T) { Auth: PasswordAuth(nil), }, } - err = clear.Execute(thetpm) + _, err = clear.Execute(thetpm) if err != nil { t.Fatalf("could not clear TPM: %v", err) } @@ -48,7 +46,7 @@ func TestClear(t *testing.T) { flush := FlushContext{ FlushHandle: srkCreateRsp.ObjectHandle, } - err := flush.Execute(thetpm) + _, err := flush.Execute(thetpm) if err != nil { t.Fatalf("could not flush SRK: %v", err) } diff --git a/tpm2/test/combined_context_test.go b/tpm2/test/combined_context_test.go index 3cb724fc..3640431d 100644 --- a/tpm2/test/combined_context_test.go +++ b/tpm2/test/combined_context_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" "github.com/google/go-tpm/tpm2/transport/simulator" @@ -37,32 +38,31 @@ func TestCombinedContext(t *testing.T) { createPrimary := CreatePrimary{ PrimaryHandle: TPMRHOwner, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgRSA, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - SignEncrypt: true, - FixedTPM: true, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - }, - Parameters: TPMUPublicParms{ - RSADetail: &TPMSRSAParms{ - Scheme: TPMTRSAScheme{ - Scheme: TPMAlgRSASSA, - Details: TPMUAsymScheme{ - RSASSA: &TPMSSigSchemeRSASSA{ - HashAlg: TPMAlgSHA256, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgRSA, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + SignEncrypt: true, + FixedTPM: true, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ + Scheme: TPMTRSAScheme{ + Scheme: TPMAlgRSASSA, + Details: NewTPMUAsymScheme( + TPMAlgRSASSA, &TPMSSigSchemeRSASSA{ + HashAlg: TPMAlgSHA256, }, - }, - KeyBits: 2048, + ), }, + KeyBits: 2048, }, - }, - }, + ), + }), CreationPCR: TPMLPCRSelection{ PCRSelections: []TPMSPCRSelection{ { @@ -105,7 +105,7 @@ func TestCombinedContext(t *testing.T) { rspCLName := ReadPublicName(t, rspCL.LoadedHandle, thetpm) rspCPName := ReadPublicName(t, rspCP.ObjectHandle, thetpm) - if !cmp.Equal(rspCLName, rspCPName) { + if !cmp.Equal(rspCLName, rspCPName, cmpopts.IgnoreUnexported(rspCLName)) { t.Error("Mismatch between public returned from ContextLoad & CreateLoaded") } } diff --git a/tpm2/test/commit_test.go b/tpm2/test/commit_test.go index 96b512fa..226df310 100644 --- a/tpm2/test/commit_test.go +++ b/tpm2/test/commit_test.go @@ -20,14 +20,14 @@ func TestCommit(t *testing.T) { create := CreateLoaded{ ParentHandle: TPMRHOwner, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: password, }, }, }, - InPublic: TPM2BTemplate{ - Template: TPMTPublic{ + InPublic: New2BTemplate( + &TPMTPublic{ Type: TPMAlgECC, NameAlg: TPMAlgSHA256, ObjectAttributes: TPMAObject{ @@ -37,27 +37,28 @@ func TestCommit(t *testing.T) { SensitiveDataOrigin: true, SignEncrypt: true, }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ Symmetric: TPMTSymDefObject{ Algorithm: TPMAlgNull, }, Scheme: TPMTECCScheme{ Scheme: TPMAlgECDAA, - Details: TPMUAsymScheme{ - ECDAA: &TPMSSigSchemeECDAA{ + Details: NewTPMUAsymScheme( + TPMAlgECDAA, + &TPMSSchemeECDAA{ HashAlg: TPMAlgSHA256, }, - }, + ), }, CurveID: TPMECCBNP256, KDF: TPMTKDFScheme{ Scheme: TPMAlgNull, }, }, - }, - }, - }, + ), + }), } rspCP, err := create.Execute(thetpm) @@ -74,16 +75,15 @@ func TestCommit(t *testing.T) { Name: rspCP.Name, Auth: PasswordAuth(password), }, - P1: TPM2BECCPoint{ - Point: TPMSECCPoint{ + P1: New2B( + TPMSECCPoint{ X: TPM2BECCParameter{ Buffer: []byte{1}, }, Y: TPM2BECCParameter{ Buffer: []byte{2}, }, - }, - }, + }), S2: TPM2BSensitiveData{ Buffer: []byte{}, }, diff --git a/tpm2/test/create_loaded_test.go b/tpm2/test/create_loaded_test.go index 06f8582d..f55b5e4b 100644 --- a/tpm2/test/create_loaded_test.go +++ b/tpm2/test/create_loaded_test.go @@ -13,8 +13,8 @@ func getDeriver(t *testing.T, thetpm transport.TPM) NamedHandle { cl := CreateLoaded{ ParentHandle: TPMRHOwner, - InPublic: TPM2BTemplate{ - Template: TPMTPublic{ + InPublic: New2BTemplate( + &TPMTPublic{ Type: TPMAlgKeyedHash, NameAlg: TPMAlgSHA256, ObjectAttributes: TPMAObject{ @@ -23,21 +23,22 @@ func getDeriver(t *testing.T, thetpm transport.TPM) NamedHandle { Decrypt: true, Restricted: true, }, - Parameters: TPMUPublicParms{ - KeyedHashDetail: &TPMSKeyedHashParms{ + Parameters: NewTPMUPublicParms( + TPMAlgKeyedHash, + &TPMSKeyedHashParms{ Scheme: TPMTKeyedHashScheme{ Scheme: TPMAlgXOR, - Details: TPMUSchemeKeyedHash{ - XOR: &TPMSSchemeXOR{ + Details: NewTPMUSchemeKeyedHash( + TPMAlgXOR, + &TPMSSchemeXOR{ HashAlg: TPMAlgSHA256, KDF: TPMAlgKDF1SP800108, }, - }, + ), }, }, - }, - }, - }, + ), + }), } rsp, err := cl.Execute(thetpm) if err != nil { @@ -58,24 +59,32 @@ func TestCreateLoaded(t *testing.T) { deriver := getDeriver(t, thetpm) + derive := New2B( + TPMSDerive{ + Label: TPM2BLabel{ + Buffer: []byte("label"), + }, + Context: TPM2BLabel{ + Buffer: []byte("context"), + }, + }) + createLoadeds := map[string]*CreateLoaded{ "PrimaryKey": { ParentHandle: TPMRHEndorsement, - InPublic: TPM2BTemplate{ - Template: ECCEKTemplate, - }, + InPublic: New2BTemplate(&ECCEKTemplate), }, "OrdinaryKey": { ParentHandle: TPMRHOwner, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: []byte("p@ssw0rd"), }, }, }, - InPublic: TPM2BTemplate{ - Template: TPMTPublic{ + InPublic: New2BTemplate( + &TPMTPublic{ Type: TPMAlgECC, NameAlg: TPMAlgSHA256, ObjectAttributes: TPMAObject{ @@ -83,57 +92,47 @@ func TestCreateLoaded(t *testing.T) { UserWithAuth: true, SignEncrypt: true, }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ CurveID: TPMECCNistP256, }, - }, - }, - }, + ), + }), }, "DataBlob": { ParentHandle: TPMRHOwner, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: []byte("p@ssw0rd"), }, - Data: TPM2BSensitiveData{ + Data: NewTPMUSensitiveCreate(&TPM2BSensitiveData{ Buffer: []byte("secrets"), - }, + }), }, }, - InPublic: TPM2BTemplate{ - Template: TPMTPublic{ + InPublic: New2BTemplate( + &TPMTPublic{ Type: TPMAlgKeyedHash, NameAlg: TPMAlgSHA256, ObjectAttributes: TPMAObject{ UserWithAuth: true, }, - }, - }, + }), }, "Derived": { ParentHandle: deriver, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: []byte("p@ssw0rd"), }, - Data: TPM2BDerive{ - Buffer: TPMSDerive{ - Label: TPM2BLabel{ - Buffer: []byte("label"), - }, - Context: TPM2BLabel{ - Buffer: []byte("context"), - }, - }, - }, + Data: NewTPMUSensitiveCreate(&derive), }, }, - InPublic: TPM2BTemplate{ - Template: TPMTPublic{ + InPublic: New2BTemplate( + &TPMTPublic{ Type: TPMAlgECC, NameAlg: TPMAlgSHA256, ObjectAttributes: TPMAObject{ @@ -141,13 +140,13 @@ func TestCreateLoaded(t *testing.T) { UserWithAuth: true, SignEncrypt: true, }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ CurveID: TPMECCNistP256, }, - }, - }, - }, + ), + }), }, } @@ -157,7 +156,7 @@ func TestCreateLoaded(t *testing.T) { if err != nil { t.Fatalf("error from CreateLoaded: %v", err) } - if err = (&FlushContext{FlushHandle: rsp.ObjectHandle}).Execute(thetpm); err != nil { + if _, err = (FlushContext{FlushHandle: rsp.ObjectHandle}).Execute(thetpm); err != nil { t.Errorf("error from FlushContext: %v", err) } }) diff --git a/tpm2/test/ecdh_test.go b/tpm2/test/ecdh_test.go index 9c76d4c5..fbf333f2 100644 --- a/tpm2/test/ecdh_test.go +++ b/tpm2/test/ecdh_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport/simulator" ) @@ -21,46 +22,53 @@ func TestECDH(t *testing.T) { // Create a TPM ECDH key tpmCreate := CreatePrimary{ PrimaryHandle: TPMRHOwner, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgECC, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - FixedTPM: true, - STClear: false, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - AdminWithPolicy: false, - NoDA: true, - EncryptedDuplication: false, - Restricted: false, - Decrypt: true, - SignEncrypt: false, - X509Sign: false, - }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ - CurveID: TPMECCNistP256, - Scheme: TPMTECCScheme{ - Scheme: TPMAlgECDH, - Details: TPMUAsymScheme{ - ECDH: &TPMSKeySchemeECDH{ - HashAlg: TPMAlgSHA256, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgECC, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + STClear: false, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + AdminWithPolicy: false, + NoDA: true, + EncryptedDuplication: false, + Restricted: false, + Decrypt: true, + SignEncrypt: false, + X509Sign: false, + }, + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ + CurveID: TPMECCNistP256, + Scheme: TPMTECCScheme{ + Scheme: TPMAlgECDH, + Details: NewTPMUAsymScheme( + TPMAlgECDH, + &TPMSKeySchemeECDH{ + HashAlg: TPMAlgSHA256, }, - }, + ), }, }, - }, - }, + ), + }), } tpmCreateRsp, err := tpmCreate.Execute(thetpm) if err != nil { t.Fatalf("could not create the TPM key: %v", err) } - tpmPub := tpmCreateRsp.OutPublic.PublicArea.Unique.ECC + outPub, err := tpmCreateRsp.OutPublic.Contents() + if err != nil { + t.Fatalf("%v", err) + } + tpmPub, err := outPub.Unique.ECC() + if err != nil { + t.Fatalf("%v", err) + } tpmX := big.NewInt(0).SetBytes(tpmPub.X.Buffer) tpmY := big.NewInt(0).SetBytes(tpmPub.Y.Buffer) @@ -88,16 +96,18 @@ func TestECDH(t *testing.T) { Name: tpmCreateRsp.Name, Auth: PasswordAuth(nil), }, - InPoint: TPM2BECCPoint{ - Point: swPub, - }, + InPoint: New2B(swPub), } ecdhRsp, err := ecdh.Execute(thetpm) if err != nil { t.Fatalf("ECDH_ZGen failed: %v", err) } - if !cmp.Equal(z, ecdhRsp.OutPoint.Point) { - t.Errorf("want %x got %x", z, ecdhRsp.OutPoint.Point) + outPoint, err := ecdhRsp.OutPoint.Contents() + if err != nil { + t.Fatalf("%v", err) + } + if !cmp.Equal(z.X, outPoint.X, cmpopts.IgnoreUnexported(z.X)) { + t.Errorf("want %x got %x", z, outPoint) } } diff --git a/tpm2/test/ek_test.go b/tpm2/test/ek_test.go index 3dc84cb5..1c91b0f4 100644 --- a/tpm2/test/ek_test.go +++ b/tpm2/test/ek_test.go @@ -102,22 +102,35 @@ func ekTest(t *testing.T, ekTemplate TPMTPublic) { // Create the EK createEKCmd := CreatePrimary{ PrimaryHandle: TPMRHEndorsement, - InPublic: TPM2BPublic{ - PublicArea: ekTemplate, - }, + InPublic: New2B(ekTemplate), } createEKRsp, err := createEKCmd.Execute(thetpm) if err != nil { t.Fatalf("%v", err) } - if createEKRsp.OutPublic.PublicArea.Unique.ECC != nil { - t.Logf("EK pub:\n%x\n%x\n", createEKRsp.OutPublic.PublicArea.Unique.ECC.X, createEKRsp.OutPublic.PublicArea.Unique.ECC.Y) - t.Logf("EK name: %x", createEKRsp.Name) + outPub, err := createEKRsp.OutPublic.Contents() + if err != nil { + t.Fatalf("%v", err) } + switch outPub.Type { + case TPMAlgRSA: + rsa, err := outPub.Unique.RSA() + if err != nil { + t.Fatalf("%v", err) + } + t.Logf("EK pub:\n%x\n", rsa.Buffer) + case TPMAlgECC: + ecc, err := outPub.Unique.ECC() + if err != nil { + t.Fatalf("%v", err) + } + t.Logf("EK pub:\n%x\n%x\n", ecc.X, ecc.Y) + } + t.Logf("EK name: %x", createEKRsp.Name) defer func() { // Flush the EK flush := FlushContext{FlushHandle: createEKRsp.ObjectHandle} - if err := flush.Execute(thetpm); err != nil { + if _, err := flush.Execute(thetpm); err != nil { t.Errorf("%v", err) } }() @@ -131,24 +144,22 @@ func ekTest(t *testing.T, ekTemplate TPMTPublic) { Name: createEKRsp.Name, }, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ - Data: TPM2BSensitiveData{ + Sensitive: &TPMSSensitiveCreate{ + Data: NewTPMUSensitiveCreate(&TPM2BSensitiveData{ Buffer: data, - }, + }), }, }, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgKeyedHash, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - FixedTPM: true, - FixedParent: true, - UserWithAuth: true, - NoDA: true, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgKeyedHash, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + FixedParent: true, + UserWithAuth: true, + NoDA: true, }, - }, + }), } var sessions []Session @@ -164,7 +175,7 @@ func ekTest(t *testing.T, ekTemplate TPMTPublic) { options = append(options, Bound(createEKRsp.ObjectHandle, createEKRsp.Name, nil)) } if c.salted { - options = append(options, Salted(createEKRsp.ObjectHandle, createEKRsp.OutPublic.PublicArea)) + options = append(options, Salted(createEKRsp.ObjectHandle, *outPub)) } var s Session diff --git a/tpm2/test/load_external_test.go b/tpm2/test/load_external_test.go index 09d557cf..238d4e51 100644 --- a/tpm2/test/load_external_test.go +++ b/tpm2/test/load_external_test.go @@ -20,54 +20,54 @@ func decodeHex(t *testing.T, h string) []byte { func TestLoadExternal(t *testing.T) { loads := map[string]*LoadExternal{ "ECCNoSensitive": { - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgECC, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - SignEncrypt: true, - }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ - CurveID: TPMECCNistP256, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgECC, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + SignEncrypt: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ + CurveID: TPMECCNistP256, }, - Unique: TPMUPublicID{ - // This happens to be a P256 EKpub from the simulator - ECC: &TPMSECCPoint{ - X: TPM2BECCParameter{Buffer: decodeHex(t, "9855efa3514873b88067ab127b2d4692864a395db3d9e4ccad0592478a245c16")}, - Y: TPM2BECCParameter{Buffer: decodeHex(t, "e802a26649839a2d7b13c812a5dc0b61c110cbe62db784d96e60a823448c8993")}, - }, + ), + Unique: NewTPMUPublicID( + // This happens to be a P256 EKpub from the simulator + TPMAlgECC, + &TPMSECCPoint{ + X: TPM2BECCParameter{Buffer: decodeHex(t, "9855efa3514873b88067ab127b2d4692864a395db3d9e4ccad0592478a245c16")}, + Y: TPM2BECCParameter{Buffer: decodeHex(t, "e802a26649839a2d7b13c812a5dc0b61c110cbe62db784d96e60a823448c8993")}, }, - }, - }, + ), + }), }, "KeyedHashSensitive": { - InPrivate: &TPM2BSensitive{ - SensitiveArea: TPMTSensitive{ + InPrivate: New2B( + TPMTSensitive{ SensitiveType: TPMAlgKeyedHash, SeedValue: TPM2BDigest{ Buffer: []byte("obfuscation is my middle name!!!"), }, - Sensitive: TPMUSensitiveComposite{ - Bits: &TPM2BSensitiveData{ + Sensitive: NewTPMUSensitiveComposite( + TPMAlgKeyedHash, + &TPM2BSensitiveData{ Buffer: []byte("secrets"), }, - }, - }, - }, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ + ), + }), + InPublic: New2B( + TPMTPublic{ Type: TPMAlgKeyedHash, NameAlg: TPMAlgSHA256, - Unique: TPMUPublicID{ - KeyedHash: &TPM2BDigest{ + Unique: NewTPMUPublicID( + TPMAlgKeyedHash, + &TPM2BDigest{ // SHA256("obfuscation is my middle name!!!secrets") Buffer: decodeHex(t, "ed4fe8e2bff97665e7bfbe27c2365d07a9be91dd92d997cd91cc706b6074eb08"), }, - }, - }, - }, + ), + }), }, } @@ -83,7 +83,7 @@ func TestLoadExternal(t *testing.T) { if err != nil { t.Fatalf("error from LoadExternal: %v", err) } - if err = (&FlushContext{FlushHandle: rsp.ObjectHandle}).Execute(thetpm); err != nil { + if _, err = (FlushContext{FlushHandle: rsp.ObjectHandle}).Execute(thetpm); err != nil { t.Errorf("error from FlushContext: %v", err) } }) diff --git a/tpm2/test/names_test.go b/tpm2/test/names_test.go index eca93f6c..270396a6 100644 --- a/tpm2/test/names_test.go +++ b/tpm2/test/names_test.go @@ -25,9 +25,7 @@ func TestObjectName(t *testing.T) { createPrimary := CreatePrimary{ PrimaryHandle: TPMRHEndorsement, - InPublic: TPM2BPublic{ - PublicArea: ECCEKTemplate, - }, + InPublic: New2B(ECCEKTemplate), } rsp, err := createPrimary.Execute(thetpm) if err != nil { @@ -38,7 +36,11 @@ func TestObjectName(t *testing.T) { public := rsp.OutPublic want := rsp.Name - name, err := ObjectName(&public.PublicArea) + pub, err := public.Contents() + if err != nil { + t.Fatalf("%v", err) + } + name, err := ObjectName(pub) if err != nil { t.Fatalf("error from ObjectName: %v", err) } @@ -54,8 +56,8 @@ func TestNVName(t *testing.T) { } defer thetpm.Close() - public := TPM2BNVPublic{ - NVPublic: TPMSNVPublic{ + public := New2B( + TPMSNVPublic{ NVIndex: TPMHandle(0x0180000F), NameAlg: TPMAlgSHA256, Attributes: TPMANV{ @@ -64,19 +66,22 @@ func TestNVName(t *testing.T) { NT: TPMNTOrdinary, }, DataSize: 4, - }, - } + }) defineSpace := NVDefineSpace{ AuthHandle: TPMRHOwner, PublicInfo: public, } - if err := defineSpace.Execute(thetpm); err != nil { + if _, err := defineSpace.Execute(thetpm); err != nil { t.Fatalf("could not call TPM2_DefineSpace: %v", err) } + pub, err := public.Contents() + if err != nil { + t.Fatalf("%v", err) + } readPublic := NVReadPublic{ - NVIndex: public.NVPublic.NVIndex, + NVIndex: pub.NVIndex, } rsp, err := readPublic.Execute(thetpm) if err != nil { @@ -84,7 +89,7 @@ func TestNVName(t *testing.T) { } want := rsp.NVName - name, err := NVName(&public.NVPublic) + name, err := NVName(pub) if err != nil { t.Fatalf("error from NVIndexName: %v", err) } diff --git a/tpm2/test/nv_test.go b/tpm2/test/nv_test.go index 44633147..f37b526f 100644 --- a/tpm2/test/nv_test.go +++ b/tpm2/test/nv_test.go @@ -21,8 +21,8 @@ func TestNVAuthWrite(t *testing.T) { Auth: TPM2BAuth{ Buffer: []byte("p@ssw0rd"), }, - PublicInfo: TPM2BNVPublic{ - NVPublic: TPMSNVPublic{ + PublicInfo: New2B( + TPMSNVPublic{ NVIndex: TPMHandle(0x0180000F), NameAlg: TPMAlgSHA256, Attributes: TPMANV{ @@ -34,26 +34,29 @@ func TestNVAuthWrite(t *testing.T) { NoDA: true, }, DataSize: 4, - }, - }, + }), } - if err := def.Execute(thetpm); err != nil { + if _, err := def.Execute(thetpm); err != nil { t.Fatalf("Calling TPM2_NV_DefineSpace: %v", err) } - nvName, err := NVName(&def.PublicInfo.NVPublic) + pub, err := def.PublicInfo.Contents() + if err != nil { + t.Fatalf("%v", err) + } + nvName, err := NVName(pub) if err != nil { t.Fatalf("Calculating name of NV index: %v", err) } prewrite := NVWrite{ AuthHandle: AuthHandle{ - Handle: def.PublicInfo.NVPublic.NVIndex, + Handle: pub.NVIndex, Name: *nvName, Auth: PasswordAuth([]byte("p@ssw0rd")), }, NVIndex: NamedHandle{ - Handle: def.PublicInfo.NVPublic.NVIndex, + Handle: pub.NVIndex, Name: *nvName, }, Data: TPM2BMaxNVBuffer{ @@ -61,12 +64,12 @@ func TestNVAuthWrite(t *testing.T) { }, Offset: 0, } - if err := prewrite.Execute(thetpm); err != nil { + if _, err := prewrite.Execute(thetpm); err != nil { t.Errorf("Calling TPM2_NV_Write: %v", err) } read := NVReadPublic{ - NVIndex: def.PublicInfo.NVPublic.NVIndex, + NVIndex: pub.NVIndex, } readRsp, err := read.Execute(thetpm) if err != nil { @@ -80,7 +83,7 @@ func TestNVAuthWrite(t *testing.T) { Auth: HMAC(TPMAlgSHA256, 16, Auth([]byte{})), }, NVIndex: NamedHandle{ - Handle: def.PublicInfo.NVPublic.NVIndex, + Handle: pub.NVIndex, Name: readRsp.NVName, }, Data: TPM2BMaxNVBuffer{ @@ -88,7 +91,7 @@ func TestNVAuthWrite(t *testing.T) { }, Offset: 0, } - if err := write.Execute(thetpm); err != nil { + if _, err := write.Execute(thetpm); err != nil { t.Errorf("Calling TPM2_NV_Write: %v", err) } } @@ -106,8 +109,8 @@ func TestNVAuthIncrement(t *testing.T) { Auth: TPM2BAuth{ Buffer: []byte("p@ssw0rd"), }, - PublicInfo: TPM2BNVPublic{ - NVPublic: TPMSNVPublic{ + PublicInfo: New2B( + TPMSNVPublic{ NVIndex: TPMHandle(0x0180000F), NameAlg: TPMAlgSHA256, Attributes: TPMANV{ @@ -119,16 +122,19 @@ func TestNVAuthIncrement(t *testing.T) { NoDA: true, }, DataSize: 8, - }, - }, + }), } - if err := def.Execute(thetpm); err != nil { + if _, err := def.Execute(thetpm); err != nil { t.Fatalf("Calling TPM2_NV_DefineSpace: %v", err) } + pub, err := def.PublicInfo.Contents() + if err != nil { + t.Fatalf("%v", err) + } // Calculate the Name of the index as of its creation // (i.e., without NV_WRITTEN set). - nvName, err := NVName(&def.PublicInfo.NVPublic) + nvName, err := NVName(pub) if err != nil { t.Fatalf("Calculating name of NV index: %v", err) } @@ -139,24 +145,24 @@ func TestNVAuthIncrement(t *testing.T) { Auth: HMAC(TPMAlgSHA256, 16, Auth([]byte{})), }, NVIndex: NamedHandle{ - Handle: def.PublicInfo.NVPublic.NVIndex, + Handle: pub.NVIndex, Name: *nvName, }, } - if err := incr.Execute(thetpm); err != nil { + if _, err := incr.Execute(thetpm); err != nil { t.Errorf("Calling TPM2_NV_Increment: %v", err) } // The NV index's Name has changed. Ask the TPM for it. readPub := NVReadPublic{ - NVIndex: def.PublicInfo.NVPublic.NVIndex, + NVIndex: pub.NVIndex, } readPubRsp, err := readPub.Execute(thetpm) if err != nil { t.Fatalf("Calling TPM2_NV_ReadPublic: %v", err) } incr.NVIndex = NamedHandle{ - Handle: def.PublicInfo.NVPublic.NVIndex, + Handle: pub.NVIndex, Name: readPubRsp.NVName, } @@ -166,7 +172,7 @@ func TestNVAuthIncrement(t *testing.T) { Auth: HMAC(TPMAlgSHA256, 16, Auth([]byte{})), }, NVIndex: NamedHandle{ - Handle: def.PublicInfo.NVPublic.NVIndex, + Handle: pub.NVIndex, Name: readPubRsp.NVName, }, Size: 8, @@ -176,7 +182,7 @@ func TestNVAuthIncrement(t *testing.T) { t.Fatalf("Calling TPM2_NV_Read: %v", err) } - if err := incr.Execute(thetpm); err != nil { + if _, err := incr.Execute(thetpm); err != nil { t.Errorf("Calling TPM2_NV_Increment: %v", err) } diff --git a/tpm2/test/pcr_test.go b/tpm2/test/pcr_test.go index 48ad88ff..bb13ba48 100644 --- a/tpm2/test/pcr_test.go +++ b/tpm2/test/pcr_test.go @@ -91,7 +91,7 @@ func TestPCRReset(t *testing.T) { }, }, } - if err := pcrExtend.Execute(thetpm); err != nil { + if _, err := pcrExtend.Execute(thetpm); err != nil { t.Fatalf("failed to extend pcr for test %v", err) } } diff --git a/tpm2/test/policy_test.go b/tpm2/test/policy_test.go index ad0f7eea..82d06fc3 100644 --- a/tpm2/test/policy_test.go +++ b/tpm2/test/policy_test.go @@ -14,32 +14,32 @@ func signingKey(t *testing.T, thetpm transport.TPM) (NamedHandle, func()) { t.Helper() createPrimary := CreatePrimary{ PrimaryHandle: TPMRHOwner, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgECC, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - FixedTPM: true, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - SignEncrypt: true, - }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ - Scheme: TPMTECCScheme{ - Scheme: TPMAlgECDSA, - Details: TPMUAsymScheme{ - ECDSA: &TPMSSigSchemeECDSA{ - HashAlg: TPMAlgSHA256, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgECC, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + SignEncrypt: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ + Scheme: TPMTECCScheme{ + Scheme: TPMAlgECDSA, + Details: NewTPMUAsymScheme( + TPMAlgECDSA, + &TPMSSigSchemeECDSA{ + HashAlg: TPMAlgSHA256, }, - }, - CurveID: TPMECCNistP256, + ), }, + CurveID: TPMECCNistP256, }, - }, - }, + ), + }), } rsp, err := createPrimary.Execute(thetpm) if err != nil { @@ -50,7 +50,7 @@ func signingKey(t *testing.T, thetpm transport.TPM) (NamedHandle, func()) { flush := FlushContext{ FlushHandle: rsp.ObjectHandle, } - if err := flush.Execute(thetpm); err != nil { + if _, err := flush.Execute(thetpm); err != nil { t.Errorf("could not flush signing key: %v", err) } } @@ -64,8 +64,8 @@ func nvIndex(t *testing.T, thetpm transport.TPM) (NamedHandle, func()) { t.Helper() defSpace := NVDefineSpace{ AuthHandle: TPMRHOwner, - PublicInfo: TPM2BNVPublic{ - NVPublic: TPMSNVPublic{ + PublicInfo: New2B( + TPMSNVPublic{ NVIndex: 0x01800001, NameAlg: TPMAlgSHA256, Attributes: TPMANV{ @@ -73,14 +73,17 @@ func nvIndex(t *testing.T, thetpm transport.TPM) (NamedHandle, func()) { AuthRead: true, NT: TPMNTOrdinary, }, - }, - }, + }), } - if err := defSpace.Execute(thetpm); err != nil { + if _, err := defSpace.Execute(thetpm); err != nil { t.Fatalf("could not create NV index: %v", err) } + pub, err := defSpace.PublicInfo.Contents() + if err != nil { + t.Fatalf("%v", err) + } readPub := NVReadPublic{ - NVIndex: defSpace.PublicInfo.NVPublic.NVIndex, + NVIndex: pub.NVIndex, } readRsp, err := readPub.Execute(thetpm) if err != nil { @@ -91,16 +94,16 @@ func nvIndex(t *testing.T, thetpm transport.TPM) (NamedHandle, func()) { undefine := NVUndefineSpace{ AuthHandle: TPMRHOwner, NVIndex: NamedHandle{ - Handle: defSpace.PublicInfo.NVPublic.NVIndex, + Handle: pub.NVIndex, Name: readRsp.NVName, }, } - if err := undefine.Execute(thetpm); err != nil { + if _, err := undefine.Execute(thetpm); err != nil { t.Errorf("could not undefine NV index: %v", err) } } return NamedHandle{ - Handle: defSpace.PublicInfo.NVPublic.NVIndex, + Handle: pub.NVIndex, Name: readRsp.NVName, }, cleanup } @@ -133,11 +136,12 @@ func TestPolicySignedUpdate(t *testing.T) { PolicyRef: TPM2BNonce{Buffer: []byte{5, 6, 7, 8}}, Auth: TPMTSignature{ SigAlg: TPMAlgECDSA, - Signature: TPMUSignature{ - ECDSA: &TPMSSignatureECC{ + Signature: NewTPMUSignature( + TPMAlgECDSA, + &TPMSSignatureECC{ Hash: TPMAlgSHA256, }, - }, + ), }, } @@ -251,7 +255,7 @@ func TestPolicyOrUpdate(t *testing.T) { }, } - if err := policyOr.Execute(thetpm); err != nil { + if _, err := policyOr.Execute(thetpm); err != nil { t.Fatalf("executing PolicyOr: %v", err) } @@ -357,7 +361,7 @@ func TestPolicyPCR(t *testing.T) { Pcrs: selection, } - err = policyPCR.Execute(thetpm) + _, err = policyPCR.Execute(thetpm) if tt.callShouldSucceed { if err != nil { t.Fatalf("executing PolicyPCR: %v", err) @@ -432,7 +436,7 @@ func TestPolicyCpHashUpdate(t *testing.T) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}}, } - if err := policyCpHash.Execute(thetpm); err != nil { + if _, err := policyCpHash.Execute(thetpm); err != nil { t.Fatalf("executing PolicyCpHash: %v", err) } @@ -489,7 +493,7 @@ func TestPolicyAuthorizeUpdate(t *testing.T) { }, } - if err := policyAuthorize.Execute(thetpm); err != nil { + if _, err := policyAuthorize.Execute(thetpm); err != nil { t.Fatalf("executing PolicyAuthorize: %v", err) } @@ -594,7 +598,7 @@ func TestPolicyNVUpdate(t *testing.T) { Operation: TPMEOSignedLE, } - if err := policyNV.Execute(thetpm); err != nil { + if _, err := policyNV.Execute(thetpm); err != nil { t.Fatalf("executing PolicyAuthorizeNV: %v", err) } @@ -647,7 +651,7 @@ func TestPolicyAuthorizeNVUpdate(t *testing.T) { NVIndex: nv, } - if err := policyAuthorizeNV.Execute(thetpm); err != nil { + if _, err := policyAuthorizeNV.Execute(thetpm); err != nil { t.Fatalf("executing PolicyAuthorizeNV: %v", err) } @@ -695,7 +699,7 @@ func TestPolicyCommandCodeUpdate(t *testing.T) { PolicySession: sess.Handle(), Code: TPMCCCreate, } - if err := pcc.Execute(thetpm); err != nil { + if _, err := pcc.Execute(thetpm); err != nil { t.Fatalf("executing PolicyCommandCode: %v", err) } diff --git a/tpm2/test/read_public_test.go b/tpm2/test/read_public_test.go index 7e9f77cf..e3e99cc3 100644 --- a/tpm2/test/read_public_test.go +++ b/tpm2/test/read_public_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport/simulator" ) @@ -25,32 +26,32 @@ func TestReadPublicKey(t *testing.T) { // See tpm2/templates/go for more TPMTPublic examples. createPrimary := CreatePrimary{ PrimaryHandle: TPMRHOwner, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgECC, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - FixedTPM: true, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - SignEncrypt: true, - }, - Parameters: TPMUPublicParms{ - ECCDetail: &TPMSECCParms{ - Scheme: TPMTECCScheme{ - Scheme: TPMAlgECDSA, - Details: TPMUAsymScheme{ - ECDSA: &TPMSSigSchemeECDSA{ - HashAlg: TPMAlgSHA256, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgECC, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + SignEncrypt: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ + Scheme: TPMTECCScheme{ + Scheme: TPMAlgECDSA, + Details: NewTPMUAsymScheme( + TPMAlgECDSA, + &TPMSSigSchemeECDSA{ + HashAlg: TPMAlgSHA256, }, - }, - CurveID: TPMECCNistP256, + ), }, + CurveID: TPMECCNistP256, }, - }, - }, + ), + }), } // Executing the command uses reflection to pack the bytes into a @@ -89,9 +90,23 @@ func TestReadPublicKey(t *testing.T) { // PublicArea.Unique represents the unique identifier of the TPMTPublic. // Notice how this test uses verification of another TPM command that is // able to produce similar results to validate the response. - rspCPUnique := rspCP.OutPublic.PublicArea.Unique - rspRPUnique := rspRP.OutPublic.PublicArea.Unique - if !cmp.Equal(rspCPUnique, rspRPUnique) { + pubCreate, err := rspCP.OutPublic.Contents() + if err != nil { + t.Fatalf("%v", err) + } + pubRead, err := rspRP.OutPublic.Contents() + if err != nil { + t.Fatalf("%v", err) + } + eccCreate, err := pubCreate.Unique.ECC() + if err != nil { + t.Fatalf("%v", err) + } + eccRead, err := pubRead.Unique.ECC() + if err != nil { + t.Fatalf("%v", err) + } + if !cmp.Equal(eccCreate.X, eccRead.X, cmpopts.IgnoreUnexported(eccCreate.X)) { t.Error("Mismatch between public returned from CreatePrimary & ReadPublic") } } diff --git a/tpm2/test/sealing_test.go b/tpm2/test/sealing_test.go index a4904960..8f498f61 100644 --- a/tpm2/test/sealing_test.go +++ b/tpm2/test/sealing_test.go @@ -37,15 +37,13 @@ func unsealingTest(t *testing.T, srkTemplate TPMTPublic) { createSRKCmd := CreatePrimary{ PrimaryHandle: TPMRHOwner, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: srkAuth, }, }, }, - InPublic: TPM2BPublic{ - PublicArea: srkTemplate, - }, + InPublic: New2B(srkTemplate), } createSRKRsp, err := createSRKCmd.Execute(thetpm) if err != nil { @@ -55,7 +53,7 @@ func unsealingTest(t *testing.T, srkTemplate TPMTPublic) { defer func() { // Flush the SRK flushSRKCmd := FlushContext{FlushHandle: createSRKRsp.ObjectHandle} - if err := flushSRKCmd.Execute(thetpm); err != nil { + if _, err := flushSRKCmd.Execute(thetpm); err != nil { t.Errorf("%v", err) } }() @@ -72,28 +70,27 @@ func unsealingTest(t *testing.T, srkTemplate TPMTPublic) { Auth: PasswordAuth(srkAuth), }, InSensitive: TPM2BSensitiveCreate{ - Sensitive: TPMSSensitiveCreate{ + Sensitive: &TPMSSensitiveCreate{ UserAuth: TPM2BAuth{ Buffer: auth, }, - Data: TPM2BSensitiveData{ + Data: NewTPMUSensitiveCreate(&TPM2BSensitiveData{ Buffer: data, - }, + }), }, }, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgKeyedHash, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - FixedTPM: true, - FixedParent: true, - UserWithAuth: true, - NoDA: true, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgKeyedHash, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + FixedParent: true, + UserWithAuth: true, + NoDA: true, }, - }, + }), } + var createBlobRsp *CreateResponse // Create the blob with password auth, without any session encryption @@ -174,12 +171,16 @@ func unsealingTest(t *testing.T, srkTemplate TPMTPublic) { // Create the blob with decrypt and encrypt session bound to SRK t.Run("CreateDecryptEncryptSalted", func(t *testing.T) { + outPub, err := createSRKRsp.OutPublic.Contents() + if err != nil { + t.Fatalf("%v", err) + } createBlobCmd.ParentHandle = AuthHandle{ Handle: createSRKRsp.ObjectHandle, Name: createSRKRsp.Name, Auth: HMAC(TPMAlgSHA256, 16, Auth(srkAuth), AESEncryption(128, EncryptInOut), - Salted(createSRKRsp.ObjectHandle, createSRKRsp.OutPublic.PublicArea)), + Salted(createSRKRsp.ObjectHandle, *outPub)), } createBlobRsp, err = createBlobCmd.Execute(thetpm) if err != nil { @@ -263,7 +264,7 @@ func unsealingTest(t *testing.T, srkTemplate TPMTPublic) { defer func() { // Flush the blob flushBlobCmd := FlushContext{FlushHandle: loadBlobRsp.ObjectHandle} - if err := flushBlobCmd.Execute(thetpm); err != nil { + if _, err := flushBlobCmd.Execute(thetpm); err != nil { t.Errorf("%v", err) } }() diff --git a/tpm2/test/sign_test.go b/tpm2/test/sign_test.go index 806a38fe..75c85573 100644 --- a/tpm2/test/sign_test.go +++ b/tpm2/test/sign_test.go @@ -69,32 +69,32 @@ func TestSign(t *testing.T) { createPrimary := CreatePrimary{ PrimaryHandle: TPMRHOwner, - InPublic: TPM2BPublic{ - PublicArea: TPMTPublic{ - Type: TPMAlgRSA, - NameAlg: TPMAlgSHA256, - ObjectAttributes: TPMAObject{ - SignEncrypt: true, - FixedTPM: true, - FixedParent: true, - SensitiveDataOrigin: true, - UserWithAuth: true, - }, - Parameters: TPMUPublicParms{ - RSADetail: &TPMSRSAParms{ - Scheme: TPMTRSAScheme{ - Scheme: TPMAlgRSASSA, - Details: TPMUAsymScheme{ - RSASSA: &TPMSSigSchemeRSASSA{ - HashAlg: TPMAlgSHA256, - }, + InPublic: New2B(TPMTPublic{ + Type: TPMAlgRSA, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + SignEncrypt: true, + FixedTPM: true, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + }, + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ + Scheme: TPMTRSAScheme{ + Scheme: TPMAlgRSASSA, + Details: NewTPMUAsymScheme( + TPMAlgRSASSA, + &TPMSSigSchemeRSASSA{ + HashAlg: TPMAlgSHA256, }, - }, - KeyBits: 2048, + ), }, + KeyBits: 2048, }, - }, - }, + ), + }), CreationPCR: TPMLPCRSelection{ PCRSelections: []TPMSPCRSelection{ { @@ -125,11 +125,12 @@ func TestSign(t *testing.T) { }, InScheme: TPMTSigScheme{ Scheme: TPMAlgRSASSA, - Details: TPMUSigScheme{ - RSASSA: &TPMSSchemeHash{ + Details: NewTPMUSigScheme( + TPMAlgRSASSA, + &TPMSSchemeHash{ HashAlg: TPMAlgSHA256, }, - }, + ), }, Validation: TPMTTKHashCheck{ Tag: TPMSTHashCheck, @@ -141,13 +142,29 @@ func TestSign(t *testing.T) { t.Fatalf("Failed to Sign Digest: %v", err) } - pub := rspCP.OutPublic.PublicArea - rsaPub, err := RSAPub(pub.Parameters.RSADetail, pub.Unique.RSA) + pub, err := rspCP.OutPublic.Contents() + if err != nil { + t.Fatalf("%v", err) + } + rsaDetail, err := pub.Parameters.RSADetail() + if err != nil { + t.Fatalf("%v", err) + } + rsaUnique, err := pub.Unique.RSA() + if err != nil { + t.Fatalf("%v", err) + } + + rsaPub, err := RSAPub(rsaDetail, rsaUnique) if err != nil { t.Fatalf("Failed to retrieve Public Key: %v", err) } - if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, digest[:], rspSign.Signature.Signature.RSASSA.Sig.Buffer); err != nil { + rsassa, err := rspSign.Signature.Signature.RSASSA() + if err != nil { + t.Fatalf("%v", err) + } + if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, digest[:], rsassa.Sig.Buffer); err != nil { t.Errorf("Signature verification failed: %v", err) } diff --git a/tpm2/tpm2.go b/tpm2/tpm2.go index 3407e19d..f0fce24b 100644 --- a/tpm2/tpm2.go +++ b/tpm2/tpm2.go @@ -59,23 +59,13 @@ func (h AuthHandle) KnownName() *TPM2BName { return h.Handle.KnownName() } -// Command is a placeholder interface for TPM command structures so that they -// can be easily distinguished from other types of structures. -// TODO: once go-tpm requires Go 1.18, parameterize this type for compile-time -// command/response matching. -type Command interface { +// Command is an interface for any TPM command, parameterized by its response +// type. +type Command[R any, PR *R] interface { // The TPM command code associated with this command. Command() TPMCC -} - -// Response is a placeholder interface for TPM response structures so that they -// can be easily distinguished from other types of structures. -// All implementations of this interface are pointers to structures, for -// settability. -// See https://go.dev/blog/laws-of-reflection -type Response interface { - // The TPM command code associated with this response. - Response() TPMCC + // Executes the command and returns the response. + Execute(t transport.TPM, s ...Session) (PR, error) } // PolicyCommand is a TPM command that can be part of a TPM policy. @@ -94,20 +84,21 @@ type Shutdown_ struct { } // Command implements the Command interface. -func (*Shutdown_) Command() TPMCC { return TPMCCShutdown } +func (Shutdown_) Command() TPMCC { return TPMCCShutdown } // Execute executes the command and returns the response. -func (cmd *Shutdown_) Execute(t transport.TPM, s ...Session) error { +func (cmd Shutdown_) Execute(t transport.TPM, s ...Session) (*ShutdownResponse, error) { var rsp ShutdownResponse - return execute(t, cmd, &rsp, s...) + err := execute[ShutdownResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // ShutdownResponse is the response from TPM2_Shutdown. type ShutdownResponse struct{} -// Response implements the Response interface. -func (*ShutdownResponse) Response() TPMCC { return TPMCCShutdown } - // Startup_ is the input to TPM2_Startup. // See definition in Part 3, Commands, section 9.3. // TODO: Rename this to Startup after adapter.go is deleted. @@ -117,20 +108,21 @@ type Startup_ struct { } // Command implements the Command interface. -func (*Startup_) Command() TPMCC { return TPMCCStartup } +func (Startup_) Command() TPMCC { return TPMCCStartup } // Execute executes the command and returns the response. -func (cmd *Startup_) Execute(t transport.TPM, s ...Session) error { +func (cmd Startup_) Execute(t transport.TPM, s ...Session) (*StartupResponse, error) { var rsp StartupResponse - return execute(t, cmd, &rsp, s...) + err := execute[StartupResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // StartupResponse is the response from TPM2_Startup. type StartupResponse struct{} -// Response implements the Response interface. -func (*StartupResponse) Response() TPMCC { return TPMCCStartup } - // StartAuthSession is the input to TPM2_StartAuthSession. // See definition in Part 3, Commands, section 11.1 type StartAuthSession struct { @@ -158,12 +150,12 @@ type StartAuthSession struct { } // Command implements the Command interface. -func (*StartAuthSession) Command() TPMCC { return TPMCCStartAuthSession } +func (StartAuthSession) Command() TPMCC { return TPMCCStartAuthSession } // Execute executes the command and returns the response. -func (cmd *StartAuthSession) Execute(t transport.TPM, s ...Session) (*StartAuthSessionResponse, error) { +func (cmd StartAuthSession) Execute(t transport.TPM, s ...Session) (*StartAuthSessionResponse, error) { var rsp StartAuthSessionResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[StartAuthSessionResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -177,9 +169,6 @@ type StartAuthSessionResponse struct { NonceTPM TPM2BNonce } -// Response implements the Response interface. -func (*StartAuthSessionResponse) Response() TPMCC { return TPMCCStartAuthSession } - // Create is the input to TPM2_Create. // See definition in Part 3, Commands, section 12.1 type Create struct { @@ -198,12 +187,12 @@ type Create struct { } // Command implements the Command interface. -func (*Create) Command() TPMCC { return TPMCCCreate } +func (Create) Command() TPMCC { return TPMCCCreate } // Execute executes the command and returns the response. -func (cmd *Create) Execute(t transport.TPM, s ...Session) (*CreateResponse, error) { +func (cmd Create) Execute(t transport.TPM, s ...Session) (*CreateResponse, error) { var rsp CreateResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[CreateResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -216,7 +205,7 @@ type CreateResponse struct { // the public portion of the created object OutPublic TPM2BPublic // contains a TPMS_CREATION_DATA - CreationData TPM2BCreationData + CreationData tpm2bCreationData // digest of creationData using nameAlg of outPublic CreationHash TPM2BDigest // ticket used by TPM2_CertifyCreation() to validate that the @@ -224,9 +213,6 @@ type CreateResponse struct { CreationTicket TPMTTKCreation } -// Response implements the Response interface. -func (*CreateResponse) Response() TPMCC { return TPMCCCreate } - // Load is the input to TPM2_Load. // See definition in Part 3, Commands, section 12.2 type Load struct { @@ -239,12 +225,12 @@ type Load struct { } // Command implements the Command interface. -func (*Load) Command() TPMCC { return TPMCCLoad } +func (Load) Command() TPMCC { return TPMCCLoad } // Execute executes the command and returns the response. -func (cmd *Load) Execute(t transport.TPM, s ...Session) (*LoadResponse, error) { +func (cmd Load) Execute(t transport.TPM, s ...Session) (*LoadResponse, error) { var rsp LoadResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[LoadResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -258,14 +244,11 @@ type LoadResponse struct { Name TPM2BName } -// Response implements the Response interface. -func (*LoadResponse) Response() TPMCC { return TPMCCLoad } - // LoadExternal is the input to TPM2_LoadExternal. // See definition in Part 3, Commands, section 12.3 type LoadExternal struct { // the sensitive portion of the object (optional) - InPrivate *TPM2BSensitive `gotpm:"optional"` + InPrivate TPM2BSensitive `gotpm:"optional"` // the public portion of the object InPublic TPM2BPublic // hierarchy with which the object area is associated @@ -273,12 +256,12 @@ type LoadExternal struct { } // Command implements the Command interface. -func (*LoadExternal) Command() TPMCC { return TPMCCLoadExternal } +func (LoadExternal) Command() TPMCC { return TPMCCLoadExternal } // Execute executes the command and returns the response. -func (cmd *LoadExternal) Execute(t transport.TPM, s ...Session) (*LoadExternalResponse, error) { +func (cmd LoadExternal) Execute(t transport.TPM, s ...Session) (*LoadExternalResponse, error) { var rsp LoadExternalResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[LoadExternalResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -292,9 +275,6 @@ type LoadExternalResponse struct { Name TPM2BName } -// Response implements the Response interface. -func (*LoadExternalResponse) Response() TPMCC { return TPMCCLoadExternal } - // ReadPublic is the input to TPM2_ReadPublic. // See definition in Part 3, Commands, section 12.4 type ReadPublic struct { @@ -303,12 +283,12 @@ type ReadPublic struct { } // Command implements the Command interface. -func (*ReadPublic) Command() TPMCC { return TPMCCReadPublic } +func (ReadPublic) Command() TPMCC { return TPMCCReadPublic } // Execute executes the command and returns the response. -func (cmd *ReadPublic) Execute(t transport.TPM, s ...Session) (*ReadPublicResponse, error) { +func (cmd ReadPublic) Execute(t transport.TPM, s ...Session) (*ReadPublicResponse, error) { var rsp ReadPublicResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[ReadPublicResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -324,9 +304,6 @@ type ReadPublicResponse struct { QualifiedName TPM2BName } -// Response implements the Response interface. -func (*ReadPublicResponse) Response() TPMCC { return TPMCCReadPublic } - // ActivateCredential is the input to TPM2_ActivateCredential. // See definition in Part 3, Commands, section 12.5. type ActivateCredential struct { @@ -341,12 +318,12 @@ type ActivateCredential struct { } // Command implements the Command interface. -func (*ActivateCredential) Command() TPMCC { return TPMCCActivateCredential } +func (ActivateCredential) Command() TPMCC { return TPMCCActivateCredential } // Execute executes the command and returns the response. -func (cmd *ActivateCredential) Execute(t transport.TPM, s ...Session) (*ActivateCredentialResponse, error) { +func (cmd ActivateCredential) Execute(t transport.TPM, s ...Session) (*ActivateCredentialResponse, error) { var rsp ActivateCredentialResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[ActivateCredentialResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -358,9 +335,6 @@ type ActivateCredentialResponse struct { CertInfo TPM2BDigest } -// Response implements the Response interface. -func (*ActivateCredentialResponse) Response() TPMCC { return TPMCCActivateCredential } - // MakeCredential is the input to TPM2_MakeCredential. // See definition in Part 3, Commands, section 12.6. type MakeCredential struct { @@ -373,12 +347,12 @@ type MakeCredential struct { } // Command implements the Command interface. -func (*MakeCredential) Command() TPMCC { return TPMCCMakeCredential } +func (MakeCredential) Command() TPMCC { return TPMCCMakeCredential } // Execute executes the command and returns the response. -func (cmd *MakeCredential) Execute(t transport.TPM, s ...Session) (*MakeCredentialResponse, error) { +func (cmd MakeCredential) Execute(t transport.TPM, s ...Session) (*MakeCredentialResponse, error) { var rsp MakeCredentialResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[MakeCredentialResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -392,9 +366,6 @@ type MakeCredentialResponse struct { Secret TPM2BEncryptedSecret } -// Response implements the Response interface. -func (*MakeCredentialResponse) Response() TPMCC { return TPMCCMakeCredential } - // Unseal is the input to TPM2_Unseal. // See definition in Part 3, Commands, section 12.7 type Unseal struct { @@ -402,12 +373,12 @@ type Unseal struct { } // Command implements the Command interface. -func (*Unseal) Command() TPMCC { return TPMCCUnseal } +func (Unseal) Command() TPMCC { return TPMCCUnseal } // Execute executes the command and returns the response. -func (cmd *Unseal) Execute(t transport.TPM, s ...Session) (*UnsealResponse, error) { +func (cmd Unseal) Execute(t transport.TPM, s ...Session) (*UnsealResponse, error) { var rsp UnsealResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[UnsealResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -418,9 +389,6 @@ type UnsealResponse struct { OutData TPM2BSensitiveData } -// Response implements the Response interface. -func (*UnsealResponse) Response() TPMCC { return TPMCCUnseal } - // CreateLoaded is the input to TPM2_CreateLoaded. // See definition in Part 3, Commands, section 12.9 type CreateLoaded struct { @@ -434,12 +402,12 @@ type CreateLoaded struct { } // Command implements the Command interface. -func (*CreateLoaded) Command() TPMCC { return TPMCCCreateLoaded } +func (CreateLoaded) Command() TPMCC { return TPMCCCreateLoaded } // Execute executes the command and returns the response. -func (cmd *CreateLoaded) Execute(t transport.TPM, s ...Session) (*CreateLoadedResponse, error) { +func (cmd CreateLoaded) Execute(t transport.TPM, s ...Session) (*CreateLoadedResponse, error) { var rsp CreateLoadedResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[CreateLoadedResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -457,9 +425,6 @@ type CreateLoadedResponse struct { Name TPM2BName } -// Response implements the Response interface. -func (*CreateLoadedResponse) Response() TPMCC { return TPMCCCreateLoaded } - // ECDHZGen is the input to TPM2_ECDHZGen. // See definition in Part 3, Commands, section 14.5 type ECDHZGen struct { @@ -470,12 +435,12 @@ type ECDHZGen struct { } // Command implements the Command interface. -func (*ECDHZGen) Command() TPMCC { return TPMCCECDHZGen } +func (ECDHZGen) Command() TPMCC { return TPMCCECDHZGen } // Execute executes the command and returns the response. -func (cmd *ECDHZGen) Execute(t transport.TPM, s ...Session) (*ECDHZGenResponse, error) { +func (cmd ECDHZGen) Execute(t transport.TPM, s ...Session) (*ECDHZGenResponse, error) { var rsp ECDHZGenResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[ECDHZGenResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -487,9 +452,6 @@ type ECDHZGenResponse struct { OutPoint TPM2BECCPoint } -// Response implements the Response interface. -func (*ECDHZGenResponse) Response() TPMCC { return TPMCCECDHZGen } - // Hash is the input to TPM2_Hash. // See definition in Part 3, Commands, section 15.4 type Hash struct { @@ -502,12 +464,12 @@ type Hash struct { } // Command implements the Command interface. -func (*Hash) Command() TPMCC { return TPMCCHash } +func (Hash) Command() TPMCC { return TPMCCHash } // Execute executes the command and returns the response. -func (cmd *Hash) Execute(t transport.TPM, s ...Session) (*HashResponse, error) { +func (cmd Hash) Execute(t transport.TPM, s ...Session) (*HashResponse, error) { var rsp HashResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[HashResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -522,9 +484,6 @@ type HashResponse struct { Validation TPMTTKHashCheck } -// Response implements the Response interface. -func (*HashResponse) Response() TPMCC { return TPMCCHash } - // GetRandom is the input to TPM2_GetRandom. // See definition in Part 3, Commands, section 16.1 type GetRandom struct { @@ -533,12 +492,12 @@ type GetRandom struct { } // Command implements the Command interface. -func (*GetRandom) Command() TPMCC { return TPMCCGetRandom } +func (GetRandom) Command() TPMCC { return TPMCCGetRandom } // Execute executes the command and returns the response. -func (cmd *GetRandom) Execute(t transport.TPM, s ...Session) (*GetRandomResponse, error) { +func (cmd GetRandom) Execute(t transport.TPM, s ...Session) (*GetRandomResponse, error) { var rsp GetRandomResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[GetRandomResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -550,9 +509,6 @@ type GetRandomResponse struct { RandomBytes TPM2BDigest } -// Response implements the Response interface. -func (*GetRandomResponse) Response() TPMCC { return TPMCCGetRandom } - // HashSequenceStart is the input to TPM2_HashSequenceStart. // See definition in Part 3, Commands, section 17.3 type HashSequenceStart struct { @@ -564,12 +520,12 @@ type HashSequenceStart struct { } // Command implements the Command interface. -func (*HashSequenceStart) Command() TPMCC { return TPMCCHashSequenceStart } +func (HashSequenceStart) Command() TPMCC { return TPMCCHashSequenceStart } // Execute executes the command and returns the response. -func (cmd *HashSequenceStart) Execute(t transport.TPM, s ...Session) (*HashSequenceStartResponse, error) { +func (cmd HashSequenceStart) Execute(t transport.TPM, s ...Session) (*HashSequenceStartResponse, error) { var rsp HashSequenceStartResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[HashSequenceStartResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -581,9 +537,6 @@ type HashSequenceStartResponse struct { SequenceHandle TPMIDHObject } -// Response implements the Response interface. -func (*HashSequenceStartResponse) Response() TPMCC { return TPMCCHashSequenceStart } - // SequenceUpdate is the input to TPM2_SequenceUpdate. // See definition in Part 3, Commands, section 17.4 type SequenceUpdate struct { @@ -594,12 +547,12 @@ type SequenceUpdate struct { } // Command implements the Command interface. -func (*SequenceUpdate) Command() TPMCC { return TPMCCSequenceUpdate } +func (SequenceUpdate) Command() TPMCC { return TPMCCSequenceUpdate } // Execute executes the command and returns the response. -func (cmd *SequenceUpdate) Execute(t transport.TPM, s ...Session) (*SequenceUpdateResponse, error) { +func (cmd SequenceUpdate) Execute(t transport.TPM, s ...Session) (*SequenceUpdateResponse, error) { var rsp SequenceUpdateResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[SequenceUpdateResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -608,9 +561,6 @@ func (cmd *SequenceUpdate) Execute(t transport.TPM, s ...Session) (*SequenceUpda // SequenceUpdateResponse is the response from TPM2_SequenceUpdate. type SequenceUpdateResponse struct{} -// Response implements the Response interface. -func (*SequenceUpdateResponse) Response() TPMCC { return TPMCCSequenceUpdate } - // SequenceComplete is the input to TPM2_SequenceComplete. // See definition in Part 3, Commands, section 17.5 type SequenceComplete struct { @@ -623,12 +573,12 @@ type SequenceComplete struct { } // Command implements the Command interface. -func (*SequenceComplete) Command() TPMCC { return TPMCCSequenceComplete } +func (SequenceComplete) Command() TPMCC { return TPMCCSequenceComplete } // Execute executes the command and returns the response. -func (cmd *SequenceComplete) Execute(t transport.TPM, s ...Session) (*SequenceCompleteResponse, error) { +func (cmd SequenceComplete) Execute(t transport.TPM, s ...Session) (*SequenceCompleteResponse, error) { var rsp SequenceCompleteResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[SequenceCompleteResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -643,9 +593,6 @@ type SequenceCompleteResponse struct { Validation TPMTTKHashCheck } -// Response implements the Response interface. -func (*SequenceCompleteResponse) Response() TPMCC { return TPMCCSequenceComplete } - // Certify is the input to TPM2_Certify. // See definition in Part 3, Commands, section 18.2. type Certify struct { @@ -660,12 +607,12 @@ type Certify struct { } // Command implements the Command interface. -func (*Certify) Command() TPMCC { return TPMCCCertify } +func (Certify) Command() TPMCC { return TPMCCCertify } // Execute executes the command and returns the response. -func (cmd *Certify) Execute(t transport.TPM, s ...Session) (*CertifyResponse, error) { +func (cmd Certify) Execute(t transport.TPM, s ...Session) (*CertifyResponse, error) { var rsp CertifyResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[CertifyResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -679,9 +626,6 @@ type CertifyResponse struct { Signature TPMTSignature } -// Response implements the Response interface. -func (*CertifyResponse) Response() TPMCC { return TPMCCCertify } - // CertifyCreation is the input to TPM2_CertifyCreation. // See definition in Part 3, Commands, section 18.3. type CertifyCreation struct { @@ -700,12 +644,12 @@ type CertifyCreation struct { } // Command implements the Command interface. -func (*CertifyCreation) Command() TPMCC { return TPMCCCertifyCreation } +func (CertifyCreation) Command() TPMCC { return TPMCCCertifyCreation } // Execute executes the command and returns the response. -func (cmd *CertifyCreation) Execute(t transport.TPM, s ...Session) (*CertifyCreationResponse, error) { +func (cmd CertifyCreation) Execute(t transport.TPM, s ...Session) (*CertifyCreationResponse, error) { var rsp CertifyCreationResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[CertifyCreationResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -719,9 +663,6 @@ type CertifyCreationResponse struct { Signature TPMTSignature } -// Response implements the Response interface. -func (*CertifyCreationResponse) Response() TPMCC { return TPMCCCertifyCreation } - // Quote is the input to TPM2_Quote. // See definition in Part 3, Commands, section 18.4 type Quote struct { @@ -736,12 +677,12 @@ type Quote struct { } // Command implements the Command interface. -func (*Quote) Command() TPMCC { return TPMCCQuote } +func (Quote) Command() TPMCC { return TPMCCQuote } // Execute executes the command and returns the response. -func (cmd *Quote) Execute(t transport.TPM, s ...Session) (*QuoteResponse, error) { +func (cmd Quote) Execute(t transport.TPM, s ...Session) (*QuoteResponse, error) { var rsp QuoteResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[QuoteResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -755,9 +696,6 @@ type QuoteResponse struct { Signature TPMTSignature } -// Response implements the Response interface. -func (*QuoteResponse) Response() TPMCC { return TPMCCQuote } - // GetSessionAuditDigest is the input to TPM2_GetSessionAuditDigest. // See definition in Part 3, Commands, section 18.5 type GetSessionAuditDigest struct { @@ -774,12 +712,12 @@ type GetSessionAuditDigest struct { } // Command implements the Command interface. -func (*GetSessionAuditDigest) Command() TPMCC { return TPMCCGetSessionAuditDigest } +func (GetSessionAuditDigest) Command() TPMCC { return TPMCCGetSessionAuditDigest } // Execute executes the command and returns the response. -func (cmd *GetSessionAuditDigest) Execute(t transport.TPM, s ...Session) (*GetSessionAuditDigestResponse, error) { +func (cmd GetSessionAuditDigest) Execute(t transport.TPM, s ...Session) (*GetSessionAuditDigestResponse, error) { var rsp GetSessionAuditDigestResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[GetSessionAuditDigestResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -794,9 +732,6 @@ type GetSessionAuditDigestResponse struct { Signature TPMTSignature } -// Response implements the Response interface. -func (*GetSessionAuditDigestResponse) Response() TPMCC { return TPMCCGetSessionAuditDigest } - // Commit is the input to TPM2_Commit. // See definition in Part 3, Commands, section 19.2. type Commit struct { @@ -811,12 +746,12 @@ type Commit struct { } // Command implements the Command interface. -func (*Commit) Command() TPMCC { return TPMCCCommit } +func (Commit) Command() TPMCC { return TPMCCCommit } // Execute executes the command and returns the response. -func (cmd *Commit) Execute(t transport.TPM, s ...Session) (*CommitResponse, error) { +func (cmd Commit) Execute(t transport.TPM, s ...Session) (*CommitResponse, error) { var rsp CommitResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[CommitResponse](t, cmd, &rsp, s...); err != nil { return nil, err } @@ -835,9 +770,6 @@ type CommitResponse struct { Counter uint16 } -// Response implements the Response interface. -func (*CommitResponse) Response() TPMCC { return TPMCCCommit } - // VerifySignature is the input to TPM2_VerifySignature. // See definition in Part 3, Commands, section 20.1 type VerifySignature struct { @@ -850,12 +782,12 @@ type VerifySignature struct { } // Command implements the Command interface. -func (*VerifySignature) Command() TPMCC { return TPMCCVerifySignature } +func (VerifySignature) Command() TPMCC { return TPMCCVerifySignature } // Execute executes the command and returns the response. -func (cmd *VerifySignature) Execute(t transport.TPM, s ...Session) (*VerifySignatureResponse, error) { +func (cmd VerifySignature) Execute(t transport.TPM, s ...Session) (*VerifySignatureResponse, error) { var rsp VerifySignatureResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[VerifySignatureResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -866,9 +798,6 @@ type VerifySignatureResponse struct { Validation TPMTTKVerified } -// Response implements the Response interface. -func (*VerifySignatureResponse) Response() TPMCC { return TPMCCVerifySignature } - // Sign is the input to TPM2_Sign. // See definition in Part 3, Commands, section 20.2. type Sign struct { @@ -885,12 +814,12 @@ type Sign struct { } // Command implements the Command interface. -func (*Sign) Command() TPMCC { return TPMCCSign } +func (Sign) Command() TPMCC { return TPMCCSign } // Execute executes the command and returns the response. -func (cmd *Sign) Execute(t transport.TPM, s ...Session) (*SignResponse, error) { +func (cmd Sign) Execute(t transport.TPM, s ...Session) (*SignResponse, error) { var rsp SignResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[SignResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -902,9 +831,6 @@ type SignResponse struct { Signature TPMTSignature } -// Response implements the Response interface. -func (*SignResponse) Response() TPMCC { return TPMCCSign } - // PCRExtend is the input to TPM2_PCR_Extend. // See definition in Part 3, Commands, section 22.2 type PCRExtend struct { @@ -915,20 +841,21 @@ type PCRExtend struct { } // Command implements the Command interface. -func (*PCRExtend) Command() TPMCC { return TPMCCPCRExtend } +func (PCRExtend) Command() TPMCC { return TPMCCPCRExtend } // Execute executes the command and returns the response. -func (cmd *PCRExtend) Execute(t transport.TPM, s ...Session) error { +func (cmd PCRExtend) Execute(t transport.TPM, s ...Session) (*PCRExtendResponse, error) { var rsp PCRExtendResponse - return execute(t, cmd, &rsp, s...) + err := execute[PCRExtendResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // PCRExtendResponse is the response from TPM2_PCR_Extend. type PCRExtendResponse struct{} -// Response implements the Response interface. -func (*PCRExtendResponse) Response() TPMCC { return TPMCCPCRExtend } - // PCREvent is the input to TPM2_PCR_Event. // See definition in Part 3, Commands, section 22.3 type PCREvent struct { @@ -939,20 +866,21 @@ type PCREvent struct { } // Command implements the Command interface. -func (*PCREvent) Command() TPMCC { return TPMCCPCREvent } +func (PCREvent) Command() TPMCC { return TPMCCPCREvent } // Execute executes the command and returns the response. -func (cmd *PCREvent) Execute(t transport.TPM, s ...Session) error { +func (cmd PCREvent) Execute(t transport.TPM, s ...Session) (*PCREventResponse, error) { var rsp PCREventResponse - return execute(t, cmd, &rsp, s...) + err := execute[PCREventResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // PCREventResponse is the response from TPM2_PCR_Event. type PCREventResponse struct{} -// Response implements the Response interface. -func (*PCREventResponse) Response() TPMCC { return TPMCCPCREvent } - // PCRRead is the input to TPM2_PCR_Read. // See definition in Part 3, Commands, section 22.4 type PCRRead struct { @@ -961,12 +889,12 @@ type PCRRead struct { } // Command implements the Command interface. -func (*PCRRead) Command() TPMCC { return TPMCCPCRRead } +func (PCRRead) Command() TPMCC { return TPMCCPCRRead } // Execute executes the command and returns the response. -func (cmd *PCRRead) Execute(t transport.TPM, s ...Session) (*PCRReadResponse, error) { +func (cmd PCRRead) Execute(t transport.TPM, s ...Session) (*PCRReadResponse, error) { var rsp PCRReadResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[PCRReadResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -982,9 +910,6 @@ type PCRReadResponse struct { PCRValues TPMLDigest } -// Response implements the Response interface. -func (*PCRReadResponse) Response() TPMCC { return TPMCCPCRRead } - // PCRReset is the input to TPM2_PCRReset. // See definition in Part 3, Commands, section 22.8. type PCRReset struct { @@ -993,12 +918,12 @@ type PCRReset struct { } // Command implements the Command interface. -func (*PCRReset) Command() TPMCC { return TPMCCPCRReset } +func (PCRReset) Command() TPMCC { return TPMCCPCRReset } // Execute executes the command and returns the response. -func (cmd *PCRReset) Execute(t transport.TPM, s ...Session) (*PCRResetResponse, error) { +func (cmd PCRReset) Execute(t transport.TPM, s ...Session) (*PCRResetResponse, error) { var rsp PCRResetResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[PCRResetResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1007,9 +932,6 @@ func (cmd *PCRReset) Execute(t transport.TPM, s ...Session) (*PCRResetResponse, // PCRResetResponse is the response from TPM2_PCRReset. type PCRResetResponse struct{} -// Response implements the Response interface. -func (*PCRResetResponse) Response() TPMCC { return TPMCCPCRReset } - // PolicySigned is the input to TPM2_PolicySigned. // See definition in Part 3, Commands, section 23.3. type PolicySigned struct { @@ -1031,12 +953,12 @@ type PolicySigned struct { } // Command implements the Command interface. -func (*PolicySigned) Command() TPMCC { return TPMCCPolicySigned } +func (PolicySigned) Command() TPMCC { return TPMCCPolicySigned } // Execute executes the command and returns the response. -func (cmd *PolicySigned) Execute(t transport.TPM, s ...Session) (*PolicySignedResponse, error) { +func (cmd PolicySigned) Execute(t transport.TPM, s ...Session) (*PolicySignedResponse, error) { var rsp PolicySignedResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[PolicySignedResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1052,7 +974,7 @@ func policyUpdate(policy *PolicyCalculator, cc TPMCC, arg2, arg3 []byte) error { } // Update implements the PolicyCommand interface. -func (cmd *PolicySigned) Update(policy *PolicyCalculator) error { +func (cmd PolicySigned) Update(policy *PolicyCalculator) error { return policyUpdate(policy, TPMCCPolicySigned, cmd.AuthObject.KnownName().Buffer, cmd.PolicyRef.Buffer) } @@ -1064,9 +986,6 @@ type PolicySignedResponse struct { PolicyTicket TPMTTKAuth } -// Response implements the Response interface. -func (*PolicySignedResponse) Response() TPMCC { return TPMCCPolicySigned } - // PolicySecret is the input to TPM2_PolicySecret. // See definition in Part 3, Commands, section 23.4. type PolicySecret struct { @@ -1086,19 +1005,19 @@ type PolicySecret struct { } // Command implements the Command interface. -func (*PolicySecret) Command() TPMCC { return TPMCCPolicySecret } +func (PolicySecret) Command() TPMCC { return TPMCCPolicySecret } // Execute executes the command and returns the response. -func (cmd *PolicySecret) Execute(t transport.TPM, s ...Session) (*PolicySecretResponse, error) { +func (cmd PolicySecret) Execute(t transport.TPM, s ...Session) (*PolicySecretResponse, error) { var rsp PolicySecretResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[PolicySecretResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicySecret) Update(policy *PolicyCalculator) { +func (cmd PolicySecret) Update(policy *PolicyCalculator) { policyUpdate(policy, TPMCCPolicySecret, cmd.AuthHandle.KnownName().Buffer, cmd.PolicyRef.Buffer) } @@ -1110,9 +1029,6 @@ type PolicySecretResponse struct { PolicyTicket TPMTTKAuth } -// Response implements the Response interface. -func (*PolicySecretResponse) Response() TPMCC { return TPMCCPolicySecret } - // PolicyOr is the input to TPM2_PolicyOR. // See definition in Part 3, Commands, section 23.6. type PolicyOr struct { @@ -1123,16 +1039,20 @@ type PolicyOr struct { } // Command implements the Command interface. -func (*PolicyOr) Command() TPMCC { return TPMCCPolicyOR } +func (PolicyOr) Command() TPMCC { return TPMCCPolicyOR } // Execute executes the command and returns the response. -func (cmd *PolicyOr) Execute(t transport.TPM, s ...Session) error { +func (cmd PolicyOr) Execute(t transport.TPM, s ...Session) (*PolicyOrResponse, error) { var rsp PolicyOrResponse - return execute(t, cmd, &rsp, s...) + err := execute[PolicyOrResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicyOr) Update(policy *PolicyCalculator) error { +func (cmd PolicyOr) Update(policy *PolicyCalculator) error { policy.Reset() var digests bytes.Buffer for _, digest := range cmd.PHashList.Digests { @@ -1144,9 +1064,6 @@ func (cmd *PolicyOr) Update(policy *PolicyCalculator) error { // PolicyOrResponse is the response from TPM2_PolicyOr. type PolicyOrResponse struct{} -// Response implements the Response interface. -func (*PolicyOrResponse) Response() TPMCC { return TPMCCPolicyOR } - // PolicyPCR is the input to TPM2_PolicyPCR. // See definition in Part 3, Commands, section 23.7. type PolicyPCR struct { @@ -1160,25 +1077,26 @@ type PolicyPCR struct { } // Command implements the Command interface. -func (*PolicyPCR) Command() TPMCC { return TPMCCPolicyPCR } +func (PolicyPCR) Command() TPMCC { return TPMCCPolicyPCR } // Execute executes the command and returns the response. -func (cmd *PolicyPCR) Execute(t transport.TPM, s ...Session) error { +func (cmd PolicyPCR) Execute(t transport.TPM, s ...Session) (*PolicyPCRResponse, error) { var rsp PolicyPCRResponse - return execute(t, cmd, &rsp, s...) + err := execute[PolicyPCRResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicyPCR) Update(policy *PolicyCalculator) error { +func (cmd PolicyPCR) Update(policy *PolicyCalculator) error { return policy.Update(TPMCCPolicyPCR, cmd.Pcrs, cmd.PcrDigest.Buffer) } // PolicyPCRResponse is the response from TPM2_PolicyPCR. type PolicyPCRResponse struct{} -// Response implements the Response interface. -func (*PolicyPCRResponse) Response() TPMCC { return TPMCCPolicyPCR } - // PolicyNV is the input to TPM2_PolicyNV. // See definition in Part 3, Commands, section 23.9. type PolicyNV struct { @@ -1197,16 +1115,20 @@ type PolicyNV struct { } // Command implements the Command interface. -func (*PolicyNV) Command() TPMCC { return TPMCCPolicyNV } +func (PolicyNV) Command() TPMCC { return TPMCCPolicyNV } // Execute executes the command and returns the response. -func (cmd *PolicyNV) Execute(t transport.TPM, s ...Session) error { +func (cmd PolicyNV) Execute(t transport.TPM, s ...Session) (*PolicyNVResponse, error) { var rsp PolicyNVResponse - return execute(t, cmd, &rsp, s...) + err := execute[PolicyNVResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicyNV) Update(policy *PolicyCalculator) error { +func (cmd PolicyNV) Update(policy *PolicyCalculator) error { alg, err := policy.alg.Hash() if err != nil { return err @@ -1222,9 +1144,6 @@ func (cmd *PolicyNV) Update(policy *PolicyCalculator) error { // PolicyNVResponse is the response from TPM2_PolicyPCR. type PolicyNVResponse struct{} -// Response implements the Response interface. -func (*PolicyNVResponse) Response() TPMCC { return TPMCCPolicyNV } - // PolicyCommandCode is the input to TPM2_PolicyCommandCode. // See definition in Part 3, Commands, section 23.11. type PolicyCommandCode struct { @@ -1235,25 +1154,26 @@ type PolicyCommandCode struct { } // Command implements the Command interface. -func (*PolicyCommandCode) Command() TPMCC { return TPMCCPolicyCommandCode } +func (PolicyCommandCode) Command() TPMCC { return TPMCCPolicyCommandCode } // Execute executes the command and returns the response. -func (cmd *PolicyCommandCode) Execute(t transport.TPM, s ...Session) error { +func (cmd PolicyCommandCode) Execute(t transport.TPM, s ...Session) (*PolicyCommandCodeResponse, error) { var rsp PolicyCommandCodeResponse - return execute(t, cmd, &rsp, s...) + err := execute[PolicyCommandCodeResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicyCommandCode) Update(policy *PolicyCalculator) error { +func (cmd PolicyCommandCode) Update(policy *PolicyCalculator) error { return policy.Update(TPMCCPolicyCommandCode, cmd.Code) } // PolicyCommandCodeResponse is the response from TPM2_PolicyCommandCode. type PolicyCommandCodeResponse struct{} -// Response implements the Response interface. -func (*PolicyCommandCodeResponse) Response() TPMCC { return TPMCCPolicyCommandCode } - // PolicyCPHash is the input to TPM2_PolicyCpHash. // See definition in Part 3, Commands, section 23.13. type PolicyCPHash struct { @@ -1264,25 +1184,26 @@ type PolicyCPHash struct { } // Command implements the Command interface. -func (*PolicyCPHash) Command() TPMCC { return TPMCCPolicyCpHash } +func (PolicyCPHash) Command() TPMCC { return TPMCCPolicyCpHash } // Execute executes the command and returns the response. -func (cmd *PolicyCPHash) Execute(t transport.TPM, s ...Session) error { +func (cmd PolicyCPHash) Execute(t transport.TPM, s ...Session) (*PolicyCPHashResponse, error) { var rsp PolicyCPHashResponse - return execute(t, cmd, &rsp, s...) + err := execute[PolicyCPHashResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicyCPHash) Update(policy *PolicyCalculator) error { +func (cmd PolicyCPHash) Update(policy *PolicyCalculator) error { return policy.Update(TPMCCPolicyCpHash, cmd.CPHashA.Buffer) } // PolicyCPHashResponse is the response from TPM2_PolicyCpHash. type PolicyCPHashResponse struct{} -// Response implements the Response interface. -func (*PolicyCPHashResponse) Response() TPMCC { return TPMCCPolicyCpHash } - // PolicyAuthorize is the input to TPM2_PolicySigned. // See definition in Part 3, Commands, section 23.16. type PolicyAuthorize struct { @@ -1299,25 +1220,26 @@ type PolicyAuthorize struct { } // Command implements the Command interface. -func (*PolicyAuthorize) Command() TPMCC { return TPMCCPolicyAuthorize } +func (PolicyAuthorize) Command() TPMCC { return TPMCCPolicyAuthorize } // Execute executes the command and returns the response. -func (cmd *PolicyAuthorize) Execute(t transport.TPM, s ...Session) error { +func (cmd PolicyAuthorize) Execute(t transport.TPM, s ...Session) (*PolicyAuthorizeResponse, error) { var rsp PolicyAuthorizeResponse - return execute(t, cmd, &rsp, s...) + err := execute[PolicyAuthorizeResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicyAuthorize) Update(policy *PolicyCalculator) error { +func (cmd PolicyAuthorize) Update(policy *PolicyCalculator) error { return policyUpdate(policy, TPMCCPolicyAuthorize, cmd.KeySign.Buffer, cmd.PolicyRef.Buffer) } // PolicyAuthorizeResponse is the response from TPM2_PolicyAuthorize. type PolicyAuthorizeResponse struct{} -// Response implements the Response interface. -func (*PolicyAuthorizeResponse) Response() TPMCC { return TPMCCPolicyAuthorize } - // PolicyGetDigest is the input to TPM2_PolicyGetDigest. // See definition in Part 3, Commands, section 23.19. type PolicyGetDigest struct { @@ -1326,12 +1248,12 @@ type PolicyGetDigest struct { } // Command implements the Command interface. -func (*PolicyGetDigest) Command() TPMCC { return TPMCCPolicyGetDigest } +func (PolicyGetDigest) Command() TPMCC { return TPMCCPolicyGetDigest } // Execute executes the command and returns the response. -func (cmd *PolicyGetDigest) Execute(t transport.TPM, s ...Session) (*PolicyGetDigestResponse, error) { +func (cmd PolicyGetDigest) Execute(t transport.TPM, s ...Session) (*PolicyGetDigestResponse, error) { var rsp PolicyGetDigestResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[PolicyGetDigestResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1343,9 +1265,6 @@ type PolicyGetDigestResponse struct { PolicyDigest TPM2BDigest } -// Response implements the Response interface. -func (*PolicyGetDigestResponse) Response() TPMCC { return TPMCCPolicyGetDigest } - // PolicyNVWritten is the input to TPM2_PolicyNvWritten. // See definition in Part 3, Commands, section 23.20. type PolicyNVWritten struct { @@ -1357,19 +1276,19 @@ type PolicyNVWritten struct { } // Command implements the Command interface. -func (*PolicyNVWritten) Command() TPMCC { return TPMCCPolicyNvWritten } +func (PolicyNVWritten) Command() TPMCC { return TPMCCPolicyNvWritten } // Execute executes the command and returns the response. -func (cmd *PolicyNVWritten) Execute(t transport.TPM, s ...Session) (*PolicyNVWrittenResponse, error) { +func (cmd PolicyNVWritten) Execute(t transport.TPM, s ...Session) (*PolicyNVWrittenResponse, error) { var rsp PolicyNVWrittenResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[PolicyNVWrittenResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicyNVWritten) Update(policy *PolicyCalculator) error { +func (cmd PolicyNVWritten) Update(policy *PolicyCalculator) error { return policy.Update(TPMCCPolicyNvWritten, cmd.WrittenSet) } @@ -1377,9 +1296,6 @@ func (cmd *PolicyNVWritten) Update(policy *PolicyCalculator) error { type PolicyNVWrittenResponse struct { } -// Response implements the Response interface. -func (*PolicyNVWrittenResponse) Response() TPMCC { return TPMCCPolicyNvWritten } - // PolicyAuthorizeNV is the input to TPM2_PolicyAuthorizeNV. // See definition in Part 3, Commands, section 23.22. type PolicyAuthorizeNV struct { @@ -1392,16 +1308,20 @@ type PolicyAuthorizeNV struct { } // Command implements the Command interface. -func (*PolicyAuthorizeNV) Command() TPMCC { return TPMCCPolicyAuthorizeNV } +func (PolicyAuthorizeNV) Command() TPMCC { return TPMCCPolicyAuthorizeNV } // Execute executes the command and returns the response. -func (cmd *PolicyAuthorizeNV) Execute(t transport.TPM, s ...Session) error { +func (cmd PolicyAuthorizeNV) Execute(t transport.TPM, s ...Session) (*PolicyAuthorizeNVResponse, error) { var rsp PolicyAuthorizeNVResponse - return execute(t, cmd, &rsp, s...) + err := execute[PolicyAuthorizeNVResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // Update implements the PolicyCommand interface. -func (cmd *PolicyAuthorizeNV) Update(policy *PolicyCalculator) error { +func (cmd PolicyAuthorizeNV) Update(policy *PolicyCalculator) error { policy.Reset() return policy.Update(TPMCCPolicyAuthorizeNV, cmd.NVIndex.KnownName().Buffer) } @@ -1409,9 +1329,6 @@ func (cmd *PolicyAuthorizeNV) Update(policy *PolicyCalculator) error { // PolicyAuthorizeNVResponse is the response from TPM2_PolicyAuthorizeNV. type PolicyAuthorizeNVResponse struct{} -// Response implements the Response interface. -func (*PolicyAuthorizeNVResponse) Response() TPMCC { return TPMCCPolicyAuthorizeNV } - // CreatePrimary is the input to TPM2_CreatePrimary. // See definition in Part 3, Commands, section 24.1 type CreatePrimary struct { @@ -1431,12 +1348,12 @@ type CreatePrimary struct { } // Command implements the Command interface. -func (*CreatePrimary) Command() TPMCC { return TPMCCCreatePrimary } +func (CreatePrimary) Command() TPMCC { return TPMCCCreatePrimary } // Execute executes the command and returns the response. -func (cmd *CreatePrimary) Execute(t transport.TPM, s ...Session) (*CreatePrimaryResponse, error) { +func (cmd CreatePrimary) Execute(t transport.TPM, s ...Session) (*CreatePrimaryResponse, error) { var rsp CreatePrimaryResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[CreatePrimaryResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1449,7 +1366,7 @@ type CreatePrimaryResponse struct { // the public portion of the created object OutPublic TPM2BPublic // contains a TPMS_CREATION_DATA - CreationData TPM2BCreationData + CreationData tpm2bCreationData // digest of creationData using nameAlg of outPublic CreationHash TPM2BDigest // ticket used by TPM2_CertifyCreation() to validate that the @@ -1459,9 +1376,6 @@ type CreatePrimaryResponse struct { Name TPM2BName } -// Response implements the Response interface. -func (*CreatePrimaryResponse) Response() TPMCC { return TPMCCCreatePrimary } - // Clear is the input to TPM2_Clear. // See definition in Part 3, Commands, section 24.6 type Clear struct { @@ -1470,20 +1384,21 @@ type Clear struct { } // Command implements the Command interface. -func (*Clear) Command() TPMCC { return TPMCCClear } +func (Clear) Command() TPMCC { return TPMCCClear } // Execute executes the command and returns the response. -func (cmd *Clear) Execute(t transport.TPM, s ...Session) error { +func (cmd Clear) Execute(t transport.TPM, s ...Session) (*ClearResponse, error) { var rsp ClearResponse - return execute(t, cmd, &rsp, s...) + err := execute[ClearResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // ClearResponse is the response from TPM2_Clear. type ClearResponse struct{} -// Response implements the Response interface. -func (*ClearResponse) Response() TPMCC { return TPMCCClear } - // ContextSave is the input to TPM2_ContextSave. // See definition in Part 3, Commands, section 28.2 type ContextSave struct { @@ -1492,12 +1407,12 @@ type ContextSave struct { } // Command implements the Command interface. -func (*ContextSave) Command() TPMCC { return TPMCCContextSave } +func (ContextSave) Command() TPMCC { return TPMCCContextSave } // Execute executes the command and returns the response. -func (cmd *ContextSave) Execute(t transport.TPM, s ...Session) (*ContextSaveResponse, error) { +func (cmd ContextSave) Execute(t transport.TPM, s ...Session) (*ContextSaveResponse, error) { var rsp ContextSaveResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[ContextSaveResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1508,9 +1423,6 @@ type ContextSaveResponse struct { Context TPMSContext } -// Response implements the Response interface. -func (*ContextSaveResponse) Response() TPMCC { return TPMCCContextSave } - // ContextLoad is the input to TPM2_ContextLoad. // See definition in Part 3, Commands, section 28.3 type ContextLoad struct { @@ -1519,12 +1431,12 @@ type ContextLoad struct { } // Command implements the Command interface. -func (*ContextLoad) Command() TPMCC { return TPMCCContextLoad } +func (ContextLoad) Command() TPMCC { return TPMCCContextLoad } // Execute executes the command and returns the response. -func (cmd *ContextLoad) Execute(t transport.TPM, s ...Session) (*ContextLoadResponse, error) { +func (cmd ContextLoad) Execute(t transport.TPM, s ...Session) (*ContextLoadResponse, error) { var rsp ContextLoadResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[ContextLoadResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1536,9 +1448,6 @@ type ContextLoadResponse struct { LoadedHandle TPMIDHContext } -// Response implements the Response interface. -func (*ContextLoadResponse) Response() TPMCC { return TPMCCContextLoad } - // FlushContext is the input to TPM2_FlushContext. // See definition in Part 3, Commands, section 28.4 type FlushContext struct { @@ -1547,20 +1456,21 @@ type FlushContext struct { } // Command implements the Command interface. -func (*FlushContext) Command() TPMCC { return TPMCCFlushContext } +func (FlushContext) Command() TPMCC { return TPMCCFlushContext } // Execute executes the command and returns the response. -func (cmd *FlushContext) Execute(t transport.TPM, s ...Session) error { +func (cmd FlushContext) Execute(t transport.TPM, s ...Session) (*FlushContextResponse, error) { var rsp FlushContextResponse - return execute(t, cmd, &rsp, s...) + err := execute[FlushContextResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // FlushContextResponse is the response from TPM2_FlushContext. type FlushContextResponse struct{} -// Response implements the Response interface. -func (*FlushContextResponse) Response() TPMCC { return TPMCCFlushContext } - // GetCapability is the input to TPM2_GetCapability. // See definition in Part 3, Commands, section 30.2 type GetCapability struct { @@ -1573,12 +1483,12 @@ type GetCapability struct { } // Command implements the Command interface. -func (*GetCapability) Command() TPMCC { return TPMCCGetCapability } +func (GetCapability) Command() TPMCC { return TPMCCGetCapability } // Execute executes the command and returns the response. -func (cmd *GetCapability) Execute(t transport.TPM, s ...Session) (*GetCapabilityResponse, error) { +func (cmd GetCapability) Execute(t transport.TPM, s ...Session) (*GetCapabilityResponse, error) { var rsp GetCapabilityResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[GetCapabilityResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1592,9 +1502,6 @@ type GetCapabilityResponse struct { CapabilityData TPMSCapabilityData } -// Response implements the Response interface. -func (*GetCapabilityResponse) Response() TPMCC { return TPMCCGetCapability } - // NVDefineSpace is the input to TPM2_NV_DefineSpace. // See definition in Part 3, Commands, section 31.3. type NVDefineSpace struct { @@ -1607,20 +1514,21 @@ type NVDefineSpace struct { } // Command implements the Command interface. -func (*NVDefineSpace) Command() TPMCC { return TPMCCNVDefineSpace } +func (NVDefineSpace) Command() TPMCC { return TPMCCNVDefineSpace } // Execute executes the command and returns the response. -func (cmd *NVDefineSpace) Execute(t transport.TPM, s ...Session) error { +func (cmd NVDefineSpace) Execute(t transport.TPM, s ...Session) (*NVDefineSpaceResponse, error) { var rsp NVDefineSpaceResponse - return execute(t, cmd, &rsp, s...) + err := execute[NVDefineSpaceResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // NVDefineSpaceResponse is the response from TPM2_NV_DefineSpace. type NVDefineSpaceResponse struct{} -// Response implements the Response interface. -func (*NVDefineSpaceResponse) Response() TPMCC { return TPMCCNVDefineSpace } - // NVUndefineSpace is the input to TPM2_NV_UndefineSpace. // See definition in Part 3, Commands, section 31.4. type NVUndefineSpace struct { @@ -1631,20 +1539,21 @@ type NVUndefineSpace struct { } // Command implements the Command interface. -func (*NVUndefineSpace) Command() TPMCC { return TPMCCNVUndefineSpace } +func (NVUndefineSpace) Command() TPMCC { return TPMCCNVUndefineSpace } // Execute executes the command and returns the response. -func (cmd *NVUndefineSpace) Execute(t transport.TPM, s ...Session) error { +func (cmd NVUndefineSpace) Execute(t transport.TPM, s ...Session) (*NVUndefineSpaceResponse, error) { var rsp NVUndefineSpaceResponse - return execute(t, cmd, &rsp, s...) + err := execute[NVUndefineSpaceResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // NVUndefineSpaceResponse is the response from TPM2_NV_UndefineSpace. type NVUndefineSpaceResponse struct{} -// Response implements the Response interface. -func (*NVUndefineSpaceResponse) Response() TPMCC { return TPMCCNVUndefineSpace } - // NVUndefineSpaceSpecial is the input to TPM2_NV_UndefineSpaceSpecial. // See definition in Part 3, Commands, section 31.5. type NVUndefineSpaceSpecial struct { @@ -1655,20 +1564,21 @@ type NVUndefineSpaceSpecial struct { } // Command implements the Command interface. -func (*NVUndefineSpaceSpecial) Command() TPMCC { return TPMCCNVUndefineSpaceSpecial } +func (NVUndefineSpaceSpecial) Command() TPMCC { return TPMCCNVUndefineSpaceSpecial } // Execute executes the command and returns the response. -func (cmd *NVUndefineSpaceSpecial) Execute(t transport.TPM, s ...Session) error { +func (cmd NVUndefineSpaceSpecial) Execute(t transport.TPM, s ...Session) (*NVUndefineSpaceSpecialResponse, error) { var rsp NVUndefineSpaceSpecialResponse - return execute(t, cmd, &rsp, s...) + err := execute[NVUndefineSpaceSpecialResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // NVUndefineSpaceSpecialResponse is the response from TPM2_NV_UndefineSpaceSpecial. type NVUndefineSpaceSpecialResponse struct{} -// Response implements the Response interface. -func (*NVUndefineSpaceSpecialResponse) Response() TPMCC { return TPMCCNVUndefineSpaceSpecial } - // NVReadPublic is the input to TPM2_NV_ReadPublic. // See definition in Part 3, Commands, section 31.6. type NVReadPublic struct { @@ -1677,12 +1587,12 @@ type NVReadPublic struct { } // Command implements the Command interface. -func (*NVReadPublic) Command() TPMCC { return TPMCCNVReadPublic } +func (NVReadPublic) Command() TPMCC { return TPMCCNVReadPublic } // Execute executes the command and returns the response. -func (cmd *NVReadPublic) Execute(t transport.TPM, s ...Session) (*NVReadPublicResponse, error) { +func (cmd NVReadPublic) Execute(t transport.TPM, s ...Session) (*NVReadPublicResponse, error) { var rsp NVReadPublicResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[NVReadPublicResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1694,9 +1604,6 @@ type NVReadPublicResponse struct { NVName TPM2BName } -// Response implements the Response interface. -func (*NVReadPublicResponse) Response() TPMCC { return TPMCCNVReadPublic } - // NVWrite is the input to TPM2_NV_Write. // See definition in Part 3, Commands, section 31.7. type NVWrite struct { @@ -1711,20 +1618,21 @@ type NVWrite struct { } // Command implements the Command interface. -func (*NVWrite) Command() TPMCC { return TPMCCNVWrite } +func (NVWrite) Command() TPMCC { return TPMCCNVWrite } // Execute executes the command and returns the response. -func (cmd *NVWrite) Execute(t transport.TPM, s ...Session) error { +func (cmd NVWrite) Execute(t transport.TPM, s ...Session) (*NVWriteResponse, error) { var rsp NVWriteResponse - return execute(t, cmd, &rsp, s...) + err := execute[NVWriteResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // NVWriteResponse is the response from TPM2_NV_Write. type NVWriteResponse struct{} -// Response implements the Response interface. -func (*NVWriteResponse) Response() TPMCC { return TPMCCNVWrite } - // NVIncrement is the input to TPM2_NV_Increment. // See definition in Part 3, Commands, section 31.8. type NVIncrement struct { @@ -1735,20 +1643,21 @@ type NVIncrement struct { } // Command implements the Command interface. -func (*NVIncrement) Command() TPMCC { return TPMCCNVIncrement } +func (NVIncrement) Command() TPMCC { return TPMCCNVIncrement } // Execute executes the command and returns the response. -func (cmd *NVIncrement) Execute(t transport.TPM, s ...Session) error { +func (cmd NVIncrement) Execute(t transport.TPM, s ...Session) (*NVIncrementResponse, error) { var rsp NVIncrementResponse - return execute(t, cmd, &rsp, s...) + err := execute[NVIncrementResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // NVIncrementResponse is the response from TPM2_NV_Increment. type NVIncrementResponse struct{} -// Response implements the Response interface. -func (*NVIncrementResponse) Response() TPMCC { return TPMCCNVIncrement } - // NVWriteLock is the input to TPM2_NV_WriteLock. // See definition in Part 3, Commands, section 31.11. type NVWriteLock struct { @@ -1759,20 +1668,21 @@ type NVWriteLock struct { } // Command implements the Command interface. -func (*NVWriteLock) Command() TPMCC { return TPMCCNVWriteLock } +func (NVWriteLock) Command() TPMCC { return TPMCCNVWriteLock } // Execute executes the command and returns the response. -func (cmd *NVWriteLock) Execute(t transport.TPM, s ...Session) error { +func (cmd NVWriteLock) Execute(t transport.TPM, s ...Session) (*NVWriteLockResponse, error) { var rsp NVWriteLockResponse - return execute(t, cmd, &rsp, s...) + err := execute[NVWriteLockResponse](t, cmd, &rsp, s...) + if err != nil { + return nil, err + } + return &rsp, nil } // NVWriteLockResponse is the response from TPM2_NV_WriteLock. type NVWriteLockResponse struct{} -// Response implements the Response interface. -func (*NVWriteLockResponse) Response() TPMCC { return TPMCCNVWriteLock } - // NVRead is the input to TPM2_NV_Read. // See definition in Part 3, Commands, section 31.13. type NVRead struct { @@ -1787,12 +1697,12 @@ type NVRead struct { } // Command implements the Command interface. -func (*NVRead) Command() TPMCC { return TPMCCNVRead } +func (NVRead) Command() TPMCC { return TPMCCNVRead } // Execute executes the command and returns the response. -func (cmd *NVRead) Execute(t transport.TPM, s ...Session) (*NVReadResponse, error) { +func (cmd NVRead) Execute(t transport.TPM, s ...Session) (*NVReadResponse, error) { var rsp NVReadResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[NVReadResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1804,9 +1714,6 @@ type NVReadResponse struct { Data TPM2BMaxNVBuffer } -// Response implements the Response interface. -func (*NVReadResponse) Response() TPMCC { return TPMCCNVRead } - // NVCertify is the input to TPM2_NV_Certify. // See definition in Part 3, Commands, section 31.16. type NVCertify struct { @@ -1827,12 +1734,12 @@ type NVCertify struct { } // Command implements the Command interface. -func (*NVCertify) Command() TPMCC { return TPMCCNVCertify } +func (NVCertify) Command() TPMCC { return TPMCCNVCertify } // Execute executes the command and returns the response. -func (cmd *NVCertify) Execute(t transport.TPM, s ...Session) (*NVCertifyResponse, error) { +func (cmd NVCertify) Execute(t transport.TPM, s ...Session) (*NVCertifyResponse, error) { var rsp NVCertifyResponse - if err := execute(t, cmd, &rsp, s...); err != nil { + if err := execute[NVCertifyResponse](t, cmd, &rsp, s...); err != nil { return nil, err } return &rsp, nil @@ -1845,6 +1752,3 @@ type NVCertifyResponse struct { // the asymmetric signature over certifyInfo using the key referenced by signHandle Signature TPMTSignature } - -// Response implements the Response interface. -func (*NVCertifyResponse) Response() TPMCC { return TPMCCNVCertify } diff --git a/tpm2/tpm2b.go b/tpm2/tpm2b.go new file mode 100644 index 00000000..f5af16ab --- /dev/null +++ b/tpm2/tpm2b.go @@ -0,0 +1,83 @@ +package tpm2 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" +) + +// TPM2B is a helper type for all sized TPM structures. It can be instantiated with either a raw byte buffer or the actual struct. +type TPM2B[T Marshallable, P interface { + *T + Unmarshallable +}] struct { + contents *T + buffer []byte +} + +// New2B creates a new TPM2B containing the given contents. +func New2B[T Marshallable, P interface { + *T + Unmarshallable +}](t T) TPM2B[T, P] { + return TPM2B[T, P]{contents: &t} +} + +// BytesAs2B creates a new TPM2B containing the given byte array. +func BytesAs2B[T Marshallable, P interface { + *T + Unmarshallable +}](b []byte) TPM2B[T, P] { + return TPM2B[T, P]{buffer: b} +} + +// Contents returns the structured contents of the TPM2B. +// It can fail if the TPM2B was instantiated with an invalid byte buffer. +func (value *TPM2B[T, P]) Contents() (*T, error) { + if value.contents != nil { + return value.contents, nil + } + if value.buffer == nil { + return nil, fmt.Errorf("TPMB had no contents or buffer") + } + contents, err := Unmarshal[T, P](value.buffer) + if err != nil { + return nil, err + } + // Cache the result + value.contents = (*T)(contents) + return value.contents, nil +} + +// Bytes returns the inner contents of the TPM2B as a byte array, not including the length field. +func (value *TPM2B[T, P]) Bytes() []byte { + if value.buffer != nil { + return value.buffer + } + if value.contents == nil { + return []byte{} + } + + // Cache the result + value.buffer = Marshal(*value.contents) + return value.buffer +} + +// marshal implements the tpm2.Marshallable interface. +func (value TPM2B[T, P]) marshal(buf *bytes.Buffer) { + b := value.Bytes() + binary.Write(buf, binary.BigEndian, uint16(len(b))) + buf.Write(b) +} + +// unmarshal implements the tpm2.Unmarshallable interface. +// Note: the structure contents are not validated during unmarshalling. +func (value *TPM2B[T, P]) unmarshal(buf *bytes.Buffer) error { + var size uint16 + binary.Read(buf, binary.BigEndian, &size) + value.contents = nil + value.buffer = make([]byte, size) + _, err := io.ReadAtLeast(buf, value.buffer, int(size)) + return err +} diff --git a/tpm2/wrappers.go b/tpm2/wrappers.go deleted file mode 100644 index 20e4aae5..00000000 --- a/tpm2/wrappers.go +++ /dev/null @@ -1,10 +0,0 @@ -package tpm2 - -// This file provides wrapper functions for concrete types used by tpm2, for -// setting union member pointers. - -// NewKeyBits allocates and returns the address of a new TPMKeyBits. -func NewKeyBits(v TPMKeyBits) *TPMKeyBits { return &v } - -// NewAlgID allocates and returns the address of a new TPMAlgID. -func NewAlgID(v TPMAlgID) *TPMAlgID { return &v }