Skip to content

Commit c11e706

Browse files
authored
Merge pull request #534 from smallstep/mariano/validity
Add support validities in templates
2 parents 5a67b5f + 4a9ee47 commit c11e706

File tree

10 files changed

+232
-28
lines changed

10 files changed

+232
-28
lines changed

internal/templates/funcmap.go

+35-4
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@ package templates
33
import (
44
"errors"
55
"text/template"
6+
"time"
67

78
"github.com/Masterminds/sprig/v3"
9+
"go.step.sm/crypto/jose"
810
)
911

10-
// GetFuncMap returns the list of functions provided by sprig. It changes the
11-
// function "fail" to set the given string, this way we can report template
12-
// errors directly to the template without having the wrapper that text/template
13-
// adds.
12+
// GetFuncMap returns the list of functions provided by sprig. It adds the
13+
// function "toTime" and changes the function "fail".
14+
//
15+
// The "toTime" function receives a time or a Unix epoch and formats it to
16+
// RFC3339 in UTC. The "fail" function sets the provided message, so that
17+
// template errors are reported directly to the template without having the
18+
// wrapper that text/template adds.
1419
//
1520
// sprig "env" and "expandenv" functions are removed to avoid the leak of
1621
// information.
@@ -22,5 +27,31 @@ func GetFuncMap(failMessage *string) template.FuncMap {
2227
*failMessage = msg
2328
return "", errors.New(msg)
2429
}
30+
m["toTime"] = toTime
2531
return m
2632
}
33+
34+
func toTime(v any) string {
35+
var t time.Time
36+
switch date := v.(type) {
37+
case time.Time:
38+
t = date
39+
case *time.Time:
40+
t = *date
41+
case int64:
42+
t = time.Unix(date, 0)
43+
case float64: // from json
44+
t = time.Unix(int64(date), 0)
45+
case int:
46+
t = time.Unix(int64(date), 0)
47+
case int32:
48+
t = time.Unix(int64(date), 0)
49+
case jose.NumericDate:
50+
t = date.Time()
51+
case *jose.NumericDate:
52+
t = date.Time()
53+
default:
54+
t = time.Now()
55+
}
56+
return t.UTC().Format(time.RFC3339)
57+
}

internal/templates/funcmap_test.go

+53
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ package templates
33
import (
44
"errors"
55
"testing"
6+
"time"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
"go.step.sm/crypto/jose"
611
)
712

813
func Test_GetFuncMap_fail(t *testing.T) {
@@ -20,3 +25,51 @@ func Test_GetFuncMap_fail(t *testing.T) {
2025
t.Errorf("fail() message = \"%s\", want \"the fail message\"", failMesage)
2126
}
2227
}
28+
29+
func TestGetFuncMap_toTime(t *testing.T) {
30+
now := time.Now()
31+
numericDate := jose.NewNumericDate(now)
32+
expected := now.UTC().Format(time.RFC3339)
33+
loc, err := time.LoadLocation("America/Los_Angeles")
34+
require.NoError(t, err)
35+
36+
type args struct {
37+
v any
38+
}
39+
tests := []struct {
40+
name string
41+
args args
42+
want string
43+
}{
44+
{"time", args{now}, expected},
45+
{"time pointer", args{&now}, expected},
46+
{"time UTC", args{now.UTC()}, expected},
47+
{"time with location", args{now.In(loc)}, expected},
48+
{"unix", args{now.Unix()}, expected},
49+
{"unix int", args{int(now.Unix())}, expected},
50+
{"unix int32", args{int32(now.Unix())}, expected},
51+
{"unix float64", args{float64(now.Unix())}, expected},
52+
{"unix float64", args{float64(now.Unix()) + 0.999}, expected},
53+
{"jose.NumericDate", args{*numericDate}, expected},
54+
{"jose.NumericDate pointer", args{numericDate}, expected},
55+
}
56+
for _, tt := range tests {
57+
t.Run(tt.name, func(t *testing.T) {
58+
var failMesage string
59+
fns := GetFuncMap(&failMesage)
60+
fn := fns["toTime"].(func(any) string)
61+
assert.Equal(t, tt.want, fn(tt.args.v))
62+
})
63+
}
64+
65+
t.Run("default", func(t *testing.T) {
66+
var failMesage string
67+
fns := GetFuncMap(&failMesage)
68+
fn := fns["toTime"].(func(any) string)
69+
want := time.Now()
70+
got, err := time.Parse(time.RFC3339, fn(nil))
71+
require.NoError(t, err)
72+
assert.WithinDuration(t, want, got, time.Second)
73+
assert.Equal(t, time.UTC, got.Location())
74+
})
75+
}

sshutil/certificate.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"crypto/rand"
77
"encoding/binary"
88
"encoding/json"
9+
"time"
910

1011
"github.com/pkg/errors"
1112
"go.step.sm/crypto/randutil"
@@ -20,8 +21,8 @@ type Certificate struct {
2021
Type CertType `json:"type"`
2122
KeyID string `json:"keyId"`
2223
Principals []string `json:"principals"`
23-
ValidAfter uint64 `json:"-"`
24-
ValidBefore uint64 `json:"-"`
24+
ValidAfter time.Time `json:"validAfter"`
25+
ValidBefore time.Time `json:"validBefore"`
2526
CriticalOptions map[string]string `json:"criticalOptions"`
2627
Extensions map[string]string `json:"extensions"`
2728
Reserved []byte `json:"reserved"`
@@ -62,8 +63,8 @@ func (c *Certificate) GetCertificate() *ssh.Certificate {
6263
CertType: uint32(c.Type),
6364
KeyId: c.KeyID,
6465
ValidPrincipals: c.Principals,
65-
ValidAfter: c.ValidAfter,
66-
ValidBefore: c.ValidBefore,
66+
ValidAfter: toValidity(c.ValidAfter),
67+
ValidBefore: toValidity(c.ValidBefore),
6768
Permissions: ssh.Permissions{
6869
CriticalOptions: c.CriticalOptions,
6970
Extensions: c.Extensions,
@@ -124,3 +125,10 @@ func CreateCertificate(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certifica
124125

125126
return cert, nil
126127
}
128+
129+
func toValidity(t time.Time) uint64 {
130+
if t.IsZero() {
131+
return 0
132+
}
133+
return uint64(t.Unix())
134+
}

sshutil/certificate_test.go

+65-20
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"io"
1111
"reflect"
1212
"testing"
13+
"time"
1314

15+
"github.com/stretchr/testify/assert"
1416
"golang.org/x/crypto/ssh"
1517
)
1618

@@ -71,6 +73,7 @@ func mustGeneratePublicKey(t *testing.T) ssh.PublicKey {
7173
}
7274

7375
func TestNewCertificate(t *testing.T) {
76+
now := time.Now().Truncate(time.Second)
7477
key := mustGeneratePublicKey(t)
7578
cr := CertificateRequest{
7679
Key: key,
@@ -100,8 +103,8 @@ func TestNewCertificate(t *testing.T) {
100103
Type: UserCert,
101104
102105
Principals: []string{"jane"},
103-
ValidAfter: 0,
104-
ValidBefore: 0,
106+
ValidAfter: time.Time{},
107+
ValidBefore: time.Time{},
105108
CriticalOptions: nil,
106109
Extensions: map[string]string{
107110
"permit-X11-forwarding": "",
@@ -121,8 +124,8 @@ func TestNewCertificate(t *testing.T) {
121124
Type: HostCert,
122125
KeyID: "foobar",
123126
Principals: []string{"foo.internal", "bar.internal"},
124-
ValidAfter: 0,
125-
ValidBefore: 0,
127+
ValidAfter: time.Time{},
128+
ValidBefore: time.Time{},
126129
CriticalOptions: nil,
127130
Extensions: nil,
128131
Reserved: nil,
@@ -136,8 +139,8 @@ func TestNewCertificate(t *testing.T) {
136139
Type: HostCert,
137140
KeyID: `foobar", "criticalOptions": {"foo": "bar"},"foo":"`,
138141
Principals: []string{"foo.internal", "bar.internal"},
139-
ValidAfter: 0,
140-
ValidBefore: 0,
142+
ValidAfter: time.Time{},
143+
ValidBefore: time.Time{},
141144
CriticalOptions: nil,
142145
Extensions: nil,
143146
Reserved: nil,
@@ -159,8 +162,8 @@ func TestNewCertificate(t *testing.T) {
159162
Type: UserCert,
160163
161164
Principals: []string{"john", "[email protected]"},
162-
ValidAfter: 0,
163-
ValidBefore: 0,
165+
ValidAfter: time.Time{},
166+
ValidBefore: time.Time{},
164167
CriticalOptions: nil,
165168
Extensions: map[string]string{
166169
"[email protected]": "john",
@@ -174,15 +177,47 @@ func TestNewCertificate(t *testing.T) {
174177
SignatureKey: nil,
175178
Signature: nil,
176179
}, false},
180+
{"file with dates", args{cr, []Option{WithTemplateFile("./testdata/date.tpl", TemplateData{
181+
TypeKey: UserCert,
182+
KeyIDKey: "[email protected]",
183+
PrincipalsKey: []string{"john", "[email protected]"},
184+
ExtensionsKey: DefaultExtensions(UserCert),
185+
InsecureKey: TemplateData{
186+
"User": map[string]interface{}{"username": "john"},
187+
},
188+
WebhooksKey: TemplateData{
189+
"Test": map[string]interface{}{"validity": "16h"},
190+
},
191+
})}}, &Certificate{
192+
Nonce: nil,
193+
Key: key,
194+
Serial: 0,
195+
Type: UserCert,
196+
197+
Principals: []string{"john", "[email protected]"},
198+
ValidAfter: now,
199+
ValidBefore: now.Add(16 * time.Hour),
200+
CriticalOptions: nil,
201+
Extensions: map[string]string{
202+
"permit-X11-forwarding": "",
203+
"permit-agent-forwarding": "",
204+
"permit-port-forwarding": "",
205+
"permit-pty": "",
206+
"permit-user-rc": "",
207+
},
208+
Reserved: nil,
209+
SignatureKey: nil,
210+
Signature: nil,
211+
}, false},
177212
{"base64", args{cr, []Option{WithTemplateBase64(base64.StdEncoding.EncodeToString([]byte(DefaultTemplate)), CreateTemplateData(HostCert, "foo.internal", nil))}}, &Certificate{
178213
Nonce: nil,
179214
Key: key,
180215
Serial: 0,
181216
Type: HostCert,
182217
KeyID: "foo.internal",
183218
Principals: nil,
184-
ValidAfter: 0,
185-
ValidBefore: 0,
219+
ValidAfter: time.Time{},
220+
ValidBefore: time.Time{},
186221
CriticalOptions: nil,
187222
Extensions: nil,
188223
Reserved: nil,
@@ -203,6 +238,15 @@ func TestNewCertificate(t *testing.T) {
203238
t.Errorf("NewCertificate() error = %v, wantErr %v", err, tt.wantErr)
204239
return
205240
}
241+
if got != nil && tt.want != nil {
242+
if assert.WithinDuration(t, tt.want.ValidAfter, got.ValidAfter, 2*time.Second) {
243+
tt.want.ValidAfter = got.ValidAfter
244+
}
245+
if assert.WithinDuration(t, tt.want.ValidBefore, got.ValidBefore, 2*time.Second) {
246+
tt.want.ValidBefore = got.ValidBefore
247+
}
248+
249+
}
206250
if !reflect.DeepEqual(got, tt.want) {
207251
t.Errorf("NewCertificate() = %v, want %v", got, tt.want)
208252
}
@@ -212,6 +256,7 @@ func TestNewCertificate(t *testing.T) {
212256

213257
func TestCertificate_GetCertificate(t *testing.T) {
214258
key := mustGeneratePublicKey(t)
259+
now := time.Now()
215260

216261
type fields struct {
217262
Nonce []byte
@@ -220,8 +265,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
220265
Type CertType
221266
KeyID string
222267
Principals []string
223-
ValidAfter uint64
224-
ValidBefore uint64
268+
ValidAfter time.Time
269+
ValidBefore time.Time
225270
CriticalOptions map[string]string
226271
Extensions map[string]string
227272
Reserved []byte
@@ -240,8 +285,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
240285
Type: UserCert,
241286
KeyID: "key-id",
242287
Principals: []string{"john"},
243-
ValidAfter: 1111,
244-
ValidBefore: 2222,
288+
ValidAfter: now,
289+
ValidBefore: now.Add(time.Hour),
245290
CriticalOptions: map[string]string{"foo": "bar"},
246291
Extensions: map[string]string{"[email protected]": "john"},
247292
Reserved: []byte("reserved"),
@@ -254,8 +299,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
254299
CertType: ssh.UserCert,
255300
KeyId: "key-id",
256301
ValidPrincipals: []string{"john"},
257-
ValidAfter: 1111,
258-
ValidBefore: 2222,
302+
ValidAfter: uint64(now.Unix()),
303+
ValidBefore: uint64(now.Add(time.Hour).Unix()),
259304
Permissions: ssh.Permissions{
260305
CriticalOptions: map[string]string{"foo": "bar"},
261306
Extensions: map[string]string{"[email protected]": "john"},
@@ -269,8 +314,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
269314
Type: HostCert,
270315
KeyID: "key-id",
271316
Principals: []string{"foo.internal", "bar.internal"},
272-
ValidAfter: 1111,
273-
ValidBefore: 2222,
317+
ValidAfter: time.Time{},
318+
ValidBefore: time.Time{},
274319
CriticalOptions: map[string]string{"foo": "bar"},
275320
Extensions: nil,
276321
Reserved: []byte("reserved"),
@@ -283,8 +328,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
283328
CertType: ssh.HostCert,
284329
KeyId: "key-id",
285330
ValidPrincipals: []string{"foo.internal", "bar.internal"},
286-
ValidAfter: 1111,
287-
ValidBefore: 2222,
331+
ValidAfter: 0,
332+
ValidBefore: 0,
288333
Permissions: ssh.Permissions{
289334
CriticalOptions: map[string]string{"foo": "bar"},
290335
Extensions: nil,

sshutil/testdata/date.tpl

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"type": "{{ .Type }}",
3+
"keyId": "{{ .KeyID }}",
4+
"principals": {{ toJson .Principals }},
5+
"extensions": {{ toJson .Extensions }},
6+
"validAfter": {{ now | toJson }},
7+
"validBefore": {{ now | dateModify .Webhooks.Test.validity | toJson }}
8+
}

x509util/certificate.go

+7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/rand"
88
"crypto/x509"
99
"encoding/json"
10+
"time"
1011

1112
"github.com/pkg/errors"
1213
)
@@ -23,6 +24,8 @@ type Certificate struct {
2324
IPAddresses MultiIP `json:"ipAddresses"`
2425
URIs MultiURL `json:"uris"`
2526
SANs []SubjectAlternativeName `json:"sans"`
27+
NotBefore time.Time `json:"notBefore"`
28+
NotAfter time.Time `json:"notAfter"`
2629
Extensions []Extension `json:"extensions"`
2730
KeyUsage KeyUsage `json:"keyUsage"`
2831
ExtKeyUsage ExtKeyUsage `json:"extKeyUsage"`
@@ -165,6 +168,10 @@ func (c *Certificate) GetCertificate() *x509.Certificate {
165168
e.Set(cert)
166169
}
167170

171+
// Validity bounds.
172+
cert.NotBefore = c.NotBefore
173+
cert.NotAfter = c.NotAfter
174+
168175
// Others.
169176
c.SerialNumber.Set(cert)
170177
c.SignatureAlgorithm.Set(cert)

0 commit comments

Comments
 (0)