diff --git a/Makefile b/Makefile index 1ea22d00..bd737d93 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,8 @@ LINT_NEW_FROM ?= origin/main WK_CLIENT_ID ?= $(GOG_CLIENT_ID) WK_CLIENT_SECRET ?= $(GOG_CLIENT_SECRET) WK_CALLBACK_SERVER ?= $(GOG_CALLBACK_SERVER) +WK_M365_CLIENT_ID ?= $(GOG_M365_CLIENT_ID) +WK_M365_TENANT_ID ?= $(GOG_M365_TENANT_ID) # Allow passing CLI args as extra "targets": # make workit -- --help @@ -60,7 +62,9 @@ build-internal: @go build -ldflags "$(LDFLAGS) \ -X 'github.com/automagik-dev/workit/internal/config.DefaultClientID=$(WK_CLIENT_ID)' \ -X 'github.com/automagik-dev/workit/internal/config.DefaultClientSecret=$(WK_CLIENT_SECRET)' \ - -X 'github.com/automagik-dev/workit/internal/config.DefaultCallbackServer=$(WK_CALLBACK_SERVER)'" \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultCallbackServer=$(WK_CALLBACK_SERVER)' \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultM365ClientID=$(WK_M365_CLIENT_ID)' \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultM365TenantID=$(WK_M365_TENANT_ID)'" \ -o $(BIN) $(CMD) # Build with credentials from ~/.config/workit/credentials.env (WK_* primary contract). @@ -73,20 +77,28 @@ build-automagik: wk_client_id="$${WK_CLIENT_ID:-$${GOG_CLIENT_ID}}" && \ wk_client_secret="$${WK_CLIENT_SECRET:-$${GOG_CLIENT_SECRET}}" && \ wk_callback_server="$${WK_CALLBACK_SERVER:-$${GOG_CALLBACK_SERVER}}" && \ + wk_m365_client_id="$${WK_M365_CLIENT_ID:-$${GOG_M365_CLIENT_ID}}" && \ + wk_m365_tenant_id="$${WK_M365_TENANT_ID:-$${GOG_M365_TENANT_ID}}" && \ go build -ldflags "$(LDFLAGS) \ -X 'github.com/automagik-dev/workit/internal/config.DefaultClientID=$$wk_client_id' \ -X 'github.com/automagik-dev/workit/internal/config.DefaultClientSecret=$$wk_client_secret' \ - -X 'github.com/automagik-dev/workit/internal/config.DefaultCallbackServer=$$wk_callback_server'" \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultCallbackServer=$$wk_callback_server' \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultM365ClientID=$$wk_m365_client_id' \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultM365TenantID=$$wk_m365_tenant_id'" \ -o $(BIN) $(CMD); \ elif [ -f "$(HOME)/.config/gog/credentials.env" ]; then \ . $(HOME)/.config/gog/credentials.env && \ wk_client_id="$${WK_CLIENT_ID:-$${GOG_CLIENT_ID}}" && \ wk_client_secret="$${WK_CLIENT_SECRET:-$${GOG_CLIENT_SECRET}}" && \ wk_callback_server="$${WK_CALLBACK_SERVER:-$${GOG_CALLBACK_SERVER}}" && \ + wk_m365_client_id="$${WK_M365_CLIENT_ID:-$${GOG_M365_CLIENT_ID}}" && \ + wk_m365_tenant_id="$${WK_M365_TENANT_ID:-$${GOG_M365_TENANT_ID}}" && \ go build -ldflags "$(LDFLAGS) \ -X 'github.com/automagik-dev/workit/internal/config.DefaultClientID=$$wk_client_id' \ -X 'github.com/automagik-dev/workit/internal/config.DefaultClientSecret=$$wk_client_secret' \ - -X 'github.com/automagik-dev/workit/internal/config.DefaultCallbackServer=$$wk_callback_server'" \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultCallbackServer=$$wk_callback_server' \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultM365ClientID=$$wk_m365_client_id' \ + -X 'github.com/automagik-dev/workit/internal/config.DefaultM365TenantID=$$wk_m365_tenant_id'" \ -o $(BIN) $(CMD); \ else \ echo "Missing credentials file: $(HOME)/.config/workit/credentials.env"; \ diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go index 448ecb6f..789cd83d 100644 --- a/internal/cmd/auth.go +++ b/internal/cmd/auth.go @@ -28,6 +28,8 @@ var ( checkRefreshToken = googleauth.CheckRefreshToken ensureKeychainAccess = secrets.EnsureKeychainAccess fetchAuthorizedEmail = googleauth.EmailForRefreshToken + authorizeM365 = msauth.Authorize + m365ManualAuthURL = msauth.ManualAuthURL headlessAuthorize = googleauth.HeadlessAuthorize pollForToken = googleauth.PollForToken callbackServerURLFn = googleauth.CallbackServerURL @@ -505,6 +507,10 @@ type AuthAddCmd struct { func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error { u := ui.FromContext(ctx) + if isM365ServicesCSV(c.ServicesCSV) { + return c.runM365(ctx, flags, u) + } + override := authclient.ClientOverrideFromContext(ctx) client, err := authclient.ResolveClientWithOverride(c.Email, override) if err != nil { @@ -1439,6 +1445,10 @@ type AuthManageCmd struct { } func (c *AuthManageCmd) Run(ctx context.Context, _ *RootFlags) error { + if isM365ServicesCSV(c.ServicesCSV) { + return c.runM365(ctx) + } + services, err := parseAuthServices(c.ServicesCSV) if err != nil { return err diff --git a/internal/cmd/auth_m365.go b/internal/cmd/auth_m365.go new file mode 100644 index 00000000..ce20e8d8 --- /dev/null +++ b/internal/cmd/auth_m365.go @@ -0,0 +1,119 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "sort" + "strings" + + "github.com/automagik-dev/workit/internal/msauth" + "github.com/automagik-dev/workit/internal/outfmt" + "github.com/automagik-dev/workit/internal/secrets" + "github.com/automagik-dev/workit/internal/ui" +) + +func isM365ServicesCSV(value string) bool { + parts := strings.Split(value, ",") + if len(parts) == 0 { + return false + } + for _, part := range parts { + if strings.ToLower(strings.TrimSpace(part)) != "m365" { + return false + } + } + return true +} + +func (c *AuthAddCmd) runM365(ctx context.Context, flags *RootFlags, u *ui.UI) error { + if !c.Readonly { + return usage("m365 auth requires explicit --readonly") + } + if c.Headless || c.NoPoll || c.CallbackServer != "" { + return usage("m365 auth uses browser OAuth; headless callback-server mode is not supported yet") + } + if c.AuthCode != "" { + return usage("m365 auth does not accept raw --auth-code; use browser OAuth") + } + if c.Step != 0 && c.Step != 1 && c.Step != 2 { + return usage("step must be 1 or 2") + } + if c.Step != 0 && !c.Remote { + return usage("--step requires --remote") + } + if c.Remote || c.Step != 0 || c.AuthURL != "" { + return usage("m365 remote auth is not supported yet; use browser OAuth on this machine") + } + if dryRunErr := dryRunExit(ctx, flags, "auth.add.m365", map[string]any{ + "email": strings.TrimSpace(c.Email), + "provider": "microsoft_graph", + "services": []string{"m365"}, + "scopes": msauth.PilotAllowedScopes(), + "readonly": c.Readonly, + }); dryRunErr != nil { + return dryRunErr + } + if keychainErr := ensureKeychainAccessIfNeeded(); keychainErr != nil { + return fmt.Errorf("keychain access: %w", keychainErr) + } + result, err := authorizeM365(ctx, msauth.AuthorizeOptions{ + ExpectedEmail: strings.TrimSpace(c.Email), + Readonly: c.Readonly, + ForceConsent: c.ForceConsent, + Timeout: c.Timeout, + }) + if err != nil { + return err + } + if normalizeEmail(result.Email) != normalizeEmail(c.Email) { + return fmt.Errorf("authorized as %s, expected %s", result.Email, c.Email) + } + return storeM365Token(ctx, u, result.Email, result.RefreshToken) +} + +func (c *AuthManageCmd) runM365(ctx context.Context) error { + if !outfmt.IsJSON(ctx) && !c.PrintURL { + return usage("m365 auth manage requires --print-url") + } + + result, err := m365ManualAuthURL(ctx, msauth.ManualAuthURLOptions{Readonly: true, ForceConsent: c.ForceConsent}) + if err != nil { + return err + } + if outfmt.IsJSON(ctx) || c.PrintURL { + return outfmt.WriteJSON(ctx, os.Stdout, map[string]any{ + "provider": "microsoft_graph", + "auth_url": result.URL, + "state": result.State, + "expires_in": result.ExpiresIn, + }) + } + return nil +} + +func storeM365Token(ctx context.Context, u *ui.UI, email string, refreshToken string) error { + store, err := openSecretsStore() + if err != nil { + return err + } + serviceNames := []string{"m365"} + scopes := append([]string(nil), msauth.PilotAllowedScopes()...) + sort.Strings(scopes) + if err := store.MergeToken(msauth.ClientName, email, secrets.Token{ + Client: msauth.ClientName, + Email: email, + Services: serviceNames, + Scopes: scopes, + RefreshToken: refreshToken, + }); err != nil { + return err + } + return writeResult(ctx, u, + kv("stored", true), + kv("provider", "microsoft_graph"), + kv("email", email), + kv("services", serviceNames), + kv("client", msauth.ClientName), + ) +} diff --git a/internal/cmd/m365_auth_dryrun_test.go b/internal/cmd/m365_auth_dryrun_test.go new file mode 100644 index 00000000..91566bd1 --- /dev/null +++ b/internal/cmd/m365_auth_dryrun_test.go @@ -0,0 +1,50 @@ +package cmd + +import ( + "os" + "strings" + "testing" + + "github.com/automagik-dev/workit/internal/config" +) + +func TestAuthAddM365DryRunReportsPilotScopes(t *testing.T) { + out := captureStdout(t, func() { + _ = captureStderr(t, func() { + if err := Execute([]string{"--json", "--dry-run", "auth", "add", "pilot@example.com", "--services", "m365", "--readonly"}); err != nil { + t.Fatalf("dry-run m365 auth: %v", err) + } + }) + }) + + for _, want := range []string{"auth.add.m365", "microsoft_graph", "User.Read", "Mail.Read", "Calendars.Read"} { + if !strings.Contains(out, want) { + t.Fatalf("dry-run output missing %s: %s", want, out) + } + } +} + +func TestAuthAddM365RealFlowFailsClosedWithoutClientID(t *testing.T) { + origClientID := config.DefaultM365ClientID + origEnv, hadEnv := os.LookupEnv("WK_M365_CLIENT_ID") + t.Cleanup(func() { + config.DefaultM365ClientID = origClientID + if hadEnv { + _ = os.Setenv("WK_M365_CLIENT_ID", origEnv) + } else { + _ = os.Unsetenv("WK_M365_CLIENT_ID") + } + }) + config.DefaultM365ClientID = "" + _ = os.Unsetenv("WK_M365_CLIENT_ID") + + _ = captureStderr(t, func() { + err := Execute([]string{"--json", "auth", "add", "pilot@example.com", "--services", "m365", "--readonly"}) + if err == nil { + t.Fatal("expected missing m365 client id") + } + if !strings.Contains(err.Error(), "client id") { + t.Fatalf("unexpected error: %v", err) + } + }) +} diff --git a/internal/cmd/m365_auth_edges_test.go b/internal/cmd/m365_auth_edges_test.go new file mode 100644 index 00000000..80ec76de --- /dev/null +++ b/internal/cmd/m365_auth_edges_test.go @@ -0,0 +1,41 @@ +package cmd + +import ( + "strings" + "testing" +) + +func TestAuthAddM365RejectsUnsupportedModes(t *testing.T) { + tests := []struct { + args []string + want string + }{ + {[]string{"--json", "auth", "add", "pilot@example.com", "--services", "m365", "--readonly", "--headless"}, "headless callback-server mode is not supported yet"}, + {[]string{"--json", "auth", "add", "pilot@example.com", "--services", "m365", "--readonly", "--auth-code", "raw"}, "does not accept raw --auth-code"}, + {[]string{"--json", "auth", "add", "pilot@example.com", "--services", "m365", "--readonly", "--remote", "--step", "2"}, "remote auth is not supported yet"}, + } + + for _, tc := range tests { + _ = captureStderr(t, func() { + err := Execute(tc.args) + if err == nil { + t.Fatalf("expected error for %#v", tc.args) + } + if !strings.Contains(err.Error(), tc.want) { + t.Fatalf("expected %q for %#v, got %v", tc.want, tc.args, err) + } + }) + } +} + +func TestAuthAddMixedM365AndGoogleFailsClosed(t *testing.T) { + _ = captureStderr(t, func() { + err := Execute([]string{"--json", "auth", "add", "pilot@example.com", "--services", "m365,gmail", "--readonly"}) + if err == nil { + t.Fatal("expected mixed m365/google services to fail closed") + } + if !strings.Contains(err.Error(), "unknown service") { + t.Fatalf("unexpected error: %v", err) + } + }) +} diff --git a/internal/cmd/m365_auth_more_test.go b/internal/cmd/m365_auth_more_test.go new file mode 100644 index 00000000..4d7a831c --- /dev/null +++ b/internal/cmd/m365_auth_more_test.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "strings" + "testing" + + "github.com/automagik-dev/workit/internal/msauth" +) + +func TestAuthManageM365PrintURLPropagatesForceConsent(t *testing.T) { + origURL := m365ManualAuthURL + t.Cleanup(func() { m365ManualAuthURL = origURL }) + + var got msauth.ManualAuthURLOptions + m365ManualAuthURL = func(_ context.Context, opts msauth.ManualAuthURLOptions) (msauth.ManualAuthURLResult, error) { + got = opts + return msauth.ManualAuthURLResult{URL: "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize", State: "state"}, nil + } + + _ = captureStdout(t, func() { + _ = captureStderr(t, func() { + if err := Execute([]string{"--json", "auth", "manage", "--services", "m365", "--force-consent", "--print-url"}); err != nil { + t.Fatalf("auth manage m365: %v", err) + } + }) + }) + + if !got.Readonly || !got.ForceConsent { + t.Fatalf("options = %#v", got) + } +} + +func TestAuthManageM365RequiresPrintURLForTextMode(t *testing.T) { + _ = captureStderr(t, func() { + err := Execute([]string{"auth", "manage", "--services", "m365"}) + if err == nil { + t.Fatal("expected m365 manage text mode to fail closed") + } + if !strings.Contains(err.Error(), "requires --print-url") { + t.Fatalf("unexpected error: %v", err) + } + }) +} diff --git a/internal/cmd/m365_auth_test.go b/internal/cmd/m365_auth_test.go new file mode 100644 index 00000000..f2764e02 --- /dev/null +++ b/internal/cmd/m365_auth_test.go @@ -0,0 +1,108 @@ +package cmd + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" + + "github.com/automagik-dev/workit/internal/msauth" + "github.com/automagik-dev/workit/internal/secrets" +) + +func TestAuthAddM365UsesOAuthAndStoresToken(t *testing.T) { + origOpen := openSecretsStore + origAuth := authorizeM365 + origKeychain := ensureKeychainAccess + t.Cleanup(func() { + openSecretsStore = origOpen + authorizeM365 = origAuth + ensureKeychainAccess = origKeychain + }) + + store := newMemSecretsStore() + openSecretsStore = func() (secrets.Store, error) { return store, nil } + ensureKeychainAccess = func() error { return nil } + authorizeM365 = func(context.Context, msauth.AuthorizeOptions) (msauth.AuthorizeResult, error) { + return msauth.AuthorizeResult{Email: "pilot@example.com", RefreshToken: "m365-refresh-token"}, nil + } + + out := captureStdout(t, func() { + _ = captureStderr(t, func() { + if err := Execute([]string{"--json", "auth", "add", "pilot@example.com", "--services", "m365", "--readonly"}); err != nil { + t.Fatalf("auth add m365: %v", err) + } + }) + }) + + var payload map[string]any + if err := json.Unmarshal([]byte(out), &payload); err != nil { + t.Fatalf("json output: %v\n%s", err, out) + } + if payload["provider"] != "microsoft_graph" || payload["stored"] != true { + t.Fatalf("unexpected output: %#v", payload) + } + tok, err := store.GetToken(msauth.ClientName, "pilot@example.com") + if err != nil { + t.Fatalf("stored m365 token: %v", err) + } + if tok.RefreshToken != "m365-refresh-token" { + t.Fatalf("refresh token = %q", tok.RefreshToken) + } + if !stringSliceContainsForM365AuthTest(tok.Services, "m365") { + t.Fatalf("services = %#v", tok.Services) + } +} + +func TestAuthAddM365RequiresReadonly(t *testing.T) { + _ = captureStderr(t, func() { + err := Execute([]string{"--json", "auth", "add", "pilot@example.com", "--services", "m365"}) + if err == nil { + t.Fatal("expected missing --readonly to fail closed") + } + if !strings.Contains(err.Error(), "--readonly") { + t.Fatalf("expected --readonly error, got: %v", err) + } + }) +} + +func TestAuthAddM365RemoteModeFailsClosed(t *testing.T) { + _ = captureStderr(t, func() { + err := Execute([]string{"--json", "auth", "add", "pilot@example.com", "--services", "m365", "--readonly", "--remote", "--step", "1"}) + if err == nil { + t.Fatal("expected remote m365 auth to fail closed") + } + if !strings.Contains(err.Error(), "remote auth is not supported yet") { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestAuthManageM365PrintURLIsNonTechnicalOAuthHandoff(t *testing.T) { + origURL := m365ManualAuthURL + t.Cleanup(func() { m365ManualAuthURL = origURL }) + m365ManualAuthURL = func(context.Context, msauth.ManualAuthURLOptions) (msauth.ManualAuthURLResult, error) { + return msauth.ManualAuthURLResult{URL: "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?client_id=test", State: "state", ExpiresIn: int((5 * time.Minute).Seconds())}, nil + } + + out := captureStdout(t, func() { + _ = captureStderr(t, func() { + if err := Execute([]string{"--json", "auth", "manage", "--services", "m365", "--print-url"}); err != nil { + t.Fatalf("auth manage m365 print-url: %v", err) + } + }) + }) + if !strings.Contains(out, "login.microsoftonline.com") || !strings.Contains(out, "microsoft_graph") { + t.Fatalf("unexpected output: %s", out) + } +} + +func stringSliceContainsForM365AuthTest(values []string, want string) bool { + for _, value := range values { + if value == want { + return true + } + } + return false +} diff --git a/internal/cmd/testutil_test.go b/internal/cmd/testutil_test.go index e5053786..b1b17550 100644 --- a/internal/cmd/testutil_test.go +++ b/internal/cmd/testutil_test.go @@ -80,11 +80,17 @@ func captureStdout(t *testing.T, fn func()) string { } os.Stdout = w + outCh := make(chan []byte, 1) + go func() { + b, _ := io.ReadAll(r) + outCh <- b + }() + fn() _ = w.Close() os.Stdout = orig - b, _ := io.ReadAll(r) + b := <-outCh _ = r.Close() return string(b) } @@ -99,11 +105,17 @@ func captureStderr(t *testing.T, fn func()) string { } os.Stderr = w + outCh := make(chan []byte, 1) + go func() { + b, _ := io.ReadAll(r) + outCh <- b + }() + fn() _ = w.Close() os.Stderr = orig - b, _ := io.ReadAll(r) + b := <-outCh _ = r.Close() return string(b) } diff --git a/internal/config/defaults.go b/internal/config/defaults.go index b761005b..db0f6af1 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -11,6 +11,11 @@ var ( DefaultClientID string DefaultClientSecret string + // DefaultM365ClientID and DefaultM365TenantID are optional Microsoft 365 OAuth defaults + // injected into internal builds so non-technical users can log in with a browser. + DefaultM365ClientID string + DefaultM365TenantID string + // DefaultCallbackServer is the default OAuth relay used when WK_CALLBACK_SERVER is not set. // Override at build time via: -ldflags "-X github.com/automagik-dev/workit/internal/config.DefaultCallbackServer=https://custom.example.com" DefaultCallbackServer = "https://auth.automagik.dev" diff --git a/internal/msauth/oauth.go b/internal/msauth/oauth.go new file mode 100644 index 00000000..76a6ce5c --- /dev/null +++ b/internal/msauth/oauth.go @@ -0,0 +1,383 @@ +package msauth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "os" + "strings" + "time" + + "golang.org/x/oauth2" + + "github.com/automagik-dev/workit/internal/config" +) + +const ( + ClientName = "m365" + DefaultLocalAuthPort = 8085 +) + +var localAuthPort = DefaultLocalAuthPort + +var ( + ErrMissingClientID = errors.New("m365 oauth client id missing") + ErrMissingScopes = errors.New("m365 oauth scopes missing") + ErrNoRefreshToken = errors.New("m365 refresh token missing; ensure offline_access is granted") + ErrStateMismatch = errors.New("m365 oauth state mismatch") + ErrMissingCode = errors.New("m365 oauth missing code") + ErrAuthorization = errors.New("m365 authorization error") + ErrProfileStatus = errors.New("fetch m365 profile status error") + ErrProfileMissingEmail = errors.New("m365 profile missing email") + ErrContextDone = errors.New("m365 oauth context done") +) + +var ( + openBrowserFn func(context.Context, string) error = openBrowser + randomStateFn = randomURLToken + oauthConfigFn = oauthConfig + graphMeURL = "https://graph.microsoft.com/v1.0/me" +) + +type AuthorizeOptions struct { + ExpectedEmail string + Readonly bool + Manual bool + ForceConsent bool + Timeout time.Duration + AuthURL string +} + +type AuthorizeResult struct { + Email string + RefreshToken string +} + +type ManualAuthURLOptions struct { + Readonly bool + ForceConsent bool +} + +type ManualAuthURLResult struct { + URL string `json:"auth_url"` + State string `json:"state"` + ExpiresIn int `json:"expires_in"` +} + +type oauthSettings struct { + ClientID string + TenantID string +} + +func Authorize(ctx context.Context, opts AuthorizeOptions) (AuthorizeResult, error) { + if opts.Timeout <= 0 { + opts.Timeout = 5 * time.Minute + } + + settings, err := resolveOAuthSettings() + if err != nil { + return AuthorizeResult{}, err + } + + scopes, err := OAuthScopes(opts.Readonly) + if err != nil { + return AuthorizeResult{}, err + } + + ctx, cancel := context.WithTimeout(ctx, opts.Timeout) + defer cancel() + + state, verifier, challenge, err := newOAuthStateAndPKCE() + if err != nil { + return AuthorizeResult{}, err + } + + ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", localAuthPort)) + if err != nil { + return AuthorizeResult{}, fmt.Errorf("listen for m365 callback on port %d: %w", localAuthPort, err) + } + + defer func() { _ = ln.Close() }() + + redirectURI := fmt.Sprintf("http://localhost:%d/oauth2/callback", localAuthPort) + cfg := oauthConfigFn(settings, redirectURI, scopes) + codeCh := make(chan string, 1) + errCh := make(chan error, 1) + srv := m365OAuthServer(ctx, state, codeCh, errCh) + + go func() { + <-ctx.Done() + _ = srv.Close() + }() + + go func() { + if serveErr := srv.Serve(ln); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + select { + case errCh <- serveErr: + default: + } + } + }() + + authURL := cfg.AuthCodeURL(state, authParams(opts.ForceConsent, challenge)...) + + fmt.Fprintln(os.Stderr, "Opening browser for Microsoft 365 authorization…") + fmt.Fprintln(os.Stderr, "If the browser doesn't open, visit this URL:") + fmt.Fprintln(os.Stderr, authURL) + _ = openBrowserFn(ctx, authURL) + + select { + case code := <-codeCh: + return exchangeCodeAndProfile(ctx, srv, cfg, code, verifier) + case authErr := <-errCh: + return AuthorizeResult{}, authErr + case <-ctx.Done(): + return AuthorizeResult{}, fmt.Errorf("%w: %w", ErrContextDone, ctx.Err()) + } +} + +func ManualAuthURL(_ context.Context, opts ManualAuthURLOptions) (ManualAuthURLResult, error) { + settings, err := resolveOAuthSettings() + if err != nil { + return ManualAuthURLResult{}, err + } + + scopes, err := OAuthScopes(opts.Readonly) + if err != nil { + return ManualAuthURLResult{}, err + } + + state, _, challenge, err := newOAuthStateAndPKCE() + if err != nil { + return ManualAuthURLResult{}, err + } + + redirectURI := fmt.Sprintf("http://localhost:%d/oauth2/callback", localAuthPort) + cfg := oauthConfigFn(settings, redirectURI, scopes) + url := cfg.AuthCodeURL(state, authParams(opts.ForceConsent, challenge)...) + + return ManualAuthURLResult{URL: url, State: state, ExpiresIn: 300}, nil +} + +func OAuthScopes(readonly bool) ([]string, error) { + if !readonly { + return nil, ErrPilotScopeNotAllowed + } + + scopes := append([]string{"offline_access"}, PilotAllowedScopes()...) + if len(scopes) == 0 { + return nil, ErrMissingScopes + } + + return scopes, nil +} + +func FetchEmail(ctx context.Context, accessToken string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, graphMeURL, nil) + if err != nil { + return "", fmt.Errorf("create m365 profile request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetch m365 profile: %w", err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("%w: %d", ErrProfileStatus, resp.StatusCode) + } + + var me struct { + Mail string `json:"mail"` + UserPrincipalName string `json:"userPrincipalName"` //nolint:tagliatelle // Microsoft Graph field name. + } + if err := json.NewDecoder(resp.Body).Decode(&me); err != nil { + return "", fmt.Errorf("decode m365 profile: %w", err) + } + + email := strings.TrimSpace(me.Mail) + if email == "" { + email = strings.TrimSpace(me.UserPrincipalName) + } + + if email == "" { + return "", ErrProfileMissingEmail + } + + return email, nil +} + +func resolveOAuthSettings() (oauthSettings, error) { + clientID := strings.TrimSpace(config.DefaultM365ClientID) + if clientID == "" { + clientID = strings.TrimSpace(os.Getenv("WK_M365_CLIENT_ID")) + } + + if clientID == "" { + return oauthSettings{}, ErrMissingClientID + } + + tenantID := strings.TrimSpace(config.DefaultM365TenantID) + if tenantID == "" { + tenantID = strings.TrimSpace(os.Getenv("WK_M365_TENANT_ID")) + } + + if tenantID == "" { + tenantID = "organizations" + } + + return oauthSettings{ClientID: clientID, TenantID: tenantID}, nil +} + +func oauthConfig(settings oauthSettings, redirectURI string, scopes []string) oauth2.Config { + base := "https://login.microsoftonline.com/" + settings.TenantID + "/oauth2/v2.0" + + return oauth2.Config{ + ClientID: settings.ClientID, + Endpoint: oauth2.Endpoint{AuthURL: base + "/authorize", TokenURL: base + "/token"}, + RedirectURL: redirectURI, + Scopes: scopes, + } +} + +func authParams(forceConsent bool, challenge string) []oauth2.AuthCodeOption { + params := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", challenge), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + } + if forceConsent { + params = append(params, oauth2.SetAuthURLParam("prompt", "consent")) + } + + return params +} + +func newOAuthStateAndPKCE() (state string, verifier string, challenge string, err error) { + state, err = randomStateFn() + if err != nil { + return "", "", "", err + } + + verifier, challenge, err = pkcePair() + if err != nil { + return "", "", "", err + } + + return state, verifier, challenge, nil +} + +func pkcePair() (verifier string, challenge string, err error) { + verifier, err = randomURLToken() + if err != nil { + return "", "", err + } + + sum := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(sum[:]) + + return verifier, challenge, nil +} + +func randomURLToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate random token: %w", err) + } + + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func m365OAuthServer(ctx context.Context, state string, codeCh chan<- string, errCh chan<- error) *http.Server { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleM365OAuthCallback(w, r, state, codeCh, errCh) + }) + + return &http.Server{ReadHeaderTimeout: 5 * time.Second, Handler: handler, BaseContext: func(net.Listener) context.Context { return ctx }} +} + +func handleM365OAuthCallback(w http.ResponseWriter, r *http.Request, state string, codeCh chan<- string, errCh chan<- error) { + if r.URL.Path != "/oauth2/callback" { + http.NotFound(w, r) + return + } + + q := r.URL.Query() + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + + if q.Get("error") != "" { + select { + case errCh <- fmt.Errorf("%w: %s", ErrAuthorization, q.Get("error")): + default: + } + + _, _ = w.Write([]byte("Microsoft 365 authorization failed. You may close this tab.")) + + return + } + + if q.Get("state") != state { + select { + case errCh <- ErrStateMismatch: + default: + } + + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("State mismatch. Please try again.")) + + return + } + + code := q.Get("code") + if code == "" { + select { + case errCh <- ErrMissingCode: + default: + } + + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("Missing authorization code. Please try again.")) + + return + } + + select { + case codeCh <- code: + default: + } + + _, _ = w.Write([]byte("Microsoft 365 authorization complete. You may close this tab.")) +} + +func exchangeCodeAndProfile(ctx context.Context, srv *http.Server, cfg oauth2.Config, code string, verifier string) (AuthorizeResult, error) { + tok, err := cfg.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", verifier)) + if err != nil { + _ = srv.Close() + return AuthorizeResult{}, fmt.Errorf("exchange m365 code: %w", err) + } + + if tok.RefreshToken == "" { + _ = srv.Close() + return AuthorizeResult{}, ErrNoRefreshToken + } + + email, err := FetchEmail(ctx, tok.AccessToken) + if err != nil { + _ = srv.Close() + return AuthorizeResult{}, err + } + + _ = srv.Shutdown(ctx) + + return AuthorizeResult{Email: email, RefreshToken: tok.RefreshToken}, nil +} diff --git a/internal/msauth/oauth_more_test.go b/internal/msauth/oauth_more_test.go new file mode 100644 index 00000000..3f427db1 --- /dev/null +++ b/internal/msauth/oauth_more_test.go @@ -0,0 +1,273 @@ +package msauth + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "golang.org/x/oauth2" + + "github.com/automagik-dev/workit/internal/config" +) + +func TestFetchEmailUsesMailThenUserPrincipalName(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer access-token" { + t.Fatalf("Authorization = %q", got) + } + + _, _ = w.Write([]byte(`{"mail":"","userPrincipalName":"pilot@example.com"}`)) + })) + defer server.Close() + + origURL := graphMeURL + + t.Cleanup(func() { graphMeURL = origURL }) + graphMeURL = server.URL + + email, err := FetchEmail(context.Background(), "access-token") + if err != nil { + t.Fatalf("FetchEmail: %v", err) + } + + if email != "pilot@example.com" { + t.Fatalf("email = %q", email) + } +} + +func TestFetchEmailFailsClosedOnBadProfileStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer server.Close() + + origURL := graphMeURL + + t.Cleanup(func() { graphMeURL = origURL }) + graphMeURL = server.URL + + _, err := FetchEmail(context.Background(), "access-token") + if !errors.Is(err, ErrProfileStatus) { + t.Fatalf("expected ErrProfileStatus, got: %v", err) + } +} + +func TestHandleM365OAuthCallbackValidatesStateAndCode(t *testing.T) { + codeCh := make(chan string, 1) + errCh := make(chan error, 1) + + req := httptest.NewRequest(http.MethodGet, "/oauth2/callback?state=good&code=abc", nil) + rec := httptest.NewRecorder() + handleM365OAuthCallback(rec, req, "good", codeCh, errCh) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d", rec.Code) + } + + if got := <-codeCh; got != "abc" { + t.Fatalf("code = %q", got) + } + + badReq := httptest.NewRequest(http.MethodGet, "/oauth2/callback?state=bad&code=abc", nil) + badRec := httptest.NewRecorder() + handleM365OAuthCallback(badRec, badReq, "good", codeCh, errCh) + + if badRec.Code != http.StatusBadRequest { + t.Fatalf("bad status = %d", badRec.Code) + } + + if err := <-errCh; !errors.Is(err, ErrStateMismatch) { + t.Fatalf("expected state mismatch, got: %v", err) + } +} + +func TestHandleM365OAuthCallbackReportsProviderErrorAndMissingCode(t *testing.T) { + codeCh := make(chan string, 1) + errCh := make(chan error, 2) + + errorReq := httptest.NewRequest(http.MethodGet, "/oauth2/callback?error=access_denied", nil) + errorRec := httptest.NewRecorder() + handleM365OAuthCallback(errorRec, errorReq, "state", codeCh, errCh) + + if err := <-errCh; !errors.Is(err, ErrAuthorization) { + t.Fatalf("expected ErrAuthorization, got: %v", err) + } + + missingCodeReq := httptest.NewRequest(http.MethodGet, "/oauth2/callback?state=state", nil) + missingCodeRec := httptest.NewRecorder() + handleM365OAuthCallback(missingCodeRec, missingCodeReq, "state", codeCh, errCh) + + if missingCodeRec.Code != http.StatusBadRequest { + t.Fatalf("missing-code status = %d", missingCodeRec.Code) + } + + if err := <-errCh; !errors.Is(err, ErrMissingCode) { + t.Fatalf("expected ErrMissingCode, got: %v", err) + } +} + +func TestExchangeCodeAndProfileRequiresRefreshToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/token") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"access-token","token_type":"Bearer"}`)) + + return + } + + t.Fatalf("unexpected path: %s", r.URL.Path) + })) + defer server.Close() + + cfg := oauth2.Config{ + ClientID: "client-id", + Endpoint: oauth2.Endpoint{TokenURL: server.URL + "/token"}, + RedirectURL: "http://localhost:8085/oauth2/callback", + Scopes: []string{"User.Read"}, + } + + _, err := exchangeCodeAndProfile(context.Background(), &http.Server{}, cfg, "code", "verifier") + if !errors.Is(err, ErrNoRefreshToken) { + t.Fatalf("expected ErrNoRefreshToken, got: %v", err) + } +} + +func TestExchangeCodeAndProfileStoresSuccessfulProfileEmail(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, "/token") { + t.Fatalf("unexpected token path: %s", r.URL.Path) + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"access-token","refresh_token":"refresh-token","token_type":"Bearer"}`)) + })) + defer tokenServer.Close() + + graphServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"mail":"pilot@example.com"}`)) + })) + defer graphServer.Close() + + origURL := graphMeURL + + t.Cleanup(func() { graphMeURL = origURL }) + graphMeURL = graphServer.URL + + cfg := oauth2.Config{ + ClientID: "client-id", + Endpoint: oauth2.Endpoint{TokenURL: tokenServer.URL + "/token"}, + RedirectURL: "http://localhost:8085/oauth2/callback", + Scopes: []string{"User.Read"}, + } + + result, err := exchangeCodeAndProfile(context.Background(), &http.Server{}, cfg, "code", "verifier") + if err != nil { + t.Fatalf("exchangeCodeAndProfile: %v", err) + } + + if result.Email != "pilot@example.com" || result.RefreshToken != "refresh-token" { + t.Fatalf("result = %#v", result) + } +} + +func TestAuthorizeCompletesBrowserOAuthWithPKCE(t *testing.T) { + origClientID := config.DefaultM365ClientID + origTenantID := config.DefaultM365TenantID + origRandom := randomStateFn + origOpen := openBrowserFn + origOAuthConfig := oauthConfigFn + origGraphURL := graphMeURL + origPort := localAuthPort + + t.Cleanup(func() { + config.DefaultM365ClientID = origClientID + config.DefaultM365TenantID = origTenantID + randomStateFn = origRandom + openBrowserFn = origOpen + oauthConfigFn = origOAuthConfig + graphMeURL = origGraphURL + localAuthPort = origPort + }) + + config.DefaultM365ClientID = "client-id" + config.DefaultM365TenantID = "organizations" + localAuthPort = 18085 + randomStateFn = func() (string, error) { return "fixed-state", nil } + + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"access-token","refresh_token":"refresh-token","token_type":"Bearer"}`)) + })) + defer tokenServer.Close() + + graphServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"mail":"pilot@example.com"}`)) + })) + defer graphServer.Close() + graphMeURL = graphServer.URL + + oauthConfigFn = func(settings oauthSettings, redirectURI string, scopes []string) oauth2.Config { + return oauth2.Config{ + ClientID: settings.ClientID, + Endpoint: oauth2.Endpoint{AuthURL: tokenServer.URL + "/authorize", TokenURL: tokenServer.URL + "/token"}, + RedirectURL: redirectURI, + Scopes: scopes, + } + } + openBrowserFn = func(ctx context.Context, authURL string) error { + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("parse auth url: %v", err) + } + + redirectURI := parsed.Query().Get("redirect_uri") + if redirectURI == "" { + t.Fatal("missing redirect_uri") + } + callbackURL := redirectURI + "?state=fixed-state&" + "code=ok" + + go func() { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, callbackURL, nil) + if err != nil { + return + } + + resp, err := http.DefaultClient.Do(req) + if err == nil { + _ = resp.Body.Close() + } + }() + + return nil + } + + result, err := Authorize(context.Background(), AuthorizeOptions{Readonly: true}) + if err != nil { + t.Fatalf("Authorize: %v", err) + } + + if result.Email != "pilot@example.com" || result.RefreshToken != "refresh-token" { + t.Fatalf("result = %#v", result) + } +} + +func TestServicesInfoReturnsM365ReadOnlyMetadata(t *testing.T) { + infos := ServicesInfo() + if len(infos) != 1 || infos[0].Service != "m365" { + t.Fatalf("unexpected infos: %#v", infos) + } + + for _, forbidden := range []string{"Mail.Send", "Calendars.ReadWrite"} { + for _, scope := range infos[0].Scopes { + if scope == forbidden { + t.Fatalf("forbidden scope exposed: %s", forbidden) + } + } + } +} diff --git a/internal/msauth/oauth_test.go b/internal/msauth/oauth_test.go new file mode 100644 index 00000000..cd21f2c5 --- /dev/null +++ b/internal/msauth/oauth_test.go @@ -0,0 +1,91 @@ +package msauth + +import ( + "context" + "net/url" + "os" + "strings" + "testing" + + "github.com/automagik-dev/workit/internal/config" +) + +func TestManualAuthURLRequiresClientIDFailClosed(t *testing.T) { + origClientID := config.DefaultM365ClientID + origTenantID := config.DefaultM365TenantID + + t.Cleanup(func() { + config.DefaultM365ClientID = origClientID + config.DefaultM365TenantID = origTenantID + }) + + config.DefaultM365ClientID = "" + config.DefaultM365TenantID = "" + + t.Setenv("WK_M365_CLIENT_ID", "") + + _, err := ManualAuthURL(context.Background(), ManualAuthURLOptions{Readonly: true}) + if err == nil || !strings.Contains(err.Error(), "client id") { + t.Fatalf("expected missing client id error, got: %v", err) + } +} + +func TestManualAuthURLUsesMicrosoftOAuthWithOnlyReadPilotScopes(t *testing.T) { + origClientID := config.DefaultM365ClientID + origTenantID := config.DefaultM365TenantID + origRandom := randomStateFn + + t.Cleanup(func() { + config.DefaultM365ClientID = origClientID + config.DefaultM365TenantID = origTenantID + randomStateFn = origRandom + }) + + config.DefaultM365ClientID = "test-client-id" + config.DefaultM365TenantID = "organizations" + _ = os.Unsetenv("WK_M365_CLIENT_ID") + randomStateFn = func() (string, error) { return "state-for-test", nil } + + result, err := ManualAuthURL(context.Background(), ManualAuthURLOptions{Readonly: true}) + if err != nil { + t.Fatalf("ManualAuthURL: %v", err) + } + + parsed, err := url.Parse(result.URL) + if err != nil { + t.Fatalf("parse auth url: %v", err) + } + + if parsed.Host != "login.microsoftonline.com" || !strings.Contains(parsed.Path, "/organizations/oauth2/v2.0/authorize") { + t.Fatalf("unexpected auth endpoint: %s", result.URL) + } + + q := parsed.Query() + if q.Get("client_id") != "test-client-id" { + t.Fatalf("client_id = %q", q.Get("client_id")) + } + + if q.Get("code_challenge_method") != "S256" || q.Get("code_challenge") == "" { + t.Fatalf("missing PKCE params: %s", result.URL) + } + + scope := q.Get("scope") + for _, want := range []string{"offline_access", "User.Read", "Mail.Read", "Calendars.Read"} { + if !strings.Contains(scope, want) { + t.Fatalf("scope %q missing %s", scope, want) + } + } + + for _, forbidden := range []string{"Mail.Send", "Calendars.ReadWrite"} { + if strings.Contains(scope, forbidden) { + t.Fatalf("scope %q contains forbidden %s", scope, forbidden) + } + } +} + +func TestOAuthScopesRejectsNonReadonly(t *testing.T) { + _, err := OAuthScopes(false) + if err == nil { + t.Fatal("expected non-readonly scopes to fail closed") + } +} diff --git a/internal/msauth/open_browser.go b/internal/msauth/open_browser.go new file mode 100644 index 00000000..bdc9888a --- /dev/null +++ b/internal/msauth/open_browser.go @@ -0,0 +1,32 @@ +package msauth + +import ( + "context" + "fmt" + "os/exec" + "runtime" + "strings" +) + +func quoteWindowsStartURL(url string) string { + return `"` + strings.ReplaceAll(url, `"`, `%22`) + `"` +} + +func openBrowser(ctx context.Context, url string) error { + var cmd *exec.Cmd + + switch runtime.GOOS { + case "darwin": + cmd = exec.CommandContext(ctx, "open", url) + case "windows": + cmd = exec.CommandContext(ctx, "cmd", "/c", "start", "", quoteWindowsStartURL(url)) //nolint:gosec // URL is quoted for cmd/start; browser launch intentionally uses the OAuth URL. + default: + cmd = exec.CommandContext(ctx, "xdg-open", url) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("open browser: %w", err) + } + + return nil +} diff --git a/internal/msauth/open_browser_test.go b/internal/msauth/open_browser_test.go new file mode 100644 index 00000000..074b57b9 --- /dev/null +++ b/internal/msauth/open_browser_test.go @@ -0,0 +1,30 @@ +package msauth + +import ( + "strings" + "testing" +) + +func TestQuoteWindowsStartURLPreservesOAuthQuerySeparators(t *testing.T) { + url := `https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?client_id=client&scope=User.Read+Mail.Read&redirect_uri=http://localhost:8085/oauth2/callback` + got := quoteWindowsStartURL(url) + + if !strings.HasPrefix(got, `"`) || !strings.HasSuffix(got, `"`) { + t.Fatalf("expected quoted URL, got %q", got) + } + + if !strings.Contains(got, `&scope=`) || !strings.Contains(got, `&redirect_uri=`) { + t.Fatalf("expected OAuth query separators preserved, got %q", got) + } +} + +func TestQuoteWindowsStartURLEscapesEmbeddedQuotes(t *testing.T) { + got := quoteWindowsStartURL(`https://example.test/callback?x="bad"&y=1`) + if strings.Contains(strings.Trim(got, `"`), `"`) { + t.Fatalf("expected embedded quotes to be escaped, got %q", got) + } + + if !strings.Contains(got, `%22bad%22`) { + t.Fatalf("expected escaped quote payload, got %q", got) + } +}