diff --git a/README.md b/README.md index 86f40e4..db9243e 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,9 @@ notion-cli page sync ./document.md # Updates page using notion-cli page sync ./document.md --parent "Engineering" # Set parent on first sync notion-cli page sync ./document.md --parent-db # Sync as database entry +# Local image paths like ./image.png are uploaded via official API fallback during upload/sync +# when detected in markdown content. Requires NOTION_API_TOKEN. + # Edit an existing page notion-cli page edit --replace "New content" # Replace all content notion-cli page edit --find "old text" --replace-with "new text" # Find and replace diff --git a/cmd/page.go b/cmd/page.go index 539fcae..6b2f065 100644 --- a/cmd/page.go +++ b/cmd/page.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" + "github.com/lox/notion-cli/internal/api" "github.com/lox/notion-cli/internal/cli" "github.com/lox/notion-cli/internal/mcp" "github.com/lox/notion-cli/internal/output" @@ -206,6 +207,15 @@ func runPageUpload(ctx *Context, file, title, parent, parentDB, icon string) err } markdown := string(content) + bgCtx := context.Background() + localUploads, err := maybeUploadLocalImages(bgCtx, file, markdown, "", "") + if err != nil { + output.PrintError(err) + return err + } + if len(localUploads) > 0 { + output.PrintInfo(fmt.Sprintf("Uploaded %d local image(s) via Notion REST file uploads", len(localUploads))) + } if title == "" { title = extractTitleFromMarkdown(markdown) @@ -224,8 +234,6 @@ func runPageUpload(ctx *Context, file, title, parent, parentDB, icon string) err } defer func() { _ = client.Close() }() - bgCtx := context.Background() - req := mcp.CreatePageRequest{ Title: title, Content: markdown, @@ -257,6 +265,15 @@ func runPageUpload(ctx *Context, file, title, parent, parentDB, icon string) err output.PrintError(err) return err } + pageID := pageIDFromCreateResponse(resp) + if len(localUploads) > 0 { + if pageID == "" { + output.PrintWarning("Page created 
but could not retrieve ID to append uploaded local images") + } else if err := appendUploadedLocalImages(bgCtx, pageID, localUploads); err != nil { + output.PrintError(err) + return err + } + } displayTitle := title if icon != "" { @@ -265,7 +282,7 @@ func runPageUpload(ctx *Context, file, title, parent, parentDB, icon string) err if ctx.JSON { outPage := output.Page{ - ID: resp.ID, + ID: pageID, URL: resp.URL, Title: displayTitle, Icon: icon, @@ -392,6 +409,15 @@ func runPageSync(ctx *Context, file, title, parent, parentDB, icon string) error content := string(raw) fm, body := cli.ParseFrontmatter(content) + bgCtx := context.Background() + localUploads, err := maybeUploadLocalImages(bgCtx, file, body, "", "") + if err != nil { + output.PrintError(err) + return err + } + if len(localUploads) > 0 { + output.PrintInfo(fmt.Sprintf("Uploaded %d local image(s) via Notion REST file uploads", len(localUploads))) + } if title == "" { title = extractTitleFromMarkdown(body) @@ -409,8 +435,6 @@ func runPageSync(ctx *Context, file, title, parent, parentDB, icon string) error } defer func() { _ = client.Close() }() - bgCtx := context.Background() - if fm.NotionID != "" { req := mcp.UpdatePageRequest{ PageID: fm.NotionID, @@ -421,6 +445,12 @@ func runPageSync(ctx *Context, file, title, parent, parentDB, icon string) error output.PrintError(err) return err } + if len(localUploads) > 0 { + if err := appendUploadedLocalImages(bgCtx, fm.NotionID, localUploads); err != nil { + output.PrintError(err) + return err + } + } displayTitle := title if icon != "" { @@ -472,9 +502,14 @@ func runPageSync(ctx *Context, file, title, parent, parentDB, icon string) error return err } - pageID := resp.ID - if pageID == "" && resp.URL != "" { - pageID, _ = cli.ExtractNotionUUID(resp.URL) + pageID := pageIDFromCreateResponse(resp) + if len(localUploads) > 0 { + if pageID == "" { + output.PrintWarning("Page created but could not retrieve ID to append uploaded local images") + } else if err := 
appendUploadedLocalImages(bgCtx, pageID, localUploads); err != nil { + output.PrintError(err) + return err + } } if pageID == "" { output.PrintWarning("Page created but could not retrieve ID for frontmatter") @@ -511,3 +546,88 @@ func runPageSync(ctx *Context, file, title, parent, parentDB, icon string) error } return nil } + +type uploadedLocalImage struct { + Alt string + FileUploadID string + ResolvedPath string +} + +func maybeUploadLocalImages(ctx context.Context, sourceFile, markdown, assetBaseURL, _ string) ([]uploadedLocalImage, error) { + if strings.TrimSpace(assetBaseURL) != "" { + return nil, nil + } + + images, err := cli.FindLocalMarkdownImages(markdown, sourceFile) + if err != nil { + return nil, err + } + if len(images) == 0 { + return nil, nil + } + + apiClient, err := cli.RequireOfficialAPIClient() + if err != nil { + return nil, err + } + + uploadIDByPath := make(map[string]string, len(images)) + uploads := make([]uploadedLocalImage, 0, len(images)) + for _, image := range images { + uploadID, ok := uploadIDByPath[image.Resolved] + if !ok { + fileData, err := os.ReadFile(image.Resolved) + if err != nil { + return nil, fmt.Errorf("read local image %q: %w", image.Resolved, err) + } + uploadID, err = apiClient.UploadFile(ctx, filepath.Base(image.Resolved), fileData) + if err != nil { + return nil, fmt.Errorf("upload local image %q: %w", image.Resolved, err) + } + uploadIDByPath[image.Resolved] = uploadID + } + + uploads = append(uploads, uploadedLocalImage{ + Alt: image.Alt, + FileUploadID: uploadID, + ResolvedPath: image.Resolved, + }) + } + + return uploads, nil +} + +func appendUploadedLocalImages(ctx context.Context, pageID string, uploads []uploadedLocalImage) error { + if strings.TrimSpace(pageID) == "" || len(uploads) == 0 { + return nil + } + + apiClient, err := cli.RequireOfficialAPIClient() + if err != nil { + return err + } + + blocks := make([]api.UploadedImageBlock, 0, len(uploads)) + for _, upload := range uploads { + blocks = 
append(blocks, api.UploadedImageBlock{ + FileUploadID: upload.FileUploadID, + Caption: upload.Alt, + }) + } + + return apiClient.AppendUploadedImageBlocks(ctx, pageID, blocks) +} + +func pageIDFromCreateResponse(resp *mcp.CreatePageResponse) string { + if resp == nil { + return "" + } + if strings.TrimSpace(resp.ID) != "" { + return strings.TrimSpace(resp.ID) + } + if strings.TrimSpace(resp.URL) == "" { + return "" + } + id, _ := cli.ExtractNotionUUID(resp.URL) + return id +} diff --git a/cmd/page_local_images_test.go b/cmd/page_local_images_test.go new file mode 100644 index 0000000..7b71d4a --- /dev/null +++ b/cmd/page_local_images_test.go @@ -0,0 +1,116 @@ +package cmd + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestMaybeUploadLocalImagesSkipsWhenAssetBaseURLSet(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + + uploads, err := maybeUploadLocalImages(context.Background(), "/tmp/doc.md", "![A](./a.png)", "https://cdn.example.com/base", "") + if err != nil { + t.Fatalf("maybeUploadLocalImages: %v", err) + } + if len(uploads) != 0 { + t.Fatalf("expected no uploads, got %d", len(uploads)) + } +} + +func TestMaybeUploadLocalImagesUploadsAndDeduplicates(t *testing.T) { + tmp := t.TempDir() + docDir := filepath.Join(tmp, "docs") + if err := os.MkdirAll(filepath.Join(docDir, "assets"), 0o755); err != nil { + t.Fatalf("mkdir assets: %v", err) + } + img := filepath.Join(docDir, "assets", "diagram.png") + if err := os.WriteFile(img, []byte("PNGDATA"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + doc := filepath.Join(docDir, "guide.md") + markdown := "![One](./assets/diagram.png)\n![Two](./assets/diagram.png)\n" + + createCalls := 0 + sendCalls := 0 + getCalls := 0 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && r.URL.Path == "/v1/file_uploads": + createCalls++ + 
w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"upload_123","status":"pending"}`)) + return + + case r.Method == http.MethodPost && r.URL.Path == "/v1/file_uploads/upload_123/send": + sendCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"upload_123","status":"uploaded"}`)) + return + + case r.Method == http.MethodGet && r.URL.Path == "/v1/file_uploads/upload_123": + getCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"upload_123","status":"uploaded"}`)) + return + } + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) + })) + defer srv.Close() + + t.Setenv("HOME", t.TempDir()) + t.Setenv("NOTION_API_BASE_URL", srv.URL+"/v1") + t.Setenv("NOTION_API_TOKEN", "test-token") + + uploads, err := maybeUploadLocalImages(context.Background(), doc, markdown, "", "") + if err != nil { + t.Fatalf("maybeUploadLocalImages: %v", err) + } + if len(uploads) != 2 { + t.Fatalf("len(uploads)=%d, want 2", len(uploads)) + } + if uploads[0].FileUploadID != "upload_123" || uploads[1].FileUploadID != "upload_123" { + t.Fatalf("unexpected upload ids: %#v", uploads) + } + if createCalls != 1 || sendCalls != 1 || getCalls != 1 { + t.Fatalf("unexpected call counts create=%d send=%d get=%d", createCalls, sendCalls, getCalls) + } +} + +func TestAppendUploadedLocalImages(t *testing.T) { + var gotBody map[string]any + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPatch || r.URL.Path != "/v1/blocks/page_123/children" { + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) + } + defer func() { _ = r.Body.Close() }() + if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { + t.Fatalf("decode request body: %v", err) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"object":"list","results":[]}`)) + })) + defer srv.Close() + + t.Setenv("HOME", 
t.TempDir()) + t.Setenv("NOTION_API_BASE_URL", srv.URL+"/v1") + t.Setenv("NOTION_API_TOKEN", "test-token") + + err := appendUploadedLocalImages(context.Background(), "page_123", []uploadedLocalImage{ + {Alt: "Diagram", FileUploadID: "upload_1"}, + }) + if err != nil { + t.Fatalf("appendUploadedLocalImages: %v", err) + } + + children, ok := gotBody["children"].([]any) + if !ok || len(children) != 1 { + t.Fatalf("children payload mismatch: %#v", gotBody["children"]) + } +} diff --git a/internal/api/client.go b/internal/api/client.go index acd1a0d..339e5d7 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -6,7 +6,9 @@ import ( "encoding/json" "fmt" "io" + "mime/multipart" "net/http" + "path/filepath" "strings" "time" @@ -16,6 +18,7 @@ import ( const ( defaultBaseURL = "https://api.notion.com/v1" defaultNotionAPIRev = "2022-06-28" + fileUploadAPIRev = "2025-09-03" ) type Client struct { @@ -25,6 +28,16 @@ type Client struct { token string } +type FileUpload struct { + ID string `json:"id"` + Status string `json:"status"` +} + +type UploadedImageBlock struct { + FileUploadID string + Caption string +} + func NewClient(cfg config.APIConfig, token string) (*Client, error) { token = strings.TrimSpace(token) if token == "" { @@ -62,7 +75,105 @@ func (c *Client) PatchPage(ctx context.Context, pageID string, patch map[string] return c.doJSON(ctx, http.MethodPatch, "/pages/"+pageID, patch, nil) } +func (c *Client) UploadFile(ctx context.Context, filename string, data []byte) (string, error) { + filename = strings.TrimSpace(filename) + if filename == "" { + return "", fmt.Errorf("filename is required") + } + if len(data) == 0 { + return "", fmt.Errorf("file data is required") + } + + filename = filepath.Base(filename) + + var created FileUpload + createPayload := map[string]any{ + "mode": "single_part", + "filename": filename, + } + if err := c.doJSONWithVersion(ctx, http.MethodPost, "/file_uploads", createPayload, &created, fileUploadAPIRev); err != nil { + 
return "", err + } + if strings.TrimSpace(created.ID) == "" { + return "", fmt.Errorf("create file upload failed: empty upload ID") + } + + sent, err := c.sendFileUploadPart(ctx, created.ID, filename, data) + if err != nil { + return "", err + } + + uploaded, err := c.waitForFileUploadUploaded(ctx, sent.ID) + if err != nil { + return "", err + } + return uploaded.ID, nil +} + +func (c *Client) GetFileUpload(ctx context.Context, fileUploadID string) (*FileUpload, error) { + fileUploadID = strings.TrimSpace(fileUploadID) + if fileUploadID == "" { + return nil, fmt.Errorf("file upload ID is required") + } + var out FileUpload + if err := c.doJSONWithVersion(ctx, http.MethodGet, "/file_uploads/"+fileUploadID, nil, &out, fileUploadAPIRev); err != nil { + return nil, err + } + if strings.TrimSpace(out.ID) == "" { + out.ID = fileUploadID + } + return &out, nil +} + +func (c *Client) AppendUploadedImageBlocks(ctx context.Context, parentID string, blocks []UploadedImageBlock) error { + parentID = strings.TrimSpace(parentID) + if parentID == "" { + return fmt.Errorf("parent ID is required") + } + if len(blocks) == 0 { + return nil + } + + children := make([]map[string]any, 0, len(blocks)) + for _, block := range blocks { + id := strings.TrimSpace(block.FileUploadID) + if id == "" { + return fmt.Errorf("file upload ID is required for image block") + } + + image := map[string]any{ + "type": "file_upload", + "file_upload": map[string]any{ + "id": id, + }, + } + if caption := strings.TrimSpace(block.Caption); caption != "" { + image["caption"] = []map[string]any{ + { + "type": "text", + "text": map[string]any{ + "content": caption, + }, + }, + } + } + + children = append(children, map[string]any{ + "object": "block", + "type": "image", + "image": image, + }) + } + + payload := map[string]any{"children": children} + return c.doJSONWithVersion(ctx, http.MethodPatch, "/blocks/"+parentID+"/children", payload, nil, fileUploadAPIRev) +} + func (c *Client) doJSON(ctx context.Context, 
method, path string, payload any, out any) error { + return c.doJSONWithVersion(ctx, method, path, payload, out, c.notionVersion) +} + +func (c *Client) doJSONWithVersion(ctx context.Context, method, path string, payload any, out any, notionVersion string) error { var bodyReader io.Reader if payload != nil { data, err := json.Marshal(payload) @@ -72,15 +183,23 @@ func (c *Client) doJSON(ctx context.Context, method, path string, payload any, o bodyReader = bytes.NewReader(data) } - req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, bodyReader) + contentType := "" + if payload != nil { + contentType = "application/json" + } + return c.doRequest(ctx, method, path, bodyReader, contentType, out, notionVersion) +} + +func (c *Client) doRequest(ctx context.Context, method, path string, body io.Reader, contentType string, out any, notionVersion string) error { + req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body) if err != nil { return err } req.Header.Set("accept", "application/json") req.Header.Set("authorization", "Bearer "+c.token) - req.Header.Set("notion-version", c.notionVersion) - if payload != nil { - req.Header.Set("content-type", "application/json") + req.Header.Set("notion-version", notionVersion) + if contentType != "" { + req.Header.Set("content-type", contentType) } resp, err := c.httpClient.Do(req) @@ -117,3 +236,62 @@ func (c *Client) doJSON(ctx context.Context, method, path string, payload any, o } return nil } + +func (c *Client) sendFileUploadPart(ctx context.Context, fileUploadID, filename string, data []byte) (*FileUpload, error) { + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return nil, fmt.Errorf("create multipart file part: %w", err) + } + if _, err := part.Write(data); err != nil { + return nil, fmt.Errorf("write multipart file data: %w", err) + } + if err := writer.Close(); err != nil { + return nil, 
fmt.Errorf("close multipart writer: %w", err) + } + + var out FileUpload + path := "/file_uploads/" + strings.TrimSpace(fileUploadID) + "/send" + if err := c.doRequest(ctx, http.MethodPost, path, bytes.NewReader(body.Bytes()), writer.FormDataContentType(), &out, fileUploadAPIRev); err != nil { + return nil, err + } + if strings.TrimSpace(out.ID) == "" { + out.ID = strings.TrimSpace(fileUploadID) + } + return &out, nil +} + +func (c *Client) waitForFileUploadUploaded(ctx context.Context, fileUploadID string) (*FileUpload, error) { + id := strings.TrimSpace(fileUploadID) + if id == "" { + return nil, fmt.Errorf("file upload ID is required") + } + + const maxChecks = 20 + for i := 0; i < maxChecks; i++ { + upload, err := c.GetFileUpload(ctx, id) + if err != nil { + return nil, err + } + + status := strings.ToLower(strings.TrimSpace(upload.Status)) + switch status { + case "", "uploaded": + return upload, nil + case "pending": + if i == maxChecks-1 { + return nil, fmt.Errorf("file upload %s did not reach uploaded status in time", id) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(250 * time.Millisecond): + } + default: + return nil, fmt.Errorf("file upload %s failed with status %q", id, upload.Status) + } + } + return nil, fmt.Errorf("file upload %s did not reach uploaded status in time", id) +} diff --git a/internal/api/client_test.go b/internal/api/client_test.go index 075081f..a9ce7f0 100644 --- a/internal/api/client_test.go +++ b/internal/api/client_test.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "io" "net/http" "net/http/httptest" "strings" @@ -106,3 +107,146 @@ func TestPatchPageReturnsAPIErrorMessage(t *testing.T) { t.Fatalf("expected unauthorized message, got: %v", err) } } + +func TestUploadFileSinglePart(t *testing.T) { + t.Parallel() + + var createVersion string + var sendVersion string + var getVersion string + var sentFileName string + var sentFileData string + getCalls := 0 + + srv := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && r.URL.Path == "/file_uploads": + createVersion = r.Header.Get("Notion-Version") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"upload_123","status":"pending"}`)) + return + + case r.Method == http.MethodPost && r.URL.Path == "/file_uploads/upload_123/send": + sendVersion = r.Header.Get("Notion-Version") + if err := r.ParseMultipartForm(1 << 20); err != nil { + t.Fatalf("parse multipart form: %v", err) + } + file, hdr, err := r.FormFile("file") + if err != nil { + t.Fatalf("form file: %v", err) + } + defer func() { _ = file.Close() }() + data, err := io.ReadAll(file) + if err != nil { + t.Fatalf("read form file: %v", err) + } + sentFileName = hdr.Filename + sentFileData = string(data) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"upload_123","status":"pending"}`)) + return + + case r.Method == http.MethodGet && r.URL.Path == "/file_uploads/upload_123": + getVersion = r.Header.Get("Notion-Version") + getCalls++ + w.Header().Set("Content-Type", "application/json") + if getCalls == 1 { + _, _ = w.Write([]byte(`{"id":"upload_123","status":"pending"}`)) + } else { + _, _ = w.Write([]byte(`{"id":"upload_123","status":"uploaded"}`)) + } + return + } + + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + id, err := client.UploadFile(context.Background(), "diagram.png", []byte("PNGDATA")) + if err != nil { + t.Fatalf("upload file: %v", err) + } + if id != "upload_123" { + t.Fatalf("upload id mismatch: got %q", id) + } + if sentFileName != "diagram.png" { + t.Fatalf("file name mismatch: got %q", sentFileName) + } + if sentFileData != "PNGDATA" { + t.Fatalf("file data mismatch: got %q", 
sentFileData) + } + if createVersion != "2025-09-03" || sendVersion != "2025-09-03" || getVersion != "2025-09-03" { + t.Fatalf("unexpected notion version headers: create=%q send=%q get=%q", createVersion, sendVersion, getVersion) + } +} + +func TestAppendUploadedImageBlocks(t *testing.T) { + t.Parallel() + + var gotMethod string + var gotPath string + var gotVersion string + var gotBody map[string]any + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotVersion = r.Header.Get("Notion-Version") + + defer func() { _ = r.Body.Close() }() + if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { + t.Fatalf("decode request body: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"object":"list","results":[]}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + err = client.AppendUploadedImageBlocks(context.Background(), "page_123", []UploadedImageBlock{ + {FileUploadID: "upload_1", Caption: "Diagram"}, + }) + if err != nil { + t.Fatalf("append uploaded image blocks: %v", err) + } + + if gotMethod != http.MethodPatch { + t.Fatalf("method mismatch: got %s", gotMethod) + } + if gotPath != "/blocks/page_123/children" { + t.Fatalf("path mismatch: got %s", gotPath) + } + if gotVersion != "2025-09-03" { + t.Fatalf("notion-version mismatch: got %s", gotVersion) + } + + children, ok := gotBody["children"].([]any) + if !ok || len(children) != 1 { + t.Fatalf("children payload mismatch: %#v", gotBody["children"]) + } + child, ok := children[0].(map[string]any) + if !ok { + t.Fatalf("child payload mismatch: %#v", children[0]) + } + image, ok := child["image"].(map[string]any) + if !ok { + t.Fatalf("image payload mismatch: %#v", child["image"]) + } + fileUpload, ok := image["file_upload"].(map[string]any) + if !ok { + 
t.Fatalf("file_upload payload mismatch: %#v", image["file_upload"]) + } + if fileUpload["id"] != "upload_1" { + t.Fatalf("file_upload id mismatch: got %#v", fileUpload["id"]) + } +} diff --git a/internal/cli/markdown_images.go b/internal/cli/markdown_images.go new file mode 100644 index 0000000..a475407 --- /dev/null +++ b/internal/cli/markdown_images.go @@ -0,0 +1,321 @@ +package cli + +import ( + "fmt" + "net/url" + "os" + "path/filepath" + "regexp" + "strings" +) + +type MarkdownImageRewriteOptions struct { + SourceFile string + AssetBaseURL string + AssetRoot string +} + +type MarkdownImageRewrite struct { + Original string + Resolved string + URL string +} + +type LocalMarkdownImage struct { + Alt string + Original string + Resolved string +} + +var markdownImageRE = regexp.MustCompile(`!\[([^\]]*)\]\(([^)\n]+)\)`) +var uriSchemeRE = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9+.-]*:`) + +// RewriteLocalMarkdownImages rewrites local markdown image links to absolute URLs. +// If AssetBaseURL is empty, markdown is returned unchanged. 
+func RewriteLocalMarkdownImages(markdown string, opts MarkdownImageRewriteOptions) (string, []MarkdownImageRewrite, error) { + if strings.TrimSpace(opts.AssetBaseURL) == "" { + return markdown, nil, nil + } + + baseURL, err := url.Parse(strings.TrimSpace(opts.AssetBaseURL)) + if err != nil || baseURL.Scheme == "" || baseURL.Host == "" { + return "", nil, fmt.Errorf("invalid asset base URL %q", opts.AssetBaseURL) + } + if baseURL.Scheme != "http" && baseURL.Scheme != "https" { + return "", nil, fmt.Errorf("asset base URL must use http or https: %q", opts.AssetBaseURL) + } + + sourceFileAbs, err := filepath.Abs(opts.SourceFile) + if err != nil { + return "", nil, fmt.Errorf("resolve source file path: %w", err) + } + sourceDir := filepath.Dir(sourceFileAbs) + + assetRootAbs := "" + if strings.TrimSpace(opts.AssetRoot) != "" { + assetRootAbs, err = filepath.Abs(opts.AssetRoot) + if err != nil { + return "", nil, fmt.Errorf("resolve asset root path: %w", err) + } + assetRootAbs = filepath.Clean(assetRootAbs) + } + + matches := markdownImageRE.FindAllStringSubmatchIndex(markdown, -1) + if len(matches) == 0 { + return markdown, nil, nil + } + + var out strings.Builder + out.Grow(len(markdown) + len(matches)*16) + + last := 0 + rewrites := make([]MarkdownImageRewrite, 0, len(matches)) + for _, m := range matches { + matchStart, matchEnd := m[0], m[1] + altStart, altEnd := m[2], m[3] + destStart, destEnd := m[4], m[5] + + out.WriteString(markdown[last:matchStart]) + + alt := markdown[altStart:altEnd] + rawDest := markdown[destStart:destEnd] + + dest, ok := parseMarkdownDestination(rawDest) + if !ok || !isLocalDestination(dest) { + out.WriteString(markdown[matchStart:matchEnd]) + last = matchEnd + continue + } + + originalDest := dest + resolvedPath, err := resolveLocalPath(dest, sourceDir) + if err != nil { + return "", nil, err + } + + info, err := os.Stat(resolvedPath) + if err != nil { + return "", nil, fmt.Errorf("local image %q not found (from %s): %w", originalDest, 
opts.SourceFile, err) + } + if info.IsDir() { + return "", nil, fmt.Errorf("local image %q resolves to a directory: %s", originalDest, resolvedPath) + } + + urlPath := buildURLPath(originalDest, resolvedPath, sourceDir, assetRootAbs) + assetURL := joinBaseURL(baseURL, urlPath) + + out.WriteString("![") + out.WriteString(alt) + out.WriteString("](") + out.WriteString(assetURL) + out.WriteString(")") + + rewrites = append(rewrites, MarkdownImageRewrite{ + Original: originalDest, + Resolved: resolvedPath, + URL: assetURL, + }) + last = matchEnd + } + + out.WriteString(markdown[last:]) + return out.String(), rewrites, nil +} + +// FindLocalMarkdownImages returns all local markdown image links in order. +func FindLocalMarkdownImages(markdown, sourceFile string) ([]LocalMarkdownImage, error) { + sourceFileAbs, err := filepath.Abs(sourceFile) + if err != nil { + return nil, fmt.Errorf("resolve source file path: %w", err) + } + sourceDir := filepath.Dir(sourceFileAbs) + + matches := markdownImageRE.FindAllStringSubmatchIndex(markdown, -1) + if len(matches) == 0 { + return nil, nil + } + + local := make([]LocalMarkdownImage, 0, len(matches)) + for _, m := range matches { + altStart, altEnd := m[2], m[3] + destStart, destEnd := m[4], m[5] + + alt := markdown[altStart:altEnd] + rawDest := markdown[destStart:destEnd] + + dest, ok := parseMarkdownDestination(rawDest) + if !ok || !isLocalDestination(dest) { + continue + } + + resolvedPath, err := resolveLocalPath(dest, sourceDir) + if err != nil { + return nil, err + } + + info, err := os.Stat(resolvedPath) + if err != nil { + return nil, fmt.Errorf("local image %q not found (from %s): %w", dest, sourceFile, err) + } + if info.IsDir() { + return nil, fmt.Errorf("local image %q resolves to a directory: %s", dest, resolvedPath) + } + + local = append(local, LocalMarkdownImage{ + Alt: alt, + Original: dest, + Resolved: resolvedPath, + }) + } + + return local, nil +} + +func parseMarkdownDestination(raw string) (string, bool) { + s 
:= strings.TrimSpace(raw) + if s == "" { + return "", false + } + + if strings.HasPrefix(s, "<") { + end := strings.Index(s, ">") + if end > 1 { + return s[1:end], true + } + } + + escaped := false + for i, r := range s { + if escaped { + escaped = false + continue + } + if r == '\\' { + escaped = true + continue + } + if r == ' ' || r == '\t' || r == '\n' || r == '\r' { + s = s[:i] + break + } + } + + s = strings.TrimSpace(s) + if s == "" { + return "", false + } + return s, true +} + +func isLocalDestination(dest string) bool { + d := strings.TrimSpace(dest) + if d == "" { + return false + } + + // Windows absolute paths like C:\foo are local paths. + if len(d) >= 2 && d[1] == ':' { + return true + } + + lower := strings.ToLower(d) + switch { + case strings.HasPrefix(lower, "#"): + return false + case strings.HasPrefix(lower, "http://"), + strings.HasPrefix(lower, "https://"), + strings.HasPrefix(lower, "mailto:"), + strings.HasPrefix(lower, "tel:"), + strings.HasPrefix(lower, "data:"): + return false + case strings.HasPrefix(lower, "file://"): + return true + } + + return !uriSchemeRE.MatchString(d) +} + +func resolveLocalPath(dest, sourceDir string) (string, error) { + d := strings.TrimSpace(dest) + if strings.HasPrefix(strings.ToLower(d), "file://") { + parsed, err := url.Parse(d) + if err != nil { + return "", fmt.Errorf("invalid file URL %q: %w", d, err) + } + unescaped, err := url.PathUnescape(parsed.Path) + if err != nil { + return "", fmt.Errorf("invalid file URL path %q: %w", d, err) + } + d = unescaped + } + + if strings.HasPrefix(d, "~"+string(filepath.Separator)) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("expand home path %q: %w", d, err) + } + d = filepath.Join(home, strings.TrimPrefix(d, "~"+string(filepath.Separator))) + } + + if !filepath.IsAbs(d) { + d = filepath.Join(sourceDir, d) + } + + abs, err := filepath.Abs(d) + if err != nil { + return "", fmt.Errorf("resolve local path %q: %w", dest, err) + } + return 
filepath.Clean(abs), nil +} + +func buildURLPath(originalDest, resolvedPath, sourceDir, assetRootAbs string) string { + if assetRootAbs != "" { + if rel, ok := relativeInside(assetRootAbs, resolvedPath); ok { + return rel + } + } + + if !filepath.IsAbs(originalDest) && !strings.HasPrefix(strings.ToLower(originalDest), "file://") { + return filepath.ToSlash(filepath.Clean(originalDest)) + } + + if rel, ok := relativeInside(sourceDir, resolvedPath); ok { + return rel + } + + return filepath.Base(resolvedPath) +} + +func relativeInside(root, target string) (string, bool) { + rel, err := filepath.Rel(root, target) + if err != nil { + return "", false + } + rel = filepath.Clean(rel) + if rel == "." { + return "", true + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", false + } + return filepath.ToSlash(rel), true +} + +func joinBaseURL(base *url.URL, relPath string) string { + u := *base + basePath := strings.TrimSuffix(u.Path, "/") + if relPath == "" { + if basePath == "" { + u.Path = "/" + } else { + u.Path = basePath + } + return u.String() + } + if basePath == "" { + u.Path = "/" + strings.TrimPrefix(relPath, "/") + } else { + u.Path = basePath + "/" + strings.TrimPrefix(relPath, "/") + } + return u.String() +} diff --git a/internal/cli/markdown_images_test.go b/internal/cli/markdown_images_test.go new file mode 100644 index 0000000..6c2f549 --- /dev/null +++ b/internal/cli/markdown_images_test.go @@ -0,0 +1,186 @@ +package cli + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestRewriteLocalMarkdownImages(t *testing.T) { + tmp := t.TempDir() + docDir := filepath.Join(tmp, "docs") + assetsDir := filepath.Join(docDir, "assets") + if err := os.MkdirAll(assetsDir, 0o755); err != nil { + t.Fatalf("mkdir assets: %v", err) + } + + imagePath := filepath.Join(assetsDir, "diagram.png") + if err := os.WriteFile(imagePath, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + docFile 
:= filepath.Join(docDir, "guide.md")
+
+	md := "Intro\n\n![Diagram](./assets/diagram.png \"caption\")\n![Remote](https://example.com/x.png)\n"
+	got, rewrites, err := RewriteLocalMarkdownImages(md, MarkdownImageRewriteOptions{
+		SourceFile:   docFile,
+		AssetBaseURL: "https://assets.example.com/notion",
+	})
+	if err != nil {
+		t.Fatalf("RewriteLocalMarkdownImages() error: %v", err)
+	}
+
+	if len(rewrites) != 1 {
+		t.Fatalf("rewrites len = %d, want 1", len(rewrites))
+	}
+	if rewrites[0].Resolved != imagePath {
+		t.Fatalf("resolved = %q, want %q", rewrites[0].Resolved, imagePath)
+	}
+	if !strings.Contains(got, "![Diagram](https://assets.example.com/notion/assets/diagram.png)") {
+		t.Fatalf("expected rewritten local image, got: %q", got)
+	}
+	if !strings.Contains(got, "![Remote](https://example.com/x.png)") {
+		t.Fatalf("expected remote image untouched, got: %q", got)
+	}
+}
+
+// TestRewriteLocalMarkdownImages_AssetRoot creates an image under
+// <tmp>/render/images, passes <tmp>/render as AssetRoot, and expects the
+// rewritten link to be the base URL followed by "images/chart%201.png",
+// i.e. without the "render" segment and with the space percent-encoded.
+func TestRewriteLocalMarkdownImages_AssetRoot(t *testing.T) {
+	root := t.TempDir()
+	assetRoot := filepath.Join(root, "render")
+	imageDir := filepath.Join(assetRoot, "images")
+
+	err := os.MkdirAll(imageDir, 0o755)
+	if err != nil {
+		t.Fatalf("mkdir nested: %v", err)
+	}
+
+	imagePath := filepath.Join(imageDir, "chart 1.png")
+	err = os.WriteFile(imagePath, []byte("png"), 0o644)
+	if err != nil {
+		t.Fatalf("write image: %v", err)
+	}
+
+	// The destination is wrapped in angle brackets because the path
+	// contains a space.
+	const input = "![Chart](<./render/images/chart 1.png>)\n"
+	const want = "![Chart](https://cdn.example.com/base/images/chart%201.png)\n"
+
+	opts := MarkdownImageRewriteOptions{
+		SourceFile:   filepath.Join(root, "doc.md"),
+		AssetBaseURL: "https://cdn.example.com/base/",
+		AssetRoot:    assetRoot,
+	}
+	got, rewrites, err := RewriteLocalMarkdownImages(input, opts)
+	switch {
+	case err != nil:
+		t.Fatalf("RewriteLocalMarkdownImages() error: %v", err)
+	case len(rewrites) != 1:
+		t.Fatalf("rewrites len = %d, want 1", len(rewrites))
+	case got != want:
+		t.Fatalf("got %q, want %q", got, want)
+	}
+}
+
+// TestRewriteLocalMarkdownImages_FileURL references an existing image through
+// a file:// URL and expects exactly one rewrite whose result is the base URL
+// followed by the image's file name.
+func TestRewriteLocalMarkdownImages_FileURL(t *testing.T) {
+	dir := t.TempDir()
+	imagePath := filepath.Join(dir, "socket.png")
+	if writeErr := os.WriteFile(imagePath, []byte("png"), 0o644); writeErr != nil {
+		t.Fatalf("write image: %v", writeErr)
+	}
+
+	const want = "![Socket](https://assets.example.com/socket.png)\n"
+	input := "![Socket](file://" + filepath.ToSlash(imagePath) + ")\n"
+
+	opts := MarkdownImageRewriteOptions{
+		SourceFile:   filepath.Join(dir, "doc.md"),
+		AssetBaseURL: "https://assets.example.com",
+	}
+	got, rewrites, err := RewriteLocalMarkdownImages(input, opts)
+	switch {
+	case err != nil:
+		t.Fatalf("RewriteLocalMarkdownImages() error: %v", err)
+	case len(rewrites) != 1:
+		t.Fatalf("rewrites len = %d, want 1", len(rewrites))
+	case got != want:
+		t.Fatalf("unexpected rewrite: %q", got)
+	}
+}
+
+// TestRewriteLocalMarkdownImages_MissingFile expects an error when the
+// markdown points at a local image that was never created on disk.
+func TestRewriteLocalMarkdownImages_MissingFile(t *testing.T) {
+	opts := MarkdownImageRewriteOptions{
+		SourceFile:   filepath.Join(t.TempDir(), "doc.md"),
+		AssetBaseURL: "https://assets.example.com",
+	}
+	if _, _, err := RewriteLocalMarkdownImages("![Missing](./missing.png)\n", opts); err == nil {
+		t.Fatal("expected error for missing file")
+	}
+}
+
+// TestRewriteLocalMarkdownImages_NoBaseURL expects the markdown to come back
+// unchanged, with no rewrites and no error, when AssetBaseURL is left empty.
+func TestRewriteLocalMarkdownImages_NoBaseURL(t *testing.T) {
+	const input = "![Local](./img.png)\n"
+
+	opts := MarkdownImageRewriteOptions{SourceFile: "/tmp/doc.md"}
+	got, rewrites, err := RewriteLocalMarkdownImages(input, opts)
+	switch {
+	case err != nil:
+		t.Fatalf("unexpected error: %v", err)
+	case len(rewrites) != 0:
+		t.Fatalf("rewrites len = %d, want 0", len(rewrites))
+	case got != input:
+		t.Fatalf("got %q, want %q", got, input)
+	}
+}
+
+// TestFindLocalMarkdownImages writes two images next to the document and
+// expects them to be returned in document order with their resolved paths,
+// while the remote https image is left out of the result.
+func TestFindLocalMarkdownImages(t *testing.T) {
+	docDir := filepath.Join(t.TempDir(), "docs")
+	assetsDir := filepath.Join(docDir, "assets")
+	if mkErr := os.MkdirAll(assetsDir, 0o755); mkErr != nil {
+		t.Fatalf("mkdir assets: %v", mkErr)
+	}
+
+	diagramPath := filepath.Join(assetsDir, "diagram.png")
+	chartPath := filepath.Join(assetsDir, "chart.jpg")
+	if writeErr := os.WriteFile(diagramPath, []byte("png"), 0o644); writeErr != nil {
+		t.Fatalf("write image1: %v", writeErr)
+	}
+	if writeErr := os.WriteFile(chartPath, []byte("jpg"), 0o644); writeErr != nil {
+		t.Fatalf("write image2: %v", writeErr)
+	}
+
+	const input = "![Diagram](./assets/diagram.png)\n![Remote](https://example.com/r.png)\n![Chart](./assets/chart.jpg)\n"
+
+	images, err := FindLocalMarkdownImages(input, filepath.Join(docDir, "guide.md"))
+	if err != nil {
+		t.Fatalf("FindLocalMarkdownImages() error: %v", err)
+	}
+	if n := len(images); n != 2 {
+		t.Fatalf("len(got)=%d, want 2", n)
+	}
+
+	switch {
+	case images[0].Resolved != diagramPath:
+		t.Fatalf("first resolved = %q, want %q", images[0].Resolved, diagramPath)
+	case images[0].Alt != "Diagram":
+		t.Fatalf("first alt = %q, want %q", images[0].Alt, "Diagram")
+	case images[1].Resolved != chartPath:
+		t.Fatalf("second resolved = %q, want %q", images[1].Resolved, chartPath)
+	}
+}
+
+// TestFindLocalMarkdownImages_MissingFile expects an error whose message
+// contains "not found" when the referenced local image does not exist.
+func TestFindLocalMarkdownImages_MissingFile(t *testing.T) {
+	docFile := filepath.Join(t.TempDir(), "doc.md")
+
+	_, err := FindLocalMarkdownImages("![Missing](./missing.png)\n", docFile)
+	switch {
+	case err == nil:
+		t.Fatal("expected error for missing local file")
+	case !strings.Contains(err.Error(), "not found"):
+		t.Fatalf("unexpected error: %v", err)
+	}
+}