Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions internal/api/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/lox/notion-cli/internal/config"
)

const (
defaultBaseURL = "https://api.notion.com/v1"
defaultNotionAPIRev = "2022-06-28"
)

type Client struct {
httpClient *http.Client
baseURL string
notionVersion string
token string
}

func NewClient(cfg config.APIConfig, token string) (*Client, error) {
token = strings.TrimSpace(token)
if token == "" {
return nil, fmt.Errorf("official API token is required")
}

baseURL := strings.TrimSpace(cfg.BaseURL)
if baseURL == "" {
baseURL = defaultBaseURL
}
baseURL = strings.TrimRight(baseURL, "/")

notionVersion := strings.TrimSpace(cfg.NotionVersion)
if notionVersion == "" {
notionVersion = defaultNotionAPIRev
}

return &Client{
httpClient: &http.Client{Timeout: 20 * time.Second},
baseURL: baseURL,
notionVersion: notionVersion,
token: token,
}, nil
}

func (c *Client) PatchPage(ctx context.Context, pageID string, patch map[string]any) error {
pageID = strings.TrimSpace(pageID)
if pageID == "" {
return fmt.Errorf("page ID is required")
}
if len(patch) == 0 {
return fmt.Errorf("patch payload is required")
}

return c.doJSON(ctx, http.MethodPatch, "/pages/"+pageID, patch, nil)
}

func (c *Client) doJSON(ctx context.Context, method, path string, payload any, out any) error {
var bodyReader io.Reader
if payload != nil {
data, err := json.Marshal(payload)
if err != nil {
return err
}
bodyReader = bytes.NewReader(data)
}

req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, bodyReader)
if err != nil {
return err
}
req.Header.Set("accept", "application/json")
req.Header.Set("authorization", "Bearer "+c.token)
req.Header.Set("notion-version", c.notionVersion)
if payload != nil {
req.Header.Set("content-type", "application/json")
}

resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
}

if resp.StatusCode >= 400 {
message := strings.TrimSpace(string(respBody))
if message == "" {
message = http.StatusText(resp.StatusCode)
} else {
var errResp struct {
Message string `json:"message"`
}
if err := json.Unmarshal(respBody, &errResp); err == nil && strings.TrimSpace(errResp.Message) != "" {
message = strings.TrimSpace(errResp.Message)
}
}
return fmt.Errorf("official API %s %s failed (%d): %s", method, path, resp.StatusCode, message)
}

if out == nil || len(respBody) == 0 {
return nil
}
if err := json.Unmarshal(respBody, out); err != nil {
return fmt.Errorf("parse official API response for %s %s: %w", method, path, err)
}
return nil
}
108 changes: 108 additions & 0 deletions internal/api/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package api

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/lox/notion-cli/internal/config"
)

func TestNewClientRequiresToken(t *testing.T) {
t.Parallel()

_, err := NewClient(config.APIConfig{}, "")
if err == nil {
t.Fatal("expected token error")
}
}

func TestPatchPageSendsPatchRequest(t *testing.T) {
t.Parallel()

var gotMethod string
var gotPath string
var gotAuth string
var gotVersion string
var gotContentType string
var gotBody map[string]any

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotMethod = r.Method
gotPath = r.URL.Path
gotAuth = r.Header.Get("Authorization")
gotVersion = r.Header.Get("Notion-Version")
gotContentType = r.Header.Get("Content-Type")

defer func() { _ = r.Body.Close() }()
if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
t.Fatalf("decode request body: %v", err)
}

w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"page-id","object":"page"}`))
}))
defer srv.Close()

client, err := NewClient(config.APIConfig{
BaseURL: srv.URL,
NotionVersion: "2022-06-28",
}, "secret-token")
if err != nil {
t.Fatalf("new client: %v", err)
}

patch := map[string]any{
"archived": true,
}

if err := client.PatchPage(context.Background(), "page-id", patch); err != nil {
t.Fatalf("patch page: %v", err)
}

if gotMethod != http.MethodPatch {
t.Fatalf("method mismatch: got %s", gotMethod)
}
if gotPath != "/pages/page-id" {
t.Fatalf("path mismatch: got %s", gotPath)
}
if gotAuth != "Bearer secret-token" {
t.Fatalf("auth mismatch: got %s", gotAuth)
}
if gotVersion != "2022-06-28" {
t.Fatalf("notion-version mismatch: got %s", gotVersion)
}
if gotContentType != "application/json" {
t.Fatalf("content-type mismatch: got %s", gotContentType)
}

if gotBody["archived"] != true {
t.Fatalf("archived mismatch: %v", gotBody["archived"])
}
}

func TestPatchPageReturnsAPIErrorMessage(t *testing.T) {
t.Parallel()

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"object":"error","message":"unauthorized"}`))
}))
defer srv.Close()

client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token")
if err != nil {
t.Fatalf("new client: %v", err)
}

err = client.PatchPage(context.Background(), "page-id", map[string]any{"archived": true})
if err == nil {
t.Fatal("expected API error")
}
if !strings.Contains(err.Error(), "unauthorized") {
t.Fatalf("expected unauthorized message, got: %v", err)
}
}
22 changes: 22 additions & 0 deletions internal/cli/official_api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package cli

import (
"fmt"

"github.com/lox/notion-cli/internal/api"
"github.com/lox/notion-cli/internal/config"
)

func RequireOfficialAPIClient() (*api.Client, error) {
cfg, err := config.Load()
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}

client, err := api.NewClient(cfg.API, cfg.API.Token)
if err != nil {
return nil, fmt.Errorf("create official API client: %w (set api.token in ~/.config/notion-cli/config.json or NOTION_API_TOKEN)", err)
}

return client, nil
}
96 changes: 96 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package config

import (
"encoding/json"
"errors"
"os"
"path/filepath"
"strings"
)

const (
configDirName = ".config/notion-cli"
configFileName = "config.json"
)

type Config struct {
ActiveAccount string `json:"active_account,omitempty"`
API APIConfig `json:"api,omitempty"`
}

type APIConfig struct {
BaseURL string `json:"base_url,omitempty"`
NotionVersion string `json:"notion_version,omitempty"`
Token string `json:"token,omitempty"`
}

func Default() Config {
return Config{
API: APIConfig{
BaseURL: "https://api.notion.com/v1",
NotionVersion: "2022-06-28",
},
}
}

func Load() (Config, error) {
cfg := Default()

path, err := Path()
if err != nil {
return cfg, err
}

if data, err := os.ReadFile(path); err == nil {
if err := json.Unmarshal(data, &cfg); err != nil {
return cfg, err
}
} else if !errors.Is(err, os.ErrNotExist) {
return cfg, err
}

applyEnvOverrides(&cfg)
normalize(&cfg)
return cfg, nil
}

func Path() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, configDirName, configFileName), nil
}

func applyEnvOverrides(cfg *Config) {
if cfg == nil {
return
}

if s := os.Getenv("NOTION_API_BASE_URL"); s != "" {
cfg.API.BaseURL = s
}
if s := os.Getenv("NOTION_API_NOTION_VERSION"); s != "" {
cfg.API.NotionVersion = s
}
if s := os.Getenv("NOTION_API_TOKEN"); s != "" {
cfg.API.Token = s
}
}

func normalize(cfg *Config) {
if cfg == nil {
return
}

cfg.API.BaseURL = strings.TrimSpace(cfg.API.BaseURL)
if cfg.API.BaseURL == "" {
cfg.API.BaseURL = "https://api.notion.com/v1"
}
cfg.API.BaseURL = strings.TrimRight(cfg.API.BaseURL, "/")
cfg.API.NotionVersion = strings.TrimSpace(cfg.API.NotionVersion)
if cfg.API.NotionVersion == "" {
cfg.API.NotionVersion = "2022-06-28"
}
cfg.API.Token = strings.TrimSpace(cfg.API.Token)
}
47 changes: 47 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package config

import "testing"

func TestApplyEnvOverrides(t *testing.T) {
t.Setenv("NOTION_API_BASE_URL", "https://api.example.com/v1/")
t.Setenv("NOTION_API_NOTION_VERSION", "2022-06-28")
t.Setenv("NOTION_API_TOKEN", "api-token")

cfg := Default()
applyEnvOverrides(&cfg)
normalize(&cfg)

if cfg.API.BaseURL != "https://api.example.com/v1" {
t.Fatalf("unexpected api.base_url normalization: %q", cfg.API.BaseURL)
}
if cfg.API.NotionVersion != "2022-06-28" {
t.Fatalf("unexpected api.notion_version: %q", cfg.API.NotionVersion)
}
if cfg.API.Token != "api-token" {
t.Fatalf("unexpected api.token: %q", cfg.API.Token)
}
}

func TestNormalizeAppliesAPIDefaults(t *testing.T) {
cfg := Config{}
normalize(&cfg)

if cfg.API.BaseURL != "https://api.notion.com/v1" {
t.Fatalf("unexpected api.base_url default: %q", cfg.API.BaseURL)
}
if cfg.API.NotionVersion != "2022-06-28" {
t.Fatalf("unexpected api.notion_version default: %q", cfg.API.NotionVersion)
}
}

func TestPathUsesHome(t *testing.T) {
t.Setenv("HOME", "/tmp/example-home")

path, err := Path()
if err != nil {
t.Fatal(err)
}
if path != "/tmp/example-home/.config/notion-cli/config.json" {
t.Fatalf("unexpected path: %s", path)
}
}