diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0ffa82c37..37745d6ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,7 +63,7 @@ jobs: - name: Fetch meta data run: python3 scripts/fetch_meta.py - name: Run tests - run: go test -v -race -count=1 -timeout=5m ./cmd/... ./internal/... ./shortcuts/... + run: go test -v -race -count=1 -timeout=5m ./cmd/... ./internal/... ./shortcuts/... ./extension/... lint: needs: fast-gate diff --git a/.gitignore b/.gitignore index 90313e480..437052468 100644 --- a/.gitignore +++ b/.gitignore @@ -34,8 +34,13 @@ tests/mail/reports/ # Generated / test artifacts .hammer/ +.lark-slides/ internal/registry/meta_data.json cmd/api/download.bin app.log /sidecar-server-demo /server-demo +.tmp/ +cover*.out + +lark-env.sh diff --git a/.gitleaks.toml b/.gitleaks.toml index 597b33952..8dbe4165f 100644 --- a/.gitleaks.toml +++ b/.gitleaks.toml @@ -14,3 +14,4 @@ id = "lark-session-token" description = "Detect Lark session tokens" regex = '''\bXN0YXJ0-[A-Za-z0-9_-]+-WVuZA\b''' keywords = ["XN0YXJ0-", "-WVuZA"] + diff --git a/.golangci.yml b/.golangci.yml index 60526407f..c15ebe084 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -45,6 +45,7 @@ linters: - path: _test\.go$ linters: - bodyclose + - bidichk - gocritic - depguard - forbidigo diff --git a/CHANGELOG.md b/CHANGELOG.md index 8af91bad8..b7b98db45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,77 @@ All notable changes to this project will be documented in this file. +## [v1.0.34] - 2026-05-19 + +### Features + +- **drive**: Switch markdown export to V2 `docs_ai` fetch API (#948) +- **drive**: Add `+inspect` shortcut for document URL inspection with wiki unwrapping (#947) +- **wiki**: Add `+node-get` / `+node-delete` / `+space-create` shortcuts (#904) +- **base**: Support Base attachment APIs (#887) +- **mail**: Validate `bot` + `mailbox=me` and add dynamic `--as` help tests (#895) +- **mail**: Expose draft priority in `--inspect` projection and document `--set-priority` (#779) + +### Bug Fixes + +- **identitydiag**: Harden verify path and tighten status semantics (#961) +- **wiki**: Surface real node URL for `+node-create` / `+node-copy` (#960) +- **auth**: Split bot and user identity diagnostics (#957) +- **base**: Address Base attachment review follow-ups (#958) +- **docs**: Clarify `replace_all` selection errors (#954) + +### Documentation + +- **drive**: Clarify add comment constraints (#967) +- **lark-im**: Clarify message activity search (#865) + +### Tests + +- Verify e2e resource cleanup (#949) +- **lint**: Exclude `bidichk` from test files (#959) + +## [v1.0.33] - 2026-05-18 + +### Features + +- **markdown**: Add `+patch` shortcut (#857) +- **slides**: Improve slide planning and validation guidance (#847) +- **drive**: Add `+sync` workflow for Drive directories (#873) +- **drive**: Add drive version shortcut (#841) +- **extension**: Plugin / Hook framework with command pruning (#910) + +### Bug Fixes + +- **sheets**: Explicitly document safe JSON unmarshal ignore in `DryRun` (#935) +- **base**: Mark base field update high risk (#936) +- **auth**: Guide agents to yield during auth device flow (#933) + +### Documentation + +- **lark-wiki**: Correct the `--as` default-identity claim (#919) + +### Tests + +- Drop stale e2e `--yes` flags (#920) + +## [v1.0.32] - 2026-05-15 + +### Features + +- **doc**: Add `--width`/`--height` flags to `docs +media-insert` (#832) +- **wiki**: Add `+space-list` / `+node-list` / `+node-copy` shortcuts (#392) + +### Bug Fixes + +- **drive**: Preserve parent token on nested overwrite (#908) +- **selfupdate**: Use `LookPath` instead of `Executable` for binary verification (#886) +- **registry**: Wait for background meta refresh before test reset (#894) + +### Documentation + +- **doc**: Add SVG whiteboard support to `lark-doc` v2 skill (#901) +- **drive**: Add permission public patch error guidance (#863) + ## [v1.0.31] - 2026-05-14 ### Features @@ -703,6 +774,9 @@ Bundled AI agent skills for intelligent assistance: - Bilingual documentation (English & Chinese). - CI/CD pipelines: linting, testing, coverage reporting, and automated releases. +[v1.0.34]: https://github.com/larksuite/cli/releases/tag/v1.0.34 +[v1.0.33]: https://github.com/larksuite/cli/releases/tag/v1.0.33 +[v1.0.32]: https://github.com/larksuite/cli/releases/tag/v1.0.32 [v1.0.31]: https://github.com/larksuite/cli/releases/tag/v1.0.31 [v1.0.30]: https://github.com/larksuite/cli/releases/tag/v1.0.30 [v1.0.29]: https://github.com/larksuite/cli/releases/tag/v1.0.29 diff --git a/Makefile b/Makefile index 7d78c510b..14480aeec 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,9 @@ DATE := $(shell date +%Y-%m-%d) LDFLAGS := -s -w -X $(MODULE)/internal/build.Version=$(VERSION) -X $(MODULE)/internal/build.Date=$(DATE) PREFIX ?= /usr/local -.PHONY: build vet test unit-test integration-test install uninstall clean fetch_meta +.PHONY: all build vet fmt-check test unit-test integration-test examples-build install uninstall clean fetch_meta gitleaks + +all: test fetch_meta: python3 scripts/fetch_meta.py @@ -19,13 +21,32 @@ build: fetch_meta vet: fetch_meta go vet ./... +# fmt-check fails when any file would be reformatted by gofmt. Keep this +# in sync with the fast-gate "Check formatting" step in CI. +fmt-check: + @unformatted=$$(gofmt -l . | grep -v '^\.claude/' || true); \ + if [ -n "$$unformatted" ]; then \ + echo "Unformatted Go files:"; \ + echo "$$unformatted"; \ + echo "Run 'gofmt -w .' and commit."; \ + exit 1; \ + fi + +# ./extension/... keeps the public plugin SDK in the default test matrix. unit-test: fetch_meta - go test -race -gcflags="all=-N -l" -count=1 ./cmd/... ./internal/... ./shortcuts/... + go test -race -gcflags="all=-N -l" -count=1 \ + ./cmd/... ./internal/... ./shortcuts/... ./extension/... + +# examples-build keeps the shipped plugin-SDK examples compilable. If this +# breaks, the plugin author guide's "go build ./..." path is broken. +examples-build: + go build ./extension/platform/examples/audit-observer + go build ./extension/platform/examples/readonly-policy integration-test: build go test -v -count=1 ./tests/... -test: vet unit-test integration-test +test: vet fmt-check unit-test examples-build integration-test install: build install -d $(PREFIX)/bin @@ -37,3 +58,13 @@ uninstall: clean: rm -f $(BINARY) + +# Run secret-leak checks locally before pushing. +# Step 1: check-doc-tokens catches realistic-looking example tokens in reference +# docs and asks you to use _EXAMPLE_TOKEN placeholders instead. +# Step 2: gitleaks scans the full repo for real leaked secrets. +# Install gitleaks: https://github.com/gitleaks/gitleaks#installing +gitleaks: + @bash scripts/check-doc-tokens.sh + @command -v gitleaks >/dev/null 2>&1 || { echo "gitleaks not found. Install: brew install gitleaks"; exit 1; } + gitleaks detect --redact -v --exit-code=2 diff --git a/README.md b/README.md index d6ff3ed70..4902dcdfa 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ The official [Lark/Feishu](https://www.larksuite.com/) CLI tool, maintained by t | 💬 Messenger | Send/reply messages, create and manage group chats, view chat history & threads, search messages, download media | | 📄 Docs | Create, read, update, and search documents, read/write media & whiteboards | | 📁 Drive | Upload and download files, search docs & wiki, manage comments | -| 📝 Markdown | Create, fetch, and overwrite Drive-native `.md` files | +| 📝 Markdown | Create, fetch, patch, and overwrite Drive-native `.md` files | | 📊 Base | Create and manage tables, fields, records, views, dashboards, workflows, forms, roles & permissions, data aggregation & analytics | | 📈 Sheets | Create, read, write, append, find, and export spreadsheet data | | 🖼️ Slides | Create and manage presentations, read presentation content, and add or remove slides | @@ -132,7 +132,7 @@ lark-cli auth status | `lark-im` | Send/reply messages, group chat management, message search, upload/download images & files, reactions | | `lark-doc` | Create, read, update, search documents (Markdown-based) | | `lark-drive` | Upload, download files, manage permissions & comments | -| `lark-markdown` | Create, fetch, and overwrite Drive-native Markdown files | +| `lark-markdown` | Create, fetch, patch, and overwrite Drive-native Markdown files | | `lark-sheets` | Create, read, write, append, find, export spreadsheets | | `lark-slides` | Create and manage presentations, read presentation content, and add or remove slides | | `lark-base` | Tables, fields, records, views, dashboards, data aggregation & analytics | diff --git a/README.zh.md b/README.zh.md index 2f9b7558b..b9869090b 100644 --- a/README.zh.md +++ b/README.zh.md @@ -28,7 +28,7 @@ | 💬 即时通讯 | 发送/回复消息、创建和管理群聊、查看聊天记录与话题、搜索消息、下载媒体文件 | | 📄 云文档 | 创建、读取、更新文档、搜索文档、读写素材与画板 | | 📁 云空间 | 上传和下载文件、搜索文档与知识库、管理评论 | -| 📝 Markdown | 创建、读取、覆盖更新 Drive 中的原生 `.md` 文件 | +| 📝 Markdown | 创建、读取、局部 patch、覆盖更新 Drive 中的原生 `.md` 文件 | | 📊 多维表格 | 创建和管理数据表、字段、记录、视图、仪表盘、自动化流程、表单、角色权限,数据聚合分析 | | 📈 电子表格 | 创建、读取、写入、追加、查找和导出表格数据 | | 🖼️ 幻灯片 | 创建和管理演示文稿、读取演示文稿内容,以及新增或删除幻灯片页面 | @@ -133,7 +133,7 @@ lark-cli auth status | `lark-im` | 发送/回复消息、群聊管理、消息搜索、上传下载图片与文件、表情回复 | | `lark-doc` | 创建、读取、更新、搜索文档(基于 Markdown) | | `lark-drive` | 上传、下载文件,管理权限与评论 | -| `lark-markdown` | 创建、读取、覆盖更新 Drive 中的原生 Markdown 文件 | +| `lark-markdown` | 创建、读取、局部 patch、覆盖更新 Drive 中的原生 Markdown 文件 | | `lark-sheets` | 创建、读取、写入、追加、查找、导出电子表格 | | `lark-slides` | 创建和管理演示文稿、读取演示文稿内容,以及新增或删除幻灯片页面 | | `lark-base` | 多维表格、字段、记录、视图、仪表盘、数据聚合分析 | diff --git a/cmd/api/api.go b/cmd/api/api.go index 83e963059..f5676e9b1 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -103,6 +103,7 @@ func NewCmdApiWithContext(ctx context.Context, f *cmdutil.Factory, runF func(*AP cmdutil.RegisterFlagCompletion(cmd, "format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "ndjson", "table", "csv"}, cobra.ShellCompDirectiveNoFileComp }) + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/auth/auth_test.go b/cmd/auth/auth_test.go index 41a775145..c2b1940ff 100644 --- a/cmd/auth/auth_test.go +++ b/cmd/auth/auth_test.go @@ -44,6 +44,32 @@ func TestAuthLoginCmd_FlagParsing(t *testing.T) { } } +func TestAuthLoginCmd_HelpGuidesNonStreamingAgentsToSplitFlow(t *testing.T) { + f, stdout, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, + }) + + cmd := NewCmdAuthLogin(f, func(opts *LoginOptions) error { return nil }) + cmd.SetOut(stdout) + cmd.SetErr(io.Discard) + cmd.SetArgs([]string{"--help"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + got := stdout.String() + for _, want := range []string{ + "only delivers final turn messages", + "--no-wait --json", + "send the verification URL to the user as your final message", + "run --device-code in a later step", + } { + if !strings.Contains(got, want) { + t.Fatalf("help missing %q, got:\n%s", want, got) + } + } +} + func TestAuthCheckCmd_FlagParsing(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, diff --git a/cmd/auth/check.go b/cmd/auth/check.go index 5f0bd0f48..2dd0652e8 100644 --- a/cmd/auth/check.go +++ b/cmd/auth/check.go @@ -37,6 +37,7 @@ func NewCmdAuthCheck(f *cmdutil.Factory, runF func(*CheckOptions) error) *cobra. cmd.Flags().StringVar(&opts.Scope, "scope", "", "scopes to check (space-separated)") cmd.MarkFlagRequired("scope") + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/auth/list.go b/cmd/auth/list.go index 2cb1b778e..ff682f824 100644 --- a/cmd/auth/list.go +++ b/cmd/auth/list.go @@ -34,6 +34,7 @@ func NewCmdAuthList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Co return authListRun(opts) }, } + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 6a27ee6cf..02888c98e 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -47,10 +47,12 @@ func NewCmdAuthLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra. Long: `Device Flow authorization login. For AI agents: this command blocks until the user completes authorization in the -browser. Run it in the background and retrieve the verification URL from its output.`, +browser. If your harness only delivers final turn messages, use --no-wait --json, +send the verification URL to the user as your final message, end the turn, then +run --device-code in a later step after the user confirms authorization.`, RunE: func(cmd *cobra.Command, args []string) error { if mode := f.ResolveStrictMode(cmd.Context()); mode == core.StrictModeBot { - return output.ErrWithHint(output.ExitValidation, "strict_mode", + return output.ErrWithHint(output.ExitValidation, "command_denied", fmt.Sprintf("strict mode is %q, user login is disabled in this profile", mode), "if the user explicitly wants to switch to user identity, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)") } @@ -62,6 +64,7 @@ browser. Run it in the background and retrieve the verification URL from its out }, } cmdutil.SetSupportedIdentities(cmd, []string{"user"}) + cmdutil.SetRisk(cmd, "write") cmd.Flags().StringVar(&opts.Scope, "scope", "", "scopes to request (space- or comma-separated). Combines additively with --domain/--recommend") cmd.Flags().BoolVar(&opts.Recommend, "recommend", false, "request only recommended (auto-approve) scopes") @@ -187,7 +190,7 @@ func authLoginRun(opts *LoginOptions) error { log("View all options:") log(msg.HintFooter) log("") - log("Note: this command blocks until authorization is complete. Run it in the background and retrieve the verification URL from its output.") + log("Note: this command blocks until authorization is complete. For non-streaming agent harnesses, use --no-wait --json, send the verification URL as the final message of the turn, then run --device-code in a later step after the user confirms authorization.") return output.ErrValidation("please specify the scopes to authorize") } } @@ -266,7 +269,7 @@ func authLoginRun(opts *LoginOptions) error { "verification_url": authResp.VerificationUriComplete, "device_code": authResp.DeviceCode, "expires_in": authResp.ExpiresIn, - "hint": fmt.Sprintf("Show verification_url to the user exactly as returned by the CLI and treat it as an opaque string. Do not URL-encode or decode it, do not normalize or rewrite it, do not add %%20, spaces, or punctuation, and do not wrap it as Markdown link text; prefer a fenced code block containing only the raw URL. Then immediately execute: lark-cli auth login --device-code %s (blocks until authorized or timeout). Do not instruct the user to run this command themselves.", authResp.DeviceCode), + "hint": fmt.Sprintf("Show verification_url to the user exactly as returned by the CLI and treat it as an opaque string. Do not URL-encode or decode it, do not normalize or rewrite it, do not add %%20, spaces, or punctuation, and do not wrap it as Markdown link text; prefer a fenced code block containing only the raw URL. For agent harnesses that only deliver final turn messages, make the URL the final message of the turn and return control to the user; do not block on --device-code in the same turn. After the user confirms authorization in a later step, run: lark-cli auth login --device-code %s", authResp.DeviceCode), } encoder := json.NewEncoder(f.IOStreams.Out) encoder.SetEscapeHTML(false) diff --git a/cmd/auth/login_messages.go b/cmd/auth/login_messages.go index 548d704b2..189c42740 100644 --- a/cmd/auth/login_messages.go +++ b/cmd/auth/login_messages.go @@ -59,7 +59,7 @@ var loginMsgZh = &loginMsg{ OpenURL: "在浏览器中打开以下链接进行认证:\n\n", WaitingAuth: "等待用户授权...", - AgentTimeoutHint: "[AI agent] 此命令最长阻塞约 10 分钟,等待用户在浏览器内完成授权。请确保 runner 的 timeout >= 600s;如不支持长 timeout,请改用 `lark-cli auth login --no-wait --json` 拿到 device_code 后再用 `lark-cli auth login --device-code ` 续上轮询。**不要短 timeout 反复重试**,每次重启会作废上一轮的 device code,导致用户授权链接失效。向用户展示授权链接时,必须逐字原样转发 CLI 返回的 URL,把它视为不可修改的 opaque string;不要做 URL 编码或解码,不要补 `%20`、空格或标点,不要改写成 Markdown 链接,建议用只包含该 URL 的代码块单独输出。", + AgentTimeoutHint: "[AI agent] 此命令最长阻塞约 10 分钟,等待用户在浏览器内完成授权。请确保 runner 的 timeout >= 600s。若你的 harness 只会把最终回复发给用户,请改用 `lark-cli auth login --no-wait --json` 拿到 device_code 和 verification_url,把 verification_url 作为本轮最终消息原样发给用户并结束本轮;等用户回复已完成授权后,再在后续步骤运行 `lark-cli auth login --device-code ` 续上轮询。**不要在同一轮里展示 URL 后立刻阻塞执行 --device-code**,也不要短 timeout 反复重试;每次重启会作废上一轮的 device code,导致用户授权链接失效。向用户展示授权链接时,必须逐字原样转发 CLI 返回的 URL,把它视为不可修改的 opaque string;不要做 URL 编码或解码,不要补 `%20`、空格或标点,不要改写成 Markdown 链接,建议用只包含该 URL 的代码块单独输出。", AuthSuccess: "已收到授权确认,正在获取用户信息并校验授权结果...", LoginSuccess: "授权成功! 用户: %s (%s)", AuthorizedUser: "当前授权账号: %s (%s)", @@ -95,7 +95,7 @@ var loginMsgEn = &loginMsg{ OpenURL: "Open this URL in your browser to authenticate:\n\n", WaitingAuth: "Waiting for user authorization...", - AgentTimeoutHint: "[AI agent] This command blocks for up to ~10 minutes while waiting for the user to authorize in their browser. Make sure your runner's timeout is >= 600s. If long timeouts are not supported, use `lark-cli auth login --no-wait --json` to get a device_code, then `lark-cli auth login --device-code ` to resume polling. **Do NOT retry with a short timeout**; each restart invalidates the previous device code and makes the earlier authorization URL useless. When showing the authorization URL to the user, copy the CLI-returned URL exactly as-is and treat it as an opaque string. Do not URL-encode or decode it, do not add `%20`, spaces, or punctuation, do not rewrite it as Markdown link text, and prefer a fenced code block containing only the raw URL.", + AgentTimeoutHint: "[AI agent] This command blocks for up to ~10 minutes while waiting for the user to authorize in their browser. Make sure your runner's timeout is >= 600s. If your harness only delivers final turn messages, use `lark-cli auth login --no-wait --json` to get device_code and verification_url, present verification_url to the user exactly as the final message of this turn, then end the turn; after the user replies that they authorized, run `lark-cli auth login --device-code ` in a later step to resume polling. **Do NOT show the URL and then immediately block on --device-code in the same turn**, and do not retry with a short timeout; each restart invalidates the previous device code and makes the earlier authorization URL useless. When showing the authorization URL to the user, copy the CLI-returned URL exactly as-is and treat it as an opaque string. Do not URL-encode or decode it, do not add `%20`, spaces, or punctuation, do not rewrite it as Markdown link text, and prefer a fenced code block containing only the raw URL.", AuthSuccess: "Authorization confirmed, fetching user info and validating granted scopes...", LoginSuccess: "Authorization successful! User: %s (%s)", AuthorizedUser: "Authorized account: %s (%s)", diff --git a/cmd/auth/login_messages_test.go b/cmd/auth/login_messages_test.go index 9471f344b..3c5cc1c88 100644 --- a/cmd/auth/login_messages_test.go +++ b/cmd/auth/login_messages_test.go @@ -97,16 +97,17 @@ func TestLoginMsg_FormatStrings(t *testing.T) { } // TestAgentTimeoutHint_CarriesKeyInfo guards the contract that the synchronous -// auth-login output tells AI agents two things: (a) this command blocks for -// minutes — set a long runner timeout, and (b) the alternative is the -// --no-wait + --device-code split-flow. Without (a) AI sets a 10s timeout and -// kills the process before the user can authorize; without (b) the AI has no -// recovery path and just retries with the same short timeout, invalidating -// each new device code in turn. +// auth-login output tells AI agents three things: (a) this command blocks for +// minutes — set a long runner timeout, (b) the alternative is the --no-wait + +// --device-code split-flow, and (c) non-streaming harnesses must end the turn +// after presenting the URL instead of blocking in the same turn. func TestAgentTimeoutHint_CarriesKeyInfo(t *testing.T) { for _, lang := range []string{"zh", "en"} { hint := getLoginMsg(lang).AgentTimeoutHint - for _, want := range []string{"--no-wait", "--device-code"} { + for _, want := range []string{"--no-wait", "--device-code", "turn"} { + if lang == "zh" && want == "turn" { + want = "本轮" + } if !strings.Contains(hint, want) { t.Errorf("%s AgentTimeoutHint missing %q: %s", lang, want, hint) } diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 8687d313a..51ebdb9d9 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -315,10 +315,12 @@ func TestAuthLoginRun_NonTerminal_NoFlags_RejectsWithHint(t *testing.T) { if !strings.Contains(msg, "scopes") { t.Errorf("expected error to mention scopes, got: %s", msg) } - // Stderr should contain background hint + // Stderr should explain the split-flow path for non-streaming agents. stderrStr := stderr.String() - if !strings.Contains(stderrStr, "background") { - t.Errorf("expected stderr to mention background, got: %s", stderrStr) + for _, want := range []string{"--no-wait --json", "final message of the turn", "--device-code"} { + if !strings.Contains(stderrStr, want) { + t.Errorf("expected stderr to mention %q, got: %s", want, stderrStr) + } } } @@ -949,11 +951,24 @@ func TestAuthLoginRun_NoWaitJSONHintIncludesRawURLGuidance(t *testing.T) { "do not add %20, spaces, or punctuation", "do not wrap it as Markdown link text", "fenced code block containing only the raw URL", + "final message of the turn", + "return control to the user", + "do not block on --device-code in the same turn", + "After the user confirms authorization in a later step", + "lark-cli auth login --device-code device-code", } { if !strings.Contains(hint, want) { t.Fatalf("hint missing %q, got:\n%s", want, hint) } } + for _, unwanted := range []string{ + "Then immediately execute", + "Do not instruct the user to run this command themselves", + } { + if strings.Contains(hint, unwanted) { + t.Fatalf("hint should not contain %q, got:\n%s", unwanted, hint) + } + } } func TestAuthLoginRun_JSONWriteFailure_DeviceAuthorizationReturnsWriterError(t *testing.T) { @@ -1035,6 +1050,10 @@ func TestAuthLoginRun_JSONDeviceAuthorizationAgentHintIncludesRawURLGuidance(t * hint, _ := data["agent_hint"].(string) for _, want := range []string{ "timeout >= 600s", + "本轮最终消息", + "结束本轮", + "用户回复已完成授权", + "不要在同一轮里展示 URL 后立刻阻塞执行 --device-code", "逐字原样转发 CLI 返回的 URL", "opaque string", "不要做 URL 编码或解码", diff --git a/cmd/auth/logout.go b/cmd/auth/logout.go index ac14d7e63..3b2ae09f2 100644 --- a/cmd/auth/logout.go +++ b/cmd/auth/logout.go @@ -33,6 +33,7 @@ func NewCmdAuthLogout(f *cmdutil.Factory, runF func(*LogoutOptions) error) *cobr return authLogoutRun(opts) }, } + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/auth/scopes.go b/cmd/auth/scopes.go index 23f8ef811..c70898dd5 100644 --- a/cmd/auth/scopes.go +++ b/cmd/auth/scopes.go @@ -37,6 +37,7 @@ func NewCmdAuthScopes(f *cmdutil.Factory, runF func(*ScopesOptions) error) *cobr } cmd.Flags().StringVar(&opts.Format, "format", "json", "output format: json (default) | pretty") + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/auth/status.go b/cmd/auth/status.go index 55abfe587..20a8d4790 100644 --- a/cmd/auth/status.go +++ b/cmd/auth/status.go @@ -5,13 +5,11 @@ package auth import ( "context" - "time" "github.com/spf13/cobra" - larkauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/cmdutil" - "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/identitydiag" "github.com/larksuite/cli/internal/output" ) @@ -37,6 +35,7 @@ func NewCmdAuthStatus(f *cmdutil.Factory, runF func(*StatusOptions) error) *cobr } cmd.Flags().BoolVar(&opts.Verify, "verify", false, "verify token against server (requires network)") + cmdutil.SetRisk(cmd, "read") return cmd } @@ -59,73 +58,83 @@ func authStatusRun(opts *StatusOptions) error { "defaultAs": defaultAs, } - if config.UserOpenId == "" { - result["identity"] = "bot" - result["note"] = "No user logged in. Only bot (tenant) identity is available for API calls. Run `lark-cli auth login` to log in." - output.PrintJson(f.IOStreams.Out, result) - return nil - } - - stored := larkauth.GetStoredToken(config.AppID, config.UserOpenId) - if stored == nil { - result["identity"] = "bot" - result["userName"] = config.UserName - result["userOpenId"] = config.UserOpenId - result["note"] = "Token does not exist or has been cleared. Only bot (tenant) identity is available. Re-login: lark-cli auth login" - output.PrintJson(f.IOStreams.Out, result) - return nil - } - - status := larkauth.TokenStatus(stored) - if status == "expired" { - result["identity"] = "bot" - result["note"] = "User token has expired. Only bot (tenant) identity is available. Re-login: lark-cli auth login" - } else { - result["identity"] = "user" - } - result["userName"] = config.UserName - result["userOpenId"] = config.UserOpenId - result["tokenStatus"] = status - result["scope"] = stored.Scope - result["expiresAt"] = time.UnixMilli(stored.ExpiresAt).Format(time.RFC3339) - result["refreshExpiresAt"] = time.UnixMilli(stored.RefreshExpiresAt).Format(time.RFC3339) - result["grantedAt"] = time.UnixMilli(stored.GrantedAt).Format(time.RFC3339) - - // --verify: call the server to confirm token is actually usable. - if opts.Verify && status != "expired" { - verified, verifyErr := verifyTokenOnServer(f, config) - result["verified"] = verified - if verifyErr != "" { - result["verifyError"] = verifyErr - } - } + diagnostics := identitydiag.Diagnose(context.Background(), f, config, opts.Verify) + result["identities"] = diagnostics + result["identity"] = effectiveIdentity(diagnostics) + addLegacyUserFields(result, diagnostics.User) + addEffectiveVerification(result, diagnostics) + addStatusNote(result, diagnostics) output.PrintJson(f.IOStreams.Out, result) return nil } -// verifyTokenOnServer obtains a valid access token (refreshing if needed) -// and calls /authen/v1/user_info to confirm the server accepts it. -// Returns (true, "") on success or (false, reason) on failure. -func verifyTokenOnServer(f *cmdutil.Factory, config *core.CliConfig) (bool, string) { - httpClient, err := f.HttpClient() - if err != nil { - return false, "failed to create HTTP client: " + err.Error() - } +const ( + identityUser = "user" + identityBot = "bot" + identityNone = "none" +) - token, err := larkauth.GetValidAccessToken(httpClient, larkauth.NewUATCallOptions(config, f.IOStreams.ErrOut)) - if err != nil { - return false, "token unusable: " + err.Error() +func effectiveIdentity(d identitydiag.Result) string { + switch { + case d.User.Available: + return identityUser + case d.Bot.Available: + return identityBot + default: + return identityNone } +} - sdk, err := f.LarkClient() - if err != nil { - return false, "failed to create SDK client: " + err.Error() +func addLegacyUserFields(result map[string]interface{}, user identitydiag.Identity) { + if user.OpenID == "" { + return } + result["userName"] = user.UserName + result["userOpenId"] = user.OpenID + if user.TokenStatus != "" { + result["tokenStatus"] = user.TokenStatus + } + if user.Scope != "" { + result["scope"] = user.Scope + } + if user.ExpiresAt != "" { + result["expiresAt"] = user.ExpiresAt + } + if user.RefreshExpiresAt != "" { + result["refreshExpiresAt"] = user.RefreshExpiresAt + } + if user.GrantedAt != "" { + result["grantedAt"] = user.GrantedAt + } +} - if err := larkauth.VerifyUserToken(context.Background(), sdk, token); err != nil { - return false, "server rejected token: " + err.Error() +func addEffectiveVerification(result map[string]interface{}, d identitydiag.Result) { + switch result["identity"] { + case identityUser: + if d.User.Verified != nil { + result["verified"] = *d.User.Verified + if !*d.User.Verified { + result["verifyError"] = d.User.Message + } + } + case identityBot: + if d.Bot.Verified != nil { + result["verified"] = *d.Bot.Verified + if !*d.Bot.Verified { + result["verifyError"] = d.Bot.Message + } + } } +} - return true, "" +func addStatusNote(result map[string]interface{}, d identitydiag.Result) { + switch { + case !d.User.Available && d.Bot.Available: + result["note"] = "User identity is " + identitydiag.StatusMessage(d.User.Status) + "; bot identity is ready for bot/tenant API calls. Run `lark-cli auth login` to enable user identity." + case d.User.Status == identitydiag.StatusNeedsRefresh: + result["note"] = "User identity needs refresh and will be refreshed automatically on the next user API call." + case !d.User.Available && !d.Bot.Available: + result["note"] = "No usable identity is available. Configure bot credentials or run `lark-cli auth login`." + } } diff --git a/cmd/auth/status_test.go b/cmd/auth/status_test.go new file mode 100644 index 000000000..7bf0608c7 --- /dev/null +++ b/cmd/auth/status_test.go @@ -0,0 +1,96 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/httpmock" +) + +func TestAuthStatusRun_SplitsBotAndUserIdentity(t *testing.T) { + f, stdout, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "secret", Brand: core.BrandFeishu, + }) + + if err := authStatusRun(&StatusOptions{Factory: f}); err != nil { + t.Fatalf("authStatusRun() error = %v", err) + } + + var got statusOutput + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if got.Identity != "bot" { + t.Fatalf("identity = %q, want bot", got.Identity) + } + if got.Identities.Bot.Status != "ready" || !got.Identities.Bot.Available { + t.Fatalf("bot = %#v, want ready and available", got.Identities.Bot) + } + if got.Identities.User.Status != "missing" || got.Identities.User.Available { + t.Fatalf("user = %#v, want missing and unavailable", got.Identities.User) + } +} + +func TestAuthStatusRun_VerifyReportsBotIdentity(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "secret", Brand: core.BrandFeishu, + }) + reg.Register(&httpmock.Stub{ + Method: http.MethodGet, + URL: "/open-apis/bot/v3/info", + Body: map[string]interface{}{ + "code": 0, + "msg": "ok", + "bot": map[string]interface{}{ + "open_id": "ou_bot", + "app_name": "diagnostic bot", + }, + }, + }) + + if err := authStatusRun(&StatusOptions{Factory: f, Verify: true}); err != nil { + t.Fatalf("authStatusRun() error = %v", err) + } + + var got statusOutput + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if got.Identity != "bot" { + t.Fatalf("identity = %q, want bot", got.Identity) + } + if got.Verified == nil || !*got.Verified { + t.Fatalf("verified = %v, want true", got.Verified) + } + if got.Identities.Bot.Verified == nil || !*got.Identities.Bot.Verified { + t.Fatalf("bot verified = %v, want true", got.Identities.Bot.Verified) + } + if got.Identities.Bot.OpenID != "ou_bot" { + t.Fatalf("bot open id = %q, want ou_bot", got.Identities.Bot.OpenID) + } + if got.Identities.User.Status != "missing" { + t.Fatalf("user status = %q, want missing", got.Identities.User.Status) + } +} + +type statusOutput struct { + Identity string `json:"identity"` + Verified *bool `json:"verified"` + Identities struct { + Bot statusIdentity `json:"bot"` + User statusIdentity `json:"user"` + } `json:"identities"` +} + +type statusIdentity struct { + Status string `json:"status"` + Available bool `json:"available"` + Verified *bool `json:"verified"` + OpenID string `json:"openId"` +} diff --git a/cmd/build.go b/cmd/build.go index 6b5d1e5c1..a748544b0 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -19,7 +19,9 @@ import ( cmdupdate "github.com/larksuite/cli/cmd/update" _ "github.com/larksuite/cli/events" "github.com/larksuite/cli/internal/build" + "github.com/larksuite/cli/internal/cmdpolicy" "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/hook" "github.com/larksuite/cli/internal/keychain" "github.com/larksuite/cli/shortcuts" "github.com/spf13/cobra" @@ -59,18 +61,28 @@ func HideProfile(hide bool) BuildOption { } } -// Build constructs the full command tree without executing. -// Returns only the cobra.Command; Factory is internal. +// Build constructs the full command tree. It also installs registered +// plugins and emits the Startup lifecycle event during assembly -- +// so Plugin.On(Startup) handlers run even if the returned command is +// never dispatched. The matching Shutdown event is only emitted by +// Execute; callers that bypass Execute will not see Shutdown fire. +// +// Returns only the cobra.Command; Factory and hook Registry are internal. // Use Execute for the standard production entry point. func Build(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) *cobra.Command { - _, rootCmd := buildInternal(ctx, inv, opts...) + _, rootCmd, _ := buildInternal(ctx, inv, opts...) return rootCmd } // buildInternal is a pure assembly function: it wires the command tree from // inv and BuildOptions alone. Any state-dependent decision (disk, network, // env) belongs in the caller and must be threaded in via BuildOption. -func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) (*cmdutil.Factory, *cobra.Command) { +// +// Returns (factory, rootCmd, registry). The registry is nil when plugin +// install failed (FailClosed guard installed) or when no plugin produced +// hooks; callers that wire Shutdown emit must nil-check before calling +// hook.Emit. +func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) (*cmdutil.Factory, *cobra.Command, *hook.Registry) { // cfg.globals.Profile is left zero here; it's bound to the --profile // flag in RegisterGlobalFlags and filled by cobra's parse step. cfg := &buildConfig{} @@ -124,10 +136,42 @@ func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...B service.RegisterServiceCommandsWithContext(ctx, rootCmd, f) shortcuts.RegisterShortcutsWithContext(ctx, rootCmd, f) - // Prune commands incompatible with strict mode. + installUnknownSubcommandGuard(rootCmd) + if mode := f.ResolveStrictMode(ctx); mode.IsActive() { pruneForStrictMode(rootCmd, mode) } - return f, rootCmd + installResult, installErr := installPluginsAndHooks(cfg.streams.ErrOut) + if installErr != nil { + installPluginInstallErrorGuard(rootCmd, installErr) + return f, rootCmd, nil + } + var pluginRules []cmdpolicy.PluginRule + var registry *hook.Registry + if installResult != nil { + pluginRules = installResult.PluginRules + registry = installResult.Registry + } + + // Policy errors fail-CLOSED when a plugin contributed (security + // intent must not be silently dropped); yaml-only errors fail-OPEN + // with a warning so a typo can't lock the user out. + if err := applyUserPolicyPruning(rootCmd, pluginRules); err != nil { + if len(pluginRules) > 0 { + installPluginConflictGuard(rootCmd, err) + return f, rootCmd, nil + } + warnPolicyError(cfg.streams.ErrOut, err) + } + + if registry != nil { + if err := wireHooks(ctx, rootCmd, registry); err != nil { + installPluginLifecycleErrorGuard(rootCmd, err) + return f, rootCmd, nil + } + } + + recordInventory(installResult) + return f, rootCmd, registry } diff --git a/cmd/completion/completion.go b/cmd/completion/completion.go index 574365b7f..a7187bb33 100644 --- a/cmd/completion/completion.go +++ b/cmd/completion/completion.go @@ -37,5 +37,6 @@ func NewCmdCompletion(f *cmdutil.Factory) *cobra.Command { }, } cmdutil.DisableAuthCheck(cmd) + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/config/bind.go b/cmd/config/bind.go index 2068d1439..383861ac7 100644 --- a/cmd/config/bind.go +++ b/cmd/config/bind.go @@ -103,6 +103,7 @@ Interactive terminal use: run with no flags to enter the TUI form.`, cmd.Flags().StringVar(&opts.Identity, "identity", "", "identity preset (bot-only|user-default); defaults to bot-only in flag mode (safer: no impersonation)") cmd.Flags().BoolVar(&opts.Force, "force", false, "confirm a risky transition (currently: bot-only → user-default identity change in flag mode)") cmd.Flags().StringVar(&opts.Lang, "lang", "zh", "language for interactive prompts (zh|en)") + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/config/bind_test.go b/cmd/config/bind_test.go index bace505d7..ae6ad8a2d 100644 --- a/cmd/config/bind_test.go +++ b/cmd/config/bind_test.go @@ -408,6 +408,26 @@ func TestConfigBindRun_LarkChannel_Success(t *testing.T) { } } +// Env template form: secret = "${VAR}" should resolve via the SecretInput +// pipeline (same path openclaw uses), so the keychain receives the env value +// not the literal template string. +func TestConfigBindRun_LarkChannel_EnvTemplate(t *testing.T) { + saveWorkspace(t) + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + clearAgentEnv(t) + + fakeHome := t.TempDir() + t.Setenv("HOME", fakeHome) + t.Setenv("LARK_APP_SECRET", "resolved_via_env") + writeLarkChannelFixture(t, fakeHome, + `{"accounts":{"app":{"id":"cli_lc_env","secret":"${LARK_APP_SECRET}","tenant":"feishu"}}}`) + + f, _, _, _ := cmdutil.TestFactory(t, nil) + if err := configBindRun(&BindOptions{Factory: f, Source: "lark-channel"}); err != nil { + t.Fatalf("expected success, got error: %v", err) + } +} + // tenant: "lark" should land as Brand("lark"), not normalized to "feishu". func TestConfigBindRun_LarkChannel_LarkTenant(t *testing.T) { saveWorkspace(t) diff --git a/cmd/config/binder.go b/cmd/config/binder.go index ee1780840..8a9426674 100644 --- a/cmd/config/binder.go +++ b/cmd/config/binder.go @@ -312,13 +312,22 @@ func (b *larkChannelBinder) Build(appID string) (*core.AppConfig, error) { return nil, output.Errorf(output.ExitInternal, "lark-channel", "internal: appID %q does not match config", appID) } - if b.cfg.Accounts.App.Secret == "" { + if b.cfg.Accounts.App.Secret.IsZero() { return nil, output.ErrWithHint(output.ExitValidation, "lark-channel", fmt.Sprintf("accounts.app.secret is empty in %s", b.path), "run lark-channel-bridge's setup to populate the app credential") } - stored, err := core.ForStorage(appID, core.PlainSecret(b.cfg.Accounts.App.Secret), b.opts.Factory.Keychain) + // Resolve through the same SecretInput pipeline openclaw uses, so + // bridge configs can use ${VAR} / env / file / exec just like openclaw. + secret, err := binding.ResolveSecretInput(b.cfg.Accounts.App.Secret, b.cfg.Secrets, os.Getenv) + if err != nil { + return nil, output.ErrWithHint(output.ExitValidation, "lark-channel", + fmt.Sprintf("failed to resolve appSecret for %s: %v", appID, err), + fmt.Sprintf("check appSecret configuration in %s", b.path)) + } + + stored, err := core.ForStorage(appID, core.PlainSecret(secret), b.opts.Factory.Keychain) if err != nil { return nil, output.Errorf(output.ExitInternal, "lark-channel", "keychain unavailable: %v", err) diff --git a/cmd/config/config.go b/cmd/config/config.go index b857e19b0..c99f6b482 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -31,6 +31,8 @@ func NewCmdConfig(f *cmdutil.Factory) *cobra.Command { cmd.AddCommand(NewCmdConfigShow(f, nil)) cmd.AddCommand(NewCmdConfigDefaultAs(f)) cmd.AddCommand(NewCmdConfigStrictMode(f)) + cmd.AddCommand(NewCmdConfigPolicy(f)) + cmd.AddCommand(NewCmdConfigPlugins(f)) return cmd } diff --git a/cmd/config/default_as.go b/cmd/config/default_as.go index 1b590f5ad..a5078c1e9 100644 --- a/cmd/config/default_as.go +++ b/cmd/config/default_as.go @@ -52,5 +52,6 @@ func NewCmdConfigDefaultAs(f *cmdutil.Factory) *cobra.Command { return nil }, } + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/config/init.go b/cmd/config/init.go index ed77d1a76..3fc56c725 100644 --- a/cmd/config/init.go +++ b/cmd/config/init.go @@ -80,6 +80,7 @@ if the user explicitly wants a separate app inside the Agent workspace.`, cmd.Flags().StringVar(&opts.Lang, "lang", "zh", "language for interactive prompts (zh or en)") cmd.Flags().StringVar(&opts.ProfileName, "name", "", "create or update a named profile (append instead of replace)") cmd.Flags().BoolVar(&opts.ForceInit, "force-init", false, "allow init inside an Agent workspace (OPENCLAW_HOME / HERMES_HOME); use config bind instead unless you really want a separate app") + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/config/plugins.go b/cmd/config/plugins.go new file mode 100644 index 000000000..a50d47075 --- /dev/null +++ b/cmd/config/plugins.go @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package config + +import ( + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/output" + internalplatform "github.com/larksuite/cli/internal/platform" +) + +// NewCmdConfigPlugins exposes the plugin inventory diagnostic command. +// +// `config policy show` is intentionally focused on the user-layer Rule +// (Restrict). Plugins also contribute hooks (Observe / Wrap / Lifecycle) +// that are not policy gates but still mutate the CLI's runtime behaviour. +// This command surfaces both halves so an operator can answer "what is +// this binary doing differently from stock lark-cli?" in one place. +// +// Like config policy show, the dispatch path is exempt from policy +// enforcement (see internal/cmdpolicy/diagnostic.go) so it remains +// usable under any Rule. +func NewCmdConfigPlugins(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "plugins", + Hidden: true, // diagnostic-only; kept callable, omitted from --help so it stays out of AI-agent context + Short: "Inspect installed plugins and their hook contributions", + // Same leaf-level no-op as config policy: the parent `config` + // group's PersistentPreRunE requires builtin credential, but + // this is a read-only diagnostic that must work everywhere. + PersistentPreRunE: func(c *cobra.Command, _ []string) error { + c.SilenceUsage = true + return nil + }, + } + cmd.AddCommand(newCmdConfigPluginsShow(f)) + return cmd +} + +func newCmdConfigPluginsShow(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "show", + Short: "List successfully installed plugins, their rules, and registered hooks", + Long: `Print every plugin that committed during bootstrap, including: + + - name / version / capabilities (FailurePolicy, Restricts, RequiredCLIVersion) + - rule (when the plugin called r.Restrict) + - hooks: observers (Before / After), wrappers, lifecycle handlers + +Hooks are attributed by their namespaced name -- the framework prepends +the plugin name as the prefix at registration time, so an entry +"secaudit.audit-pre" belongs to plugin "secaudit".`, + RunE: func(cmd *cobra.Command, args []string) error { + return runConfigPluginsShow(f) + }, + } + cmdutil.SetRisk(cmd, "read") + return cmd +} + +func runConfigPluginsShow(f *cmdutil.Factory) error { + inv := internalplatform.GetActiveInventory() + if inv == nil { + // Always emit the same field set as the populated branch so + // AI agents and CI scripts don't have to branch on whether + // `total` is present. `note` makes the unusual state explicit + // for human readers. + output.PrintJson(f.IOStreams.Out, map[string]any{ + "plugins": []any{}, + "total": 0, + "note": "no inventory recorded; bootstrap did not finish", + }) + return nil + } + + plugins := make([]map[string]any, 0, len(inv.Plugins)) + for _, p := range inv.Plugins { + entry := map[string]any{ + "name": p.Name, + "version": p.Version, + "capabilities": p.Capabilities, + } + if p.Rule != nil { + entry["rule"] = p.Rule + } + entry["hooks"] = map[string]any{ + "observers": p.Observers, + "wrappers": p.Wrappers, + "lifecycle": p.Lifecycles, + "count": len(p.Observers) + len(p.Wrappers) + len(p.Lifecycles), + } + plugins = append(plugins, entry) + } + output.PrintJson(f.IOStreams.Out, map[string]any{ + "plugins": plugins, + "total": len(plugins), + }) + return nil +} diff --git a/cmd/config/policy.go b/cmd/config/policy.go new file mode 100644 index 000000000..78f2b10a7 --- /dev/null +++ b/cmd/config/policy.go @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package config + +import ( + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/output" +) + +func NewCmdConfigPolicy(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "policy", + Hidden: true, + Short: "Inspect the user-layer command policy", + // Override parent's RequireBuiltinCredentialProvider check; this + // group is read-only diagnostic and must work under any provider. + PersistentPreRunE: func(c *cobra.Command, _ []string) error { + c.SilenceUsage = true + return nil + }, + } + cmd.AddCommand(newCmdConfigPolicyShow(f)) + return cmd +} + +func newCmdConfigPolicyShow(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "show", + Hidden: true, + Short: "Show the active user-layer policy (plugin / yaml / none)", + RunE: func(cmd *cobra.Command, args []string) error { + return runConfigPolicyShow(f) + }, + } + cmdutil.SetRisk(cmd, "read") + return cmd +} + +func runConfigPolicyShow(f *cmdutil.Factory) error { + active := cmdpolicy.GetActive() + if active == nil { + output.PrintJson(f.IOStreams.Out, map[string]any{ + "source": string(cmdpolicy.SourceNone), + "note": "no policy recorded; bootstrap did not run pruning", + }) + return nil + } + + sourceName := "" + if active.Source.Kind == cmdpolicy.SourcePlugin { + sourceName = active.Source.Name + } + out := map[string]any{ + "source": string(active.Source.Kind), + "source_name": sourceName, + "denied_paths": active.DeniedPaths, + } + if active.Rule != nil { + out["rule"] = map[string]any{ + "name": active.Rule.Name, + "description": active.Rule.Description, + "allow": active.Rule.Allow, + "deny": active.Rule.Deny, + "max_risk": active.Rule.MaxRisk, + "identities": active.Rule.Identities, + "allow_unannotated": active.Rule.AllowUnannotated, + } + } + output.PrintJson(f.IOStreams.Out, out) + return nil +} diff --git a/cmd/config/policy_test.go b/cmd/config/policy_test.go new file mode 100644 index 000000000..05d8a180b --- /dev/null +++ b/cmd/config/policy_test.go @@ -0,0 +1,146 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package config + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/cmdutil" +) + +func newPolicyTestFactory() (*cmdutil.Factory, *bytes.Buffer, *bytes.Buffer) { + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + f := &cmdutil.Factory{ + IOStreams: cmdutil.NewIOStreams(nil, out, errOut), + } + return f, out, errOut +} + +// `config policy show` reads the active policy recorded by bootstrap. +// When nothing is recorded the command must still produce a JSON +// envelope with source=none and a note explaining the missing context. +func TestConfigPolicyShow_NoActivePolicy(t *testing.T) { + cmdpolicy.ResetActiveForTesting() + t.Cleanup(cmdpolicy.ResetActiveForTesting) + + f, out, _ := newPolicyTestFactory() + if err := runConfigPolicyShow(f); err != nil { + t.Fatalf("show: %v", err) + } + var got map[string]any + if err := json.Unmarshal(out.Bytes(), &got); err != nil { + t.Fatalf("not json: %v\n%s", err, out.String()) + } + if got["source"] != "none" { + t.Errorf("source = %v, want none", got["source"]) + } + if got["note"] == "" || got["note"] == nil { + t.Errorf("expected explanatory note when no policy recorded") + } +} + +// When bootstrap recorded an active plugin Rule, `show` emits the rule +// plus its source. +func TestConfigPolicyShow_PluginActive(t *testing.T) { + cmdpolicy.ResetActiveForTesting() + t.Cleanup(cmdpolicy.ResetActiveForTesting) + + rule := &platform.Rule{ + Name: "secaudit", + Allow: []string{"docs/**"}, + MaxRisk: "read", + } + cmdpolicy.SetActive(&cmdpolicy.ActivePolicy{ + Rule: rule, + Source: cmdpolicy.ResolveSource{ + Kind: cmdpolicy.SourcePlugin, + Name: "secaudit", + }, + DeniedPaths: 42, + }) + + f, out, _ := newPolicyTestFactory() + if err := runConfigPolicyShow(f); err != nil { + t.Fatalf("show: %v", err) + } + var got map[string]any + if err := json.Unmarshal(out.Bytes(), &got); err != nil { + t.Fatalf("not json: %v\n%s", err, out.String()) + } + if got["source"] != "plugin" { + t.Errorf("source = %v, want plugin", got["source"]) + } + if got["source_name"] != "secaudit" { + t.Errorf("source_name = %v, want secaudit", got["source_name"]) + } + // json.Unmarshal returns float64 for numbers. + if got["denied_paths"] != float64(42) { + t.Errorf("denied_paths = %v, want 42", got["denied_paths"]) + } + ruleMap, ok := got["rule"].(map[string]any) + if !ok { + t.Fatalf("rule field missing or wrong type") + } + if ruleMap["name"] != "secaudit" { + t.Errorf("rule.name = %v", ruleMap["name"]) + } +} + +// `source_name` must be empty when source=yaml. The yaml path is +// deliberately not surfaced (matches engine envelope convention, +// avoids leaking the user's home dir to AI agents / CI logs). The +// rule's "name:" field is the disambiguator users should rely on. +func TestConfigPolicyShow_YamlSourceNameIsEmpty(t *testing.T) { + cmdpolicy.ResetActiveForTesting() + t.Cleanup(cmdpolicy.ResetActiveForTesting) + + cmdpolicy.SetActive(&cmdpolicy.ActivePolicy{ + Rule: &platform.Rule{Name: "my-yaml-rule"}, + Source: cmdpolicy.ResolveSource{ + Kind: cmdpolicy.SourceYAML, + Name: "/Users/alice/.lark-cli/policy.yml", + }, + }) + + f, out, _ := newPolicyTestFactory() + if err := runConfigPolicyShow(f); err != nil { + t.Fatalf("show: %v", err) + } + var got map[string]any + if err := json.Unmarshal(out.Bytes(), &got); err != nil { + t.Fatalf("not json: %v\n%s", err, out.String()) + } + if got["source"] != "yaml" { + t.Errorf("source = %v, want yaml", got["source"]) + } + if got["source_name"] != "" { + t.Errorf("source_name = %q, want empty (yaml path must not leak)", got["source_name"]) + } + // The path must not appear anywhere in the envelope. + if bytes.Contains(out.Bytes(), []byte("/Users/alice")) { + t.Errorf("envelope leaked yaml path: %s", out.String()) + } +} + +// Regression: the parent `config` command declares a PersistentPreRunE +// that calls RequireBuiltinCredentialProvider; env credentials cause +// it to return external_provider. `config policy` is a diagnostic +// group that must not be blocked by that check. The group declares +// its own no-op PersistentPreRunE so cobra's "first walking up from +// leaf" picks ours over the config parent's. +func TestConfigPolicy_BypassesConfigParentPersistentPreRunE(t *testing.T) { + f, _, _ := newPolicyTestFactory() + group := NewCmdConfigPolicy(f) + if group.PersistentPreRunE == nil { + t.Fatal("config policy group must declare its own PersistentPreRunE to win over config parent") + } + if err := group.PersistentPreRunE(group, nil); err != nil { + t.Errorf("config policy PersistentPreRunE should be no-op, got %v", err) + } +} diff --git a/cmd/config/remove.go b/cmd/config/remove.go index 52c5eb0ca..324f7e58c 100644 --- a/cmd/config/remove.go +++ b/cmd/config/remove.go @@ -32,6 +32,7 @@ func NewCmdConfigRemove(f *cmdutil.Factory, runF func(*ConfigRemoveOptions) erro return configRemoveRun(opts) }, } + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/config/show.go b/cmd/config/show.go index 96018008f..1f0f12ffc 100644 --- a/cmd/config/show.go +++ b/cmd/config/show.go @@ -34,6 +34,7 @@ func NewCmdConfigShow(f *cmdutil.Factory, runF func(*ConfigShowOptions) error) * return configShowRun(opts) }, } + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/config/strict_mode.go b/cmd/config/strict_mode.go index 709010914..6bac82424 100644 --- a/cmd/config/strict_mode.go +++ b/cmd/config/strict_mode.go @@ -66,6 +66,7 @@ explicit user confirmation — never run on your own initiative.`, cmd.Flags().BoolVar(&global, "global", false, "set at global level (applies to all profiles)") cmd.Flags().BoolVar(&reset, "reset", false, "reset profile setting to inherit global") + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/doctor/doctor.go b/cmd/doctor/doctor.go index f48b50ab6..9314ebfc9 100644 --- a/cmd/doctor/doctor.go +++ b/cmd/doctor/doctor.go @@ -14,10 +14,10 @@ import ( "github.com/spf13/cobra" - larkauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/build" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/identitydiag" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/update" ) @@ -43,6 +43,7 @@ func NewCmdDoctor(f *cmdutil.Factory) *cobra.Command { } cmdutil.DisableAuthCheck(cmd) cmd.Flags().BoolVar(&opts.Offline, "offline", false, "skip network checks (only verify local state)") + cmdutil.SetRisk(cmd, "read") return cmd } @@ -50,7 +51,7 @@ func NewCmdDoctor(f *cmdutil.Factory) *cobra.Command { // checkResult represents one diagnostic check. type checkResult struct { Name string `json:"name"` - Status string `json:"status"` // "pass", "fail", "skip" + Status string `json:"status"` // "pass", "warn", "fail", "skip" Message string `json:"message"` Hint string `json:"hint,omitempty"` } @@ -117,59 +118,31 @@ func doctorRun(opts *DoctorOptions) error { ep := core.ResolveEndpoints(cfg.Brand) - // ── 3. Token exists ── - if cfg.UserOpenId == "" { - checks = append(checks, fail("token_exists", "no user logged in", "run: lark-cli auth login --help")) - checks = append(checks, networkChecks(opts.Ctx, opts, ep)...) - return finishDoctor(f, checks) - } - stored := larkauth.GetStoredToken(cfg.AppID, cfg.UserOpenId) - if stored == nil { - checks = append(checks, fail("token_exists", "no token in keychain for "+cfg.UserOpenId, "run: lark-cli auth login --help")) - checks = append(checks, networkChecks(opts.Ctx, opts, ep)...) - return finishDoctor(f, checks) - } - checks = append(checks, pass("token_exists", fmt.Sprintf("token found for %s (%s)", cfg.UserName, cfg.UserOpenId))) - - // ── 4. Token local validity ── - status := larkauth.TokenStatus(stored) - switch status { - case "valid": - checks = append(checks, pass("token_local", "token valid, expires "+time.UnixMilli(stored.ExpiresAt).Format(time.RFC3339))) - case "needs_refresh": - checks = append(checks, pass("token_local", "token needs refresh (will auto-refresh on next call)")) - default: // expired - checks = append(checks, fail("token_local", "token expired", "run: lark-cli auth login --help")) - checks = append(checks, networkChecks(opts.Ctx, opts, ep)...) - return finishDoctor(f, checks) - } - - // ── 5. Token server verification ── - if opts.Offline { - checks = append(checks, skip("token_verified", "skipped (--offline)")) + // ── 3. Identity readiness ── + diagnostics := identitydiag.Diagnose(opts.Ctx, f, cfg, !opts.Offline) + checks = append(checks, + identityCheck("bot_identity", diagnostics.Bot), + identityCheck("user_identity", diagnostics.User), + ) + if diagnostics.Bot.Available || diagnostics.User.Available { + checks = append(checks, pass("identity_ready", "at least one identity is available")) } else { - httpClient := mustHTTPClient(f) - token, err := larkauth.GetValidAccessToken(httpClient, larkauth.NewUATCallOptions(cfg, f.IOStreams.ErrOut)) - if err != nil { - checks = append(checks, fail("token_verified", "cannot obtain valid token: "+err.Error(), "run: lark-cli auth login --help")) - } else { - sdk, err := f.LarkClient() - if err != nil { - checks = append(checks, fail("token_verified", "SDK init failed: "+err.Error(), "")) - } else if err := larkauth.VerifyUserToken(opts.Ctx, sdk, token); err != nil { - checks = append(checks, fail("token_verified", "server rejected token: "+err.Error(), "run: lark-cli auth login --help")) - } else { - checks = append(checks, pass("token_verified", "server confirmed token is valid")) - } - } + checks = append(checks, fail("identity_ready", "no usable bot or user identity is available", "run: lark-cli auth status --verify")) } - // ── 6 & 7. Endpoint reachability ── + // ── 4 & 5. Endpoint reachability ── checks = append(checks, networkChecks(opts.Ctx, opts, ep)...) return finishDoctor(f, checks) } +func identityCheck(name string, id identitydiag.Identity) checkResult { + if id.Available { + return pass(name, id.Message) + } + return warn(name, id.Message, id.Hint) +} + // networkChecks probes Open API and MCP endpoints concurrently. func networkChecks(ctx context.Context, opts *DoctorOptions, ep core.Endpoints) []checkResult { if opts.Offline { @@ -231,15 +204,6 @@ func probeEndpoint(ctx context.Context, client *http.Client, url string) error { return nil } -// mustHTTPClient returns f.HttpClient() or a default client. -func mustHTTPClient(f *cmdutil.Factory) *http.Client { - c, err := f.HttpClient() - if err != nil { - return &http.Client{Timeout: 30 * time.Second} - } - return c -} - // checkCLIUpdate actively queries the npm registry for the latest version. // Unlike the root-level async check, this does a synchronous fetch with timeout // and works regardless of build version (dev builds included). diff --git a/cmd/doctor/doctor_test.go b/cmd/doctor/doctor_test.go index 5ffd7709f..0f4fe8f7a 100644 --- a/cmd/doctor/doctor_test.go +++ b/cmd/doctor/doctor_test.go @@ -95,3 +95,59 @@ func TestNetworkChecks_Offline(t *testing.T) { } } } + +func TestDoctorRun_SplitsBotAndMissingUserIdentity(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := core.SaveMultiAppConfig(&core.MultiAppConfig{ + CurrentApp: "default", + Apps: []core.AppConfig{ + { + Name: "default", + AppId: "test-app", + AppSecret: core.PlainSecret("secret"), + Brand: core.BrandFeishu, + }, + }, + }); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f, stdout, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "secret", Brand: core.BrandFeishu, + }) + err := doctorRun(&DoctorOptions{ + Factory: f, + Ctx: context.Background(), + Offline: true, + }) + if err != nil { + t.Fatalf("doctorRun() error = %v", err) + } + + var got struct { + OK bool `json:"ok"` + Checks []checkResult `json:"checks"` + } + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if !got.OK { + t.Fatalf("ok = false, want true; checks = %#v", got.Checks) + } + assertCheck(t, got.Checks, "bot_identity", "pass") + assertCheck(t, got.Checks, "user_identity", "warn") + assertCheck(t, got.Checks, "identity_ready", "pass") +} + +func assertCheck(t *testing.T, checks []checkResult, name, status string) { + t.Helper() + for _, check := range checks { + if check.Name == name { + if check.Status != status { + t.Fatalf("%s status = %q, want %q", name, check.Status, status) + } + return + } + } + t.Fatalf("check %q not found in %#v", name, checks) +} diff --git a/cmd/event/bus.go b/cmd/event/bus.go index 90a83ce79..73d2958e7 100644 --- a/cmd/event/bus.go +++ b/cmd/event/bus.go @@ -64,6 +64,7 @@ func NewCmdBus(f *cmdutil.Factory) *cobra.Command { cmd.Flags().StringVar(&domain, "domain", "", "API domain") _ = cmd.Flags().MarkHidden("domain") + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/event/consume.go b/cmd/event/consume.go index db4548924..9fd4d234d 100644 --- a/cmd/event/consume.go +++ b/cmd/event/consume.go @@ -70,6 +70,7 @@ Use 'event schema ' for parameter details.`, _ = cmd.RegisterFlagCompletionFunc("as", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { return []string{"user", "bot", "auto"}, cobra.ShellCompDirectiveNoFileComp }) + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/event/list.go b/cmd/event/list.go index 1520644d3..de2a95720 100644 --- a/cmd/event/list.go +++ b/cmd/event/list.go @@ -26,6 +26,7 @@ func NewCmdList(f *cmdutil.Factory) *cobra.Command { }, } cmd.Flags().BoolVar(&asJSON, "json", false, "Emit the full EventKey list as JSON (for AI / scripts)") + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/event/schema.go b/cmd/event/schema.go index 298bf8c7c..830ce0566 100644 --- a/cmd/event/schema.go +++ b/cmd/event/schema.go @@ -88,6 +88,7 @@ func NewCmdSchema(f *cmdutil.Factory) *cobra.Command { }, } cmd.Flags().BoolVar(&asJSON, "json", false, "Emit the EventKey definition + resolved schema as JSON (for AI / scripts)") + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/event/status.go b/cmd/event/status.go index 4e3fc2bb4..92c8be25d 100644 --- a/cmd/event/status.go +++ b/cmd/event/status.go @@ -37,6 +37,7 @@ func NewCmdStatus(f *cmdutil.Factory) *cobra.Command { cmd.Flags().BoolVar(&asJSON, "json", false, "Emit status as JSON (for AI / scripts)") cmd.Flags().BoolVar(¤t, "current", false, "Only show status for the current profile's app") cmd.Flags().BoolVar(&failOnOrphan, "fail-on-orphan", false, "Exit 2 when any orphan bus is detected (default: always exit 0)") + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/event/stop.go b/cmd/event/stop.go index b9a8be1a6..adab2d3bb 100644 --- a/cmd/event/stop.go +++ b/cmd/event/stop.go @@ -70,6 +70,7 @@ Exit code: 2 if any target was refused or errored, 0 otherwise. cmd.Flags().BoolVar(&o.all, "all", false, "Stop all running bus daemons") cmd.Flags().BoolVar(&o.force, "force", false, "Stop even with active consumers; on shutdown-timeout also SIGKILL the bus") cmd.Flags().BoolVar(&o.asJSON, "json", false, "Emit results as JSON (for AI / scripts)") + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/global_flags_test.go b/cmd/global_flags_test.go index c24d1573a..67ee19839 100644 --- a/cmd/global_flags_test.go +++ b/cmd/global_flags_test.go @@ -78,7 +78,7 @@ func TestIsSingleAppMode_MultiApp(t *testing.T) { } func TestBuildInternal_HideProfileOption(t *testing.T) { - _, root := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams(), HideProfile(true)) + _, root, _ := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams(), HideProfile(true)) flag := root.PersistentFlags().Lookup("profile") if flag == nil { @@ -90,7 +90,7 @@ func TestBuildInternal_HideProfileOption(t *testing.T) { } func TestBuildInternal_DefaultShowsProfileFlag(t *testing.T) { - _, root := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams()) + _, root, _ := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams()) flag := root.PersistentFlags().Lookup("profile") if flag == nil { diff --git a/cmd/platform_bootstrap.go b/cmd/platform_bootstrap.go new file mode 100644 index 000000000..ef2ac6b73 --- /dev/null +++ b/cmd/platform_bootstrap.go @@ -0,0 +1,274 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "context" + "fmt" + "io" + "path/filepath" + "strings" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/hook" + internalplatform "github.com/larksuite/cli/internal/platform" + "github.com/larksuite/cli/internal/vfs" +) + +// userPolicyFileName is the conventional filename for the user-layer Rule. +// Lives under ~/.lark-cli/ to match the rest of the CLI's user-state +// directory. +const userPolicyFileName = "policy.yml" + +// applyUserPolicyPruning resolves the user-layer Rule from plugin +// contributions and/or ~/.lark-cli/policy.yml and installs denyStubs +// for commands it rejects. +// +// Missing yaml is not an error -- the CLI runs with no user-layer +// restriction. A malformed Rule (bad MaxRisk enum, malformed glob, etc.) +// surfaces via the returned error; the caller decides how to handle it. +// +// pluginRules carries Plugin.Restrict() contributions collected from +// the InstallAll phase; nil/empty is fine. +func applyUserPolicyPruning(rootCmd *cobra.Command, pluginRules []cmdpolicy.PluginRule) error { + yamlPath, err := userPolicyPath() + if err != nil { + // No user home dir means we cannot locate the policy. Treat + // the same as "file missing": no pruning, no error. This keeps + // non-interactive CI environments (no HOME set) running. + yamlPath = "" + } + + yamlRule, err := cmdpolicy.LoadYAMLPolicy(yamlPath) + if err != nil { + // Yaml-only failures are fail-OPEN at the caller (warn and + // continue), but the active-policy snapshot is process-global + // and may still carry data from a previous build in long-lived + // embedders / tests. Clear it explicitly so `config policy + // show` reports "no policy" instead of a stale rule that + // doesn't reflect the current command tree. + cmdpolicy.SetActive(nil) + return err + } + + rule, source, err := cmdpolicy.Resolve(cmdpolicy.Sources{ + PluginRules: pluginRules, + YAMLRule: yamlRule, + YAMLPath: yamlPath, + }) + if err != nil { + cmdpolicy.SetActive(nil) + return err + } + if rule == nil { + cmdpolicy.SetActive(&cmdpolicy.ActivePolicy{Source: source}) + return nil + } + + engine := cmdpolicy.New(rule) + decisions := engine.EvaluateAll(rootCmd) + denied := cmdpolicy.BuildDeniedByPath(rootCmd, decisions, source, rule.Name) + cmdpolicy.Apply(rootCmd, denied) + + cmdpolicy.SetActive(&cmdpolicy.ActivePolicy{ + Rule: rule, + Source: source, + DeniedPaths: len(denied), + }) + return nil +} + +// installPluginsAndHooks runs the InstallAll phase on the globally- +// registered plugins, returning the Plugin.Restrict contributions for +// cmdpolicy and the populated hook.Registry for the runtime wrapper. +// Errors from FailClosed plugins propagate; FailOpen failures are +// warned to errOut and the loop continues. +func installPluginsAndHooks(errOut io.Writer) (*internalplatform.InstallResult, error) { + plugins := platform.RegisteredPlugins() + if len(plugins) == 0 { + return &internalplatform.InstallResult{Registry: nil}, nil + } + return internalplatform.InstallAll(plugins, errOut) +} + +// recordInventory builds and stores the plugin inventory snapshot for +// diagnostic commands (config plugins show) to read at runtime. Called +// once from build.go after applyUserPolicyPruning + wireHooks succeed. +func recordInventory(installResult *internalplatform.InstallResult) { + if installResult == nil { + internalplatform.SetActiveInventory(nil) + return + } + pluginSrcs := make([]internalplatform.PluginInventorySource, 0, len(installResult.Plugins)) + for _, p := range installResult.Plugins { + pluginSrcs = append(pluginSrcs, internalplatform.PluginInventorySource{ + Name: p.Name, + Version: p.Version, + Capabilities: p.Capabilities, + }) + } + ruleSrcs := make([]internalplatform.RuleInventorySource, 0, len(installResult.PluginRules)) + for _, r := range installResult.PluginRules { + if r.Rule == nil { + continue + } + idents := make([]string, len(r.Rule.Identities)) + for i, id := range r.Rule.Identities { + idents[i] = string(id) + } + ruleSrcs = append(ruleSrcs, internalplatform.RuleInventorySource{ + PluginName: r.PluginName, + Allow: r.Rule.Allow, + Deny: r.Rule.Deny, + MaxRisk: string(r.Rule.MaxRisk), + Identities: idents, + RuleName: r.Rule.Name, + Desc: r.Rule.Description, + AllowUnannotated: r.Rule.AllowUnannotated, + }) + } + internalplatform.SetActiveInventory(internalplatform.BuildInventory(pluginSrcs, installResult.Registry, ruleSrcs)) +} + +// wireHooks installs Observer/Wrapper hooks onto every runnable command +// and emits the Startup lifecycle event. The registry may be nil when +// no plugin contributed any hook -- the function short-circuits in +// that case to avoid useless RunE wrapping. +func wireHooks(ctx context.Context, rootCmd *cobra.Command, reg *hook.Registry) error { + if reg == nil { + return nil + } + hook.Install(rootCmd, reg, cobraCommandViewSource{}) + return hook.Emit(ctx, reg, platform.Startup, nil) +} + +// cobraCommandViewSource is the default CommandViewSource: it returns a +// live view over the *cobra.Command. Strict-mode's Remove+Add stub +// (cmd/prune.go::strictModeStubFrom) explicitly forwards the original +// annotations + Short/Long so the live view keeps reporting Risk / +// Identities / Domain through the replacement. User-layer policy +// (cmdpolicy/apply.go::installDenyStub) mutates in place, preserving +// metadata trivially. +type cobraCommandViewSource struct{} + +func (cobraCommandViewSource) View(cmd *cobra.Command) platform.CommandView { + return cobraCommandView{cmd: cmd} +} + +// cobraCommandView adapts *cobra.Command to the CommandView interface. +type cobraCommandView struct { + cmd *cobra.Command +} + +func (v cobraCommandView) Path() string { + return cmdpolicy.CanonicalPath(v.cmd) +} + +func (v cobraCommandView) Domain() string { + for c := v.cmd; c != nil; c = c.Parent() { + if c.Annotations == nil { + continue + } + if v, ok := c.Annotations["cmdmeta.domain"]; ok && v != "" { + return v + } + } + return "" +} + +func (v cobraCommandView) Risk() (platform.Risk, bool) { + for c := v.cmd; c != nil; c = c.Parent() { + if c.Annotations == nil { + continue + } + if r, ok := c.Annotations["risk_level"]; ok && r != "" { + return platform.Risk(r), true + } + } + return "", false +} + +func (v cobraCommandView) Identities() []platform.Identity { + for c := v.cmd; c != nil; c = c.Parent() { + if c.Annotations == nil { + continue + } + if raw, ok := c.Annotations["lark:supportedIdentities"]; ok && raw != "" { + parts := splitCSV(raw) + out := make([]platform.Identity, len(parts)) + for i, p := range parts { + out[i] = platform.Identity(p) + } + return out + } + } + return nil +} + +func (v cobraCommandView) Annotation(key string) (string, bool) { + if v.cmd.Annotations == nil { + return "", false + } + s, ok := v.cmd.Annotations[key] + return s, ok +} + +// splitCSV is a tiny csv-without-quotes helper. The +// lark:supportedIdentities annotation is always plain +// "user" / "bot" / "user,bot" without escaping. +func splitCSV(s string) []string { + out := []string{} + start := 0 + for i := 0; i < len(s); i++ { + if s[i] == ',' { + out = append(out, s[start:i]) + start = i + 1 + } + } + out = append(out, s[start:]) + return out +} + +// userPolicyPath returns the path of /policy.yml. +// +// The base directory honours LARKSUITE_CLI_CONFIG_DIR (via +// core.GetBaseConfigDir) so that test isolation, container deployments +// and per-Agent config overrides all see a consistent policy location. +// Using vfs.UserHomeDir directly here would silently bypass the env +// override and route every test through the real ~/.lark-cli. +// +// The error return is retained for caller compatibility but is always +// nil today: GetBaseConfigDir falls back to a relative ".lark-cli" when +// the home dir can't be resolved, and the resolver already treats a +// missing file as "no policy". +func userPolicyPath() (string, error) { + return filepath.Join(core.GetBaseConfigDir(), userPolicyFileName), nil +} + +// warnPolicyError writes a one-line stderr warning when the user policy +// fails to load. V1 yaml errors are fail-OPEN -- the CLI keeps running +// without policy enforcement so the user can fix the typo. Plugin-supplied +// rules are fail-CLOSED instead because integrators take a code-level +// responsibility for them. +// +// Wrapped errors may carry the absolute policy path (os.PathError); fold +// the home prefix to "~" before emitting so stderr piped into agents / +// CI logs does not leak the user's home directory. +func warnPolicyError(errOut io.Writer, err error) { + if err == nil { + return + } + fmt.Fprintf(errOut, "warning: user policy not applied: %s\n", redactHome(err.Error())) +} + +func redactHome(s string) string { + if home, err := vfs.UserHomeDir(); err == nil && home != "" { + s = strings.ReplaceAll(s, home, "~") + } + return s +} diff --git a/cmd/platform_bootstrap_test.go b/cmd/platform_bootstrap_test.go new file mode 100644 index 000000000..4fe814453 --- /dev/null +++ b/cmd/platform_bootstrap_test.go @@ -0,0 +1,268 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "bytes" + "context" + "errors" + "os" + "path/filepath" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/output" +) + +// tmpHome creates a tempdir, points $HOME at it, and returns the path to +// the ~/.lark-cli/ subdirectory (created). The HOME env var is restored +// when the test ends. +// +// LARKSUITE_CLI_CONFIG_DIR is force-set to the same path. Without that +// override, a developer running the tests with a personal +// LARKSUITE_CLI_CONFIG_DIR exported in their shell (or a CI runner with +// a baked-in value) would resolve userPolicyPath() to their real +// machine and bleed unrelated yaml into the test fixtures. With the +// override pinned here, the test is hermetic regardless of the host +// environment. +func tmpHome(t *testing.T) string { + t.Helper() + dir := t.TempDir() + t.Setenv("HOME", dir) + t.Setenv("USERPROFILE", dir) // Windows fallback for os.UserHomeDir + cfgDir := filepath.Join(dir, ".lark-cli") + if err := os.MkdirAll(cfgDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", cfgDir) + return cfgDir +} + +// writePolicy writes a policy.yml into the user config dir. +func writePolicy(t *testing.T, cfgDir string, body string) { + t.Helper() + if err := os.WriteFile(filepath.Join(cfgDir, "policy.yml"), []byte(body), 0o644); err != nil { + t.Fatalf("write policy: %v", err) + } +} + +// fakeTree builds a minimal command tree with the same shape the real +// CLI exposes for these tests: lark-cli has a docs group with +fetch and +// +update, and an im group with +send. Each leaf has its risk_level set +// so MaxRisk filtering exercises a real path. +func fakeTree(t *testing.T) *cobra.Command { + t.Helper() + root := &cobra.Command{Use: "lark-cli"} + + docs := &cobra.Command{Use: "docs"} + root.AddCommand(docs) + addLeaf(docs, "+fetch", "read") + addLeaf(docs, "+update", "write") + addLeaf(docs, "+delete-doc", "high-risk-write") + + im := &cobra.Command{Use: "im"} + root.AddCommand(im) + addLeaf(im, "+send", "write") + + return root +} + +func addLeaf(parent *cobra.Command, use, risk string) { + leaf := &cobra.Command{ + Use: use, + RunE: func(*cobra.Command, []string) error { return nil }, + } + cmdutil.SetRisk(leaf, risk) + parent.AddCommand(leaf) +} + +// findLeaf walks the tree by Use names. +func findLeaf(t *testing.T, parent *cobra.Command, names ...string) *cobra.Command { + t.Helper() + cur := parent + for _, n := range names { + var next *cobra.Command + for _, c := range cur.Commands() { + if c.Use == n { + next = c + break + } + } + if next == nil { + t.Fatalf("child %q not found under %q", n, cur.Use) + } + cur = next + } + return cur +} + +// Happy path: a valid policy.yml denies one specific command. The denied +// command's RunE returns a typed ExitError envelope; allowed commands are +// untouched. +func TestApplyUserPolicyPruning_appliesValidPolicy(t *testing.T) { + cfgDir := tmpHome(t) + writePolicy(t, cfgDir, ` +name: test-policy +allow: ["docs/**", "contact/**"] +deny: ["docs/+delete-doc"] +max_risk: write +`) + + root := fakeTree(t) + if err := applyUserPolicyPruning(root, nil); err != nil { + t.Fatalf("apply policy: %v", err) + } + + // docs/+delete-doc must be denied (Deny match). + deleteCmd := findLeaf(t, root, "docs", "+delete-doc") + if !deleteCmd.Hidden { + t.Errorf("+delete-doc should be hidden after pruning") + } + err := deleteCmd.RunE(deleteCmd, nil) + if err == nil { + t.Fatalf("+delete-doc RunE should return an error") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "command_denied" { + t.Fatalf("expected command_denied ExitError, got %T %+v", err, err) + } + detail, ok := exitErr.Detail.Detail.(map[string]any) + if !ok || detail["reason_code"] != "command_denylisted" { + t.Errorf("reason_code = %v, want command_denylisted", detail["reason_code"]) + } + + // im/+send must be denied (domain not in Allow). + send := findLeaf(t, root, "im", "+send") + if !send.Hidden { + t.Errorf("im/+send should be hidden (not in Allow)") + } + + // docs/+update must stay alive (domain matches, risk within max). + update := findLeaf(t, root, "docs", "+update") + if update.Hidden { + t.Errorf("docs/+update should remain visible") + } + if err := update.RunE(update, nil); err != nil { + t.Errorf("docs/+update RunE should succeed, got %v", err) + } +} + +// Missing file means no pruning -- the CLI runs unrestricted with the +// full command surface. This is the default case for users who haven't +// opted into pruning. +func TestApplyUserPolicyPruning_missingFileIsSilent(t *testing.T) { + tmpHome(t) // home set but no policy.yml written + + root := fakeTree(t) + if err := applyUserPolicyPruning(root, nil); err != nil { + t.Fatalf("missing policy should not error, got %v", err) + } + + // Every leaf must remain non-Hidden. + for _, sub := range []string{"+fetch", "+update", "+delete-doc"} { + cmd := findLeaf(t, root, "docs", sub) + if cmd.Hidden { + t.Errorf("%s should not be Hidden when no policy file exists", sub) + } + } +} + +// Invalid yaml content (parse error) surfaces as an error from the +// wiring. The build path then decides whether to fail-open or +// fail-closed; the wiring itself stays neutral. +func TestApplyUserPolicyPruning_malformedYamlReturnsError(t *testing.T) { + cfgDir := tmpHome(t) + writePolicy(t, cfgDir, "::: not yaml :::") + + root := fakeTree(t) + err := applyUserPolicyPruning(root, nil) + if err == nil { + t.Fatalf("malformed yaml should produce an error") + } +} + +// Semantically-invalid Rule (bad MaxRisk) reaches ValidateRule inside +// Resolve and produces an error. This is the safety contract: a typo in +// the rule must not silently lower the pruning bar. +func TestApplyUserPolicyPruning_invalidRuleReturnsError(t *testing.T) { + cfgDir := tmpHome(t) + writePolicy(t, cfgDir, "max_risk: nukem\n") + + root := fakeTree(t) + err := applyUserPolicyPruning(root, nil) + if err == nil { + t.Fatalf("invalid MaxRisk should produce an error") + } +} + +// warnPolicyError emits to the supplied writer when err is non-nil and +// stays silent for nil. Verifies the build.go fail-open behaviour can be +// observed by users. +func TestWarnPolicyError(t *testing.T) { + var buf bytes.Buffer + warnPolicyError(&buf, nil) + if buf.Len() != 0 { + t.Fatalf("warnPolicyError with nil err should write nothing, got %q", buf.String()) + } + + buf.Reset() + warnPolicyError(&buf, errors.New("boom")) + if buf.String() != "warning: user policy not applied: boom\n" { + t.Fatalf("warnPolicyError output = %q", buf.String()) + } +} + +// End-to-end through buildInternal: when a valid policy.yml exists in +// HOME, building the real command tree applies pruning to it. This is +// the "actually integrated" test -- it exercises the wiring point in +// build.go itself, not just the helper. +func TestBuildInternal_appliesPolicyToRealTree(t *testing.T) { + cfgDir := tmpHome(t) + // Deny one specific shortcut path that we know exists in the real + // service tree -- we cannot enumerate it from a unit test, so we + // use an Allow-list that matches nothing to deny everything except + // the root, and then verify ANY non-root command was hidden. + writePolicy(t, cfgDir, ` +name: deny-everything +deny: ["**"] +`) + + root := Build(context.Background(), buildInvocationForTest(t)) + + // Find any leaf and verify it was hidden. + var foundHidden bool + walk(root, func(c *cobra.Command) { + if c.HasParent() && c.Runnable() && c.Hidden { + foundHidden = true + } + }) + if !foundHidden { + t.Fatalf("expected at least one runnable command to be Hidden after deny=** policy") + } + + // Root itself must stay alive. + if root.Hidden { + t.Errorf("root command must not be Hidden even under deny-everything policy") + } +} + +func walk(cmd *cobra.Command, fn func(*cobra.Command)) { + if cmd == nil { + return + } + fn(cmd) + for _, c := range cmd.Commands() { + walk(c, fn) + } +} + +// buildInvocationForTest returns a minimal cmdutil.InvocationContext so +// build.go's pure-assembly path can construct a tree without touching +// real config / credentials. Profile name is the empty default. +func buildInvocationForTest(t *testing.T) cmdutil.InvocationContext { + t.Helper() + return cmdutil.InvocationContext{} +} diff --git a/cmd/platform_guards.go b/cmd/platform_guards.go new file mode 100644 index 000000000..714d147fd --- /dev/null +++ b/cmd/platform_guards.go @@ -0,0 +1,247 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "errors" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/hook" + "github.com/larksuite/cli/internal/output" + internalplatform "github.com/larksuite/cli/internal/platform" +) + +// installFatalGuard wires a fail-closed guard at every cobra dispatch +// path on rootCmd. Used by the three abort-side fatal paths: +// +// - FailClosed plugin install failure (installPluginInstallErrorGuard) +// - Plugin Restrict conflict (installPluginConflictGuard) +// - Startup lifecycle handler failure (installPluginLifecycleErrorGuard) +// +// **Why we walk the tree rather than set PersistentPreRunE on root**: +// cobra's PersistentPreRunE has "first PersistentPreRunE wins" +// semantics -- the lookup starts at the invoked command and walks UP, +// stopping at the first non-nil PersistentPreRunE. Subcommands that +// declare their own PersistentPreRunE (cmd/auth/auth.go and +// cmd/config/config.go both do) would shadow root's, letting a +// fail-closed condition silently bypass via `lark-cli auth foo`. +// +// The fix: replace the RunE of every runnable command with one that +// returns makeErr(). Subcommands cannot bypass because the dispatch +// lands directly on their RunE, which now carries the guard. +// +// makeErr is called for every guarded dispatch; it must return a fresh +// *output.ExitError each time (the envelope writer mutates a few fields +// as it serialises). +func installFatalGuard(rootCmd *cobra.Command, makeErr func() *output.ExitError) { + // Two cobra subcommands are injected lazily at Execute() time and + // would otherwise slip past walkGuard. We pre-register both so + // walkGuard catches them. + // + // - "completion" (user-visible): InitDefaultCompletionCmd + // - "__complete" (internal shell-completion RPC): no public + // constructor; we add our own stub with the same name. cobra's + // internal initCompleteCmd checks for an existing "__complete" + // and skips registration if found, so our stub stays in place. + // (Cobra dispatches the "__completeNoDesc" alias through the + // same RunE, so guarding "__complete" covers both.) + rootCmd.InitDefaultCompletionCmd() + alreadyPresent := false + for _, c := range rootCmd.Commands() { + if c.Name() == "__complete" { + alreadyPresent = true + break + } + } + if !alreadyPresent { + rootCmd.AddCommand(&cobra.Command{ + Use: "__complete", + Hidden: true, + RunE: func(*cobra.Command, []string) error { return makeErr() }, + }) + } + + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + cmd.SilenceUsage = true + return makeErr() + } + rootCmd.PersistentPreRun = nil + walkGuard(rootCmd, makeErr) +} + +// installPluginInstallErrorGuard surfaces a FailClosed plugin install +// failure as a structured plugin_install envelope before any command +// runs. +func installPluginInstallErrorGuard(rootCmd *cobra.Command, installErr error) { + makeErr := func() *output.ExitError { + var pi *internalplatform.PluginInstallError + if errors.As(installErr, &pi) { + return &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: "plugin_install", + Message: pi.Error(), + Detail: map[string]any{ + "plugin": pi.PluginName, + "reason_code": pi.ReasonCode, + "reason": pi.Reason, + }, + }, + Err: installErr, + } + } + return &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: "plugin_install", + Message: installErr.Error(), + Detail: map[string]any{ + "reason_code": internalplatform.ReasonInstallFailed, + }, + }, + Err: installErr, + } + } + installFatalGuard(rootCmd, makeErr) +} + +// installPluginConflictGuard surfaces a Plugin.Restrict() configuration +// error (single plugin invalid Rule or multiple plugins each contributing +// Restrict). The design separates the envelope type: +// +// - "plugin_install" with reason_code "invalid_rule" - single bad rule +// - "plugin_conflict" with reason_code "multiple_restrict_plugins" - multi +// +// Either way the CLI must NOT silently continue with a broken policy. +func installPluginConflictGuard(rootCmd *cobra.Command, err error) { + makeErr := func() *output.ExitError { + envelopeType := "plugin_install" + reasonCode := internalplatform.ReasonInvalidRule + if errors.Is(err, cmdpolicy.ErrMultipleRestricts) { + envelopeType = "plugin_conflict" + reasonCode = internalplatform.ReasonMultipleRestricts + } + return &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: envelopeType, + Message: err.Error(), + Detail: map[string]any{ + "reason_code": reasonCode, + }, + }, + Err: err, + } + } + installFatalGuard(rootCmd, makeErr) +} + +// installPluginLifecycleErrorGuard surfaces a Startup lifecycle handler +// failure as a plugin_lifecycle envelope. The reason_code splits +// returned-error vs panic so consumers (audit / on-call) can tell the +// two failure modes apart. +func installPluginLifecycleErrorGuard(rootCmd *cobra.Command, err error) { + makeErr := func() *output.ExitError { + reasonCode := "lifecycle_failed" + detail := map[string]any{ + "reason_code": reasonCode, + } + var le *hook.LifecycleError + if errors.As(err, &le) { + if le.Panic { + reasonCode = "lifecycle_panic" + } + detail = map[string]any{ + "reason_code": reasonCode, + "hook_name": le.HookName, + "event": "startup", + } + } + return &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: "plugin_lifecycle", + Message: err.Error(), + Detail: detail, + }, + Err: err, + } + } + installFatalGuard(rootCmd, makeErr) +} + +// walkGuard recurses through cmd's subtree and installs the guard at +// EVERY level cobra might dispatch to. The cobra execution order is: +// +// 1. PersistentPreRunE (looked up from leaf, walking up; "first wins") +// 2. PreRunE +// 3. RunE +// 4. PostRunE +// 5. PersistentPostRunE +// +// A subcommand that declares its own PersistentPreRunE (cmd/auth and +// cmd/config both do) would not only shadow root's PersistentPreRunE +// -- if that PreRunE itself returns an error (e.g. auth's +// external_provider check), the user sees THAT error instead of +// our plugin_install envelope, even if RunE was guarded. +// +// To close every dispatch hole we replace: +// - every command's PersistentPreRunE (including non-runnable groups) +// - every runnable command's PreRunE and RunE +// +// This way the very first non-nil step in cobra's chain is always our +// guard, regardless of which leaf the user invoked. +func walkGuard(cmd *cobra.Command, makeErr func() *output.ExitError) { + if cmd == nil { + return + } + // PersistentPreRunE is the first step cobra runs (after Args / + // flag validation -- see below). Set it on every command (root + // included) so cobra's "first wins" walk-up always finds OUR + // PersistentPreRunE before hitting any subcommand's pre-existing + // one. + cmd.PersistentPreRunE = func(c *cobra.Command, args []string) error { + c.SilenceUsage = true + return makeErr() + } + cmd.PersistentPreRun = nil + + // **Cobra dispatch order before PersistentPreRunE:** + // 1. ValidateArgs(cmd.Args) -- can return arg error + // 2. ParsePersistentFlags / ParseFlags -- can return flag error + // 3. Find legacyArgs check for unknown-command at root + // 4. PersistentPreRunE / PreRunE / RunE + // 5. Non-runnable groups fall through to help (PreRunE skipped) + // + // We neutralise each step: + // - Args = ArbitraryArgs -> ValidateArgs no-op. **Not nil**: + // cobra falls back to legacyArgs + // when Args==nil, which returns an + // unknown-command error during Find + // BEFORE PersistentPreRunE runs. + // ArbitraryArgs explicitly accepts + // everything, suppressing that path. + // - DisableFlagParsing -> ParseFlags skipped (and legacy + // "unknown flag" suppressed) + // - PreRunE / RunE on EVERY -> Even non-runnable groups now run + // command (not just leaves) the guard instead of showing help + // + // Setting RunE on a parent group flips Runnable() to true, so + // cobra dispatches to it (and our guard fires) rather than calling + // the help command on a "help-only" group. + cmd.Args = cobra.ArbitraryArgs + cmd.DisableFlagParsing = true + cmd.PreRunE = func(c *cobra.Command, args []string) error { + c.SilenceUsage = true + return makeErr() + } + cmd.PreRun = nil + cmd.RunE = func(*cobra.Command, []string) error { return makeErr() } + cmd.Run = nil + for _, c := range cmd.Commands() { + walkGuard(c, makeErr) + } +} diff --git a/cmd/platform_guards_test.go b/cmd/platform_guards_test.go new file mode 100644 index 000000000..bd23e8563 --- /dev/null +++ b/cmd/platform_guards_test.go @@ -0,0 +1,208 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/hook" + "github.com/larksuite/cli/internal/output" + internalplatform "github.com/larksuite/cli/internal/platform" +) + +// failClosedAbortingPlugin returns a PluginInstallError on Install, +// declaring FailClosed so InstallAll surfaces the error. +type failClosedAbortingPlugin struct{} + +func (failClosedAbortingPlugin) Name() string { return "policy" } +func (failClosedAbortingPlugin) Version() string { return "1.0.0" } +func (failClosedAbortingPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{FailurePolicy: platform.FailClosed} +} +func (failClosedAbortingPlugin) Install(platform.Registrar) error { + return errors.New("upstream policy server unreachable") +} + +// When a FailClosed plugin fails to install, buildInternal must +// install a PersistentPreRunE that returns a structured *output.ExitError. +// The user must NEVER see a silent partial-install state. +// +// This pins the build.go fix for codex's NEW ISSUE about +// build.go demoting FailClosed errors to warnings. +func TestBuildInternal_failClosedAbortsCLI(t *testing.T) { + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + platform.Register(failClosedAbortingPlugin{}) + + root := Build(context.Background(), buildInvocationForTest(t)) + + if root.PersistentPreRunE == nil { + t.Fatalf("FailClosed install error must wire a PersistentPreRunE that aborts subsequent commands") + } + + err := root.PersistentPreRunE(root, nil) + checkGuardError(t, err) + + // CRITICAL: subcommands that declare their own PersistentPreRunE + // (cmd/auth/auth.go and cmd/config/config.go both do) would + // shadow root's via cobra's "first wins" semantics if we only set + // root.PersistentPreRunE. Moreover, those subcommand PersistentPreRunE + // handlers may themselves return an error (e.g. auth's + // external_provider check at internal/cmdutil/factory.go:223), + // which would mask the plugin_install envelope even if RunE were + // guarded. + // + // The guard MUST therefore walk the tree and replace each command's + // PersistentPreRunE / PreRunE / RunE directly. This test pins + // that the bypass is closed. + auth := findChildByUse(t, root, "auth") + if auth == nil { + t.Skip("auth subcommand not present in build; cannot exercise bypass case") + } + // (a) auth's own PersistentPreRunE must be the guard, not the + // factory-checking handler that lived there before walkGuard ran. + if auth.PersistentPreRunE == nil { + t.Fatalf("auth.PersistentPreRunE must be guarded after walkGuard") + } + checkGuardError(t, auth.PersistentPreRunE(auth, nil)) + + // (b) A runnable leaf below auth also gets the guard on RunE. We + // match by RunE != nil (not just Runnable()) because the guard + // replaces RunE specifically — selecting a Run-only command and + // then calling leaf.RunE would nil-deref. + var leaf *cobra.Command + walk(auth, func(c *cobra.Command) { + if leaf != nil { + return + } + if c != auth && c.RunE != nil { + leaf = c + } + }) + if leaf == nil { + t.Skip("no auth subcommand with RunE found") + } + checkGuardError(t, leaf.RunE(leaf, nil)) +} + +// checkGuardError asserts that err is the structured plugin_install +// ExitError the guard produces. +func checkGuardError(t *testing.T, err error) { + t.Helper() + if err == nil { + t.Fatalf("PersistentPreRunE must surface the install error, got nil") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected *output.ExitError, got %T %+v", err, err) + } + if exitErr.Detail.Type != "plugin_install" { + t.Errorf("envelope type = %q, want plugin_install", exitErr.Detail.Type) + } + detail := exitErr.Detail.Detail.(map[string]any) + if detail["plugin"] != "policy" { + t.Errorf("detail.plugin = %v, want policy", detail["plugin"]) + } + if detail["reason_code"] != internalplatform.ReasonInstallFailed { + t.Errorf("detail.reason_code = %v, want install_failed", detail["reason_code"]) + } +} + +// findChildByUse helper. +func findChildByUse(t *testing.T, parent *cobra.Command, use string) *cobra.Command { + t.Helper() + for _, c := range parent.Commands() { + if c.Use == use { + return c + } + } + return nil +} + +// namespacedWrap copy semantics: a plugin reusing a sentinel AbortError +// across two concurrent command invocations must produce two distinct +// HookName values on the wire. Mutation would interleave them. +// +// We exercise this by sharing one AbortError across two goroutines, +// each invoking through a different namespacedWrap; both observed +// errors must keep their own HookName. +func TestNamespacedWrap_doesNotMutateSharedAbortError(t *testing.T) { + shared := &platform.AbortError{HookName: "plugin-shared-name", Reason: "rejected"} + + makeWrapper := func(name string) platform.Wrapper { + return func(next platform.Handler) platform.Handler { + return func(context.Context, platform.Invocation) error { return shared } + } + } + + reg := hook.NewRegistry() + reg.AddWrapper(hook.WrapperEntry{ + Name: "p1.wrap", Selector: platform.All(), Fn: makeWrapper("p1.wrap"), + }) + reg.AddWrapper(hook.WrapperEntry{ + Name: "p2.wrap", Selector: platform.All(), Fn: makeWrapper("p2.wrap"), + }) + + // Drive matched wrappers separately to exercise both namespace paths. + matched := reg.MatchingWrappers(stubView{}) + if len(matched) != 2 { + t.Fatalf("expected 2 matched wrappers, got %d", len(matched)) + } + + results := make([]string, 2) + var wg sync.WaitGroup + wg.Add(2) + for i, m := range matched { + go func() { + defer wg.Done() + err := m.Fn(func(context.Context, platform.Invocation) error { return nil })( + context.Background(), stubInvocation{}) + if ab, ok := err.(*platform.AbortError); ok { + results[i] = ab.HookName + } + }() + } + wg.Wait() + + // We are not using namespacedWrap directly here -- the test isolates + // the semantic by reading what each WrapperEntry's Fn returns. + // The real guarantee we depend on is the install-side namespacedWrap; + // see internal/hook/install.go for the production path. This test + // pins the sentinel-not-mutated invariant at the unit level: each + // Wrap returned the shared AbortError unchanged, so the production + // namespacedWrap can safely copy without touching the original. + if shared.HookName != "plugin-shared-name" { + t.Errorf("shared sentinel AbortError was mutated: HookName = %q", shared.HookName) + } + _ = results +} + +// stubView for the wrap selector match. +type stubView struct{} + +func (stubView) Path() string { return "x" } +func (stubView) Domain() string { return "" } +func (stubView) Risk() (platform.Risk, bool) { return "", false } +func (stubView) Identities() []platform.Identity { return nil } +func (stubView) Annotation(string) (string, bool) { return "", false } + +// stubInvocation is the minimal platform.Invocation implementation +// used by tests that need to drive a Wrap without going through the +// full hook.Install pipeline. +type stubInvocation struct{} + +func (stubInvocation) Cmd() platform.CommandView { return stubView{} } +func (stubInvocation) Args() []string { return nil } +func (stubInvocation) Started() time.Time { return time.Time{} } +func (stubInvocation) Err() error { return nil } +func (stubInvocation) DeniedByPolicy() bool { return false } +func (stubInvocation) DenialLayer() string { return "" } +func (stubInvocation) DenialPolicySource() string { return "" } diff --git a/cmd/plugin_integration_test.go b/cmd/plugin_integration_test.go new file mode 100644 index 000000000..e439adbfc --- /dev/null +++ b/cmd/plugin_integration_test.go @@ -0,0 +1,684 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "context" + "errors" + "os" + "path/filepath" + "sync/atomic" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/hook" + "github.com/larksuite/cli/internal/output" + internalplatform "github.com/larksuite/cli/internal/platform" +) + +// These integration tests exercise the Hook framework's plumbing +// (Plugin -> InstallAll -> Registry -> wireHooks -> RunE wrapper) +// against a SYNTHETIC command tree, not the real lark-cli shortcut +// tree. The synthetic tree keeps the test hermetic -- invoking real +// shortcuts requires a fully-populated Factory (HTTP, credentials, +// etc.) which is out of scope for a hook plumbing test. +// +// The e2e tests that go through Build() are kept thin (see +// TestBuildInternal_appliesPolicyToRealTree in policy_test.go); they +// assert plumbing existence (Hidden flag, etc.) without invoking +// shortcuts. + +type fakeIntegrationPlugin struct { + name string + caps platform.Capabilities + rule *platform.Rule + beforeCount int64 + afterCount int64 + wrapCount int64 + wrapDeniesWrite bool // when true, Wrap returns AbortError for risk=write + shutdownCalled int64 +} + +func (p *fakeIntegrationPlugin) Name() string { return p.name } +func (p *fakeIntegrationPlugin) Version() string { return "0.0.1" } +func (p *fakeIntegrationPlugin) Capabilities() platform.Capabilities { return p.caps } + +func (p *fakeIntegrationPlugin) Install(r platform.Registrar) error { + if p.caps.Restricts && p.rule != nil { + r.Restrict(p.rule) + } + r.Observe(platform.Before, "audit-pre", platform.All(), + func(context.Context, platform.Invocation) { + atomic.AddInt64(&p.beforeCount, 1) + }) + r.Observe(platform.After, "audit-post", platform.All(), + func(context.Context, platform.Invocation) { + atomic.AddInt64(&p.afterCount, 1) + }) + r.Wrap("policy", platform.ByWrite(), + func(next platform.Handler) platform.Handler { + return func(ctx context.Context, inv platform.Invocation) error { + atomic.AddInt64(&p.wrapCount, 1) + if p.wrapDeniesWrite { + return &platform.AbortError{ + HookName: "policy", + Reason: "writes blocked by integration test plugin", + } + } + return next(ctx, inv) + } + }) + r.On(platform.Shutdown, "flush", + func(context.Context, *platform.LifecycleContext) error { + atomic.AddInt64(&p.shutdownCalled, 1) + return nil + }) + return nil +} + +// syntheticTree builds a small command tree we own end-to-end. The leaf +// has risk=write so the Wrap's ByWrite() selector matches. +func syntheticTree() (*cobra.Command, *cobra.Command) { + root := &cobra.Command{Use: "lark-cli"} + group := &cobra.Command{Use: "docs"} + root.AddCommand(group) + leaf := &cobra.Command{ + Use: "+write", + RunE: func(*cobra.Command, []string) error { return nil }, + } + cmdutil.SetRisk(leaf, "write") + group.AddCommand(leaf) + return root, leaf +} + +// End-to-end through the public install pipeline: register a plugin, +// run internalplatform.InstallAll (the same function buildInternal calls), +// wire hooks onto a synthetic tree, invoke the leaf, and confirm +// observers fired. +func TestPluginPipeline_observersWired(t *testing.T) { + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + plugin := &fakeIntegrationPlugin{ + name: "audit-plugin", + caps: platform.Capabilities{FailurePolicy: platform.FailOpen}, + } + platform.Register(plugin) + + result, err := internalplatform.InstallAll(platform.RegisteredPlugins(), nil) + if err != nil { + t.Fatalf("InstallAll: %v", err) + } + + root, leaf := syntheticTree() + if err := wireHooks(context.Background(), root, result.Registry); err != nil { + t.Fatalf("wireHooks: %v", err) + } + + _ = leaf.RunE(leaf, nil) + + if got := atomic.LoadInt64(&plugin.beforeCount); got != 1 { + t.Errorf("Before observer fired %d times, want 1", got) + } + if got := atomic.LoadInt64(&plugin.afterCount); got != 1 { + t.Errorf("After observer fired %d times, want 1", got) + } + if got := atomic.LoadInt64(&plugin.wrapCount); got != 1 { + t.Errorf("Wrap fired %d times (ByWrite matches risk=write), want 1", got) + } +} + +// A Wrapper returning AbortError on a write command must surface as +// type="hook" in the envelope so the caller can parse the structured +// rejection. +func TestPluginPipeline_wrapAbortReachesEnvelope(t *testing.T) { + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + plugin := &fakeIntegrationPlugin{ + name: "policy-plugin", + caps: platform.Capabilities{FailurePolicy: platform.FailOpen}, + wrapDeniesWrite: true, + } + platform.Register(plugin) + + result, err := internalplatform.InstallAll(platform.RegisteredPlugins(), nil) + if err != nil { + t.Fatalf("InstallAll: %v", err) + } + + root, leaf := syntheticTree() + if err := wireHooks(context.Background(), root, result.Registry); err != nil { + t.Fatalf("wireHooks: %v", err) + } + + err = leaf.RunE(leaf, nil) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected *output.ExitError, got %T %+v", err, err) + } + if exitErr.Detail.Type != "hook" { + t.Errorf("envelope type = %q, want hook", exitErr.Detail.Type) + } + detail := exitErr.Detail.Detail.(map[string]any) + if detail["reason_code"] != "aborted" { + t.Errorf("detail.reason_code = %v, want aborted", detail["reason_code"]) + } + if detail["hook_name"] != "policy-plugin.policy" { + t.Errorf("detail.hook_name = %v, want policy-plugin.policy", detail["hook_name"]) + } + + // errors.As must still reach the original AbortError so consumers + // can inspect the typed cause. + var ab *platform.AbortError + if !errors.As(err, &ab) { + t.Errorf("error chain should expose *platform.AbortError") + } +} + +// Plugin.Restrict() contribution must reach the pruning resolver and +// take precedence over a yaml file (single-rule, plugin wins). This +// goes through the REAL Build() pipeline so the wiring between +// installPluginsAndHooks -> applyUserPolicyPruning -> cmdpolicy.Resolve +// is covered. +func TestPluginPipeline_restrictBeatsYaml(t *testing.T) { + cfgDir := tmpHome(t) + // yaml says allow everything; plugin says deny everything. Plugin + // should win and a command should be denied. + if err := os.WriteFile(filepath.Join(cfgDir, "policy.yml"), + []byte("name: yaml-allow\nallow: [\"**\"]\n"), 0o644); err != nil { + t.Fatalf("write yaml: %v", err) + } + + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + plugin := &fakeIntegrationPlugin{ + name: "restricter", + caps: platform.Capabilities{ + Restricts: true, + FailurePolicy: platform.FailClosed, + }, + rule: &platform.Rule{Name: "deny-all", Deny: []string{"**"}}, + } + platform.Register(plugin) + + root := Build(context.Background(), buildInvocationForTest(t)) + + // At least one runnable command must end up Hidden because of the + // plugin Restrict (yaml had been allow-all and would have left + // everything visible). + var foundHidden bool + walk(root, func(c *cobra.Command) { + if c.HasParent() && c.Runnable() && c.Hidden { + foundHidden = true + } + }) + if !foundHidden { + t.Fatalf("plugin Restrict should have denied at least one command despite yaml allow-all") + } +} + +// Denial-guard end-to-end: register a plugin with a Wrap that would +// SILENTLY suppress denial (return nil without calling next). After +// installing pruning (which marks a command as denied) and wiring +// hooks, calling the denied command must STILL produce the denial +// error -- the Wrap must never run on the denied path. +func TestPluginPipeline_denialGuardIntegrated(t *testing.T) { + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + + wrapCalled := false + plugin := &fakeIntegrationPlugin{ + name: "policy-plugin", + caps: platform.Capabilities{FailurePolicy: platform.FailOpen}, + wrapDeniesWrite: false, // wrap would normally allow + } + // Override Wrap with a malicious behavior: return nil (silence the + // denial). We do this by wrapping the install: register a + // second Wrap that suppresses errors. + platform.Register(plugin) + + // Add another plugin with a malicious wrap. + malicious := &mockMaliciousPlugin{ + name: "malicious", + invokedFlag: &wrapCalled, + } + platform.Register(malicious) + + result, err := internalplatform.InstallAll(platform.RegisteredPlugins(), nil) + if err != nil { + t.Fatalf("InstallAll: %v", err) + } + + root, leaf := syntheticTree() + // Simulate cmdpolicy.Apply marking leaf as denied. + leaf.Hidden = true + leaf.DisableFlagParsing = true + if leaf.Annotations == nil { + leaf.Annotations = map[string]string{} + } + leaf.Annotations["lark:policy_denied_layer"] = "policy" + leaf.Annotations["lark:policy_denied_source"] = "plugin:other" + denyStubCalled := false + leaf.RunE = func(*cobra.Command, []string) error { + denyStubCalled = true + return errors.New("CommandPruned (denyStub)") + } + + if err := wireHooks(context.Background(), root, result.Registry); err != nil { + t.Fatalf("wireHooks: %v", err) + } + + err = leaf.RunE(leaf, nil) + if wrapCalled { + t.Errorf("denial guard violated: malicious Wrap ran on a denied command") + } + if !denyStubCalled { + t.Errorf("denyStub should run on the denial path even when a Wrap is registered") + } + if err == nil { + t.Errorf("denial error must propagate, got nil") + } +} + +// mockMaliciousPlugin registers a Wrap that returns nil unconditionally +// -- exactly the kind of plugin the denial guard defends against. +type mockMaliciousPlugin struct { + name string + invokedFlag *bool +} + +func (p *mockMaliciousPlugin) Name() string { return p.name } +func (p *mockMaliciousPlugin) Version() string { return "0.0.1" } +func (p *mockMaliciousPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{FailurePolicy: platform.FailOpen} +} +func (p *mockMaliciousPlugin) Install(r platform.Registrar) error { + r.Wrap("hijack", platform.All(), + func(_ platform.Handler) platform.Handler { + return func(context.Context, platform.Invocation) error { + if p.invokedFlag != nil { + *p.invokedFlag = true + } + return nil // silence everything + } + }) + return nil +} + +// Verifies buildInternal returns a non-nil *hook.Registry when a plugin +// is registered and Emit(Shutdown) on that registry fires the plugin's +// On(Shutdown) handler. This is the contract Execute relies on to fire +// Shutdown after rootCmd.Execute returns. +func TestBuildInternal_returnsRegistryForShutdownEmit(t *testing.T) { + tmpHome(t) + + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + plugin := &fakeIntegrationPlugin{ + name: "shutdown-test", + caps: platform.Capabilities{FailurePolicy: platform.FailOpen}, + } + platform.Register(plugin) + + _, _, reg := buildInternal(context.Background(), buildInvocationForTest(t)) + if reg == nil { + t.Fatalf("buildInternal returned nil registry; plugin's Shutdown handler is unreachable") + } + + if err := hook.Emit(context.Background(), reg, platform.Shutdown, nil); err != nil { + t.Fatalf("Emit(Shutdown): %v", err) + } + if got := atomic.LoadInt64(&plugin.shutdownCalled); got != 1 { + t.Errorf("On(Shutdown) handler fired %d times, want 1", got) + } +} + +// When plugin install fails (FailClosed), buildInternal returns nil +// registry. Execute must nil-check before calling Emit so we don't fault +// on the FailClosed bypass-guard path. +func TestBuildInternal_failClosedYieldsNilRegistry(t *testing.T) { + tmpHome(t) + + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + // A plugin that fails install and is FailClosed -> InstallAll + // returns an error, buildInternal installs the guard and returns + // early with nil registry. + plugin := &failingPlugin{ + name: "fail-closed", + caps: platform.Capabilities{FailurePolicy: platform.FailClosed}, + err: errors.New("install failure simulated"), + } + platform.Register(plugin) + + _, _, reg := buildInternal(context.Background(), buildInvocationForTest(t)) + if reg != nil { + t.Errorf("buildInternal returned non-nil registry on FailClosed install error") + } +} + +type failingPlugin struct { + name string + caps platform.Capabilities + err error +} + +func (p *failingPlugin) Name() string { return p.name } +func (p *failingPlugin) Version() string { return "0.0.1" } +func (p *failingPlugin) Capabilities() platform.Capabilities { return p.caps } +func (p *failingPlugin) Install(platform.Registrar) error { return p.err } + +// === Plugin Restrict conflict guard === +// +// Two plugins both calling r.Restrict must surface as a structured +// plugin_conflict envelope (reason_code multiple_restrict_plugins) at +// dispatch time, NOT as a silent stderr warning. Otherwise a +// safety-sensitive operator could miss that their policy never took +// effect. +func TestPluginConflictGuard_MultipleRestrictAbortsCLI(t *testing.T) { + tmpHome(t) + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + cmdpolicy.ResetActiveForTesting() + t.Cleanup(cmdpolicy.ResetActiveForTesting) + + rule := &platform.Rule{Name: "any", Allow: []string{"**"}} + platform.Register(&fakeIntegrationPlugin{ + name: "plugin-a", + caps: platform.Capabilities{Restricts: true, FailurePolicy: platform.FailClosed}, + rule: rule, + }) + platform.Register(&fakeIntegrationPlugin{ + name: "plugin-b", + caps: platform.Capabilities{Restricts: true, FailurePolicy: platform.FailClosed}, + rule: rule, + }) + + _, root, reg := buildInternal(context.Background(), buildInvocationForTest(t)) + if reg != nil { + t.Errorf("conflict guard path should yield nil registry") + } + + // Pick any leaf and verify it returns the structured envelope. + leaf := findRunnableLeaf(root) + if leaf == nil { + t.Fatalf("no runnable leaf in command tree") + } + err := leaf.RunE(leaf, nil) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected *output.ExitError, got %T %+v", err, err) + } + if exitErr.Detail.Type != "plugin_conflict" { + t.Errorf("envelope type = %q, want plugin_conflict", exitErr.Detail.Type) + } + if rc := exitErr.Detail.Detail.(map[string]any)["reason_code"]; rc != "multiple_restrict_plugins" { + t.Errorf("reason_code = %v, want multiple_restrict_plugins", rc) + } +} + +// Single plugin with an invalid Rule must surface as plugin_install / +// invalid_rule envelope (distinct error.type from multi-Restrict). +func TestPluginConflictGuard_InvalidRuleAbortsCLI(t *testing.T) { + tmpHome(t) + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + cmdpolicy.ResetActiveForTesting() + t.Cleanup(cmdpolicy.ResetActiveForTesting) + + // MaxRisk "nukem" is rejected by ValidateRule -> Resolve returns + // an error that is NOT ErrMultipleRestricts. + platform.Register(&fakeIntegrationPlugin{ + name: "bad", + caps: platform.Capabilities{Restricts: true, FailurePolicy: platform.FailClosed}, + rule: &platform.Rule{Name: "bad", MaxRisk: "nukem"}, + }) + + _, root, reg := buildInternal(context.Background(), buildInvocationForTest(t)) + if reg != nil { + t.Errorf("conflict guard path should yield nil registry") + } + leaf := findRunnableLeaf(root) + if leaf == nil { + t.Fatalf("no runnable leaf in command tree") + } + err := leaf.RunE(leaf, nil) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected *output.ExitError, got %T %+v", err, err) + } + if exitErr.Detail.Type != "plugin_install" { + t.Errorf("envelope type = %q, want plugin_install", exitErr.Detail.Type) + } + if rc := exitErr.Detail.Detail.(map[string]any)["reason_code"]; rc != "invalid_rule" { + t.Errorf("reason_code = %v, want invalid_rule", rc) + } +} + +// === Startup lifecycle guard === +// +// Plugin On(Startup) handler returning error must abort startup with +// a plugin_lifecycle envelope (reason_code lifecycle_failed). Silently +// continuing would leave the plugin's invariants violated while the +// rest of its hooks still fire. +func TestPluginLifecycleGuard_StartupErrorAbortsCLI(t *testing.T) { + tmpHome(t) + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + cmdpolicy.ResetActiveForTesting() + t.Cleanup(cmdpolicy.ResetActiveForTesting) + + platform.Register(&startupFailingPlugin{ + name: "lc", + failErr: errors.New("backend unreachable"), + }) + + _, root, reg := buildInternal(context.Background(), buildInvocationForTest(t)) + if reg != nil { + t.Errorf("lifecycle guard path should yield nil registry") + } + + leaf := findRunnableLeaf(root) + err := leaf.RunE(leaf, nil) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected *output.ExitError, got %T %+v", err, err) + } + if exitErr.Detail.Type != "plugin_lifecycle" { + t.Errorf("envelope type = %q, want plugin_lifecycle", exitErr.Detail.Type) + } + d := exitErr.Detail.Detail.(map[string]any) + if d["reason_code"] != "lifecycle_failed" { + t.Errorf("reason_code = %v, want lifecycle_failed", d["reason_code"]) + } + if d["hook_name"] != "lc.start" { + t.Errorf("hook_name = %v, want lc.start", d["hook_name"]) + } +} + +// Same path but the handler panics -> reason_code lifecycle_panic. +func TestPluginLifecycleGuard_StartupPanicAbortsCLI(t *testing.T) { + tmpHome(t) + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + cmdpolicy.ResetActiveForTesting() + t.Cleanup(cmdpolicy.ResetActiveForTesting) + + platform.Register(&startupFailingPlugin{ + name: "lc", + doPanic: true, + panicMsg: "kaboom", + }) + + _, root, reg := buildInternal(context.Background(), buildInvocationForTest(t)) + if reg != nil { + t.Errorf("lifecycle guard path should yield nil registry") + } + leaf := findRunnableLeaf(root) + err := leaf.RunE(leaf, nil) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected *output.ExitError, got %T", err) + } + if rc := exitErr.Detail.Detail.(map[string]any)["reason_code"]; rc != "lifecycle_panic" { + t.Errorf("reason_code = %v, want lifecycle_panic", rc) + } +} + +type startupFailingPlugin struct { + name string + failErr error // when set, handler returns this + doPanic bool // when true, handler panics with panicMsg + panicMsg string +} + +func (p *startupFailingPlugin) Name() string { return p.name } +func (p *startupFailingPlugin) Version() string { return "0.0.1" } +func (p *startupFailingPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{FailurePolicy: platform.FailClosed} +} +func (p *startupFailingPlugin) Install(r platform.Registrar) error { + r.On(platform.Startup, "start", func(context.Context, *platform.LifecycleContext) error { + if p.doPanic { + panic(p.panicMsg) + } + return p.failErr + }) + return nil +} + +// === Wrapper panic recovery === +// +// A Wrapper that panics must NOT crash the process. The framework +// recovers and converts to a structured envelope: +// +// type="hook", reason_code="panic", hook_name= +func TestWrapperPanic_BecomesHookPanicEnvelope(t *testing.T) { + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + + platform.Register(&panickingWrapPlugin{name: "p"}) + + result, err := internalplatform.InstallAll(platform.RegisteredPlugins(), nil) + if err != nil { + t.Fatalf("InstallAll: %v", err) + } + root, leaf := syntheticTree() + if err := wireHooks(context.Background(), root, result.Registry); err != nil { + t.Fatalf("wireHooks: %v", err) + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("Wrapper panic must be recovered, but it escaped: %v", r) + } + }() + + err = leaf.RunE(leaf, nil) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected *output.ExitError, got %T %+v", err, err) + } + if exitErr.Detail.Type != "hook" { + t.Errorf("envelope type = %q, want hook", exitErr.Detail.Type) + } + d := exitErr.Detail.Detail.(map[string]any) + if d["reason_code"] != "panic" { + t.Errorf("reason_code = %v, want panic", d["reason_code"]) + } + if d["hook_name"] != "p.boom" { + t.Errorf("hook_name = %v, want p.boom (namespaced)", d["hook_name"]) + } +} + +type panickingWrapPlugin struct{ name string } + +func (p *panickingWrapPlugin) Name() string { return p.name } +func (p *panickingWrapPlugin) Version() string { return "0.0.1" } +func (p *panickingWrapPlugin) Capabilities() platform.Capabilities { return platform.Capabilities{} } +func (p *panickingWrapPlugin) Install(r platform.Registrar) error { + r.Wrap("boom", platform.All(), + func(_ platform.Handler) platform.Handler { + return func(context.Context, platform.Invocation) error { + panic("intentional panic for test") + } + }) + return nil +} + +// findRunnableLeaf walks the tree and returns the first command with a +// RunE so tests can synthesize a dispatch without going through cobra. +func findRunnableLeaf(c *cobra.Command) *cobra.Command { + if c.RunE != nil && c.HasParent() { + return c + } + for _, child := range c.Commands() { + if l := findRunnableLeaf(child); l != nil { + return l + } + } + return nil +} + +// B2 regression: a plugin Wrapper whose FACTORY function (the +// `func(next Handler) Handler` itself) panics must not crash the +// process. The framework recovers and returns the same panic envelope +// it produces for runtime panics inside the inner Handler. +// +// Pre-fix code path: recoverWrap had `inner := w(next)` outside the +// deferred recover, so a factory panic escaped. +func TestWrapperFactoryPanic_BecomesHookPanicEnvelope(t *testing.T) { + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + + platform.Register(&factoryPanicWrapPlugin{name: "fac"}) + + result, err := internalplatform.InstallAll(platform.RegisteredPlugins(), nil) + if err != nil { + t.Fatalf("InstallAll: %v", err) + } + root, leaf := syntheticTree() + if err := wireHooks(context.Background(), root, result.Registry); err != nil { + t.Fatalf("wireHooks: %v", err) + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("factory panic must be recovered, but it escaped: %v", r) + } + }() + + err = leaf.RunE(leaf, nil) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected *output.ExitError, got %T %+v", err, err) + } + if exitErr.Detail.Type != "hook" { + t.Errorf("envelope type = %q, want hook", exitErr.Detail.Type) + } + d := exitErr.Detail.Detail.(map[string]any) + if d["reason_code"] != "panic" { + t.Errorf("reason_code = %v, want panic", d["reason_code"]) + } + if d["hook_name"] != "fac.bad-factory" { + t.Errorf("hook_name = %v, want fac.bad-factory (namespaced)", d["hook_name"]) + } +} + +type factoryPanicWrapPlugin struct{ name string } + +func (p *factoryPanicWrapPlugin) Name() string { return p.name } +func (p *factoryPanicWrapPlugin) Version() string { return "0.0.1" } +func (p *factoryPanicWrapPlugin) Capabilities() platform.Capabilities { return platform.Capabilities{} } +func (p *factoryPanicWrapPlugin) Install(r platform.Registrar) error { + r.Wrap("bad-factory", platform.All(), + // The factory itself panics; the returned Handler is never reached. + func(_ platform.Handler) platform.Handler { + panic("factory blew up") + }) + return nil +} diff --git a/cmd/profile/add.go b/cmd/profile/add.go index d84e1f504..a657bccb9 100644 --- a/cmd/profile/add.go +++ b/cmd/profile/add.go @@ -45,6 +45,7 @@ func NewCmdProfileAdd(f *cmdutil.Factory) *cobra.Command { _ = cmd.MarkFlagRequired("name") _ = cmd.MarkFlagRequired("app-id") + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/profile/list.go b/cmd/profile/list.go index dbe98c1e7..fb4cc1ffe 100644 --- a/cmd/profile/list.go +++ b/cmd/profile/list.go @@ -34,6 +34,7 @@ func NewCmdProfileList(f *cmdutil.Factory) *cobra.Command { return profileListRun(f) }, } + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/profile/remove.go b/cmd/profile/remove.go index 124c32e58..08c19234e 100644 --- a/cmd/profile/remove.go +++ b/cmd/profile/remove.go @@ -28,6 +28,7 @@ func NewCmdProfileRemove(f *cmdutil.Factory) *cobra.Command { cmdutil.SetTips(cmd, []string{ "AI agents: Do NOT remove profiles unless the user explicitly asks. This is destructive and clears all associated credentials.", }) + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/profile/rename.go b/cmd/profile/rename.go index 37fbc787a..2a8f6a2e5 100644 --- a/cmd/profile/rename.go +++ b/cmd/profile/rename.go @@ -24,6 +24,7 @@ func NewCmdProfileRename(f *cmdutil.Factory) *cobra.Command { return profileRenameRun(f, args[0], args[1]) }, } + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/profile/use.go b/cmd/profile/use.go index de0964d7e..013ade47e 100644 --- a/cmd/profile/use.go +++ b/cmd/profile/use.go @@ -27,6 +27,7 @@ func NewCmdProfileUse(f *cmdutil.Factory) *cobra.Command { cmdutil.SetTips(cmd, []string{ "AI agents: Do NOT switch profiles unless the user explicitly asks.", }) + cmdutil.SetRisk(cmd, "write") return cmd } diff --git a/cmd/prune.go b/cmd/prune.go index 1a3f05f52..1f503517e 100644 --- a/cmd/prune.go +++ b/cmd/prune.go @@ -7,10 +7,12 @@ import ( "fmt" "slices" + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdpolicy" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/output" - "github.com/spf13/cobra" ) // pruneForStrictMode removes commands incompatible with the active strict mode. @@ -43,15 +45,76 @@ func pruneIncompatible(parent *cobra.Command, mode core.StrictMode) { } func strictModeStubFrom(child *cobra.Command, mode core.StrictMode) *cobra.Command { + // The denial annotations let the hook layer's populateInvocationDenial + // recognise this command as denied, so the Wrap chain is physically + // isolated (wrapRunE takes the DeniedByPolicy branch and calls the + // stub RunE directly). Without these, a plugin Wrapper registered + // against platform.All() could intercept and silently swallow the + // strict-mode error -- breaking strict-mode's "hard boundary" contract. + // + // Args + PersistentPreRunE overrides mirror cmdpolicy/apply.go::installDenyStub: + // + // - Args=ArbitraryArgs: with DisableFlagParsing the user's flags + // look like positional args; the original child's Args validator + // (e.g. cobra.NoArgs) would fire BEFORE RunE and produce a + // cobra usage error instead of our strict_mode envelope. + // + // - PersistentPreRunE no-op: cmd/auth/auth.go declares a parent + // PersistentPreRunE that returns external_provider when env + // credentials are set. Cobra's "first wins walking up" would + // pick auth's instead of our denial. A leaf-level no-op makes + // cobra stop here and proceed to the wrapped RunE. + // + // strict-mode keeps its short Message + independent Hint and + // composes the shared detail.* / wrapped-CommandDeniedError shape + // by hand; BuildDenialError would override Message with the + // CommandDeniedError.Error() long form. + stubMessage := fmt.Sprintf( + "strict mode is %q, only %s-identity commands are available", + mode, mode.ForcedIdentity()) + const stubHint = "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)" + denial := cmdpolicy.Denial{ + Layer: cmdpolicy.LayerStrictMode, + PolicySource: "strict-mode", + ReasonCode: "identity_not_supported", + Reason: stubMessage, + } + // Preserve the original command's annotations (risk_level, + // lark:supportedIdentities, cmdmeta.domain, ...) and help text so + // audit / compliance observers can still see what was denied. + // Stamp the denial annotations on top. + annotations := make(map[string]string, len(child.Annotations)+2) + for k, v := range child.Annotations { + annotations[k] = v + } + annotations[cmdpolicy.AnnotationDenialLayer] = cmdpolicy.LayerStrictMode + annotations[cmdpolicy.AnnotationDenialSource] = "strict-mode" + return &cobra.Command{ Use: child.Use, Aliases: append([]string(nil), child.Aliases...), + Short: child.Short, + Long: child.Long, Hidden: true, DisableFlagParsing: true, - RunE: func(cmd *cobra.Command, args []string) error { - return output.ErrWithHint(output.ExitValidation, "strict_mode", - fmt.Sprintf("strict mode is %q, only %s-identity commands are available", mode, mode.ForcedIdentity()), - "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)") + Args: cobra.ArbitraryArgs, + Annotations: annotations, + PersistentPreRunE: func(c *cobra.Command, _ []string) error { + c.SilenceUsage = true + return nil + }, + RunE: func(c *cobra.Command, _ []string) error { + cd := cmdpolicy.CommandDeniedFromDenial(cmdpolicy.CanonicalPath(c), denial) + return &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: "command_denied", + Message: stubMessage, + Hint: stubHint, + Detail: cmdpolicy.DenialDetailMap(cd), + }, + Err: cd, + } }, } } diff --git a/cmd/prune_test.go b/cmd/prune_test.go index 8d0594737..d9a949c36 100644 --- a/cmd/prune_test.go +++ b/cmd/prune_test.go @@ -4,11 +4,15 @@ package cmd import ( + "errors" "strings" "testing" + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" "github.com/spf13/cobra" ) @@ -198,3 +202,176 @@ func TestPruneForStrictMode_User_DirectBotShortcutReturnsStrictMode(t *testing.T t.Fatalf("unexpected error: %v", err) } } + +// Regression for codex C13: a strict-mode stub whose PARENT declares +// a PersistentPreRunE (e.g. cmd/auth/auth.go's external_provider +// check on env credentials) must surface the strict_mode envelope, +// not the parent's error. Cobra's "first PersistentPreRunE wins +// walking up from leaf" semantics will pick the parent's unless the +// stub itself carries its own. +// +// Fix: strictModeStubFrom installs a no-op PersistentPreRunE so cobra +// stops at the stub and proceeds to its RunE. +func TestStrictModeStub_BypassesParentPersistentPreRunE(t *testing.T) { + root := newTestTree() + pruneForStrictMode(root, core.StrictModeBot) + stub := findCmd(root, "auth", "login") + if stub == nil { + t.Fatal("auth/login stub should exist after StrictModeBot") + } + if stub.PersistentPreRunE == nil { + t.Fatal("strict-mode stub must declare PersistentPreRunE on leaf") + } + if err := stub.PersistentPreRunE(stub, nil); err != nil { + t.Errorf("strict-mode stub PersistentPreRunE should be no-op, got %v", err) + } +} + +// Regression for codex H13: strict-mode stub must accept arbitrary +// positional args. With DisableFlagParsing=true, a user passing +// `auth login --scope ...` looks like 4 positional args; the original +// cobra.Args validator would surface a usage error BEFORE strict-mode +// stub's RunE. +func TestStrictModeStub_BypassesArgsValidator(t *testing.T) { + root := newTestTree() + pruneForStrictMode(root, core.StrictModeBot) + stub := findCmd(root, "auth", "login") + if stub == nil { + t.Fatal("auth/login stub should exist after StrictModeBot") + } + if stub.Args == nil { + t.Fatal("strict-mode stub must declare Args validator") + } + if err := stub.Args(stub, []string{"--scope", "im.message", "--profile", "default"}); err != nil { + t.Errorf("strict-mode stub Args should accept flag-like args, got %v", err) + } +} + +// Pins the strict-mode envelope shape: structured detail.* / wrapped +// CommandDeniedError for external agents, AND the historical short +// Message + independent Hint for existing consumers. +func TestStrictModeStub_StructuredEnvelope(t *testing.T) { + root := newTestTree() + pruneForStrictMode(root, core.StrictModeBot) + stub := findCmd(root, "im", "+search") + if stub == nil { + t.Fatalf("expected im/+search stub") + } + err := stub.RunE(stub, nil) + if err == nil { + t.Fatalf("strict-mode stub RunE should return error") + } + + var ee *output.ExitError + if !errors.As(err, &ee) { + t.Fatalf("err is not *output.ExitError: %T", err) + } + if ee.Detail == nil { + t.Fatalf("ExitError.Detail is nil; envelope writer cannot emit JSON") + } + if ee.Detail.Type != "command_denied" { + t.Errorf("Detail.Type = %q, want command_denied", ee.Detail.Type) + } + dm, ok := ee.Detail.Detail.(map[string]any) + if !ok { + t.Fatalf("Detail.Detail = %T, want map[string]any", ee.Detail.Detail) + } + if got, _ := dm["layer"].(string); got != cmdpolicy.LayerStrictMode { + t.Errorf("Detail.Detail[layer] = %q, want %q", got, cmdpolicy.LayerStrictMode) + } + if got, _ := dm["reason_code"].(string); got != "identity_not_supported" { + t.Errorf("Detail.Detail[reason_code] = %q, want identity_not_supported", got) + } + if got, _ := dm["policy_source"].(string); got != "strict-mode" { + t.Errorf("Detail.Detail[policy_source] = %q, want strict-mode", got) + } + + var cd *platform.CommandDeniedError + if !errors.As(err, &cd) { + t.Fatalf("err does not unwrap to *platform.CommandDeniedError") + } + if cd.Layer != cmdpolicy.LayerStrictMode { + t.Errorf("CommandDeniedError.Layer = %q, want %q", cd.Layer, cmdpolicy.LayerStrictMode) + } + if cd.ReasonCode != "identity_not_supported" { + t.Errorf("CommandDeniedError.ReasonCode = %q, want identity_not_supported", cd.ReasonCode) + } + if !strings.Contains(cd.Reason, `strict mode is "bot"`) { + t.Errorf("CommandDeniedError.Reason = %q, want substring 'strict mode is \"bot\"'", cd.Reason) + } + if ee.Detail.Message != `strict mode is "bot", only bot-identity commands are available` { + t.Errorf("Detail.Message = %q, want short historical form", ee.Detail.Message) + } + if !strings.HasPrefix(ee.Detail.Hint, "if the user explicitly wants to switch policy") { + t.Errorf("Detail.Hint = %q, want historical hint", ee.Detail.Hint) + } +} + +// strictModeStubFrom must write the denial annotations so the hook +// layer's populateInvocationDenial recognises the command as denied +// and physically isolates the Wrap chain. Without this, a plugin +// Wrapper registered against platform.All() could intercept the stub +// and silently return nil, swallowing the strict-mode error. +func TestStrictModeStub_HasDenialAnnotation(t *testing.T) { + root := newTestTree() + pruneForStrictMode(root, core.StrictModeBot) + + // im/+search is user-only -> replaced by a stub in StrictModeBot. + stub := findCmd(root, "im", "+search") + if stub == nil { + t.Fatalf("expected im/+search stub to exist") + } + got := stub.Annotations[cmdpolicy.AnnotationDenialLayer] + if got != cmdpolicy.LayerStrictMode { + t.Errorf("stub annotation %q = %q, want %q", + cmdpolicy.AnnotationDenialLayer, got, cmdpolicy.LayerStrictMode) + } + if src := stub.Annotations[cmdpolicy.AnnotationDenialSource]; src != "strict-mode" { + t.Errorf("stub annotation %q = %q, want %q", + cmdpolicy.AnnotationDenialSource, src, "strict-mode") + } +} + +// Audit / compliance observers fire even for strict-mode-denied commands +// and rely on CommandView.Risk() / Identities() / etc. The stub must +// carry the original command's annotations so those accessors keep +// returning meaningful values; the Short/Long are preserved so `--help` +// on a denied command still describes the original intent (parity with +// cmdpolicy/apply.go::installDenyStub). +func TestStrictModeStub_PreservesOriginalMetadata(t *testing.T) { + root := &cobra.Command{Use: "root"} + svc := &cobra.Command{Use: "im"} + root.AddCommand(svc) + userOnly := &cobra.Command{ + Use: "+search", + Short: "search messages", + Long: "Search across IM history.", + RunE: func(*cobra.Command, []string) error { return nil }, + } + cmdutil.SetSupportedIdentities(userOnly, []string{"user"}) + cmdutil.SetRisk(userOnly, "read") + svc.AddCommand(userOnly) + + pruneForStrictMode(root, core.StrictModeBot) + + stub := findCmd(root, "im", "+search") + if stub == nil { + t.Fatalf("expected im/+search stub") + } + if got := stub.Annotations["risk_level"]; got != "read" { + t.Errorf("stub risk_level = %q, want %q (lost in replacement)", got, "read") + } + if got := stub.Annotations["lark:supportedIdentities"]; got != "user" { + t.Errorf("stub supportedIdentities = %q, want %q", got, "user") + } + if stub.Short != "search messages" { + t.Errorf("stub Short = %q, want preserved Short", stub.Short) + } + if stub.Long != "Search across IM history." { + t.Errorf("stub Long = %q, want preserved Long", stub.Long) + } + // Denial stamps must still be present. + if stub.Annotations[cmdpolicy.AnnotationDenialLayer] != cmdpolicy.LayerStrictMode { + t.Errorf("denial annotation overwritten or missing") + } +} diff --git a/cmd/root.go b/cmd/root.go index 54fb5ed34..00d9a24bc 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -12,12 +12,17 @@ import ( "io" "net/url" "os" + "sort" "strconv" + "strings" + "github.com/larksuite/cli/extension/platform" internalauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/build" + "github.com/larksuite/cli/internal/cmdpolicy" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/hook" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/registry" "github.com/larksuite/cli/internal/skillscheck" @@ -88,8 +93,9 @@ func Execute() int { } configureFlagCompletions(os.Args) - f, rootCmd := buildInternal( - context.Background(), inv, + ctx := context.Background() + f, rootCmd, reg := buildInternal( + ctx, inv, WithIO(os.Stdin, os.Stdout, os.Stderr), HideProfile(isSingleAppMode()), ) @@ -99,8 +105,18 @@ func Execute() int { setupNotices() } - if err := rootCmd.Execute(); err != nil { - return handleRootError(f, err) + runErr := rootCmd.Execute() + + // Fire Shutdown lifecycle hooks regardless of run outcome. + // emitShutdown imposes a 2s total deadline and never propagates handler + // errors (Emit's documented Shutdown contract), so it cannot block exit + // or alter the user-visible exit code. + if reg != nil && !isCompletionCommand(os.Args) { + _ = hook.Emit(ctx, reg, platform.Shutdown, runErr) + } + + if runErr != nil { + return handleRootError(f, runErr) } return 0 } @@ -159,11 +175,17 @@ func setupNotices() { } // isCompletionCommand returns true if args indicate a shell completion request. -// Update notifications must be suppressed for these to avoid corrupting -// machine-parseable completion output. +// Update notifications and Shutdown lifecycle emits must be suppressed for +// these to avoid corrupting machine-parseable completion output and to avoid +// firing plugin Shutdown handlers on every Tab keystroke. +// +// Cobra dispatches BOTH "__complete" and its alias "__completeNoDesc" through +// the same hidden subcommand (see cobra/completions.go ShellCompRequestCmd / +// ShellCompNoDescRequestCmd). Check both, otherwise bash/zsh completion +// (which often uses NoDesc) silently bypasses the gate. func isCompletionCommand(args []string) bool { for _, arg := range args { - if arg == "completion" || arg == "__complete" { + if arg == "completion" || arg == "__complete" || arg == "__completeNoDesc" { return true } } @@ -263,6 +285,70 @@ func writeSecurityPolicyError(w io.Writer, spErr *internalauth.SecurityPolicyErr fmt.Fprint(w, buffer.String()) } +// installUnknownSubcommandGuard replaces cobra's silent help fallback on +// group commands (no Run/RunE) with an unknown_subcommand error. +// +// IMPORTANT: every command modified here is also tagged with +// cmdpolicy.AnnotationPureGroup so the user-layer policy engine +// continues to treat the command as a pure parent group. Without the +// tag, the RunE injection here would flip Runnable()=true and a user +// rule like `max_risk: read` would deny every ` --help` call +// with reason_code=risk_not_annotated. +func installUnknownSubcommandGuard(cmd *cobra.Command) { + if cmd.HasSubCommands() && cmd.Run == nil && cmd.RunE == nil { + cmd.RunE = unknownSubcommandRunE + if cmd.Annotations == nil { + cmd.Annotations = map[string]string{} + } + cmd.Annotations[cmdpolicy.AnnotationPureGroup] = "true" + } + for _, c := range cmd.Commands() { + installUnknownSubcommandGuard(c) + } +} + +func unknownSubcommandRunE(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return cmd.Help() + } + unknown := args[0] + available := availableSubcommandNames(cmd) + msg := fmt.Sprintf("unknown subcommand %q for %q", unknown, cmd.CommandPath()) + hint := fmt.Sprintf("run `%s --help` to see available subcommands", cmd.CommandPath()) + if len(available) > 0 { + hint = fmt.Sprintf("available subcommands: %s", strings.Join(available, ", ")) + } + return &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: "unknown_subcommand", + Message: msg, + Hint: hint, + Detail: map[string]any{ + "unknown": unknown, + "command_path": cmd.CommandPath(), + "available": available, + }, + }, + } +} + +func availableSubcommandNames(cmd *cobra.Command) []string { + subs := make([]string, 0, len(cmd.Commands())) + for _, c := range cmd.Commands() { + if c.Hidden || !c.IsAvailableCommand() { + continue + } + name := c.Name() + if name == "help" || name == "completion" { + continue + } + subs = append(subs, name) + } + sort.Strings(subs) + return subs +} + // installTipsHelpFunc wraps the default help function to append a TIPS section // when a command has tips set via cmdutil.SetTips. It also force-shows global // flags that are normally hidden in single-app mode (currently --profile) diff --git a/cmd/root_integration_test.go b/cmd/root_integration_test.go index 416777a44..794cb07c5 100644 --- a/cmd/root_integration_test.go +++ b/cmd/root_integration_test.go @@ -27,6 +27,14 @@ import ( "github.com/spf13/cobra" ) +// Canonical strict-mode envelope strings shared across fixtures +// (reflect.DeepEqual pins them; keep in sync with strictModeStubFrom). +const ( + strictModeBotMessage = `strict mode is "bot", only bot-identity commands are available` + strictModeUserMessage = `strict mode is "user", only user-identity commands are available` + strictModeHint = "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)" +) + // buildIntegrationRootCmd creates a root command with api, service, and shortcut // subcommands wired to a test factory, simulating the real CLI command tree. func buildIntegrationRootCmd(t *testing.T, f *cmdutil.Factory) *cobra.Command { @@ -353,9 +361,17 @@ func TestIntegration_StrictModeBot_ProfileOverride_DirectAuthLoginReturnsEnvelop assertEnvelope(t, code, output.ExitValidation, stdout, stderr, output.ErrorEnvelope{ OK: false, Error: &output.ErrDetail{ - Type: "strict_mode", - Message: `strict mode is "bot", only bot-identity commands are available`, - Hint: "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)", + Type: "command_denied", + Message: strictModeBotMessage, + Hint: strictModeHint, + Detail: map[string]any{ + "path": "auth/login", + "layer": "strict_mode", + "policy_source": "strict-mode", + "rule_name": "", + "reason_code": "identity_not_supported", + "reason": strictModeBotMessage, + }, }, }) } @@ -371,9 +387,17 @@ func TestIntegration_StrictModeBot_ProfileOverride_DirectUserShortcutReturnsEnve assertEnvelope(t, code, output.ExitValidation, stdout, stderr, output.ErrorEnvelope{ OK: false, Error: &output.ErrDetail{ - Type: "strict_mode", - Message: `strict mode is "bot", only bot-identity commands are available`, - Hint: "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)", + Type: "command_denied", + Message: strictModeBotMessage, + Hint: strictModeHint, + Detail: map[string]any{ + "path": "im/+messages-search", + "layer": "strict_mode", + "policy_source": "strict-mode", + "rule_name": "", + "reason_code": "identity_not_supported", + "reason": strictModeBotMessage, + }, }, }) } @@ -409,7 +433,7 @@ func TestIntegration_StrictModeUser_ProfileOverride_ShortcutExplicitBotReturnsEn OK: false, Identity: "bot", Error: &output.ErrDetail{ - Type: "strict_mode", + Type: "command_denied", Message: `strict mode is "user", only user-identity commands are available`, Hint: "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)", }, @@ -428,7 +452,7 @@ func TestIntegration_StrictModeBot_ProfileOverride_ServiceExplicitUserReturnsEnv OK: false, Identity: "user", Error: &output.ErrDetail{ - Type: "strict_mode", + Type: "command_denied", Message: `strict mode is "bot", only bot-identity commands are available`, Hint: "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)", }, @@ -446,9 +470,17 @@ func TestIntegration_StrictModeUser_ProfileOverride_ServiceBotOnlyMethodReturnsE assertEnvelope(t, code, output.ExitValidation, stdout, stderr, output.ErrorEnvelope{ OK: false, Error: &output.ErrDetail{ - Type: "strict_mode", - Message: `strict mode is "user", only user-identity commands are available`, - Hint: "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)", + Type: "command_denied", + Message: strictModeUserMessage, + Hint: strictModeHint, + Detail: map[string]any{ + "path": "im/images/create", + "layer": "strict_mode", + "policy_source": "strict-mode", + "rule_name": "", + "reason_code": "identity_not_supported", + "reason": strictModeUserMessage, + }, }, }) } @@ -465,7 +497,7 @@ func TestIntegration_StrictModeBot_ProfileOverride_APIExplicitUserReturnsEnvelop OK: false, Identity: "user", Error: &output.ErrDetail{ - Type: "strict_mode", + Type: "command_denied", Message: `strict mode is "bot", only bot-identity commands are available`, Hint: "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)", }, @@ -504,11 +536,8 @@ func TestIntegration_Shortcut_BusinessError_OutputsEnvelope(t *testing.T) { }) } -// TestSetupNotices_ColdStart_NoNotice verifies that a missing stamp -// produces no skills key in the composed notice. Users who installed -// skills via `npx skills add` (no stamp) must not see the misleading -// "not installed" notice — only `lark-cli update` users opt into the -// drift tracker. +// TestSetupNotices_ColdStart_NoNotice verifies that missing state +// produces no skills key in the composed notice. func TestSetupNotices_ColdStart_NoNotice(t *testing.T) { clearNoticeEnv(t) dir := t.TempDir() @@ -539,13 +568,13 @@ func TestSetupNotices_ColdStart_NoNotice(t *testing.T) { } } -// TestSetupNotices_InSync verifies that a matching stamp produces no +// TestSetupNotices_InSync verifies that matching state produces no // skills key in the composed notice. func TestSetupNotices_InSync(t *testing.T) { clearNoticeEnv(t) dir := t.TempDir() t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := skillscheck.WriteStamp("1.0.21"); err != nil { + if err := skillscheck.WriteState(skillscheck.SkillsState{Version: "1.0.21"}); err != nil { t.Fatal(err) } @@ -572,13 +601,13 @@ func TestSetupNotices_InSync(t *testing.T) { } } -// TestSetupNotices_Drift verifies a mismatching stamp produces the +// TestSetupNotices_Drift verifies mismatching state produces the // drift message with both current and target populated. func TestSetupNotices_Drift(t *testing.T) { clearNoticeEnv(t) dir := t.TempDir() t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := skillscheck.WriteStamp("1.0.20"); err != nil { + if err := skillscheck.WriteState(skillscheck.SkillsState{Version: "1.0.20"}); err != nil { t.Fatal(err) } @@ -627,7 +656,7 @@ func TestSetupNotices_BothUpdateAndSkills(t *testing.T) { clearNoticeEnv(t) dir := t.TempDir() t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := skillscheck.WriteStamp("1.0.20"); err != nil { + if err := skillscheck.WriteState(skillscheck.SkillsState{Version: "1.0.20"}); err != nil { t.Fatal(err) } diff --git a/cmd/root_test.go b/cmd/root_test.go index 0f5ac1ad9..6aac983db 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -356,6 +356,7 @@ func TestConfigureFlagCompletions(t *testing.T) { {"help flag", []string{"im", "--help"}, true}, {"no args", []string{}, true}, {"__complete request", []string{"__complete", "im", "+send", ""}, false}, + {"__completeNoDesc request", []string{"__completeNoDesc", "im", "+send", ""}, false}, {"completion subcommand", []string{"completion", "bash"}, false}, } for _, tc := range tests { @@ -368,3 +369,30 @@ func TestConfigureFlagCompletions(t *testing.T) { }) } } + +// isCompletionCommand must classify BOTH cobra completion aliases as +// completion requests so the Shutdown emit and update-notice paths skip +// shell-completion invocations. __completeNoDesc is an Alias of +// __complete (cobra/completions.go ShellCompNoDescRequestCmd) and +// dispatches the same RunE; bash/zsh completion typically calls the +// NoDesc variant. +func TestIsCompletionCommand(t *testing.T) { + tests := []struct { + name string + args []string + want bool + }{ + {"plain command", []string{"im", "+send"}, false}, + {"__complete", []string{"__complete", "im"}, true}, + {"__completeNoDesc", []string{"__completeNoDesc", "im"}, true}, + {"completion subcommand", []string{"completion", "bash"}, true}, + {"completion in tail", []string{"foo", "bar", "completion"}, true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := isCompletionCommand(tc.args); got != tc.want { + t.Fatalf("isCompletionCommand(%v) = %v, want %v", tc.args, got, tc.want) + } + }) + } +} diff --git a/cmd/schema/schema.go b/cmd/schema/schema.go index 38ecaa322..e4114c5bc 100644 --- a/cmd/schema/schema.go +++ b/cmd/schema/schema.go @@ -380,6 +380,7 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co cmdutil.RegisterFlagCompletion(cmd, "format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "pretty"}, cobra.ShellCompDirectiveNoFileComp }) + cmdutil.SetRisk(cmd, "read") return cmd } diff --git a/cmd/unknown_subcommand_test.go b/cmd/unknown_subcommand_test.go new file mode 100644 index 000000000..4bba607d5 --- /dev/null +++ b/cmd/unknown_subcommand_test.go @@ -0,0 +1,177 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "bytes" + "errors" + "strings" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/output" +) + +func newGroupTree() (root, drive, files *cobra.Command) { + root = &cobra.Command{Use: "lark-cli"} + drive = &cobra.Command{Use: "drive", Short: "drive ops"} + root.AddCommand(drive) + + search := &cobra.Command{Use: "+search", RunE: func(*cobra.Command, []string) error { return nil }} + upload := &cobra.Command{Use: "+upload", RunE: func(*cobra.Command, []string) error { return nil }} + hidden := &cobra.Command{Use: "+secret", Hidden: true, RunE: func(*cobra.Command, []string) error { return nil }} + drive.AddCommand(search, upload, hidden) + + files = &cobra.Command{Use: "files", Short: "files ops"} + drive.AddCommand(files) + files.AddCommand(&cobra.Command{Use: "list", RunE: func(*cobra.Command, []string) error { return nil }}) + + return root, drive, files +} + +func TestInstallUnknownSubcommandGuard_InstallsOnGroupsOnly(t *testing.T) { + root, drive, files := newGroupTree() + leaf := drive.Commands()[0] // +search + + installUnknownSubcommandGuard(root) + + if drive.RunE == nil { + t.Error("drive should have RunE installed") + } + if files.RunE == nil { + t.Error("files should have RunE installed") + } + if err := leaf.RunE(leaf, []string{"unexpected-arg"}); err != nil { + t.Errorf("leaf +search RunE should be untouched, got error %v", err) + } +} + +func TestInstallUnknownSubcommandGuard_PreservesExistingRunE(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + called := false + custom := &cobra.Command{ + Use: "custom", + RunE: func(*cobra.Command, []string) error { + called = true + return nil + }, + } + // Child makes custom a "group" command, exercising the Run/RunE override guard. + custom.AddCommand(&cobra.Command{Use: "leaf", RunE: func(*cobra.Command, []string) error { return nil }}) + root.AddCommand(custom) + + installUnknownSubcommandGuard(root) + + if err := custom.RunE(custom, nil); err != nil { + t.Fatalf("preserved RunE returned error: %v", err) + } + if !called { + t.Error("guard must not overwrite a command that already defines Run/RunE") + } +} + +func TestUnknownSubcommandRunE_NoArgsShowsHelp(t *testing.T) { + _, drive, _ := newGroupTree() + installUnknownSubcommandGuard(drive.Root()) + + var buf bytes.Buffer + drive.SetOut(&buf) + drive.SetErr(&buf) + + if err := drive.RunE(drive, nil); err != nil { + t.Fatalf("expected no-args invocation to succeed, got: %v", err) + } + if !strings.Contains(buf.String(), "drive ops") { + t.Errorf("expected help output to include the command's Short, got:\n%s", buf.String()) + } +} + +func TestUnknownSubcommandRunE_UnknownReturnsStructuredError(t *testing.T) { + _, drive, _ := newGroupTree() + installUnknownSubcommandGuard(drive.Root()) + + err := drive.RunE(drive, []string{"+bogus"}) + if err == nil { + t.Fatal("expected error for unknown subcommand") + } + + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected *output.ExitError, got %T", err) + } + if exitErr.Code != output.ExitValidation { + t.Errorf("expected exit code %d, got %d", output.ExitValidation, exitErr.Code) + } + if exitErr.Detail == nil { + t.Fatal("expected ExitError to carry Detail") + } + if exitErr.Detail.Type != "unknown_subcommand" { + t.Errorf("expected Detail.Type=unknown_subcommand, got %q", exitErr.Detail.Type) + } + if !strings.Contains(exitErr.Detail.Message, `"+bogus"`) { + t.Errorf("message should echo the unknown token, got %q", exitErr.Detail.Message) + } + if !strings.Contains(exitErr.Detail.Hint, "+search") || !strings.Contains(exitErr.Detail.Hint, "+upload") { + t.Errorf("hint should list available shortcuts, got %q", exitErr.Detail.Hint) + } + if strings.Contains(exitErr.Detail.Hint, "+secret") { + t.Error("hidden commands must not appear in the hint") + } + + detail, ok := exitErr.Detail.Detail.(map[string]any) + if !ok { + t.Fatalf("expected Detail.Detail to be map[string]any, got %T", exitErr.Detail.Detail) + } + if detail["unknown"] != "+bogus" { + t.Errorf("detail.unknown should be +bogus, got %v", detail["unknown"]) + } + if detail["command_path"] != "lark-cli drive" { + t.Errorf("detail.command_path should be %q, got %v", "lark-cli drive", detail["command_path"]) + } + available, ok := detail["available"].([]string) + if !ok { + t.Fatalf("detail.available should be []string, got %T", detail["available"]) + } + if len(available) != 3 { + t.Errorf("expected 3 available entries (hidden excluded), got %d: %v", len(available), available) + } +} + +func TestUnknownSubcommandRunE_NestedResourceGroup(t *testing.T) { + root, _, files := newGroupTree() + installUnknownSubcommandGuard(root) + + err := files.RunE(files, []string{"bogus"}) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected *output.ExitError on nested group, got %T", err) + } + if exitErr.Detail.Detail.(map[string]any)["command_path"] != "lark-cli drive files" { + t.Errorf("command_path should reflect the nested resource, got %v", + exitErr.Detail.Detail.(map[string]any)["command_path"]) + } +} + +func TestAvailableSubcommandNames_FiltersHelpAndCompletion(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + root.AddCommand( + &cobra.Command{Use: "alpha", RunE: func(*cobra.Command, []string) error { return nil }}, + &cobra.Command{Use: "help", RunE: func(*cobra.Command, []string) error { return nil }}, + &cobra.Command{Use: "completion", RunE: func(*cobra.Command, []string) error { return nil }}, + &cobra.Command{Use: "beta", Hidden: true, RunE: func(*cobra.Command, []string) error { return nil }}, + &cobra.Command{Use: "gamma", RunE: func(*cobra.Command, []string) error { return nil }}, + ) + + got := availableSubcommandNames(root) + want := []string{"alpha", "gamma"} + if len(got) != len(want) { + t.Fatalf("expected %v, got %v", want, got) + } + for i, name := range want { + if got[i] != name { + t.Errorf("availableSubcommandNames[%d] = %q, want %q", i, got[i], name) + } + } +} diff --git a/cmd/update/update.go b/cmd/update/update.go index 632e4fe9e..ea5dc2253 100644 --- a/cmd/update/update.go +++ b/cmd/update/update.go @@ -31,11 +31,12 @@ var ( currentVersion = func() string { return build.Version } currentOS = runtime.GOOS newUpdater = func() *selfupdate.Updater { return selfupdate.New() } + syncSkills = func(opts skillscheck.SyncOptions) *skillscheck.SyncResult { return skillscheck.SyncSkills(opts) } ) func isWindows() bool { return currentOS == osWindows } -// normalizeVersion canonicalizes a version string for stamp comparison. +// normalizeVersion canonicalizes a version string for state comparison. // Strips a leading "v" so versions written from Makefile (git describe → // "v1.0.0") and npm (no prefix → "1.0.0") compare equal. func normalizeVersion(s string) string { @@ -111,6 +112,7 @@ Use --check to only check for updates without installing.`, cmd.Flags().BoolVar(&opts.JSON, "json", false, "structured JSON output") cmd.Flags().BoolVar(&opts.Force, "force", false, "force reinstall even if already up to date") cmd.Flags().BoolVar(&opts.Check, "check", false, "only check for updates, do not install") + cmdutil.SetRisk(cmd, "high-risk-write") return cmd } @@ -120,7 +122,9 @@ func updateRun(opts *UpdateOptions) error { cur := currentVersion() updater := newUpdater() - updater.CleanupStaleFiles() + if !opts.Check { + updater.CleanupStaleFiles() + } output.PendingNotice = nil // 1. Fetch latest version @@ -136,13 +140,9 @@ func updateRun(opts *UpdateOptions) error { // 3. Compare versions if !opts.Force && !update.IsNewer(latest, cur) { - // Run skills sync before returning — covers the case where the - // binary is already current but skills were never synced. - // Stamp dedup makes this a no-op if skills are already in sync. - // Skip side-effects under --check (pure report path per spec §3.6). - var skillsResult *selfupdate.NpmResult + var skillsResult *skillscheck.SyncResult if !opts.Check { - skillsResult = runSkillsAndStamp(updater, io, cur, opts.Force) + skillsResult = runSkillsAndState(updater, io, cur, opts.Force) } return reportAlreadyUpToDate(opts, io, cur, latest, skillsResult, opts.Check) } @@ -184,16 +184,7 @@ func reportCheckResult(opts *UpdateOptions, io *cmdutil.IOStreams, cur, latest s "message": fmt.Sprintf("lark-cli %s %s %s available", cur, symArrow(), latest), "url": releaseURL(latest), "changelog": changelogURL(), } - // skills_status: pure report, no side effect, no stamp write. - // ReadStamp errors are silently swallowed — if we can't read the - // stamp we just omit the block rather than fail the --check. - if stamp, err := skillscheck.ReadStamp(); err == nil { - out["skills_status"] = map[string]interface{}{ - "current": stamp, - "target": cur, - "in_sync": stamp == cur, - } - } + applySkillsStatus(out, cur) output.PrintJson(io.Out, out) return nil } @@ -209,7 +200,7 @@ func reportCheckResult(opts *UpdateOptions, io *cmdutil.IOStreams, cur, latest s } func doManualUpdate(opts *UpdateOptions, io *cmdutil.IOStreams, cur, latest string, detect selfupdate.DetectResult, updater *selfupdate.Updater) error { - skillsResult := runSkillsAndStamp(updater, io, cur, opts.Force) + skillsResult := runSkillsAndState(updater, io, cur, opts.Force) reason := detect.ManualReason() if opts.JSON { @@ -287,10 +278,7 @@ func doNpmUpdate(opts *UpdateOptions, io *cmdutil.IOStreams, cur, latest string, return output.ErrBare(output.ExitAPI) } - // Skills update (best-effort) — uses runSkillsAndStamp so the - // stamp gets persisted on success and dedup applies if a previous - // run already stamped this version. - skillsResult := runSkillsAndStamp(updater, io, latest, opts.Force) + skillsResult := runSkillsAndState(updater, io, latest, opts.Force) if opts.JSON { result := map[string]interface{}{ @@ -327,27 +315,21 @@ func verificationFailureHint(updater *selfupdate.Updater, latest string) string return fmt.Sprintf("automatic rollback is unavailable on this platform; reinstall manually (skills will not be synced): npm install -g %s@%s && npx skills add larksuite/cli -y -g, or download %s", selfupdate.NpmPackage, latest, releaseURL(latest)) } -// runSkillsAndStamp triggers updater.RunSkillsUpdate and persists the -// stamp on success. Skips the npx invocation when the stamp already -// matches stampVersion (unless force is true). The stamp write failure -// emits a warning to io.ErrOut but does NOT fail the update command — -// best-effort. ReadStamp errors are swallowed (fail-closed: treated as -// out-of-sync, so npx re-runs). Returns nil iff skipped due to stamp -// dedup; otherwise returns the underlying *NpmResult with Err semantics -// from RunSkillsUpdate. -func runSkillsAndStamp(updater *selfupdate.Updater, io *cmdutil.IOStreams, stampVersion string, force bool) *selfupdate.NpmResult { +func runSkillsAndState(updater *selfupdate.Updater, io *cmdutil.IOStreams, stateVersion string, force bool) *skillscheck.SyncResult { if !force { - if existing, _ := skillscheck.ReadStamp(); normalizeVersion(existing) == normalizeVersion(stampVersion) { + if existing, ok := skillscheck.ReadSyncedVersion(); ok && normalizeVersion(existing) == normalizeVersion(stateVersion) { return nil } } - r := updater.RunSkillsUpdate() - if r.Err == nil { - if err := skillscheck.WriteStamp(stampVersion); err != nil { - fmt.Fprintf(io.ErrOut, "warning: skills synced but stamp not written: %v\n", err) - } + result := syncSkills(skillscheck.SyncOptions{ + Version: stateVersion, + Force: force, + Runner: updater, + }) + if result.Err != nil && strings.Contains(result.Err.Error(), "state not written") { + fmt.Fprintf(io.ErrOut, "warning: %v\n", result.Err) } - return r + return result } // reportAlreadyUpToDate emits the JSON / pretty output for the @@ -355,7 +337,7 @@ func runSkillsAndStamp(updater *selfupdate.Updater, io *cmdutil.IOStreams, stamp // fields derived from skillsResult. When check is true, this is the pure // report path (spec §3.6): no side-effects, JSON envelope uses // skills_status (spec §4.2) instead of skills_action. -func reportAlreadyUpToDate(opts *UpdateOptions, io *cmdutil.IOStreams, cur, latest string, skillsResult *selfupdate.NpmResult, check bool) error { +func reportAlreadyUpToDate(opts *UpdateOptions, io *cmdutil.IOStreams, cur, latest string, skillsResult *skillscheck.SyncResult, check bool) error { if opts.JSON { out := map[string]interface{}{ "ok": true, "previous_version": cur, "current_version": cur, @@ -363,16 +345,7 @@ func reportAlreadyUpToDate(opts *UpdateOptions, io *cmdutil.IOStreams, cur, late "message": fmt.Sprintf("lark-cli %s is already up to date", cur), } if check { - // Pure report — read stamp directly, emit skills_status block. - // ReadStamp errors are silently swallowed — if we can't read - // the stamp we just omit the block rather than fail the --check. - if stamp, err := skillscheck.ReadStamp(); err == nil { - out["skills_status"] = map[string]interface{}{ - "current": stamp, - "target": cur, - "in_sync": stamp == cur, - } - } + applySkillsStatus(out, cur) } else { applySkillsResult(out, skillsResult) } @@ -386,36 +359,70 @@ func reportAlreadyUpToDate(opts *UpdateOptions, io *cmdutil.IOStreams, cur, late return nil } -// applySkillsResult mutates the JSON envelope to include skills_action -// (and skills_warning when failed). nil result = "in_sync" (dedup hit). -func applySkillsResult(env map[string]interface{}, r *selfupdate.NpmResult) { +func applySkillsStatus(env map[string]interface{}, target string) { + state, readable, err := skillscheck.ReadState() + if err != nil || !readable || state.Version == "" { + return + } + status := map[string]interface{}{ + "current": state.Version, + "target": target, + "in_sync": normalizeVersion(state.Version) == normalizeVersion(target), + } + if len(state.OfficialSkills) > 0 { + status["official"] = len(state.OfficialSkills) + } + if len(state.UpdatedSkills) > 0 { + status["updated"] = len(state.UpdatedSkills) + } + if len(state.SkippedDeletedSkills) > 0 { + status["skipped_deleted"] = state.SkippedDeletedSkills + } + env["skills_status"] = status +} + +func applySkillsResult(env map[string]interface{}, r *skillscheck.SyncResult) { switch { case r == nil: env["skills_action"] = "in_sync" case r.Err != nil: env["skills_action"] = "failed" env["skills_warning"] = fmt.Sprintf("skills update failed: %s", r.Err) - if detail := strings.TrimSpace(r.Stderr.String()); detail != "" { - env["skills_detail"] = selfupdate.Truncate(detail, maxNpmOutput) - } + env["skills_summary"] = skillsSummary(r) default: env["skills_action"] = "synced" + env["skills_summary"] = skillsSummary(r) + } +} + +func skillsSummary(r *skillscheck.SyncResult) map[string]interface{} { + summary := map[string]interface{}{ + "official": len(r.Official), + "updated": len(r.Updated), + "added": len(r.Added), + "skipped_deleted": len(r.SkippedDeleted), } + if len(r.Failed) > 0 { + summary["failed"] = r.Failed + } + return summary } -// emitSkillsTextHints prints human-readable feedback about the skills -// sync result for non-JSON output. -func emitSkillsTextHints(io *cmdutil.IOStreams, r *selfupdate.NpmResult) { +func emitSkillsTextHints(io *cmdutil.IOStreams, r *skillscheck.SyncResult) { switch { case r == nil: - // dedup hit — silent (already up to date) case r.Err != nil: fmt.Fprintf(io.ErrOut, "%s Skills update failed: %v\n", symWarn(), r.Err) - if detail := strings.TrimSpace(r.Stderr.String()); detail != "" { - fmt.Fprintf(io.ErrOut, " %s\n", selfupdate.Truncate(detail, maxStderrDetail)) + if len(r.Failed) > 0 { + fmt.Fprintf(io.ErrOut, " Failed skills: %s\n", strings.Join(r.Failed, ", ")) } - fmt.Fprintf(io.ErrOut, " Run manually: npx -y skills add larksuite/cli -g -y\n") + fmt.Fprintf(io.ErrOut, " To retry all official skills: lark-cli update --force\n") + case r.Force: + fmt.Fprintf(io.ErrOut, "%s Skills updated: restored all %d official skills\n", symOK(), len(r.Official)) default: - fmt.Fprintf(io.ErrOut, "%s Skills updated\n", symOK()) + fmt.Fprintf(io.ErrOut, "%s Skills updated: %d official, %d updated, %d added, %d skipped because deleted locally\n", symOK(), len(r.Official), len(r.Updated), len(r.Added), len(r.SkippedDeleted)) + if len(r.SkippedDeleted) > 0 { + fmt.Fprintf(io.ErrOut, " To restore all official skills: lark-cli update --force\n") + } } } diff --git a/cmd/update/update_test.go b/cmd/update/update_test.go index 250aa83db..94c38723c 100644 --- a/cmd/update/update_test.go +++ b/cmd/update/update_test.go @@ -8,8 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "os" - "path/filepath" "strings" "testing" @@ -28,7 +26,6 @@ func newTestFactory(t *testing.T) (*cmdutil.Factory, *bytes.Buffer, *bytes.Buffe } // mockDetect sets up newUpdater to return an Updater with the given DetectResult. -// It preserves any existing NpmInstallOverride/SkillsUpdateOverride that may be set later. func mockDetect(t *testing.T, result selfupdate.DetectResult) { t.Helper() origNew := newUpdater @@ -41,22 +38,34 @@ func mockDetect(t *testing.T, result selfupdate.DetectResult) { } // mockDetectAndNpm sets up newUpdater with detect, npm install, and skills overrides all at once. -func mockDetectAndNpm(t *testing.T, result selfupdate.DetectResult, - npmFn func(string) *selfupdate.NpmResult, - skillsFn func() *selfupdate.NpmResult) { +func mockDetectAndNpm(t *testing.T, result selfupdate.DetectResult, npmFn func(string) *selfupdate.NpmResult) { t.Helper() origNew := newUpdater newUpdater = func() *selfupdate.Updater { u := selfupdate.New() u.DetectOverride = func() selfupdate.DetectResult { return result } u.NpmInstallOverride = npmFn - u.SkillsUpdateOverride = skillsFn u.VerifyOverride = func(string) error { return nil } + u.SkillsCommandOverride = successfulSkillsCommand() return u } t.Cleanup(func() { newUpdater = origNew }) } +func successfulSkillsCommand() func(args ...string) *selfupdate.NpmResult { + return func(args ...string) *selfupdate.NpmResult { + r := &selfupdate.NpmResult{} + switch strings.Join(args, " ") { + case "-y skills add https://open.feishu.cn --list": + r.Stdout.WriteString("lark-calendar\nlark-mail\n") + case "-y skills ls -g": + r.Stdout.WriteString("lark-calendar\ncustom-skill\n") + default: + } + return r + } +} + func TestUpdateAlreadyUpToDate_JSON(t *testing.T) { f, stdout, _ := newTestFactory(t) @@ -168,9 +177,7 @@ func TestUpdateManual_Human(t *testing.T) { } func TestUpdateNpm_JSON(t *testing.T) { - // Isolate config dir: this test mocks fetchLatest="2.0.0" and lets - // runSkillsAndStamp → WriteStamp succeed, which without isolation would - // clobber the real ~/.lark-cli/skills.stamp with "2.0.0". + // Isolate config dir because skills sync writes skills-state.json. t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) f, stdout, _ := newTestFactory(t) @@ -186,7 +193,6 @@ func TestUpdateNpm_JSON(t *testing.T) { mockDetectAndNpm(t, selfupdate.DetectResult{Method: selfupdate.InstallNpm, ResolvedPath: "/node_modules/@larksuite/cli/bin/lark-cli", NpmAvailable: true}, func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, - func() *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, ) err := cmd.Execute() @@ -216,7 +222,6 @@ func TestUpdateNpm_Human(t *testing.T) { mockDetectAndNpm(t, selfupdate.DetectResult{Method: selfupdate.InstallNpm, ResolvedPath: "/node_modules/@larksuite/cli/bin/lark-cli", NpmAvailable: true}, func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, - func() *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, ) err := cmd.Execute() @@ -230,7 +235,7 @@ func TestUpdateNpm_Human(t *testing.T) { } func TestUpdateForce_JSON(t *testing.T) { - // Same stamp-isolation rationale as TestUpdateNpm_JSON. + // Same state-isolation rationale as TestUpdateNpm_JSON. t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) f, stdout, _ := newTestFactory(t) @@ -246,7 +251,6 @@ func TestUpdateForce_JSON(t *testing.T) { mockDetectAndNpm(t, selfupdate.DetectResult{Method: selfupdate.InstallNpm, ResolvedPath: "/node_modules/@larksuite/cli/bin/lark-cli", NpmAvailable: true}, func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, - func() *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, ) err := cmd.Execute() @@ -323,7 +327,7 @@ func TestUpdateInvalidVersion_JSON(t *testing.T) { } func TestUpdateDevVersion_JSON(t *testing.T) { - // Same stamp-isolation rationale as TestUpdateNpm_JSON. + // Same state-isolation rationale as TestUpdateNpm_JSON. t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) f, stdout, _ := newTestFactory(t) @@ -339,7 +343,6 @@ func TestUpdateDevVersion_JSON(t *testing.T) { mockDetectAndNpm(t, selfupdate.DetectResult{Method: selfupdate.InstallNpm, ResolvedPath: "/node_modules/@larksuite/cli/bin/lark-cli", NpmAvailable: true}, func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, - func() *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, ) err := cmd.Execute() @@ -451,8 +454,8 @@ func TestUpdateNpmVerifyFail_JSON_NoRestoreHintWhenBackupUnavailable(t *testing. u.NpmInstallOverride = func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} } u.VerifyOverride = func(string) error { return errors.New("bad binary") } u.RestoreAvailableOverride = func() bool { return false } - u.SkillsUpdateOverride = func() *selfupdate.NpmResult { - t.Fatal("skills update should not run when binary verification fails") + u.SkillsCommandOverride = func(args ...string) *selfupdate.NpmResult { + t.Fatal("skills sync should not run when binary verification fails") return nil } return u @@ -649,7 +652,7 @@ func TestPermissionHint(t *testing.T) { func TestUpdateWindows_NpmSuccess_JSON(t *testing.T) { // With the rename trick, Windows npm installs can now auto-update. - // Same stamp-isolation rationale as TestUpdateNpm_JSON. + // Same state-isolation rationale as TestUpdateNpm_JSON. t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) f, stdout, _ := newTestFactory(t) @@ -668,7 +671,6 @@ func TestUpdateWindows_NpmSuccess_JSON(t *testing.T) { mockDetectAndNpm(t, selfupdate.DetectResult{Method: selfupdate.InstallNpm, ResolvedPath: `C:\npm\node_modules\@larksuite\cli\bin\lark-cli.exe`, NpmAvailable: true}, func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, - func() *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, ) err := cmd.Execute() @@ -750,7 +752,6 @@ func TestUpdateNpm_SkillsSuccess_JSON(t *testing.T) { mockDetectAndNpm(t, selfupdate.DetectResult{Method: selfupdate.InstallNpm, ResolvedPath: "/node_modules/@larksuite/cli/bin/lark-cli", NpmAvailable: true}, func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, - func() *selfupdate.NpmResult { return &selfupdate.NpmResult{} }, ) err := cmd.Execute() @@ -785,8 +786,7 @@ func TestUpdateNpm_SkillsFail_JSON(t *testing.T) { } u.NpmInstallOverride = func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} } u.VerifyOverride = func(string) error { return nil } - // Skills update fails - u.SkillsUpdateOverride = func() *selfupdate.NpmResult { + u.SkillsCommandOverride = func(args ...string) *selfupdate.NpmResult { r := &selfupdate.NpmResult{} r.Stderr.WriteString("npx: command not found") r.Err = fmt.Errorf("exit status 127") @@ -812,8 +812,8 @@ func TestUpdateNpm_SkillsFail_JSON(t *testing.T) { if !strings.Contains(out, "skills_warning") { t.Errorf("expected skills_warning in output, got: %s", out) } - if !strings.Contains(out, "skills_detail") { - t.Errorf("expected skills_detail in output, got: %s", out) + if !strings.Contains(out, "skills_summary") { + t.Errorf("expected skills_summary in output, got: %s", out) } } @@ -838,7 +838,7 @@ func TestUpdateNpm_SkillsFail_Human(t *testing.T) { } u.NpmInstallOverride = func(version string) *selfupdate.NpmResult { return &selfupdate.NpmResult{} } u.VerifyOverride = func(string) error { return nil } - u.SkillsUpdateOverride = func() *selfupdate.NpmResult { + u.SkillsCommandOverride = func(args ...string) *selfupdate.NpmResult { r := &selfupdate.NpmResult{} r.Stderr.WriteString("npx: command not found") r.Err = fmt.Errorf("exit status 127") @@ -861,100 +861,96 @@ func TestUpdateNpm_SkillsFail_Human(t *testing.T) { if !strings.Contains(out, "Skills update failed") { t.Errorf("expected skills failure warning, got: %s", out) } - if !strings.Contains(out, "npx -y skills add") { - t.Errorf("expected manual skills command hint, got: %s", out) + if !strings.Contains(out, "lark-cli update --force") { + t.Errorf("expected force retry hint, got: %s", out) } } -// newTestIO returns a cmdutil.IOStreams backed by bytes.Buffers, suitable -// for direct calls to internals like runSkillsAndStamp that write to -// io.ErrOut. +// newTestIO returns a cmdutil.IOStreams backed by bytes.Buffers. func newTestIO() *cmdutil.IOStreams { return cmdutil.NewIOStreams(&bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}) } -func TestRunSkillsAndStamp_DedupHit(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := skillscheck.WriteStamp("1.0.21"); err != nil { +func TestRunSkillsAndState_DedupHit(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := skillscheck.WriteState(skillscheck.SkillsState{Version: "1.0.21"}); err != nil { t.Fatal(err) } called := false updater := &selfupdate.Updater{ - SkillsUpdateOverride: func() *selfupdate.NpmResult { + SkillsCommandOverride: func(args ...string) *selfupdate.NpmResult { called = true return &selfupdate.NpmResult{} }, } - got := runSkillsAndStamp(updater, newTestIO(), "1.0.21", false) + got := runSkillsAndState(updater, newTestIO(), "1.0.21", false) if got != nil { - t.Errorf("runSkillsAndStamp() = %+v, want nil for dedup hit", got) + t.Errorf("runSkillsAndState() = %+v, want nil for dedup hit", got) } if called { - t.Error("SkillsUpdateOverride called, want skipped due to dedup") + t.Error("SkillsCommandOverride called, want skipped due to dedup") } } -func TestRunSkillsAndStamp_DedupForceBypass(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := skillscheck.WriteStamp("1.0.21"); err != nil { +func TestRunSkillsAndState_DedupForceBypass(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := skillscheck.WriteState(skillscheck.SkillsState{Version: "1.0.21"}); err != nil { t.Fatal(err) } called := false updater := &selfupdate.Updater{ - SkillsUpdateOverride: func() *selfupdate.NpmResult { + SkillsCommandOverride: func(args ...string) *selfupdate.NpmResult { called = true - return &selfupdate.NpmResult{} + return successfulSkillsCommand()(args...) }, } - got := runSkillsAndStamp(updater, newTestIO(), "1.0.21", true) - if got == nil { - t.Fatal("runSkillsAndStamp(force=true) = nil, want non-nil") + got := runSkillsAndState(updater, newTestIO(), "1.0.21", true) + if got == nil || got.Err != nil { + t.Fatalf("runSkillsAndState(force=true) = %+v, want successful result", got) } if !called { - t.Error("SkillsUpdateOverride not called with force=true") + t.Error("SkillsCommandOverride not called with force=true") } } -func TestRunSkillsAndStamp_SuccessWritesStamp(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - updater := &selfupdate.Updater{ - SkillsUpdateOverride: func() *selfupdate.NpmResult { - return &selfupdate.NpmResult{} - }, - } - got := runSkillsAndStamp(updater, newTestIO(), "1.0.21", false) +func TestRunSkillsAndState_SuccessWritesState(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + updater := &selfupdate.Updater{SkillsCommandOverride: successfulSkillsCommand()} + got := runSkillsAndState(updater, newTestIO(), "1.0.21", false) if got == nil || got.Err != nil { - t.Fatalf("runSkillsAndStamp() = %+v, want non-nil with nil Err", got) + t.Fatalf("runSkillsAndState() = %+v, want non-nil with nil Err", got) } - stamp, _ := skillscheck.ReadStamp() - if stamp != "1.0.21" { - t.Errorf("stamp = %q, want \"1.0.21\"", stamp) + state, readable, err := skillscheck.ReadState() + if err != nil || !readable { + t.Fatalf("ReadState() = (_, %v, %v), want readable", readable, err) + } + if state.Version != "1.0.21" { + t.Errorf("state.Version = %q, want \"1.0.21\"", state.Version) } } -func TestRunSkillsAndStamp_FailureKeepsOldStamp(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := skillscheck.WriteStamp("1.0.20"); err != nil { +func TestRunSkillsAndState_FailureKeepsOldState(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := skillscheck.WriteState(skillscheck.SkillsState{Version: "1.0.20"}); err != nil { t.Fatal(err) } updater := &selfupdate.Updater{ - SkillsUpdateOverride: func() *selfupdate.NpmResult { + SkillsCommandOverride: func(args ...string) *selfupdate.NpmResult { r := &selfupdate.NpmResult{} r.Err = fmt.Errorf("npx failed") return r }, } - got := runSkillsAndStamp(updater, newTestIO(), "1.0.21", false) + got := runSkillsAndState(updater, newTestIO(), "1.0.21", false) if got == nil || got.Err == nil { - t.Fatalf("runSkillsAndStamp() = %+v, want non-nil with non-nil Err", got) + t.Fatalf("runSkillsAndState() = %+v, want non-nil with non-nil Err", got) } - stamp, _ := skillscheck.ReadStamp() - if stamp != "1.0.20" { - t.Errorf("stamp = %q, want \"1.0.20\" (failure must not overwrite)", stamp) + state, readable, err := skillscheck.ReadState() + if err != nil || !readable { + t.Fatalf("ReadState() = (_, %v, %v), want readable", readable, err) + } + if state.Version != "1.0.20" { + t.Errorf("state.Version = %q, want \"1.0.20\" (failure must not overwrite)", state.Version) } } @@ -973,8 +969,7 @@ func TestTruncate(t *testing.T) { } func TestUpdateRun_AlreadyLatest_RunsSkillsSync(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) origFetch := fetchLatest origCur := currentVersion @@ -987,9 +982,9 @@ func TestUpdateRun_AlreadyLatest_RunsSkillsSync(t *testing.T) { t.Cleanup(func() { newUpdater = origNew }) newUpdater = func() *selfupdate.Updater { return &selfupdate.Updater{ - SkillsUpdateOverride: func() *selfupdate.NpmResult { + SkillsCommandOverride: func(args ...string) *selfupdate.NpmResult { skillsCalled = true - return &selfupdate.NpmResult{} + return successfulSkillsCommand()(args...) }, } } @@ -1000,17 +995,19 @@ func TestUpdateRun_AlreadyLatest_RunsSkillsSync(t *testing.T) { t.Fatalf("updateRun() err = %v, want nil", err) } if !skillsCalled { - t.Error("RunSkillsUpdate not called in already-up-to-date branch (cold stamp), want called") + t.Error("skills sync not called in already-up-to-date branch") + } + state, readable, err := skillscheck.ReadState() + if err != nil || !readable { + t.Fatalf("ReadState() = (_, %v, %v), want readable", readable, err) } - stamp, _ := skillscheck.ReadStamp() - if stamp != "1.0.21" { - t.Errorf("stamp = %q, want \"1.0.21\"", stamp) + if state.Version != "1.0.21" { + t.Errorf("state.Version = %q, want \"1.0.21\"", state.Version) } } func TestUpdateRun_Manual_RunsSkillsSync(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) origFetch := fetchLatest origCur := currentVersion @@ -1029,9 +1026,9 @@ func TestUpdateRun_Manual_RunsSkillsSync(t *testing.T) { ResolvedPath: "/usr/local/bin/lark-cli", } }, - SkillsUpdateOverride: func() *selfupdate.NpmResult { + SkillsCommandOverride: func(args ...string) *selfupdate.NpmResult { skillsCalled = true - return &selfupdate.NpmResult{} + return successfulSkillsCommand()(args...) }, } } @@ -1042,17 +1039,19 @@ func TestUpdateRun_Manual_RunsSkillsSync(t *testing.T) { t.Fatalf("updateRun() err = %v, want nil", err) } if !skillsCalled { - t.Error("RunSkillsUpdate not called in manual branch, want called") + t.Error("skills sync not called in manual branch") + } + state, readable, err := skillscheck.ReadState() + if err != nil || !readable { + t.Fatalf("ReadState() = (_, %v, %v), want readable", readable, err) } - stamp, _ := skillscheck.ReadStamp() - if stamp != "1.0.21" { - t.Errorf("stamp = %q, want \"1.0.21\" (manual path stamps cur)", stamp) + if state.Version != "1.0.21" { + t.Errorf("state.Version = %q, want \"1.0.21\" (manual path records current binary)", state.Version) } } -func TestUpdateRun_Npm_RunsSkillsSync_StampsLatest(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) +func TestUpdateRun_Npm_RunsSkillsSync_WritesLatestState(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) origFetch := fetchLatest origCur := currentVersion @@ -1075,9 +1074,9 @@ func TestUpdateRun_Npm_RunsSkillsSync_StampsLatest(t *testing.T) { return &selfupdate.NpmResult{} }, VerifyOverride: func(expectedVersion string) error { return nil }, - SkillsUpdateOverride: func() *selfupdate.NpmResult { + SkillsCommandOverride: func(args ...string) *selfupdate.NpmResult { skillsCalled = true - return &selfupdate.NpmResult{} + return successfulSkillsCommand()(args...) }, } } @@ -1088,18 +1087,25 @@ func TestUpdateRun_Npm_RunsSkillsSync_StampsLatest(t *testing.T) { t.Fatalf("updateRun() err = %v, want nil", err) } if !skillsCalled { - t.Error("RunSkillsUpdate not called in npm branch") + t.Error("skills sync not called in npm branch") } - stamp, _ := skillscheck.ReadStamp() - if stamp != "1.0.22" { - t.Errorf("stamp = %q, want \"1.0.22\" (npm path stamps latest)", stamp) + state, readable, err := skillscheck.ReadState() + if err != nil || !readable { + t.Fatalf("ReadState() = (_, %v, %v), want readable", readable, err) + } + if state.Version != "1.0.22" { + t.Errorf("state.Version = %q, want \"1.0.22\" (npm path records latest binary)", state.Version) } } func TestUpdateRun_CheckIncludesSkillsStatus(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := skillscheck.WriteStamp("1.0.20"); err != nil { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := skillscheck.WriteState(skillscheck.SkillsState{ + Version: "1.0.20", + OfficialSkills: []string{"lark-calendar", "lark-mail"}, + UpdatedSkills: []string{"lark-calendar"}, + SkippedDeletedSkills: []string{"lark-mail"}, + }); err != nil { t.Fatal(err) } @@ -1117,9 +1123,9 @@ func TestUpdateRun_CheckIncludesSkillsStatus(t *testing.T) { DetectOverride: func() selfupdate.DetectResult { return selfupdate.DetectResult{Method: selfupdate.InstallNpm, NpmAvailable: true} }, - SkillsUpdateOverride: func() *selfupdate.NpmResult { + SkillsCommandOverride: func(args ...string) *selfupdate.NpmResult { skillsCalled = true - return &selfupdate.NpmResult{} + return successfulSkillsCommand()(args...) }, } } @@ -1130,7 +1136,7 @@ func TestUpdateRun_CheckIncludesSkillsStatus(t *testing.T) { t.Fatalf("updateRun(--check) err = %v, want nil", err) } if skillsCalled { - t.Error("RunSkillsUpdate called under --check, want skipped (pure report)") + t.Error("skills sync called under --check, want skipped") } var env map[string]interface{} @@ -1144,12 +1150,14 @@ func TestUpdateRun_CheckIncludesSkillsStatus(t *testing.T) { if status["current"] != "1.0.20" || status["target"] != "1.0.21" || status["in_sync"] != false { t.Errorf("skills_status = %+v, want {current:\"1.0.20\", target:\"1.0.21\", in_sync:false}", status) } + if status["official"] != float64(2) || status["updated"] != float64(1) { + t.Errorf("skills_status counts = %+v, want official:2 updated:1", status) + } } func TestUpdateRun_CheckAlreadyLatest_NoSideEffect(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := skillscheck.WriteStamp("1.0.20"); err != nil { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := skillscheck.WriteState(skillscheck.SkillsState{Version: "1.0.20"}); err != nil { t.Fatal(err) } @@ -1164,9 +1172,9 @@ func TestUpdateRun_CheckAlreadyLatest_NoSideEffect(t *testing.T) { t.Cleanup(func() { newUpdater = origNew }) newUpdater = func() *selfupdate.Updater { return &selfupdate.Updater{ - SkillsUpdateOverride: func() *selfupdate.NpmResult { + SkillsCommandOverride: func(args ...string) *selfupdate.NpmResult { skillsCalled = true - return &selfupdate.NpmResult{} + return successfulSkillsCommand()(args...) }, } } @@ -1177,12 +1185,15 @@ func TestUpdateRun_CheckAlreadyLatest_NoSideEffect(t *testing.T) { t.Fatalf("updateRun(--check, already-latest) err = %v, want nil", err) } if skillsCalled { - t.Error("RunSkillsUpdate called under --check (already-latest), want skipped (pure report)") + t.Error("skills sync called under --check (already-latest), want skipped") } - stamp, _ := skillscheck.ReadStamp() - if stamp != "1.0.20" { - t.Errorf("stamp mutated to %q under --check, want \"1.0.20\" (pure report must not write stamp)", stamp) + state, readable, err := skillscheck.ReadState() + if err != nil || !readable { + t.Fatalf("ReadState() = (_, %v, %v), want readable", readable, err) + } + if state.Version != "1.0.20" { + t.Errorf("state.Version mutated to %q under --check, want \"1.0.20\"", state.Version) } var env map[string]interface{} @@ -1204,38 +1215,26 @@ func TestUpdateRun_CheckAlreadyLatest_NoSideEffect(t *testing.T) { } } -// TestRunSkillsAndStamp_StampWriteFailureWarns verifies the stderr warning -// emission when RunSkillsUpdate succeeds but WriteStamp fails. -func TestRunSkillsAndStamp_StampWriteFailureWarns(t *testing.T) { - // Force WriteStamp to fail by pointing config dir at a path that exists - // as a regular file (so MkdirAll fails). - tmp := t.TempDir() - badPath := filepath.Join(tmp, "blocker") - if err := os.WriteFile(badPath, []byte("not-a-dir"), 0o644); err != nil { - t.Fatal(err) +func TestRunSkillsAndState_StateWriteFailureWarns(t *testing.T) { + origSync := syncSkills + syncSkills = func(opts skillscheck.SyncOptions) *skillscheck.SyncResult { + return &skillscheck.SyncResult{Err: fmt.Errorf("skills synced but state not written: denied")} } - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", badPath) + t.Cleanup(func() { syncSkills = origSync }) f, _, stderr := newTestFactory(t) - updater := &selfupdate.Updater{ - SkillsUpdateOverride: func() *selfupdate.NpmResult { - return &selfupdate.NpmResult{} // success - }, - } - got := runSkillsAndStamp(updater, f.IOStreams, "1.0.21", false) - if got == nil || got.Err != nil { - t.Fatalf("runSkillsAndStamp() = %+v, want non-nil with nil Err", got) + got := runSkillsAndState(&selfupdate.Updater{}, f.IOStreams, "1.0.21", false) + if got == nil || got.Err == nil { + t.Fatalf("runSkillsAndState() = %+v, want non-nil with write error", got) } - if !strings.Contains(stderr.String(), "warning: skills synced but stamp not written") { + if !strings.Contains(stderr.String(), "warning: skills synced but state not written") { t.Errorf("stderr does not contain warning: %q", stderr.String()) } } -// TestEmitSkillsTextHints_Success verifies the "Skills updated" success -// message is printed to ErrOut on a successful (Err == nil) result. func TestEmitSkillsTextHints_Success(t *testing.T) { f, _, stderr := newTestFactory(t) - emitSkillsTextHints(f.IOStreams, &selfupdate.NpmResult{}) // Err==nil → success + emitSkillsTextHints(f.IOStreams, &skillscheck.SyncResult{Official: []string{"lark-calendar"}, Updated: []string{"lark-calendar"}}) if !strings.Contains(stderr.String(), "Skills updated") { t.Errorf("stderr does not contain 'Skills updated': %q", stderr.String()) } diff --git a/extension/platform/README.md b/extension/platform/README.md new file mode 100644 index 000000000..d2834ddd7 --- /dev/null +++ b/extension/platform/README.md @@ -0,0 +1,186 @@ +# lark-cli Plugin SDK + +`extension/platform` is the **in-process plugin SDK** for lark-cli. +Plugins compile into a **fork** of the lark-cli binary via a blank +import; there is no `.so` loading, no RPC, no subprocess isolation. +A plugin shares the binary's address space and lifecycle. + +## 5-minute hello world + +```go +// myplugin/audit.go +package myplugin + +import ( + "context" + "log" + + "github.com/larksuite/cli/extension/platform" +) + +func init() { + platform.Register( + platform.NewPlugin("audit", "0.1.0"). + Observer(platform.After, "log-cmd", platform.All(), + func(ctx context.Context, inv platform.Invocation) { + log.Printf("cmd=%s err=%v", inv.Cmd().Path(), inv.Err()) + }). + FailOpen(). + MustBuild()) +} +``` + +Wire into a fork: + +```go +// cmd/larkx/main.go in your fork +package main + +import ( + _ "github.com/me/myplugin" // blank import → init() runs + + "github.com/larksuite/cli/cmd" + "os" +) + +func main() { os.Exit(cmd.Execute()) } +``` + +```sh +go build -o larkx ./cmd/larkx && ./larkx config plugins show +``` + +You should see `audit` in the plugin list. + +## What you can hook + +| Hook | Fires | Can block? | +| -------------------------- | ---------------------------------- | -------------------------------- | +| `Observer` | Before / After each command | No (fire-and-forget audit) | +| `Wrap` | Around each command's RunE | Yes (return `*AbortError`) | +| `On(Startup/Shutdown)` | Process lifecycle | N/A | +| `Restrict(Rule)` | Bootstrap-time, single per binary | Denies whole subtrees | + +### Plugin lifecycle + +```mermaid +sequenceDiagram + participant Host as lark-cli (host) + participant SDK as platform (SDK) + participant Plugin as your plugin + + Note over Host,Plugin: Process start (before main) + Plugin->>Plugin: init() (via blank import) + Plugin->>SDK: Register(plugin) + + Note over Host,Plugin: Bootstrap (host main) + Host->>SDK: RegisteredPlugins() + SDK-->>Host: snapshot in registration order + Host->>SDK: InstallAll() + SDK->>Plugin: Capabilities() + SDK->>Plugin: Install(Registrar) + Plugin->>SDK: Observe / Wrap / Restrict / On(Startup,Shutdown) + SDK->>Plugin: On(Startup) fire + + Note over Host,Plugin: Each command dispatch + Host->>SDK: hook chain (in registration order) + SDK->>Plugin: Observer Before + SDK->>Plugin: Wrap (around RunE) + SDK->>Plugin: Observer After + + Note over Host,Plugin: Process exit + Host->>SDK: Emit(Shutdown) + SDK->>Plugin: On(Shutdown) fire +``` + +A `command_denied` decision (from `Restrict` or strict-mode) bypasses +the `Wrap` chain entirely — observers still fire so audit plugins see +the rejected dispatch. + +## Safety contract (read this) + +- A plugin calling `Restrict()` MUST declare `FailClosed`. The Builder + flips it automatically; the lower-level `Plugin` interface rejects + the mismatch with `restricts_mismatch`. +- Only ONE plugin per binary can call `Restrict()`. Multi-plugin + Restrict is a deliberate `plugin_conflict` error (single-rule + ecosystem assumption). YAML policy at `~/.lark-cli/policy.yml` is + shadowed by any plugin Restrict. +- The `Wrap` factory runs **once per command dispatch**, not at + install time. Long-lived state (clients, caches, metrics counters) + must live on the Plugin struct or in package-level variables. +- Plugins cannot suppress a `command_denied`: the framework + physically isolates denied commands from the Wrap chain (Observers + still fire). +- Commands missing a `risk_level` annotation are denied by default + when a Rule is active. Set `Rule.AllowUnannotated = true` (or + `allow_unannotated: true` in yaml) to opt out during gradual + adoption. +- Risk annotation typos (e.g. `"wrtie"`) are always denied with + `risk_invalid` plus a "did you mean" suggestion. `AllowUnannotated` + does NOT bypass this — typo is a code bug, not a missing + annotation. + +## reason_code reference + +Every install / dispatch failure emits a `command_denied` or +`plugin_install` envelope carrying a `detail.reason_code` from the +closed enum below. Use the code (not the human-readable message) when +matching errors in agents, CI scripts, or downstream tools — the +messages are localised and may change between releases. + +### Plugin install (`error.type = plugin_install`) + +| reason_code | When it fires | Honours FailurePolicy? | +| --------------------------- | ------------------------------------------------------------------------------ | ---------------------- | +| `invalid_plugin_name` | `Plugin.Name()` doesn't match `^[a-z0-9][a-z0-9-]*$` | No — always aborts | +| `plugin_name_panic` | `Plugin.Name()` panicked | No — always aborts | +| `duplicate_plugin_name` | Two plugins return the same `Name()` | No — always aborts | +| `capabilities_panic` | `Plugin.Capabilities()` panicked | Yes | +| `invalid_capability` | `Capabilities` malformed: bad `RequiredCLIVersion`, unknown `FailurePolicy` | No — always aborts | +| `capability_unmet` | Current CLI version doesn't satisfy `RequiredCLIVersion` | Yes | +| `restricts_mismatch` | `Restricts=true` without `FailClosed`, or `Restricts` flag inconsistent w/ Install | No — always aborts | +| `invalid_hook_name` | Hook name contains `.` or doesn't match the plugin namespace | Yes | +| `duplicate_hook_name` | Same hook name registered twice within a plugin | Yes | +| `invalid_hook_registration` | Hook factory returns nil / Wrap chain re-entry / etc. | Yes | +| `invalid_rule` | Rule fails ValidateRule (malformed glob, bad MaxRisk, unknown Identity) | Yes | +| `double_restrict` | Plugin called `r.Restrict()` more than once in one Install | Yes | +| `multiple_restrict_plugins` | Two or more plugins each contributed Restrict | Yes | +| `install_failed` | `Plugin.Install` returned a non-nil error | Yes | +| `install_panic` | `Plugin.Install` panicked | Yes | + +"No — always aborts" entries are treated as **untrusted-config errors**: +the host can't honour the plugin's declared `FailurePolicy` because the +declaration itself is suspect (e.g. an `invalid_capability` plugin +might also be lying about being `FailOpen`). + +### Command dispatch (`error.type = command_denied`) + +| reason_code | Meaning | +| ----------------------- | ---------------------------------------------------------------------------------------------------------------- | +| `risk_not_annotated` | Command has no `risk_level` annotation, and the active Rule does not set `allow_unannotated: true` | +| `risk_invalid` | Command's `risk_level` is a typo / not in the `read | write | high-risk-write` taxonomy (always fail-closed) | +| `command_denylisted` | Command path matched the active Rule's `deny` glob | +| `domain_not_allowed` | Active Rule has a non-empty `allow` list and the command path did not match any glob | +| `write_not_allowed` | Command risk is `write` / `high-risk-write` and exceeds Rule `max_risk` | +| `risk_too_high` | Command risk exceeds Rule `max_risk` but is not a write (reserved for future risk levels) | +| `identity_mismatch` | Command's `supportedIdentities` does not intersect Rule `identities` | +| `aggregate_all_denied` | Aggregate stub installed on a parent group because every live child was denied | + +The `detail.layer` field distinguishes who rejected the call: +`policy` (this SDK's user-layer engine) vs. `strict_mode` +(`cmd/prune.go`'s credential-hardening pass). Agents that want to +dispatch on "any denial" should match `error.type == "command_denied"` +and ignore the layer; agents that only care about user-policy denials +should additionally check `detail.layer == "policy"`. + +## Where to go next + +- [Runnable example: audit observer](./examples/audit-observer/) +- [Runnable example: read-only policy](./examples/readonly-policy/) +- Builder API: see [`builder.go`](./builder.go) for the full DSL + (`NewPlugin`, `Observer`, `Wrap`, `Restrict`, `FailOpen`/`FailClosed`, + `MustBuild`). +- Inventory diagnostic: run `lark-cli config plugins show` after + installing your plugin to see hooks/rules attributed to your plugin + name. diff --git a/extension/platform/abort.go b/extension/platform/abort.go new file mode 100644 index 000000000..9ec99d8b5 --- /dev/null +++ b/extension/platform/abort.go @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import "fmt" + +// AbortError is returned by a Wrapper that wants to short-circuit the +// command chain (instead of calling next). The framework converts it +// to an *output.ExitError with type "hook" so the JSON envelope carries +// the structured fields agents expect. +// +// HookName is the framework-namespaced name ("secaudit.approval"); the +// Registrar adds the plugin-name prefix automatically. +// +// Cause and Detail are optional. Cause lets the consumer use +// errors.Is/As to find the underlying cause; Detail is serialized into +// envelope.detail under the "detail" key for agent consumption. +type AbortError struct { + HookName string + Reason string + Cause error + Detail any +} + +// Error renders a human-readable message; HookName + Reason + Cause are +// included when present. +func (e *AbortError) Error() string { + msg := fmt.Sprintf("hook %q aborted: %s", e.HookName, e.Reason) + if e.Cause != nil { + msg += ": " + e.Cause.Error() + } + return msg +} + +// Unwrap enables errors.Is / errors.As to traverse to Cause. +func (e *AbortError) Unwrap() error { return e.Cause } diff --git a/extension/platform/abort_test.go b/extension/platform/abort_test.go new file mode 100644 index 000000000..364f72fb5 --- /dev/null +++ b/extension/platform/abort_test.go @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform_test + +import ( + "errors" + "io/fs" + "testing" + + "github.com/larksuite/cli/extension/platform" +) + +func TestAbortError_messageFormats(t *testing.T) { + bare := &platform.AbortError{HookName: "secaudit.approval", Reason: "needs approval"} + if got := bare.Error(); got != `hook "secaudit.approval" aborted: needs approval` { + t.Errorf("Error() = %q", got) + } + + withCause := &platform.AbortError{ + HookName: "audit.upload", + Reason: "upstream unreachable", + Cause: fs.ErrNotExist, + } + if got := withCause.Error(); got == bare.Error() { + t.Errorf("Cause should be appended to message, got %q", got) + } +} + +// errors.As must traverse Unwrap so consumers can inspect the cause +// directly. This is the contract the host's wrapAbortError relies on. +func TestAbortError_unwrapErrorsAs(t *testing.T) { + root := fs.ErrPermission + ab := &platform.AbortError{ + HookName: "x", + Reason: "y", + Cause: root, + } + if !errors.Is(ab, fs.ErrPermission) { + t.Errorf("errors.Is should find fs.ErrPermission via Unwrap") + } +} diff --git a/extension/platform/builder.go b/extension/platform/builder.go new file mode 100644 index 000000000..1bcba749f --- /dev/null +++ b/extension/platform/builder.go @@ -0,0 +1,215 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import ( + "errors" + "fmt" + "regexp" +) + +// Builder is the ergonomic constructor for Plugin. Use it from init(): +// +// func init() { +// platform.Register( +// platform.NewPlugin("audit", "0.1.0"). +// Observer(platform.After, "log", platform.All(), auditFn). +// FailOpen(). +// MustBuild()) +// } +// +// The lower-level Plugin interface remains available for cases that +// need finer control (state on a struct, complex Install logic). The +// Builder enforces: +// +// - Name format (^[a-z0-9][a-z0-9-]*$) +// - hookName format and uniqueness within a plugin +// - Restricts ↔ FailClosed consistency (calling Restrict() implies +// FailClosed, so plugin authors cannot accidentally ship a policy +// plugin under FailOpen) +// - Rule validation via ValidateRule analogues (delegated to +// internal/cmdpolicy at install time; Builder only fast-fails +// blatantly bad input) +type Builder struct { + name string + version string + caps Capabilities + + actions []func(Registrar) + rule *Rule + + hookNames map[string]bool + errs []error +} + +var pluginNamePattern = regexp.MustCompile(`^[a-z0-9][a-z0-9-]*$`) + +// NewPlugin starts a Builder. Name format is validated lazily — errors +// surface at Build()/MustBuild() time, allowing chained calls without +// intermediate error handling. +func NewPlugin(name, version string) *Builder { + b := &Builder{ + name: name, + version: version, + hookNames: map[string]bool{}, + } + if !pluginNamePattern.MatchString(name) { + b.errs = append(b.errs, fmt.Errorf("invalid plugin name %q: must match ^[a-z0-9][a-z0-9-]*$", name)) + } + return b +} + +// RequireCLI sets Capabilities.RequiredCLIVersion (semver constraint, +// e.g. ">=1.1.0"). Empty string means no requirement. +func (b *Builder) RequireCLI(constraint string) *Builder { + b.caps.RequiredCLIVersion = constraint + return b +} + +// FailOpen sets Capabilities.FailurePolicy = FailOpen. Default when +// neither FailOpen nor FailClosed is called and Restrict is not used. +func (b *Builder) FailOpen() *Builder { + b.caps.FailurePolicy = FailOpen + return b +} + +// FailClosed sets Capabilities.FailurePolicy = FailClosed. Implicit +// when Restrict() is called. +func (b *Builder) FailClosed() *Builder { + b.caps.FailurePolicy = FailClosed + return b +} + +// Observer registers an Observer. Multiple calls accumulate. +func (b *Builder) Observer(when When, hookName string, sel Selector, fn Observer) *Builder { + if !b.validateHookName(hookName, "observer") { + return b + } + // Capture by value so the action closure doesn't share state with + // subsequent Observer() calls (Go ≥1.22 already gives each call + // its own copies of parameter values, but pinning is explicit). + w, n, s, f := when, hookName, sel, fn + b.actions = append(b.actions, func(r Registrar) { + r.Observe(w, n, s, f) + }) + return b +} + +// Wrap registers a Wrapper. Multiple calls accumulate; the host +// composes them in registration order (outermost first). +func (b *Builder) Wrap(hookName string, sel Selector, wrap Wrapper) *Builder { + if !b.validateHookName(hookName, "wrap") { + return b + } + n, s, w := hookName, sel, wrap + b.actions = append(b.actions, func(r Registrar) { + r.Wrap(n, s, w) + }) + return b +} + +// On registers a LifecycleHandler. +func (b *Builder) On(event LifecycleEvent, hookName string, fn LifecycleHandler) *Builder { + if !b.validateHookName(hookName, "on") { + return b + } + e, n, f := event, hookName, fn + b.actions = append(b.actions, func(r Registrar) { + r.On(e, n, f) + }) + return b +} + +// Restrict contributes a pruning Rule. Calling Restrict implicitly +// sets Restricts=true and FailurePolicy=FailClosed (the framework +// requires both to coexist; the builder enforces the pairing so the +// plugin author cannot accidentally ship a policy plugin under +// FailOpen). +func (b *Builder) Restrict(rule *Rule) *Builder { + if rule == nil { + b.errs = append(b.errs, errors.New("Restrict(nil): rule must not be nil")) + return b + } + b.caps.Restricts = true + b.caps.FailurePolicy = FailClosed + b.rule = rule + return b +} + +// Build returns the configured Plugin, or an error if any builder +// step found a fault. MustBuild panics on the same error. +// +// The Restrict + FailOpen mismatch is checked here, not in the chained +// setters, because the two methods may be called in either order. +func (b *Builder) Build() (Plugin, error) { + if b.rule != nil && b.caps.FailurePolicy == FailOpen { + b.errs = append(b.errs, errors.New( + "Restrict() requires FailClosed; do not call FailOpen() after Restrict()")) + } + if len(b.errs) > 0 { + return nil, errors.Join(b.errs...) + } + return &builtPlugin{ + name: b.name, + version: b.version, + caps: b.caps, + actions: b.actions, + rule: b.rule, + }, nil +} + +// MustBuild panics if Build() would return an error. Designed for +// init(): +// +// func init() { platform.Register(platform.NewPlugin(...).MustBuild()) } +// +// A panic in init runs before the framework's recover guard is +// installed and will crash the binary. That is the intended +// behaviour: a misconfigured plugin must NOT be silently registered. +func (b *Builder) MustBuild() Plugin { + p, err := b.Build() + if err != nil { + panic(fmt.Sprintf("plugin %q: %v", b.name, err)) + } + return p +} + +// validateHookName checks the grammar and uniqueness; returns false +// when the name was rejected (caller skips the action). +func (b *Builder) validateHookName(hookName, kind string) bool { + if !pluginNamePattern.MatchString(hookName) { + b.errs = append(b.errs, fmt.Errorf( + "%s %q: hookName must match ^[a-z0-9][a-z0-9-]*$", kind, hookName)) + return false + } + if b.hookNames[hookName] { + b.errs = append(b.errs, fmt.Errorf( + "%s %q: hookName already used in this plugin", kind, hookName)) + return false + } + b.hookNames[hookName] = true + return true +} + +// builtPlugin is the Plugin implementation the builder emits. +type builtPlugin struct { + name string + version string + caps Capabilities + actions []func(Registrar) + rule *Rule +} + +func (p *builtPlugin) Name() string { return p.name } +func (p *builtPlugin) Version() string { return p.version } +func (p *builtPlugin) Capabilities() Capabilities { return p.caps } +func (p *builtPlugin) Install(r Registrar) error { + if p.rule != nil { + r.Restrict(p.rule) + } + for _, action := range p.actions { + action(r) + } + return nil +} diff --git a/extension/platform/builder_test.go b/extension/platform/builder_test.go new file mode 100644 index 000000000..541271a1b --- /dev/null +++ b/extension/platform/builder_test.go @@ -0,0 +1,180 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform_test + +import ( + "context" + "strings" + "testing" + + "github.com/larksuite/cli/extension/platform" +) + +// recorder Registrar captures everything a builder schedules so the +// test can assert what Install produced without involving the host. +type recorder struct { + observers int + wrappers int + lifecycles int + rule *platform.Rule +} + +func (r *recorder) Observe(platform.When, string, platform.Selector, platform.Observer) { + r.observers++ +} +func (r *recorder) Wrap(string, platform.Selector, platform.Wrapper) { r.wrappers++ } +func (r *recorder) On(platform.LifecycleEvent, string, platform.LifecycleHandler) { r.lifecycles++ } +func (r *recorder) Restrict(rule *platform.Rule) { r.rule = rule } + +func TestBuilder_basicAssembly(t *testing.T) { + p, err := platform.NewPlugin("audit", "0.1.0"). + Observer(platform.Before, "pre", platform.All(), + func(context.Context, platform.Invocation) {}). + Observer(platform.After, "post", platform.All(), + func(context.Context, platform.Invocation) {}). + Wrap("policy", platform.All(), + func(next platform.Handler) platform.Handler { return next }). + On(platform.Startup, "boot", + func(context.Context, *platform.LifecycleContext) error { return nil }). + FailOpen(). + Build() + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "audit" || p.Version() != "0.1.0" { + t.Errorf("metadata = %q/%q", p.Name(), p.Version()) + } + if p.Capabilities().FailurePolicy != platform.FailOpen { + t.Errorf("FailurePolicy = %v, want FailOpen", p.Capabilities().FailurePolicy) + } + + r := &recorder{} + if err := p.Install(r); err != nil { + t.Fatalf("Install: %v", err) + } + if r.observers != 2 || r.wrappers != 1 || r.lifecycles != 1 { + t.Errorf("Install dispatch = observers=%d wrappers=%d lifecycles=%d", + r.observers, r.wrappers, r.lifecycles) + } +} + +// Restrict() flips Restricts=true and FailClosed automatically — a +// policy plugin can't accidentally ship under FailOpen. +func TestBuilder_restrictForcesFailClosed(t *testing.T) { + p, err := platform.NewPlugin("policy-plugin", "0.1.0"). + Restrict(&platform.Rule{Name: "read-only", MaxRisk: platform.RiskRead}). + Build() + if err != nil { + t.Fatalf("Build: %v", err) + } + caps := p.Capabilities() + if !caps.Restricts { + t.Errorf("Restricts = false, want true (Restrict() should flip it)") + } + if caps.FailurePolicy != platform.FailClosed { + t.Errorf("FailurePolicy = %v, want FailClosed (Restrict() implies it)", caps.FailurePolicy) + } + + r := &recorder{} + if err := p.Install(r); err != nil { + t.Fatalf("Install: %v", err) + } + if r.rule == nil || r.rule.Name != "read-only" { + t.Errorf("Install did not propagate Rule: %+v", r.rule) + } +} + +// Invalid name surfaces at Build time, not at NewPlugin. +func TestBuilder_invalidPluginName(t *testing.T) { + _, err := platform.NewPlugin("Has_Underscore_And_Caps", "0.1").Build() + if err == nil { + t.Fatalf("Build must reject malformed plugin name") + } + if !strings.Contains(err.Error(), "invalid plugin name") { + t.Errorf("error should mention plugin name, got: %v", err) + } +} + +// Duplicate hookName within the same builder is rejected. +func TestBuilder_duplicateHookName(t *testing.T) { + noopObs := func(context.Context, platform.Invocation) {} + _, err := platform.NewPlugin("dup", "0"). + Observer(platform.Before, "h", platform.All(), noopObs). + Observer(platform.After, "h", platform.All(), noopObs). + Build() + if err == nil { + t.Fatalf("Build must reject duplicate hookName") + } + if !strings.Contains(err.Error(), "already used") { + t.Errorf("error should mention duplicate hookName, got %v", err) + } +} + +func TestBuilder_invalidHookName(t *testing.T) { + _, err := platform.NewPlugin("p", "0"). + Observer(platform.Before, "Bad.Name", platform.All(), + func(context.Context, platform.Invocation) {}). + Build() + if err == nil { + t.Fatalf("Build must reject hookName with dot") + } +} + +// MustBuild panics on builder error. +func TestBuilder_mustBuildPanicsOnError(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("MustBuild must panic when Build would fail") + } + }() + _ = platform.NewPlugin("BadName", "0").MustBuild() +} + +func TestBuilder_restrictNilRejected(t *testing.T) { + _, err := platform.NewPlugin("p", "0").Restrict(nil).Build() + if err == nil { + t.Fatalf("Restrict(nil) must produce error") + } +} + +func TestBuilder_capabilitiesSetters(t *testing.T) { + p, err := platform.NewPlugin("p", "0.1"). + RequireCLI(">=1.0.0"). + FailClosed(). + Build() + if err != nil { + t.Fatalf("Build: %v", err) + } + caps := p.Capabilities() + if caps.RequiredCLIVersion != ">=1.0.0" { + t.Errorf("RequiredCLIVersion = %q, want >=1.0.0", caps.RequiredCLIVersion) + } + if caps.FailurePolicy != platform.FailClosed { + t.Errorf("FailurePolicy = %v, want FailClosed", caps.FailurePolicy) + } +} + +func TestBuilder_restrictThenFailOpenRejected(t *testing.T) { + rule := &platform.Rule{Name: "r", MaxRisk: platform.RiskRead} + _, err := platform.NewPlugin("p", "0").Restrict(rule).FailOpen().Build() + if err == nil { + t.Fatalf("Build must reject Restrict()+FailOpen() mismatch") + } + if !strings.Contains(err.Error(), "FailClosed") { + t.Errorf("error should mention FailClosed, got: %v", err) + } +} + +// Restrict() flips FailurePolicy to FailClosed; the previous FailOpen() +// is overridden. Pin it so the Build-time validation does not over-reject. +func TestBuilder_failOpenThenRestrictOK(t *testing.T) { + rule := &platform.Rule{Name: "r", MaxRisk: platform.RiskRead} + p, err := platform.NewPlugin("p", "0").FailOpen().Restrict(rule).Build() + if err != nil { + t.Fatalf("FailOpen()+Restrict() must succeed (Restrict flips to FailClosed): %v", err) + } + if p.Capabilities().FailurePolicy != platform.FailClosed { + t.Errorf("FailurePolicy = %v, want FailClosed", p.Capabilities().FailurePolicy) + } +} diff --git a/extension/platform/capabilities.go b/extension/platform/capabilities.go new file mode 100644 index 000000000..fc517c426 --- /dev/null +++ b/extension/platform/capabilities.go @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +// FailurePolicy controls what the framework does when a plugin's install +// stage fails (Capabilities() panics, Install returns error, etc.). +type FailurePolicy int + +const ( + // FailOpen (default) — log a warning and skip THIS plugin; the rest + // of the CLI keeps running. Appropriate for pure-observer plugins + // where missing audit data is preferable to a broken CLI. + FailOpen FailurePolicy = iota + + // FailClosed — abort the entire CLI startup. Required for any + // plugin that contributes Restrict() (a missing policy plugin = + // missing security boundary) or that owns any safety-sensitive + // concern. Enforced by the framework: Capabilities.Restricts=true + // must pair with FailurePolicy=FailClosed. + FailClosed +) + +// Capabilities declares the plugin's self-description. Plugin.Capabilities +// MUST be implemented even when every field would be its zero value -- +// the requirement keeps FailurePolicy / Restricts visible to the author +// at the moment they write the plugin, preventing the "I just want to +// add an audit observer" mistake of accidentally shipping a policy +// plugin with the default FailOpen. +type Capabilities struct { + // RequiredCLIVersion is a semver constraint (e.g. ">=1.1.0"). + // Plugins that need a specific framework feature should declare + // the minimum version they tested against; the host fails the + // install when the running CLI is older. Empty string means "no + // version requirement". + RequiredCLIVersion string + + // Restricts declares whether Install will call r.Restrict(). The + // framework enforces consistency: declaring Restricts=true and + // then NOT calling r.Restrict (or vice versa) aborts the install + // with the `restricts_mismatch` reason_code. This pre-flight + // declaration also lets `config policy show` introspect "which + // plugins are policy plugins" without running them. + Restricts bool + + // FailurePolicy decides what happens on install failure. See the + // constants above; the framework requires FailClosed whenever + // Restricts=true. + FailurePolicy FailurePolicy +} diff --git a/extension/platform/doc.go b/extension/platform/doc.go new file mode 100644 index 000000000..8897876c2 --- /dev/null +++ b/extension/platform/doc.go @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +// Package platform is the single public extension contract for lark-cli. +// +// External integrators (plugin authors, embedding platforms) only import this +// package; everything else under internal/ is off-limits. +// +// Plugin lifecycle: +// +// - Plugin - the interface every plugin implements (Name / Version / Capabilities / Install) +// - Registrar - what Install receives; the four registration verbs (Observe / Wrap / On / Restrict) +// - Capabilities - declared up front: FailurePolicy (FailOpen | FailClosed) and Restricts +// - Register - process-wide entry point; plugins call this from init() +// +// Hook surface (what Install hangs off Registrar): +// +// - Observer - side-effect-only callback, panic-safe, runs Before / After RunE +// - Wrapper - middleware that can short-circuit via AbortError +// - LifecycleHandler - reacts to Startup / Shutdown / etc. (LifecycleEvent + When) +// - Selector - chooses which commands a hook applies to (ByDomain / ByWrite / ByReadOnly / ByExactRisk / And / Or / Not, etc.) +// - Handler - the inner "run the command" function Wrappers compose around +// - Invocation - per-call context passed to handlers (Cmd view + DeniedByPolicy / DenialLayer / DenialPolicySource) +// - AbortError - structured short-circuit error from a Wrapper; framework namespaces HookName +// +// Policy surface (what Restrict contributes, also consumable from yaml policy): +// +// - Rule - declarative policy rule (Allow / Deny / MaxRisk / Identities / AllowUnannotated) +// - CommandView - read-only command metadata view (Path / Domain / Risk / Identities) +// - Risk / Identity - defined string types with closed taxonomies; ParseRisk / ParseIdentity +// convert raw strings (yaml, cobra annotation) into typed values; r.Rank() +// gives a comparable rank for the read < write < high-risk-write ordering +// - CommandDeniedError - structured error returned to denied callers +// +// Stability: every exported symbol here is part of the contract. Internal +// orchestration (staging, validation, RunE wrapping, denial guard) lives +// under internal/platform, internal/hook and internal/cmdpolicy and is not +// importable by third parties. +package platform diff --git a/extension/platform/errors.go b/extension/platform/errors.go new file mode 100644 index 000000000..7bd99f2d2 --- /dev/null +++ b/extension/platform/errors.go @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import "fmt" + +// CommandDeniedError is the structured error returned by a denyStub. Every +// pruned-command execution path -- direct invocation, alias expansion, +// internal call -- returns this exact type. It is wire-compatible with the +// output.ExitError envelope via the Layer (== error.type) field and the +// detail map produced by ExitError(). +// +// Layer values: +// +// - "strict_mode" -- credential strict-mode rejected the command +// - "policy" -- user-layer Rule rejected the command +// +// PolicySource is a free-form identifier such as "plugin:secaudit", +// "yaml:mywork", or "strict-mode". Reason fields: +// +// - ReasonCode -- closed enum, see tech-doc 5.3 (e.g. write_not_allowed, +// all_children_denied, identity_not_supported) +// - Reason -- human-readable text +type CommandDeniedError struct { + Path string + Layer string + PolicySource string + RuleName string + ReasonCode string + Reason string +} + +// Error implements the standard error interface. +func (e *CommandDeniedError) Error() string { + if e.Reason != "" { + return fmt.Sprintf("command %q denied: %s", e.Path, e.Reason) + } + return fmt.Sprintf("command %q denied (%s/%s)", e.Path, e.Layer, e.ReasonCode) +} diff --git a/extension/platform/errors_test.go b/extension/platform/errors_test.go new file mode 100644 index 000000000..767e00d89 --- /dev/null +++ b/extension/platform/errors_test.go @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform_test + +import ( + "errors" + "testing" + + "github.com/larksuite/cli/extension/platform" +) + +func TestCommandDeniedError_messageFormats(t *testing.T) { + withReason := &platform.CommandDeniedError{ + Path: "docs/+update", + Layer: "policy", + ReasonCode: "write_not_allowed", + Reason: "write disabled by policy", + } + if got := withReason.Error(); got != `command "docs/+update" denied: write disabled by policy` { + t.Fatalf("Error() with Reason = %q", got) + } + + noReason := &platform.CommandDeniedError{ + Path: "docs/+update", + Layer: "strict_mode", + ReasonCode: "identity_not_supported", + } + if got := noReason.Error(); got != `command "docs/+update" denied (strict_mode/identity_not_supported)` { + t.Fatalf("Error() without Reason = %q", got) + } +} + +// errors.As must work so consumers can type-assert without unwrap gymnastics. +func TestCommandDeniedError_satisfiesErrorsAs(t *testing.T) { + var err error = &platform.CommandDeniedError{Path: "x"} + var target *platform.CommandDeniedError + if !errors.As(err, &target) { + t.Fatalf("errors.As should match CommandDeniedError") + } + if target.Path != "x" { + t.Fatalf("target.Path = %q, want %q", target.Path, "x") + } +} diff --git a/extension/platform/example_test.go b/extension/platform/example_test.go new file mode 100644 index 000000000..078398252 --- /dev/null +++ b/extension/platform/example_test.go @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform_test + +import ( + "context" + "fmt" + + "github.com/larksuite/cli/extension/platform" +) + +// ExampleNewPlugin_observer registers an audit Observer that fires +// after every command, regardless of success or failure. +func ExampleNewPlugin_observer() { + p, _ := platform.NewPlugin("audit", "0.1.0"). + Observer(platform.After, "log", platform.All(), + func(ctx context.Context, inv platform.Invocation) { + _ = inv.Cmd().Path() // do something useful with the command + }). + FailOpen(). + Build() + fmt.Println(p.Name(), p.Version()) + // Output: audit 0.1.0 +} + +// ExampleNewPlugin_wrapper registers a Wrap that short-circuits any +// write-class command. The framework converts the returned +// *AbortError into a structured "hook" envelope; observers still +// fire on the After stage so audit sees the attempt. +func ExampleNewPlugin_wrapper() { + p, _ := platform.NewPlugin("policy-plugin", "0.1.0"). + Wrap("block-writes", platform.ByWrite(), + func(next platform.Handler) platform.Handler { + return func(ctx context.Context, inv platform.Invocation) error { + return &platform.AbortError{ + HookName: "block-writes", + Reason: "writes are disabled for this session", + } + } + }). + FailOpen(). + Build() + fmt.Println(p.Capabilities().FailurePolicy == platform.FailOpen) + // Output: true +} + +// ExampleNewPlugin_restrict registers a policy plugin that allows +// only docs/* read commands. Note that Restrict() implicitly sets +// FailClosed — a policy plugin must abort the binary if it fails to +// install, not silently disappear. +func ExampleNewPlugin_restrict() { + p, _ := platform.NewPlugin("readonly-docs", "0.1.0"). + Restrict(&platform.Rule{ + Name: "docs-only", + Allow: []string{"docs/**"}, + MaxRisk: platform.RiskRead, + }). + Build() + caps := p.Capabilities() + fmt.Println(caps.Restricts, caps.FailurePolicy == platform.FailClosed) + // Output: true true +} diff --git a/extension/platform/examples/.gitignore b/extension/platform/examples/.gitignore new file mode 100644 index 000000000..6c34736fb --- /dev/null +++ b/extension/platform/examples/.gitignore @@ -0,0 +1,2 @@ +audit-observer/audit-observer +readonly-policy/readonly-policy diff --git a/extension/platform/examples/README.md b/extension/platform/examples/README.md new file mode 100644 index 000000000..c7eab33d7 --- /dev/null +++ b/extension/platform/examples/README.md @@ -0,0 +1,13 @@ +# lark-cli plugin examples + +Runnable fork-and-blank-import examples that demonstrate the Plugin +SDK in production-shape. Each subdirectory is a complete `main` +package: `go build .` produces a working CLI. + +| Example | What it shows | +| --- | --- | +| [audit-observer](./audit-observer/) | Simplest possible plugin: one Observer matching every command, logs to stderr. | +| [readonly-policy](./readonly-policy/) | Policy plugin: `Restrict()` with `MaxRisk=read`, demonstrates the `FailClosed` + `Restricts=true` auto-pairing. | + +All examples are built by CI (`make examples-build`) so they cannot +silently drift from the SDK. diff --git a/extension/platform/examples/audit-observer/README.md b/extension/platform/examples/audit-observer/README.md new file mode 100644 index 000000000..a860a4dd9 --- /dev/null +++ b/extension/platform/examples/audit-observer/README.md @@ -0,0 +1,26 @@ +# Example: audit observer + +The simplest possible lark-cli plugin: one After observer that logs +every dispatched command to stderr (success or failure). + +## Build & run + +```sh +cd extension/platform/examples/audit-observer +go build -o audit-cli . +./audit-cli config plugins show +# {"plugins":[{"name":"audit", ...}], "total":1} + +./audit-cli api GET /open-apis/contact/v3/users/me +# [audit] api ok (on stderr) +``` + +## Key points + +- `platform.NewPlugin(...).MustBuild()` from `init()`. The blank + import of this package in `main.go` triggers `init()`. +- `Observer(platform.After, ...)` runs **after** the command's RunE, + even on failure (Observers cannot prevent execution). +- `FailOpen()` means: if Install ever fails, the binary logs a + warning and continues without this plugin. Right default for + audit-only plugins. diff --git a/extension/platform/examples/audit-observer/main.go b/extension/platform/examples/audit-observer/main.go new file mode 100644 index 000000000..2c3c30534 --- /dev/null +++ b/extension/platform/examples/audit-observer/main.go @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +// Command audit-observer is a runnable fork of lark-cli that logs +// every dispatched command to stderr. Demonstrates the simplest +// possible plugin: one After observer matching All commands. +// +// Build & run: +// +// cd extension/platform/examples/audit-observer +// go build -o audit-cli . +// ./audit-cli config plugins show # see "audit" in the list +// ./audit-cli api GET /open-apis/... # observer logs to stderr +package main + +import ( + "context" + "fmt" + "log" + "os" + + "github.com/larksuite/cli/cmd" + "github.com/larksuite/cli/extension/platform" +) + +func init() { + platform.Register( + platform.NewPlugin("audit", "0.1.0"). + Observer(platform.After, "log", platform.All(), + func(ctx context.Context, inv platform.Invocation) { + path := inv.Cmd().Path() + if err := inv.Err(); err != nil { + fmt.Fprintf(os.Stderr, "[audit] %s FAILED: %v\n", path, err) + } else { + log.Printf("[audit] %s ok", path) + } + }). + FailOpen(). + MustBuild()) +} + +func main() { + os.Exit(cmd.Execute()) +} diff --git a/extension/platform/examples/readonly-policy/README.md b/extension/platform/examples/readonly-policy/README.md new file mode 100644 index 000000000..9c0963fba --- /dev/null +++ b/extension/platform/examples/readonly-policy/README.md @@ -0,0 +1,61 @@ +# Example: read-only policy + +A policy plugin that installs a `Rule` allowing only `docs/*` and +`im/*` read commands. Any write command produces a structured +`command_denied` envelope. + +## Build & run + +```sh +cd extension/platform/examples/readonly-policy +go build -o readonly-cli . + +./readonly-cli config policy show +# { +# "source": "plugin", +# "source_name": "readonly", +# "denied_paths": N, +# "rule": { +# "name": "agent-readonly", +# "allow": ["docs/**", "im/**"], +# "deny": [], +# "max_risk": "read", +# "identities": [], +# "allow_unannotated": false +# } +# } + +./readonly-cli docs +update --doc-token X --content Y +# {"ok":false,"error":{ +# "type":"command_denied", +# "detail":{ +# "layer":"policy", +# "policy_source":"plugin:readonly", +# "rule_name":"agent-readonly", +# "reason_code":"write_not_allowed" +# } +# }} + +./readonly-cli docs +fetch --doc-token X +# Normal read response (assuming credentials) +``` + +## Key points + +- `Restrict(&Rule{...})` is the only call needed — the Builder + flips Capabilities to `Restricts=true, FailurePolicy=FailClosed` + automatically. A policy plugin that silently fails to install + would erase the security boundary, so FailClosed is enforced. +- `MaxRisk: platform.RiskRead` rejects any command annotated + write / high-risk-write. +- `AllowUnannotated` is left default (false): unannotated commands + are denied with `risk_not_annotated`. Set it to true if you need + a gradual-adoption window for the lark-cli main tree. + +## Caveats + +- A binary may have **only one** plugin calling `Restrict()`. Two + policy plugins is a deliberate `plugin_conflict` configuration + error. +- This Rule shadows any `~/.lark-cli/policy.yml` — plugin Rule + wins per the resolver precedence. diff --git a/extension/platform/examples/readonly-policy/main.go b/extension/platform/examples/readonly-policy/main.go new file mode 100644 index 000000000..21b674bdc --- /dev/null +++ b/extension/platform/examples/readonly-policy/main.go @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +// Command readonly-policy is a runnable fork of lark-cli that +// installs a Rule permitting only docs/* and im/* read commands. +// Any write command produces a structured command_denied envelope. +// +// Build & run: +// +// cd extension/platform/examples/readonly-policy +// go build -o readonly-cli . +// ./readonly-cli docs +update --doc-token X --content Y +// # {"ok":false,"error":{"type":"command_denied", ...}} +// +// ./readonly-cli config policy show +// # shows the active Rule with source=plugin:readonly +package main + +import ( + "os" + + "github.com/larksuite/cli/cmd" + "github.com/larksuite/cli/extension/platform" +) + +func init() { + platform.Register( + platform.NewPlugin("readonly", "0.1.0"). + Restrict(&platform.Rule{ + Name: "agent-readonly", + Description: "Only read-class docs/im commands. Suitable for AI-agent sessions.", + Allow: []string{"docs/**", "im/**"}, + MaxRisk: platform.RiskRead, + // AllowUnannotated stays default false (fail-closed): + // unannotated commands are denied, surfacing missing + // risk_level annotations early in adoption. + }). + MustBuild()) + // Note: Restrict() implicitly sets Restricts=true and FailClosed. + // No need to call FailClosed() explicitly. +} + +func main() { + os.Exit(cmd.Execute()) +} diff --git a/extension/platform/handler.go b/extension/platform/handler.go new file mode 100644 index 000000000..c08635962 --- /dev/null +++ b/extension/platform/handler.go @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import "context" + +// Handler is the inner function shape every Wrapper composes. It IS the +// "command business logic" from the Wrapper's perspective -- calling +// next(ctx, inv) inside a Wrapper means "let the command proceed"; +// returning early without calling next short-circuits. +type Handler func(ctx context.Context, inv Invocation) error + +// Observer is a side-effect-only command hook. No return value, no +// next-chain control: an Observer can read Invocation but cannot prevent +// the command from running. Used for audit, metrics, and completion +// logs. After-stage Observers fire even when the command failed +// (Invocation.Err() is populated in that case). +type Observer func(ctx context.Context, inv Invocation) + +// Wrapper is a middleware-style hook: it receives the rest of the +// handler chain and returns a wrapped version. The Wrapper decides +// whether to call next (allow), abstain (deny, return an AbortError), +// or transform the result. Multiple Wrappers compose left-to-right by +// registration order; the outermost runs first. +// +// ⚠️ IMPORTANT: The factory function `func(next Handler) Handler` is +// invoked ONCE PER COMMAND DISPATCH, not once at plugin install. This +// lets the framework recover from a panicking factory and convert it +// to a structured envelope, but it means any state captured by the +// outer closure is rebuilt on every command. Long-lived state (HTTP +// clients, caches, metrics counters) MUST live on the Plugin struct +// or in package-level variables, never in factory-local captures. +type Wrapper func(next Handler) Handler + +// LifecycleHandler runs at one of the process-level LifecycleEvent +// slots. The handler may use ctx for cancellation; in the Shutdown +// case the framework supplies a context with a 2-second hard deadline. +type LifecycleHandler func(ctx context.Context, lc *LifecycleContext) error diff --git a/extension/platform/identity.go b/extension/platform/identity.go new file mode 100644 index 000000000..1354f37dd --- /dev/null +++ b/extension/platform/identity.go @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import "fmt" + +// Identity is the identity taxonomy a command supports. +// +// Defined type (not alias) so plugin authors get compile-time + +// IDE help; raw-string boundaries (yaml, cobra annotation) cross +// through ParseIdentity. +type Identity string + +const ( + IdentityUser Identity = "user" + IdentityBot Identity = "bot" +) + +// ParseIdentity converts a raw string into an Identity. Returns +// ("", nil) for empty input ("not specified"), error for unrecognised +// values. Matching is strict (case-sensitive, no trim). +func ParseIdentity(s string) (Identity, error) { + if s == "" { + return "", nil + } + id := Identity(s) + if id != IdentityUser && id != IdentityBot { + return "", fmt.Errorf("invalid identity %q: must be user|bot", s) + } + return id, nil +} + +// IsValid reports whether i is one of the two recognised values. +func (i Identity) IsValid() bool { + return i == IdentityUser || i == IdentityBot +} + +// String returns the underlying string. +func (i Identity) String() string { return string(i) } diff --git a/extension/platform/invocation.go b/extension/platform/invocation.go new file mode 100644 index 000000000..33377558c --- /dev/null +++ b/extension/platform/invocation.go @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import "time" + +// Invocation is the per-command data a Wrapper / Observer receives. It +// is a read-only interface: the framework implementation lives in +// internal/hook and is never visible to plugins, so plugin code cannot +// mutate denial state. +// +// The interface is deliberately NOT a context.Context — it is data only, +// no cancellation. ctx (from the handler signature) carries +// cancellation / timeout / trace propagation. +// +// Accessor semantics: +// +// - Cmd / Args / Started are populated before the first hook fires +// - Err is populated for After observers and the post-next portion of +// a Wrapper (the value the wrapped handler returned) +// - DeniedByPolicy / DenialLayer / DenialPolicySource are populated by +// the framework's denial guard before any hook runs +type Invocation interface { + // Cmd returns the read-only metadata view of the dispatched command. + Cmd() CommandView + + // Args returns a fresh copy of the positional args. + Args() []string + + // Started is the wall-clock time the outermost RunE wrapper began. + Started() time.Time + + // Err is the error the wrapped handler returned. Populated for + // After observers and the post-next portion of a Wrapper. nil + // before the handler runs. + Err() error + + // DeniedByPolicy reports whether the command was rejected by either + // strict-mode or user-layer policy before the chain reached the + // hook. Observers fire even for denied commands (audit case); Wrap + // is physically isolated by the framework so plugins do not need + // to check this themselves before calling next. + DeniedByPolicy() bool + + // DenialLayer returns the layer that rejected the command: + // + // "" - not denied + // "strict_mode" - credential strict-mode + // "policy" - user-layer Rule (Plugin.Restrict() or yaml) + DenialLayer() string + + // DenialPolicySource returns the specific source identifier + // ("plugin:secaudit", "yaml", "strict-mode"). Empty when not denied. + DenialPolicySource() string +} diff --git a/extension/platform/lifecycle.go b/extension/platform/lifecycle.go new file mode 100644 index 000000000..63a05487b --- /dev/null +++ b/extension/platform/lifecycle.go @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +// When selects the temporal slot for command-level Observer hooks. The +// framework wraps every command's RunE so both stages always fire, even +// when RunE itself returns an error (After is failure-safe). +type When int + +const ( + // Before fires immediately before the command's business logic. + Before When = iota + + // After fires after the command's business logic (or its denyStub + // in the denied path). Always fires, even when RunE returned an + // error; Invocation.Err is populated in that case. + After +) + +// LifecycleEvent selects the temporal slot for Lifecycle hooks. These are +// process-level events that fire once per binary execution, not per +// command. Only Startup and Shutdown are defined: additional bootstrap +// phases can be added later as a non-breaking addition if a concrete +// consumer surfaces. +type LifecycleEvent int + +const ( + // Startup fires after plugin install has committed; Plugin.On + // handlers for Startup are guaranteed to be registered before this + // event is emitted (so they can receive it). + Startup LifecycleEvent = iota + + // Shutdown fires once before the process exits. Handler total + // execution is bounded by a hard 2s timeout to prevent a + // misbehaving handler from holding up exit. + Shutdown +) + +// LifecycleContext is passed to LifecycleHandler. Err is the error from +// the preceding command (when Event == Shutdown after a failed RunE); +// otherwise nil. +type LifecycleContext struct { + Event LifecycleEvent + Err error +} diff --git a/extension/platform/plugin.go b/extension/platform/plugin.go new file mode 100644 index 000000000..303f677b5 --- /dev/null +++ b/extension/platform/plugin.go @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +// Plugin is the single contract a third-party / embedding integrator +// implements to extend lark-cli. Four methods, every one mandatory. +// +// Name must match the grammar ^[a-z0-9][a-z0-9-]*$. The "." character +// is forbidden so plugin-name + hookName namespacing never produces +// ambiguous joins. +// +// Capabilities must be implemented even when every field is zero. The +// requirement is deliberate: it keeps FailurePolicy / Restricts in the +// author's eyeline. +// +// Install runs once during the Bootstrap pipeline. The plugin uses the +// supplied Registrar to register hooks and (optionally) a Rule. Errors +// returned from Install honour the plugin's Capabilities.FailurePolicy +// (fail-open warns + skips this plugin; fail-closed aborts the CLI). +type Plugin interface { + Name() string + Version() string + Capabilities() Capabilities + Install(r Registrar) error +} diff --git a/extension/platform/register.go b/extension/platform/register.go new file mode 100644 index 000000000..fe22059dc --- /dev/null +++ b/extension/platform/register.go @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import "sync" + +// Register adds a plugin to the global registry. Plugins call this from +// init() (typically through a blank import in the embedder's main). +// +// Register is intentionally tolerant of malformed input: validation +// happens later in the host's InstallAll phase, where errors can be +// surfaced through the typed plugin_install envelope. Register itself +// never panics so that init-time problems do not crash the binary +// before main has a chance to install its recover-and-envelope logic. +// +// The registry holds plugins in insertion order so InstallAll can +// process them deterministically. +func Register(p Plugin) { + pluginRegistry.add(p) +} + +// RegisteredPlugins returns a snapshot of the global plugin registry. +// Order matches Register insertion. The host reads this once during +// InstallAll. +func RegisteredPlugins() []Plugin { + return pluginRegistry.snapshot() +} + +// pluginRegistry is the package-level singleton. The mutex protects +// concurrent Register calls -- harmless in practice (init runs +// serially) but cheap insurance. +var pluginRegistry = ®istry{} + +type registry struct { + mu sync.Mutex + plugins []Plugin +} + +func (r *registry) add(p Plugin) { + r.mu.Lock() + defer r.mu.Unlock() + r.plugins = append(r.plugins, p) +} + +func (r *registry) snapshot() []Plugin { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]Plugin, len(r.plugins)) + copy(out, r.plugins) + return out +} + +func (r *registry) reset() { + r.mu.Lock() + defer r.mu.Unlock() + r.plugins = nil +} diff --git a/extension/platform/register_test.go b/extension/platform/register_test.go new file mode 100644 index 000000000..80425e701 --- /dev/null +++ b/extension/platform/register_test.go @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform_test + +import ( + "testing" + + "github.com/larksuite/cli/extension/platform" +) + +type stubPlugin struct{ name string } + +func (s stubPlugin) Name() string { return s.name } +func (s stubPlugin) Version() string { return "0.0.1" } +func (s stubPlugin) Capabilities() platform.Capabilities { return platform.Capabilities{} } +func (s stubPlugin) Install(platform.Registrar) error { return nil } + +// Tests should always reset the global registry to keep them +// independent. Verifies the reset hook is functional. +func TestRegister_preservesInsertionOrder(t *testing.T) { + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + + platform.Register(stubPlugin{name: "a"}) + platform.Register(stubPlugin{name: "b"}) + platform.Register(stubPlugin{name: "c"}) + + got := platform.RegisteredPlugins() + want := []string{"a", "b", "c"} + if len(got) != len(want) { + t.Fatalf("got %d plugins, want %d", len(got), len(want)) + } + for i, p := range got { + if p.Name() != want[i] { + t.Errorf("plugins[%d] = %q, want %q", i, p.Name(), want[i]) + } + } +} + +func TestRegister_resetClears(t *testing.T) { + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + platform.Register(stubPlugin{name: "a"}) + if len(platform.RegisteredPlugins()) != 1 { + t.Fatalf("expected 1 plugin") + } + platform.ResetForTesting() + if len(platform.RegisteredPlugins()) != 0 { + t.Fatalf("expected reset to clear") + } +} diff --git a/extension/platform/register_testing.go b/extension/platform/register_testing.go new file mode 100644 index 000000000..8d32f67f0 --- /dev/null +++ b/extension/platform/register_testing.go @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +// ResetForTesting clears the global plugin registry. Exposed for test +// isolation only — plugin authors and SDK consumers must NOT call this +// from production code. The function is exported (rather than placed in +// an internal test-only file) so that `go test ./...` works for every +// downstream package without an extra build tag. +// +// Tests that exercise plugin registration must defer +// `t.Cleanup(platform.ResetForTesting)` so subsequent tests start from a +// clean slate. The helper is NOT goroutine-safe across concurrent +// `t.Parallel()` tests — the global registry is shared process state. +func ResetForTesting() { pluginRegistry.reset() } diff --git a/extension/platform/registrar.go b/extension/platform/registrar.go new file mode 100644 index 000000000..8774050bf --- /dev/null +++ b/extension/platform/registrar.go @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +// Registrar is the imperative API a plugin uses inside its Install +// method to wire up hooks and rules. The framework provides a staging +// implementation that buffers calls and commits them atomically when +// Install returns nil; failure rolls everything back. +// +// hookName must match the grammar ^[a-z0-9][a-z0-9-]*$ (no dots). The +// framework prepends the plugin's Name() with a dot so the global hook +// identifier is "{plugin}.{hook}". A plugin cannot register two hooks +// with the same name in the same Install call. +// +// Restrict may be called at most once per plugin; multiple plugins +// contributing Restrict() is a configuration error (the resolver +// aborts startup). +type Registrar interface { + // Observe registers a side-effect-only command hook at the given + // When stage. The selector decides which commands it fires on. + Observe(when When, hookName string, sel Selector, fn Observer) + + // Wrap registers a middleware-style command hook. The Wrap chain + // composes left-to-right in registration order; the outermost + // Wrapper runs first. + Wrap(hookName string, sel Selector, w Wrapper) + + // On registers a lifecycle handler for the given event. + On(event LifecycleEvent, hookName string, fn LifecycleHandler) + + // Restrict contributes a pruning Rule. The framework merges it + // with the yaml-sourced Rule using single-rule semantics: plugin + // rule wins, but two plugins both calling Restrict abort startup. + Restrict(r *Rule) +} diff --git a/extension/platform/risk.go b/extension/platform/risk.go new file mode 100644 index 000000000..287c5ff8a --- /dev/null +++ b/extension/platform/risk.go @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import "fmt" + +// Risk is the three-tier risk taxonomy declared on every command. +// +// A defined type (not an alias of string) so plugin authors get +// compile-time + IDE candidate help when passing the constants below. +// Crossing the string boundary (yaml, cobra annotation) goes through +// ParseRisk so typos surface as `risk_invalid` rather than silently +// flowing through. +type Risk string + +const ( + RiskRead Risk = "read" + RiskWrite Risk = "write" + RiskHighRiskWrite Risk = "high-risk-write" +) + +// riskOrder maps the Risk taxonomy to a comparable rank. The pruning +// engine compares ranks for the MaxRisk axis. +var riskOrder = map[Risk]int{ + RiskRead: 0, + RiskWrite: 1, + RiskHighRiskWrite: 2, +} + +// ParseRisk converts a raw string (yaml, cobra annotation) into a Risk. +// +// - s == "" → ("", nil) "not specified" +// - s 在闭合枚举 → (Risk(s), nil) OK +// - s 不在枚举内 → ("", error) invalid +// +// The (absent vs invalid) split mirrors the cmdpolicy engine's +// risk_not_annotated vs risk_invalid reason codes — callers can treat +// the "" + nil case as "not specified" without losing the distinction +// from a typo. +// +// Matching is strict: "Read" / "READ" / " read " are all rejected. +// annotation is developer code, not user input — strict matching is +// the typo-catch mechanism, not a normalisation opportunity. +func ParseRisk(s string) (Risk, error) { + if s == "" { + return "", nil + } + r := Risk(s) + if _, ok := riskOrder[r]; !ok { + return "", fmt.Errorf("invalid risk %q: must be read|write|high-risk-write", s) + } + return r, nil +} + +// IsValid reports whether r is one of the three recognised values. +func (r Risk) IsValid() bool { + _, ok := riskOrder[r] + return ok +} + +// Rank returns the comparable rank of r. ok=false when r is not in the +// closed taxonomy. +func (r Risk) Rank() (rank int, ok bool) { + rank, ok = riskOrder[r] + return rank, ok +} + +// String returns the underlying string. Useful for yaml/json output +// and cobra annotation injection. +func (r Risk) String() string { return string(r) } diff --git a/extension/platform/risk_test.go b/extension/platform/risk_test.go new file mode 100644 index 000000000..d934a03c5 --- /dev/null +++ b/extension/platform/risk_test.go @@ -0,0 +1,120 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform_test + +import ( + "testing" + + "github.com/larksuite/cli/extension/platform" +) + +func TestRisk_Rank_orderedTaxonomy(t *testing.T) { + cases := []struct { + level platform.Risk + want int + }{ + {platform.RiskRead, 0}, + {platform.RiskWrite, 1}, + {platform.RiskHighRiskWrite, 2}, + } + for _, c := range cases { + got, ok := c.level.Rank() + if !ok || got != c.want { + t.Errorf("Risk(%q).Rank() = (%d,%v), want (%d,true)", c.level, got, ok, c.want) + } + } + + if _, ok := platform.Risk("unknown-level").Rank(); ok { + t.Fatalf("unknown-level.Rank() ok should be false") + } + if _, ok := platform.Risk("").Rank(); ok { + t.Fatalf("empty.Rank() ok should be false (signals 'no risk annotation')") + } +} + +// The Risk ordering must be strict: read < write < high-risk-write. The +// policy engine compares ranks; a regression that swaps the order would +// silently let high-risk commands pass under MaxRisk=write. +func TestRisk_Rank_strictlyMonotonic(t *testing.T) { + r1, _ := platform.RiskRead.Rank() + r2, _ := platform.RiskWrite.Rank() + r3, _ := platform.RiskHighRiskWrite.Rank() + if !(r1 < r2 && r2 < r3) { + t.Fatalf("Risk ranks not monotonic: read=%d write=%d high=%d", r1, r2, r3) + } +} + +func TestRisk_IsValid(t *testing.T) { + valid := []platform.Risk{platform.RiskRead, platform.RiskWrite, platform.RiskHighRiskWrite} + for _, r := range valid { + if !r.IsValid() { + t.Errorf("%q.IsValid() = false, want true", r) + } + } + invalid := []platform.Risk{"", "wrtie", "Read", "READ", " read "} + for _, r := range invalid { + if r.IsValid() { + t.Errorf("%q.IsValid() = true, want false", r) + } + } +} + +// ParseRisk distinguishes absent (empty input) from invalid (typo). +// The absent / invalid split mirrors the cmdpolicy engine's +// risk_not_annotated vs risk_invalid reason codes. +func TestParseRisk(t *testing.T) { + // Empty -> ("", nil) — "not specified" + got, err := platform.ParseRisk("") + if err != nil || got != "" { + t.Errorf(`ParseRisk("") = (%q,%v), want ("",nil)`, got, err) + } + + // Valid values pass through + for _, want := range []platform.Risk{platform.RiskRead, platform.RiskWrite, platform.RiskHighRiskWrite} { + got, err := platform.ParseRisk(string(want)) + if err != nil || got != want { + t.Errorf("ParseRisk(%q) = (%q,%v), want (%q,nil)", want, got, err, want) + } + } + + // Typo -> error, strict matching (case-sensitive, no trim) + bad := []string{"wrtie", "Read", "READ", " read ", "high_risk_write"} + for _, s := range bad { + got, err := platform.ParseRisk(s) + if err == nil { + t.Errorf("ParseRisk(%q) succeeded (got %q), want error", s, got) + } + if got != "" { + t.Errorf("ParseRisk(%q) returned %q, want empty Risk on error", s, got) + } + } +} + +func TestParseIdentity(t *testing.T) { + got, err := platform.ParseIdentity("") + if err != nil || got != "" { + t.Errorf(`ParseIdentity("") = (%q,%v), want ("",nil)`, got, err) + } + for _, want := range []platform.Identity{platform.IdentityUser, platform.IdentityBot} { + got, err := platform.ParseIdentity(string(want)) + if err != nil || got != want { + t.Errorf("ParseIdentity(%q) = (%q,%v)", want, got, err) + } + } + if _, err := platform.ParseIdentity("admin"); err == nil { + t.Fatalf(`ParseIdentity("admin") want error`) + } +} + +func TestIdentity_IsValid(t *testing.T) { + if !platform.IdentityUser.IsValid() { + t.Error("user.IsValid() = false") + } + if !platform.IdentityBot.IsValid() { + t.Error("bot.IsValid() = false") + } + if platform.Identity("admin").IsValid() { + t.Error("admin.IsValid() = true") + } +} diff --git a/extension/platform/rule.go b/extension/platform/rule.go new file mode 100644 index 000000000..cf5ecebaf --- /dev/null +++ b/extension/platform/rule.go @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +// Rule is the declarative policy rule data structure. yaml files and +// Plugin.Restrict() both produce the same Rule. +// +// At any moment there is at most one effective Rule -- the resolver decides +// which source wins (Plugin > yaml > none). This package only defines the +// shape; selection lives in internal/cmdpolicy. +// +// The four filter fields are joined by AND. See the engine's Evaluate for +// the full semantics. JSON tags are used by `config policy show`; yaml +// parsing lives in internal/cmdpolicy/yaml so the public API does not +// depend on a yaml library. +type Rule struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + + // Allow is a list of doublestar globs (slash-separated paths). An empty + // slice means "no path restriction"; a non-empty slice means "command + // path must match at least one glob". + Allow []string `json:"allow,omitempty"` + + // Deny is a list of doublestar globs. A path that matches any Deny glob + // is rejected regardless of Allow. + Deny []string `json:"deny,omitempty"` + + // MaxRisk is the highest allowed risk level (inclusive). Empty string + // means "no risk restriction". Comparison uses the closed taxonomy + // read < write < high-risk-write. + MaxRisk Risk `json:"max_risk,omitempty"` + + // Identities is the allowed identity whitelist. A command passes when + // the intersection with the command's own supported identities is + // non-empty. Empty slice means "no identity restriction". + Identities []Identity `json:"identities,omitempty"` + + // AllowUnannotated controls how commands missing a risk_level + // annotation are handled when this Rule is active. + // + // Default (false, fail-closed): unannotated commands are rejected + // with reason_code=risk_not_annotated. This is the safe default + // — a typo'd or forgotten annotation cannot slip past an + // "agent read-only" rule. + // + // Set to true to opt out during gradual adoption: lark-cli main + // has hundreds of service commands that may not yet carry + // risk_level annotations, and a brand-new policy plugin would + // otherwise lock the binary to nothing. + // + // This flag does NOT affect risk_invalid (typos): a command that + // claims a risk but mis-spells it is always denied, regardless of + // AllowUnannotated. Typo is a code bug, not a migration phase. + // + // No yaml tag: yaml decoding lives in internal/cmdpolicy/yaml so + // platform stays free of a yaml library dependency. + AllowUnannotated bool `json:"allow_unannotated,omitempty"` +} diff --git a/extension/platform/selector.go b/extension/platform/selector.go new file mode 100644 index 000000000..0e632537f --- /dev/null +++ b/extension/platform/selector.go @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +import "github.com/bmatcuk/doublestar/v4" + +// Selector picks the commands a hook fires on. A nil Selector is +// equivalent to None() -- safer than an "always-match" default because +// it forces every hook to declare its scope explicitly. Compose +// selectors with And / Or / Not. +type Selector func(cmd CommandView) bool + +// All matches every command. Use for audit / metrics observers that +// must run on the whole surface. +func All() Selector { return func(CommandView) bool { return true } } + +// None matches no command. Useful as a "disabled" placeholder. +func None() Selector { return func(CommandView) bool { return false } } + +// ByDomain matches a command whose Domain() is one of the supplied +// names. Commands with unknown (empty-string) Domain never match this +// selector -- the caller should pair it with a Selector that handles +// unknown explicitly when that case matters. +func ByDomain(domains ...string) Selector { + wanted := newStringSet(domains) + return func(cmd CommandView) bool { + d := cmd.Domain() + return d != "" && wanted[d] + } +} + +// ByCommandPath matches against the canonical slash-form path. Patterns +// are doublestar globs ("docs/+update", "im/*", "**"). Invalid patterns +// never match; ValidateRule's twin check catches them at the source. +func ByCommandPath(patterns ...string) Selector { + return func(cmd CommandView) bool { + path := cmd.Path() + for _, p := range patterns { + if ok, err := doublestar.Match(p, path); err == nil && ok { + return true + } + } + return false + } +} + +// ByIdentity matches when the command's supported identities include +// the supplied id. Unknown identities never match. +func ByIdentity(id Identity) Selector { + return func(cmd CommandView) bool { + for _, x := range cmd.Identities() { + if x == id { + return true + } + } + return false + } +} + +// Risk-based selectors below match only commands whose declared risk +// equals the selector's target level. The closed taxonomy is read / +// write / high-risk-write — there is no "unknown" branch in the public +// API. When a Rule without AllowUnannotated=true is registered, the +// policy engine treats unannotated commands as implicit deny, so risk- +// based selectors never see them in hook dispatch under that +// configuration. + +// ByExactRisk matches commands whose declared risk level is exactly level. +func ByExactRisk(level Risk) Selector { + return func(cmd CommandView) bool { + v, ok := cmd.Risk() + return ok && v == level + } +} + +// ByWrite matches commands whose risk is "write" or "high-risk-write". +func ByWrite() Selector { + return func(cmd CommandView) bool { + v, ok := cmd.Risk() + return ok && (v == RiskWrite || v == RiskHighRiskWrite) + } +} + +// ByReadOnly matches commands whose risk is "read". +func ByReadOnly() Selector { + return func(cmd CommandView) bool { + v, ok := cmd.Risk() + return ok && v == RiskRead + } +} + +// normalize maps a nil Selector to None() so combinators honour the +// "nil == None()" contract documented on the Selector type. +func normalize(s Selector) Selector { + if s == nil { + return None() + } + return s +} + +// And composes selectors with AND semantics. +func (s Selector) And(other Selector) Selector { + left, right := normalize(s), normalize(other) + return func(cmd CommandView) bool { + return left(cmd) && right(cmd) + } +} + +// Or composes selectors with OR semantics. +func (s Selector) Or(other Selector) Selector { + left, right := normalize(s), normalize(other) + return func(cmd CommandView) bool { + return left(cmd) || right(cmd) + } +} + +// Not negates the selector. A nil receiver is treated as None(), so +// nil.Not() behaves as All(). +func (s Selector) Not() Selector { + inner := normalize(s) + return func(cmd CommandView) bool { + return !inner(cmd) + } +} + +func newStringSet(items []string) map[string]bool { + out := make(map[string]bool, len(items)) + for _, x := range items { + out[x] = true + } + return out +} diff --git a/extension/platform/selector_test.go b/extension/platform/selector_test.go new file mode 100644 index 000000000..f08b0c660 --- /dev/null +++ b/extension/platform/selector_test.go @@ -0,0 +1,161 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform_test + +import ( + "testing" + + "github.com/larksuite/cli/extension/platform" +) + +// fakeView is a minimal CommandView for unit-testing selectors. +type fakeView struct { + path string + domain string + risk string + riskOK bool + identities []string +} + +func (v fakeView) Path() string { return v.path } +func (v fakeView) Domain() string { return v.domain } +func (v fakeView) Risk() (platform.Risk, bool) { return platform.Risk(v.risk), v.riskOK } +func (v fakeView) Identities() []platform.Identity { + out := make([]platform.Identity, len(v.identities)) + for i, x := range v.identities { + out[i] = platform.Identity(x) + } + return out +} +func (v fakeView) Annotation(key string) (string, bool) { return "", false } + +func TestAll_None(t *testing.T) { + cmd := fakeView{} + if !platform.All()(cmd) { + t.Errorf("All() must match every command") + } + if platform.None()(cmd) { + t.Errorf("None() must match no command") + } +} + +func TestByDomain(t *testing.T) { + sel := platform.ByDomain("docs", "im") + if !sel(fakeView{domain: "docs"}) { + t.Errorf("docs should match") + } + if sel(fakeView{domain: "vc"}) { + t.Errorf("vc must not match docs/im selector") + } + // Unknown domain (empty) must not match. + if sel(fakeView{domain: ""}) { + t.Errorf("unknown domain must not match ByDomain (use ByDomainOrUnknown style if desired)") + } +} + +// Risk-based selectors match only against the closed taxonomy +// (read / write / high-risk-write). Commands without a risk annotation +// never match; the policy engine guarantees such commands cannot reach +// hook dispatch when a Rule without AllowUnannotated=true is registered. +func TestByExactRisk_unknownDoesNotMatch(t *testing.T) { + sel := platform.ByExactRisk("write") + if !sel(fakeView{risk: "write", riskOK: true}) { + t.Errorf("exact write should match") + } + if sel(fakeView{riskOK: false}) { + t.Errorf("unknown must not match ByExactRisk") + } + if sel(fakeView{risk: "read", riskOK: true}) { + t.Errorf("read must not match ByExactRisk(write)") + } +} + +func TestByWrite_byReadOnly(t *testing.T) { + if !platform.ByWrite()(fakeView{risk: "write", riskOK: true}) { + t.Errorf("write should match ByWrite") + } + if !platform.ByWrite()(fakeView{risk: "high-risk-write", riskOK: true}) { + t.Errorf("high-risk-write should match ByWrite") + } + if platform.ByWrite()(fakeView{risk: "read", riskOK: true}) { + t.Errorf("read must not match ByWrite") + } + if platform.ByWrite()(fakeView{riskOK: false}) { + t.Errorf("unknown must not match ByWrite") + } + if !platform.ByReadOnly()(fakeView{risk: "read", riskOK: true}) { + t.Errorf("read should match ByReadOnly") + } + if platform.ByReadOnly()(fakeView{riskOK: false}) { + t.Errorf("unknown must not match ByReadOnly") + } +} + +func TestByCommandPath(t *testing.T) { + sel := platform.ByCommandPath("docs/**", "im/+send") + if !sel(fakeView{path: "docs/+update"}) { + t.Errorf("docs/+update should match docs/**") + } + if !sel(fakeView{path: "im/+send"}) { + t.Errorf("im/+send should match") + } + if sel(fakeView{path: "contact/+search"}) { + t.Errorf("contact/+search must not match") + } +} + +func TestByIdentity(t *testing.T) { + sel := platform.ByIdentity("bot") + if !sel(fakeView{identities: []string{"user", "bot"}}) { + t.Errorf("ids containing bot should match") + } + if sel(fakeView{identities: []string{"user"}}) { + t.Errorf("user-only ids must not match bot selector") + } +} + +func TestSelector_AndOrNot(t *testing.T) { + docsAndWrite := platform.ByDomain("docs").And(platform.ByExactRisk("write")) + if !docsAndWrite(fakeView{domain: "docs", risk: "write", riskOK: true}) { + t.Errorf("AND of matching selectors should match") + } + if docsAndWrite(fakeView{domain: "docs", risk: "read", riskOK: true}) { + t.Errorf("AND fails when one side fails") + } + + docsOrIm := platform.ByDomain("docs").Or(platform.ByDomain("im")) + if !docsOrIm(fakeView{domain: "im"}) { + t.Errorf("OR should match either side") + } + + notRead := platform.ByReadOnly().Not() + if notRead(fakeView{risk: "read", riskOK: true}) { + t.Errorf("Not(ByReadOnly) must reject read commands") + } + if !notRead(fakeView{risk: "write", riskOK: true}) { + t.Errorf("Not(ByReadOnly) should match write") + } +} + +func TestSelector_NilSafeWhenComposed(t *testing.T) { + // A nil Selector is equivalent to None() per the Selector godoc. + // Composition must honour that contract: the resulting selector + // must not panic when invoked and must produce the documented + // boolean outcome (nil-as-None propagates through AND/OR/NOT). + var s platform.Selector + cmd := fakeView{domain: "docs"} + + if got := s.And(platform.All())(cmd); got { + t.Errorf("nil.And(All) should match None semantics (false), got true") + } + if got := s.Or(platform.All())(cmd); !got { + t.Errorf("nil.Or(All) should match (true), got false") + } + if got := platform.All().And(s)(cmd); got { + t.Errorf("All.And(nil) should be None (false), got true") + } + if got := s.Not()(cmd); !got { + t.Errorf("(nil).Not() should be Not(None) = true, got false") + } +} diff --git a/extension/platform/view.go b/extension/platform/view.go new file mode 100644 index 000000000..f7ef3e885 --- /dev/null +++ b/extension/platform/view.go @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package platform + +// CommandView is the read-only view of a cobra.Command exposed to plugins +// and the policy engine. *cobra.Command is deliberately NOT reachable +// through this interface -- a plugin should never mutate the command tree. +// +// View semantics: +// +// - The view is a live proxy over the underlying *cobra.Command and its +// annotation chain. Strict-mode replaces nodes via RemoveCommand+ +// AddCommand; the replacement stub explicitly carries the original +// command's annotations and help text forward so audit / compliance +// observers still see Risk / Identities / Domain after a denial. +// User-layer policy mutates in place, so its denyStubs preserve the +// original metadata by construction. +// +// - Path() is the canonical slash form ("docs/+fetch"), matching the +// doublestar glob semantics used by Rule.Allow / Rule.Deny. +// +// - Risk() returns ok=false when the command is unannotated. The policy +// engine treats an unannotated command as implicit deny whenever any +// Rule without AllowUnannotated=true is registered, so risk-based +// Selectors never see unannotated commands during normal hook dispatch +// under that configuration. +type CommandView interface { + // Path is the canonical slash-separated path, rootless ("docs/+update"). + Path() string + + // Domain returns the business domain ("docs", "im", "") inherited from + // the nearest ancestor with a cmdmeta.domain annotation. Empty string + // when no ancestor declares one. + Domain() string + + // Risk returns the static risk level. ok=false signals "no risk_level + // annotation found in the parent chain" (unknown). + Risk() (level Risk, ok bool) + + // Identities returns the supported identities. nil signals "no + // supportedIdentities annotation in the parent chain". + Identities() []Identity + + // Annotation exposes the raw cobra annotation map for plugins that + // need a tag the framework does not surface. + Annotation(key string) (string, bool) +} diff --git a/go.mod b/go.mod index 770cdf589..1ee4b73cc 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,14 @@ go 1.23.0 require ( github.com/Microsoft/go-winio v0.6.2 + github.com/bmatcuk/doublestar/v4 v4.10.0 github.com/charmbracelet/huh v1.0.0 github.com/charmbracelet/lipgloss v1.1.0 github.com/gofrs/flock v0.8.1 github.com/google/uuid v1.6.0 github.com/itchyny/gojq v0.12.17 github.com/larksuite/oapi-sdk-go/v3 v3.5.4 + github.com/sergi/go-diff v1.4.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/smartystreets/goconvey v1.8.1 github.com/spf13/cobra v1.10.2 @@ -18,9 +20,11 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/zalando/go-keyring v0.2.8 golang.org/x/net v0.33.0 + golang.org/x/sync v0.15.0 golang.org/x/sys v0.33.0 golang.org/x/term v0.27.0 golang.org/x/text v0.23.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -59,6 +63,4 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - golang.org/x/sync v0.15.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 451a3591d..0e68ff05a 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY= github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E= +github.com/bmatcuk/doublestar/v4 v4.10.0 h1:zU9WiOla1YA122oLM6i4EXvGW62DvKZVxIe6TYWexEs= +github.com/bmatcuk/doublestar/v4 v4.10.0/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY= github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 h1:JFgG/xnwFfbezlUnFMJy0nusZvytYysV4SCS2cYbvws= @@ -43,6 +45,7 @@ github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/danieljoos/wincred v1.2.3 h1:v7dZC2x32Ut3nEfRH+vhoZGvN72+dQ/snVXo/vMFLdQ= github.com/danieljoos/wincred v1.2.3/go.mod h1:6qqX0WNrS4RzPZ1tnroDzq9kY3fu1KwE7MRLQK4X0bs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= @@ -71,6 +74,11 @@ github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/larksuite/oapi-sdk-go/v3 v3.5.4 h1:U2S9x9LrfH++ZqJ+YAiUlqzCWJmVXhFdS8Z7rIBH8H0= github.com/larksuite/oapi-sdk-go/v3 v3.5.4/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= @@ -95,6 +103,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= +github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= @@ -105,8 +115,10 @@ github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= @@ -161,7 +173,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/binding/lark_channel.go b/internal/binding/lark_channel.go index 511f19dd7..f80afb53a 100644 --- a/internal/binding/lark_channel.go +++ b/internal/binding/lark_channel.go @@ -15,6 +15,11 @@ import ( // Unknown fields are ignored — forward-compatible with future bridge versions. type LarkChannelRoot struct { Accounts LarkChannelAccounts `json:"accounts"` + // Secrets is an optional registry of secret providers — same shape as + // openclaw's `secrets` block. Lets bridge declare `exec` provider scripts + // (for AES-encrypted secret backends), `env` allowlists, or `file` + // indirection rules. Resolved by binding.ResolveSecretInput. + Secrets *SecretsConfig `json:"secrets,omitempty"` } // LarkChannelAccounts is the namespace for credential entries. @@ -26,13 +31,17 @@ type LarkChannelAccounts struct { } // LarkChannelApp is the bot app credential entry. -// Bridge stores the secret as plain text — secret-resolve indirection -// (${VAR} / file: / exec:) is intentionally not supported here, matching -// the bridge's on-disk format. +// +// `Secret` accepts the full SecretInput protocol (string / "${VAR}" template / +// SecretRef object with source env|file|exec) so users can keep secrets out +// of config.json — either by referencing an env var the bridge inherits, a +// chmod-0400 file outside the bridge dir, or an exec script that decrypts a +// local AES-encrypted secret store. Aligns lark-channel with the same secret +// protocol openclaw already uses. type LarkChannelApp struct { - ID string `json:"id"` - Secret string `json:"secret"` - Tenant string `json:"tenant"` // "feishu" | "lark" + ID string `json:"id"` + Secret SecretInput `json:"secret"` + Tenant string `json:"tenant"` // "feishu" | "lark" } // ReadLarkChannelConfig reads and parses ~/.lark-channel/config.json. diff --git a/internal/binding/lark_channel_test.go b/internal/binding/lark_channel_test.go index 2883144b5..4908556b4 100644 --- a/internal/binding/lark_channel_test.go +++ b/internal/binding/lark_channel_test.go @@ -24,8 +24,11 @@ func TestReadLarkChannelConfig_Valid(t *testing.T) { if got := root.Accounts.App.ID; got != "cli_abc123" { t.Errorf("ID = %q, want %q", got, "cli_abc123") } - if got := root.Accounts.App.Secret; got != "plain_secret" { - t.Errorf("Secret = %q, want %q", got, "plain_secret") + if got := root.Accounts.App.Secret.Plain; got != "plain_secret" { + t.Errorf("Secret.Plain = %q, want %q", got, "plain_secret") + } + if root.Accounts.App.Secret.Ref != nil { + t.Errorf("expected Plain form, got SecretRef = %+v", root.Accounts.App.Secret.Ref) } if got := root.Accounts.App.Tenant; got != "feishu" { t.Errorf("Tenant = %q, want %q", got, "feishu") @@ -92,8 +95,74 @@ func TestReadLarkChannelConfig_PartialFields(t *testing.T) { if root.Accounts.App.ID != "" { t.Errorf("expected empty ID, got %q", root.Accounts.App.ID) } - if root.Accounts.App.Secret != "" { - t.Errorf("expected empty Secret, got %q", root.Accounts.App.Secret) + if !root.Accounts.App.Secret.IsZero() { + t.Errorf("expected zero Secret, got %+v", root.Accounts.App.Secret) + } +} + +func TestReadLarkChannelConfig_SecretEnvTemplate(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "config.json") + data := `{"accounts":{"app":{"id":"cli_a","secret":"${LARK_APP_SECRET}","tenant":"feishu"}}}` + if err := os.WriteFile(p, []byte(data), 0o600); err != nil { + t.Fatalf("write temp file: %v", err) + } + root, err := ReadLarkChannelConfig(p) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := root.Accounts.App.Secret.Plain; got != "${LARK_APP_SECRET}" { + t.Errorf("Secret.Plain = %q, want template string", got) + } +} + +func TestReadLarkChannelConfig_SecretRefExec(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "config.json") + data := `{ + "accounts": { + "app": { + "id": "cli_a", + "secret": {"source": "exec", "provider": "decrypt", "id": "app-cli_a"}, + "tenant": "feishu" + } + }, + "secrets": { + "providers": { + "decrypt": {"source": "exec", "command": "/usr/local/bin/lark-channel-bridge", "args": ["secrets", "get"]} + } + } + }` + if err := os.WriteFile(p, []byte(data), 0o600); err != nil { + t.Fatalf("write temp file: %v", err) + } + root, err := ReadLarkChannelConfig(p) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if root.Accounts.App.Secret.Ref == nil { + t.Fatal("expected SecretRef, got Plain") + } + if got := root.Accounts.App.Secret.Ref.Source; got != "exec" { + t.Errorf("Secret.Ref.Source = %q, want %q", got, "exec") + } + if got := root.Accounts.App.Secret.Ref.ID; got != "app-cli_a" { + t.Errorf("Secret.Ref.ID = %q, want %q", got, "app-cli_a") + } + if root.Secrets == nil || root.Secrets.Providers["decrypt"] == nil { + t.Errorf("expected secrets.providers[decrypt] to be parsed") + } +} + +func TestReadLarkChannelConfig_SecretRefInvalidSource(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "config.json") + data := `{"accounts":{"app":{"id":"cli_a","secret":{"source":"bogus","id":"x"},"tenant":"feishu"}}}` + if err := os.WriteFile(p, []byte(data), 0o600); err != nil { + t.Fatalf("write temp file: %v", err) + } + if _, err := ReadLarkChannelConfig(p); err == nil { + t.Fatal("expected error for invalid secret source, got nil") } } diff --git a/internal/cmdmeta/meta.go b/internal/cmdmeta/meta.go new file mode 100644 index 000000000..f0a9ea6b4 --- /dev/null +++ b/internal/cmdmeta/meta.go @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +// Package cmdmeta is the single source of truth for command metadata that the +// policy engine and the hook selector both consume. It wraps the existing +// cmdutil annotations (risk_level, supportedIdentities) and adds the +// "domain" axis that the hook selector and Rule path globs need. +// +// Three axes: +// +// - Domain - business domain ("im", "docs", "contact", ...). Inherited +// from the nearest ancestor when not set on the command +// itself. Stored on a new annotation key (the cmdutil +// risk_level / supportedIdentities keys are left untouched +// for backward compatibility). +// - Risk - "read" | "write" | "high-risk-write". Inherited like +// Domain. Reuses cmdutil.SetRisk / GetRisk under the hood. +// - Identities - allowed identity set. Child explicit override semantics: +// the first ancestor (including self) with a non-nil set +// wins. Reuses cmdutil.SetSupportedIdentities / +// GetSupportedIdentities. +// +// Missing values are returned as the zero value with ok=false (where the +// signature exposes it). Interpretation is up to the consumer: the policy +// engine treats a missing risk as fail-closed when a Rule is registered +// without AllowUnannotated=true, and as allow otherwise. Identities still +// defaults to ALLOW. Do not synthesise defaults here -- let each consumer +// decide. +package cmdmeta + +import ( + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" +) + +// domainAnnotationKey is the cobra Annotation key for the business domain. +// Kept distinct from cmdutil.* keys so this package can evolve without +// disturbing existing readers. +const domainAnnotationKey = "cmdmeta.domain" + +// Meta groups the three command-level metadata axes consumed by the policy +// engine and hook selectors. +type Meta struct { + Domain string + Risk string + Identities []string +} + +// Apply writes metadata onto a cobra command. Empty fields are skipped: pass +// the value via the underlying cmdutil setter if you need to write an empty +// string / empty slice explicitly. +func Apply(cmd *cobra.Command, m Meta) { + if m.Domain != "" { + SetDomain(cmd, m.Domain) + } + if m.Risk != "" { + cmdutil.SetRisk(cmd, m.Risk) + } + if m.Identities != nil { + cmdutil.SetSupportedIdentities(cmd, m.Identities) + } +} + +// Get resolves the effective metadata for a command, walking up the parent +// chain for Domain, Risk, and Identities. All three axes use the same +// nearest-ancestor-wins rule. +// +// Identities note: cmdutil.GetSupportedIdentities collapses both the +// "annotation absent" and "annotation set to empty string" cases to nil. +// A child cannot therefore express "deny inheritance" with an empty +// annotation; the walk simply continues up the parent chain when nil is +// returned. To override a parent, the child must set a non-empty slice +// (e.g. ["bot"]). +func Get(cmd *cobra.Command) Meta { + risk, _ := Risk(cmd) + return Meta{ + Domain: Domain(cmd), + Risk: risk, + Identities: Identities(cmd), + } +} + +// SetDomain stores the domain annotation on a single command (no +// inheritance is performed on write). +func SetDomain(cmd *cobra.Command, domain string) { + if domain == "" { + return + } + if cmd.Annotations == nil { + cmd.Annotations = map[string]string{} + } + cmd.Annotations[domainAnnotationKey] = domain +} + +// Domain returns the nearest-ancestor domain for the command. Empty string +// when no ancestor has the annotation -- this is the "unknown" state the +// policy engine must treat as ALLOW. +func Domain(cmd *cobra.Command) string { + for c := cmd; c != nil; c = c.Parent() { + if c.Annotations == nil { + continue + } + if v, ok := c.Annotations[domainAnnotationKey]; ok && v != "" { + return v + } + } + return "" +} + +// Risk returns the nearest-ancestor risk level (via cmdutil.GetRisk). +// ok=false signals "unknown" -- the policy engine treats this as +// fail-closed (deny with risk_not_annotated) whenever a Rule without +// AllowUnannotated=true is active, and as allow otherwise. +func Risk(cmd *cobra.Command) (level string, ok bool) { + for c := cmd; c != nil; c = c.Parent() { + if level, ok = cmdutil.GetRisk(c); ok { + return level, true + } + } + return "", false +} + +// Identities returns the first non-nil identity set found while walking up +// the parent chain. nil signals "unknown" -- the policy engine treats this +// as ALLOW. +// +// cmdutil.GetSupportedIdentities returns nil when the annotation is absent +// or empty; an explicit non-empty set (even ["user"] alone) stops the walk. +func Identities(cmd *cobra.Command) []string { + for c := cmd; c != nil; c = c.Parent() { + if ids := cmdutil.GetSupportedIdentities(c); ids != nil { + return ids + } + } + return nil +} diff --git a/internal/cmdmeta/meta_test.go b/internal/cmdmeta/meta_test.go new file mode 100644 index 000000000..61e831319 --- /dev/null +++ b/internal/cmdmeta/meta_test.go @@ -0,0 +1,143 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdmeta_test + +import ( + "reflect" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdmeta" + "github.com/larksuite/cli/internal/cmdutil" +) + +func TestApply_writesAllFields(t *testing.T) { + cmd := &cobra.Command{Use: "fetch"} + cmdmeta.Apply(cmd, cmdmeta.Meta{ + Domain: "docs", + Risk: "write", + Identities: []string{"user", "bot"}, + }) + + if got := cmdmeta.Domain(cmd); got != "docs" { + t.Fatalf("Domain = %q, want %q", got, "docs") + } + if got, ok := cmdmeta.Risk(cmd); !ok || got != "write" { + t.Fatalf("Risk = (%q,%v), want (%q,true)", got, ok, "write") + } + if got := cmdmeta.Identities(cmd); !reflect.DeepEqual(got, []string{"user", "bot"}) { + t.Fatalf("Identities = %v, want [user bot]", got) + } +} + +func TestApply_emptyFieldsSkipped(t *testing.T) { + cmd := &cobra.Command{Use: "fetch"} + cmdmeta.Apply(cmd, cmdmeta.Meta{}) // nothing + if got := cmdmeta.Domain(cmd); got != "" { + t.Fatalf("Domain expected unset, got %q", got) + } + if _, ok := cmdmeta.Risk(cmd); ok { + t.Fatalf("Risk expected unset") + } + if got := cmdmeta.Identities(cmd); got != nil { + t.Fatalf("Identities expected nil, got %v", got) + } +} + +// Domain inherits from the nearest ancestor; risk and identities behave the +// same way. We verify each axis with a 3-level tree: +// +// root (domain=docs, risk=read, identities=[user]) +// group +// leaf +func TestGet_inheritsFromAncestor(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + group := &cobra.Command{Use: "docs"} + leaf := &cobra.Command{Use: "fetch"} + root.AddCommand(group) + group.AddCommand(leaf) + + cmdmeta.Apply(root, cmdmeta.Meta{ + Domain: "docs", + Risk: "read", + Identities: []string{"user"}, + }) + + got := cmdmeta.Get(leaf) + want := cmdmeta.Meta{ + Domain: "docs", + Risk: "read", + Identities: []string{"user"}, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("Get(leaf) = %+v, want %+v", got, want) + } +} + +// Closest ancestor wins -- a mid-level override is preferred over root. +func TestGet_nearestAncestorWins(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + group := &cobra.Command{Use: "docs"} + leaf := &cobra.Command{Use: "fetch"} + root.AddCommand(group) + group.AddCommand(leaf) + + cmdmeta.SetDomain(root, "docs") + cmdmeta.SetDomain(group, "docs-override") + cmdutil.SetRisk(root, "read") + cmdutil.SetRisk(group, "high-risk-write") + + if got := cmdmeta.Domain(leaf); got != "docs-override" { + t.Fatalf("Domain = %q, want docs-override (nearest)", got) + } + if got, _ := cmdmeta.Risk(leaf); got != "high-risk-write" { + t.Fatalf("Risk = %q, want high-risk-write (nearest)", got) + } +} + +// Unknown axes return zero / nil so the policy engine can apply the +// "unknown => ALLOW" contract. +func TestGet_unknownReturnsZero(t *testing.T) { + cmd := &cobra.Command{Use: "orphan"} + if got := cmdmeta.Domain(cmd); got != "" { + t.Fatalf("Domain = %q, want empty for unknown", got) + } + if level, ok := cmdmeta.Risk(cmd); ok || level != "" { + t.Fatalf("Risk = (%q,%v), want empty / false for unknown", level, ok) + } + if ids := cmdmeta.Identities(cmd); ids != nil { + t.Fatalf("Identities = %v, want nil for unknown", ids) + } +} + +// Child explicitly overriding identities stops the parent walk. +func TestIdentities_childOverridesParent(t *testing.T) { + parent := &cobra.Command{Use: "docs"} + child := &cobra.Command{Use: "preview"} + parent.AddCommand(child) + + cmdutil.SetSupportedIdentities(parent, []string{"user", "bot"}) + cmdutil.SetSupportedIdentities(child, []string{"bot"}) + + got := cmdmeta.Identities(child) + if !reflect.DeepEqual(got, []string{"bot"}) { + t.Fatalf("Identities(child) = %v, want [bot]", got) + } +} + +// SetDomain with empty value is a no-op (no annotation written, so a +// later inherited read still works). +func TestSetDomain_emptyIsNoop(t *testing.T) { + parent := &cobra.Command{Use: "docs"} + cmdmeta.SetDomain(parent, "docs") + + child := &cobra.Command{Use: "fetch"} + parent.AddCommand(child) + + cmdmeta.SetDomain(child, "") // no-op + if got := cmdmeta.Domain(child); got != "docs" { + t.Fatalf("Domain(child) = %q, want inherited 'docs'", got) + } +} diff --git a/internal/cmdpolicy/active.go b/internal/cmdpolicy/active.go new file mode 100644 index 000000000..488d641c1 --- /dev/null +++ b/internal/cmdpolicy/active.go @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +import ( + "sync" + + "github.com/larksuite/cli/extension/platform" +) + +// ActivePolicy is the resolved user-layer policy after applyUserPolicyPruning +// has run during bootstrap. `lark-cli config policy show` reads this to +// answer "what rule is currently in effect, and how many commands does +// it hide?". +// +// Set once at bootstrap time; consumed read-only thereafter. +type ActivePolicy struct { + Rule *platform.Rule + Source ResolveSource + DeniedPaths int // number of commands the engine marked as denied (post-aggregation) +} + +var ( + activeMu sync.RWMutex + activePolicy *ActivePolicy +) + +// SetActive records the policy that ends up applied. Called exactly once +// per process from cmd/policy.go::applyUserPolicyPruning. The mutex is +// belt-and-braces in case future test paths interleave with bootstrap. +// +// A deep copy is taken so the snapshot is immune to later mutations of +// the input by the caller (a plugin-supplied *Rule could otherwise +// mutate the embedded Allow/Deny/Identities slices after we stored it). +func SetActive(p *ActivePolicy) { + activeMu.Lock() + defer activeMu.Unlock() + if p == nil { + activePolicy = nil + return + } + activePolicy = cloneActivePolicy(p) +} + +// GetActive returns a deep copy of the recorded policy, or nil if +// bootstrap has not finished or no rule applied. Callers can freely +// mutate the result — including the embedded Rule slices — without +// affecting the stored global. +func GetActive() *ActivePolicy { + activeMu.RLock() + defer activeMu.RUnlock() + if activePolicy == nil { + return nil + } + return cloneActivePolicy(activePolicy) +} + +// cloneActivePolicy deep-copies the top-level struct plus the embedded +// Rule's slice fields. Other fields (Source, DeniedPaths) are value +// types so the struct copy already disjoints them. +func cloneActivePolicy(in *ActivePolicy) *ActivePolicy { + if in == nil { + return nil + } + cp := *in + if in.Rule != nil { + rule := *in.Rule + rule.Allow = append([]string(nil), in.Rule.Allow...) + rule.Deny = append([]string(nil), in.Rule.Deny...) + rule.Identities = append([]platform.Identity(nil), in.Rule.Identities...) + cp.Rule = &rule + } + return &cp +} + +// ResetActiveForTesting clears the recorded policy. Tests must call this +// in t.Cleanup when they exercise the bootstrap path. +func ResetActiveForTesting() { + activeMu.Lock() + defer activeMu.Unlock() + activePolicy = nil +} diff --git a/internal/cmdpolicy/aggregation_test.go b/internal/cmdpolicy/aggregation_test.go new file mode 100644 index 000000000..59384952a --- /dev/null +++ b/internal/cmdpolicy/aggregation_test.go @@ -0,0 +1,364 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy_test + +import ( + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/output" +) + +// EvaluateAll must skip non-runnable parent groups (their decision is +// derived in the aggregation pass). The previous regression: an +// Allow:["docs/**"] rule incorrectly denied the parent "docs" group too, +// because the parent's own path "docs" did not match "docs/**". +func TestEvaluateAll_skipsPureGroups(t *testing.T) { + root := buildTree() // docs and im are pure groups, +fetch / +update / +send are leaves + e := cmdpolicy.New(&platform.Rule{Allow: []string{"docs/**"}}) + got := e.EvaluateAll(root) + + if _, present := got["docs"]; present { + t.Errorf("parent group 'docs' should not appear in Decisions (Allow=docs/**)") + } + if _, present := got["im"]; present { + t.Errorf("parent group 'im' should not appear in Decisions") + } + + // Children still evaluated normally. + if !got["docs/+fetch"].Allowed { + t.Errorf("docs/+fetch should still be allowed by docs/**") + } +} + +// BuildDeniedByPath must aggregate: a parent group whose every runnable +// child is denied must itself get an aggregated Denial in the map. +func TestBuildDeniedByPath_parentAggregationAllChildrenDenied(t *testing.T) { + // Custom tree where ALL children of "im" will be denied. + root := &cobra.Command{Use: "lark-cli"} + im := &cobra.Command{Use: "im"} + root.AddCommand(im) + send := &cobra.Command{Use: "+send", RunE: noop} + cmdutil.SetRisk(send, "write") + im.AddCommand(send) + search := &cobra.Command{Use: "+search", RunE: noop} + cmdutil.SetRisk(search, "read") + im.AddCommand(search) + + // Risk is set on both leaves so the rejection comes from the Allow + // axis (the contract this test pins), not from the risk gate. + e := cmdpolicy.New(&platform.Rule{Allow: []string{"docs/**"}}) // none of im/* matches + decisions := e.EvaluateAll(root) + + // Pin the rejection axis: both leaves are rejected by Allow miss, + // NOT by the risk_not_annotated gate. If a future edit drops the + // SetRisk lines above, this assertion fails and the test stops + // silently testing the wrong axis. + if rc := decisions["im/+send"].ReasonCode; rc != "domain_not_allowed" { + t.Errorf("im/+send ReasonCode = %q, want domain_not_allowed", rc) + } + if rc := decisions["im/+search"].ReasonCode; rc != "domain_not_allowed" { + t.Errorf("im/+search ReasonCode = %q, want domain_not_allowed", rc) + } + + denied := cmdpolicy.BuildDeniedByPath(root, decisions, + cmdpolicy.ResolveSource{Kind: cmdpolicy.SourceYAML, Name: "/policy.yml"}, "agent") + + // Both leaves denied. + if _, ok := denied["im/+send"]; !ok { + t.Errorf("im/+send should be in denied map") + } + if _, ok := denied["im/+search"]; !ok { + t.Errorf("im/+search should be in denied map") + } + // Parent must be aggregated. + parent, ok := denied["im"] + if !ok { + t.Fatalf("parent 'im' should be aggregated into denied map") + } + if parent.Layer != "policy" { + t.Errorf("parent.Layer = %q, want pruning", parent.Layer) + } +} + +// Partial children-denied means parent stays UN-denied. This is the +// counter-case to the previous regression: docs/** allowed children stays +// alive even if some siblings are denied. +func TestBuildDeniedByPath_partialDenialKeepsParent(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + docs := &cobra.Command{Use: "docs"} + root.AddCommand(docs) + + fetch := &cobra.Command{Use: "+fetch", RunE: noop} + cmdutil.SetRisk(fetch, "read") + docs.AddCommand(fetch) // allowed + + delete := &cobra.Command{Use: "+delete", RunE: noop} + cmdutil.SetRisk(delete, "high-risk-write") + docs.AddCommand(delete) // denied by Deny + + e := cmdpolicy.New(&platform.Rule{ + Allow: []string{"docs/**"}, + Deny: []string{"docs/+delete"}, + }) + denied := cmdpolicy.BuildDeniedByPath(root, e.EvaluateAll(root), + cmdpolicy.ResolveSource{Kind: cmdpolicy.SourcePlugin, Name: "secaudit"}, "secaudit-policy") + + if _, ok := denied["docs"]; ok { + t.Errorf("parent 'docs' must NOT be denied when some children are allowed") + } + if _, ok := denied["docs/+fetch"]; ok { + t.Errorf("docs/+fetch should not be in denied map (it's allowed)") + } + if _, ok := denied["docs/+delete"]; !ok { + t.Errorf("docs/+delete should be denied (in Deny)") + } +} + +// The binary root is never installed with a denyStub even when all its +// descendants are denied -- the entry point must remain dispatchable. +func TestBuildDeniedByPath_rootNeverDenied(t *testing.T) { + root := buildTree() + e := cmdpolicy.New(&platform.Rule{Allow: []string{"nonexistent/**"}}) + denied := cmdpolicy.BuildDeniedByPath(root, e.EvaluateAll(root), + cmdpolicy.ResolveSource{Kind: cmdpolicy.SourceYAML, Name: "/p.yml"}, "") + + // Every leaf should be denied. We do not assert on the root entry + // because Apply skips the root regardless; the contract is "root + // stays dispatchable". + if _, ok := denied["lark-cli"]; ok { + t.Errorf("root should not be in denied map") + } +} + +// Hybrid command: a parent with its own RunE plus children. Aggregation +// requires both own RunE denied AND all children denied for the parent +// itself to be marked denied. +func TestBuildDeniedByPath_hybridParentOwnAllowedKeepsAlive(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + docs := &cobra.Command{Use: "docs", RunE: noop} // hybrid: own RunE + subs + cmdutil.SetRisk(docs, "read") + root.AddCommand(docs) + delete := &cobra.Command{Use: "+delete", RunE: noop} + cmdutil.SetRisk(delete, "high-risk-write") + docs.AddCommand(delete) + + // Allow "docs" (parent) but deny "+delete" child. + e := cmdpolicy.New(&platform.Rule{ + Allow: []string{"docs"}, + }) + denied := cmdpolicy.BuildDeniedByPath(root, e.EvaluateAll(root), + cmdpolicy.ResolveSource{Kind: cmdpolicy.SourceYAML, Name: ""}, "") + + // docs/+delete denied (path doesn't match Allow=["docs"]). + if _, ok := denied["docs/+delete"]; !ok { + t.Errorf("docs/+delete should be denied") + } + // docs itself allowed (path matches Allow=["docs"] exactly). + if _, ok := denied["docs"]; ok { + t.Errorf("docs (hybrid) should NOT be denied -- own RunE is allowed") + } +} + +// Apply with the wrapped *output.ExitError exposes BOTH paths consumers +// rely on: +// 1. cmd/root.go's envelope writer (errors.As on *output.ExitError) +// 2. in-process consumers extracting the platform.CommandDeniedError +func TestApply_runEReturnsExitErrorAndCommandDeniedError(t *testing.T) { + root := buildTree() + denied := map[string]cmdpolicy.Denial{ + "docs/+update": { + Layer: "policy", + PolicySource: "plugin:secaudit", + RuleName: "secaudit-policy", + ReasonCode: "write_not_allowed", + Reason: "write disabled", + }, + } + cmdpolicy.Apply(root, denied) + update := findChild(t, root, "docs", "+update") + + err := update.RunE(update, []string{}) + if err == nil { + t.Fatalf("denied command should return error") + } + + // Path 1: envelope-writer view. + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("error chain must contain *output.ExitError, got %T", err) + } + if exitErr.Detail == nil { + t.Fatalf("ExitError.Detail required for envelope to render") + } + if exitErr.Detail.Type != "command_denied" { + t.Errorf("envelope error.type = %q, want command_denied", exitErr.Detail.Type) + } + // JSON envelope shape: detail.reason_code must be present and + // match the closed enum. + detailMap, ok := exitErr.Detail.Detail.(map[string]any) + if !ok { + t.Fatalf("envelope detail should be map[string]any, got %T", exitErr.Detail.Detail) + } + if detailMap["reason_code"] != "write_not_allowed" { + t.Errorf("detail.reason_code = %v, want write_not_allowed", detailMap["reason_code"]) + } + if detailMap["policy_source"] != "plugin:secaudit" { + t.Errorf("detail.policy_source = %v, want plugin:secaudit", detailMap["policy_source"]) + } + + // Path 2: in-process typed-error view. + var cd *platform.CommandDeniedError + if !errors.As(err, &cd) { + t.Fatalf("error chain must expose *platform.CommandDeniedError") + } + if cd.Path != "docs/+update" || cd.ReasonCode != "write_not_allowed" { + t.Errorf("CommandDeniedError = %+v", cd) + } + + // Envelope round-trip sanity (the actual JSON cmd/root.go would emit). + var buf strings.Builder + output.WriteErrorEnvelope(&buf, exitErr, "user") + if !strings.Contains(buf.String(), `"type": "command_denied"`) { + t.Errorf("envelope JSON missing type=command_denied, got:\n%s", buf.String()) + } + if !strings.Contains(buf.String(), `"reason_code": "write_not_allowed"`) { + t.Errorf("envelope JSON missing reason_code, got:\n%s", buf.String()) + } + // Round-trip parse to verify it's well-formed JSON. + var parsed map[string]any + if err := json.Unmarshal([]byte(buf.String()), &parsed); err != nil { + t.Fatalf("envelope JSON malformed: %v\n%s", err, buf.String()) + } +} + +// Regression: a pure parent group carrying AnnotationPureGroup must be +// skipped by both EvaluateAll and aggregateParents. Without the skip, +// the cmd.installUnknownSubcommandGuard pass (which attaches a RunE to +// every group for cobra's silent-help fallback) would flip Runnable() +// to true for `docs`, `drive`, etc., and a yaml rule like +// `max_risk: read` would deny every ` --help` invocation with +// reason_code = risk_not_annotated. +func TestEvaluateAll_skipsAnnotatedPureGroup(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + drive := &cobra.Command{ + Use: "drive", + RunE: func(*cobra.Command, []string) error { return nil }, // emulate guard injection + Annotations: map[string]string{ + cmdpolicy.AnnotationPureGroup: "true", + }, + } + root.AddCommand(drive) + pull := &cobra.Command{Use: "+pull", RunE: noop} + cmdutil.SetRisk(pull, "read") + drive.AddCommand(pull) + + e := cmdpolicy.New(&platform.Rule{MaxRisk: "read"}) + got := e.EvaluateAll(root) + + if d, present := got["drive"]; present { + t.Errorf("annotated pure group should not appear in Decisions; got %+v", d) + } + if !got["drive/+pull"].Allowed { + t.Errorf("leaf under pure group must still be evaluated; got %+v", got["drive/+pull"]) + } +} + +// Regression: hasRunnableDescendant must also treat +// AnnotationPureGroup-tagged commands as non-runnable. Without the +// skip, an entire branch consisting of a pure-group placeholder + a +// single pure-group leaf would advertise itself as a "live" subtree +// and the parent aggregation pass would refuse to install a deny stub +// (allLiveChildrenDenied flips to false because the pure group is +// neither runnable nor in `denied`). +func TestHasRunnableDescendant_ignoresAnnotatedPureGroup(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + docs := &cobra.Command{Use: "docs"} + root.AddCommand(docs) + + // A pure-group sibling of a real leaf. The parent must still + // aggregate based on the real leaf alone. + placeholder := &cobra.Command{ + Use: "placeholder", + RunE: func(*cobra.Command, []string) error { return nil }, + Annotations: map[string]string{ + cmdpolicy.AnnotationPureGroup: "true", + }, + } + docs.AddCommand(placeholder) + noChild := &cobra.Command{ + Use: "+ghost", + RunE: func(*cobra.Command, []string) error { return nil }, + Annotations: map[string]string{ + cmdpolicy.AnnotationPureGroup: "true", + }, + } + placeholder.AddCommand(noChild) + + fetch := &cobra.Command{Use: "+fetch", RunE: noop} + cmdutil.SetRisk(fetch, "write") + docs.AddCommand(fetch) + + e := cmdpolicy.New(&platform.Rule{MaxRisk: "read"}) + decisions := e.EvaluateAll(root) + denied := cmdpolicy.BuildDeniedByPath(root, decisions, cmdpolicy.ResolveSource{Kind: cmdpolicy.SourceYAML}, "") + + if _, ok := denied["docs"]; !ok { + t.Fatalf("docs should be aggregated as fully denied (pure-group children excluded from live count); map=%+v", denied) + } +} + +// Regression: aggregateParents must treat an AnnotationPureGroup-tagged +// command exactly like a parent-only group. With cmdRunnable accidentally +// true (RunE attached by the guard), the aggregator would otherwise look +// for an own-RunE denial entry and skip aggregation, leaving ` +// --help` reachable even when every live child is denied. +func TestBuildDeniedByPath_aggregatesAnnotatedPureGroup(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + drive := &cobra.Command{ + Use: "drive", + RunE: func(*cobra.Command, []string) error { return nil }, + Annotations: map[string]string{ + cmdpolicy.AnnotationPureGroup: "true", + }, + } + root.AddCommand(drive) + push := &cobra.Command{Use: "+push", RunE: noop} + cmdutil.SetRisk(push, "write") + drive.AddCommand(push) + pull := &cobra.Command{Use: "+pull", RunE: noop} + cmdutil.SetRisk(pull, "write") + drive.AddCommand(pull) + + e := cmdpolicy.New(&platform.Rule{MaxRisk: "read"}) + decisions := e.EvaluateAll(root) + denied := cmdpolicy.BuildDeniedByPath(root, decisions, cmdpolicy.ResolveSource{Kind: cmdpolicy.SourceYAML}, "") + + if _, ok := denied["drive"]; !ok { + t.Fatalf("aggregator must install drive denial when all children denied; map=%+v", denied) + } +} + +// The binary root must never receive a denyStub even if every descendant +// is denied. cobra still needs root to dispatch help / completion. +func TestApply_neverInstallsOnRoot(t *testing.T) { + root := buildTree() + denied := map[string]cmdpolicy.Denial{ + "lark-cli": {Layer: "policy", ReasonCode: "all_children_denied"}, + } + cmdpolicy.Apply(root, denied) + if root.RunE != nil { + t.Errorf("root.RunE should remain nil; got a denyStub installed") + } + if root.Hidden { + t.Errorf("root must stay visible") + } +} diff --git a/internal/cmdpolicy/apply.go b/internal/cmdpolicy/apply.go new file mode 100644 index 000000000..fead7fd4d --- /dev/null +++ b/internal/cmdpolicy/apply.go @@ -0,0 +1,227 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +import ( + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/output" +) + +// Apply walks the command tree and installs denyStubs for every path in +// deniedByPath whose Denial.Layer == "policy". It is the user-layer +// counterpart to applyStrictModeDenials in cmd/prune.go; both consume the +// same deniedByPath map produced by the bootstrap pipeline, neither +// re-evaluates rules. +// +// Three things must happen for every denied command (hard-constraints 1-4 +// in the tech doc): +// +// 1. cmd.Hidden = true -- removes from help / completion +// 2. cmd.DisableFlagParsing = true -- denial-wins invariant; otherwise +// cobra would intercept the call +// with "missing required flag" +// before we can return our error +// 3. cmd.RunE = denyStub(denial) -- returns *output.ExitError so +// cmd/root.go's envelope writer +// emits structured JSON (with +// error.type = denial.Layer and +// detail.reason_code = ReasonCode); +// the wrapped error chain still +// exposes *platform.CommandDeniedError +// via errors.As for in-process +// consumers +// +// Apply must be called once during the Bootstrap pipeline BEFORE +// cobra.Execute. It mutates the command tree in place and is not safe to +// call concurrently with command dispatch. Returns the number of commands +// modified. +func Apply(root *cobra.Command, deniedByPath map[string]Denial) int { + if root == nil || len(deniedByPath) == 0 { + return 0 + } + + count := 0 + walkTree(root, func(c *cobra.Command) { + // Never install a denyStub on the binary root itself. Even if the + // aggregation pass somehow marked it (e.g. all-children-denied at + // the top), the binary entry point must remain dispatchable so + // cobra's own help / completion paths still work. + if !c.HasParent() { + return + } + path := CanonicalPath(c) + if path == "" { + return + } + d, ok := deniedByPath[path] + if !ok || d.Layer != LayerPolicy { + return + } + if installDenyStub(c, path, d) { + count++ + } + }) + return count +} + +// AnnotationDenialLayer / AnnotationDenialSource carry the denial +// signal to internal/hook through cobra annotations, avoiding an +// import cycle between hook and cmdpolicy. +const ( + AnnotationDenialLayer = "lark:policy_denied_layer" + AnnotationDenialSource = "lark:policy_denied_source" + + // AnnotationPureGroup marks a cobra.Command that is logically a + // parent-only group but had a RunE attached by the bootstrap-time + // unknown-subcommand guard. The engine treats annotated commands + // the same as un-annotated parent groups (no RunE): they are not + // evaluated against the Rule, and aggregateParents does not treat + // them as hybrids. + // + // Without this signal, a user enabling a policy.yml with + // max_risk: read would see every group (`lark-cli drive --help`, + // `lark-cli docs --help`) return exit 2 + risk_not_annotated, + // because the guard's RunE flips Runnable()=true and the engine + // then demands a risk_level annotation on the group itself. + AnnotationPureGroup = "lark:cmd_pure_group" +) + +// IsPureGroup reports whether cmd carries the AnnotationPureGroup marker. +// Used by the engine to skip evaluation and by the aggregator to treat the +// command as a parent-only group regardless of cobra's Runnable() answer. +func IsPureGroup(cmd *cobra.Command) bool { + if cmd == nil || cmd.Annotations == nil { + return false + } + return cmd.Annotations[AnnotationPureGroup] == "true" +} + +// CommandDeniedFromDenial materialises the wrapped error type carried +// on ExitError.Err so errors.As works for in-process consumers. +func CommandDeniedFromDenial(path string, d Denial) *platform.CommandDeniedError { + return &platform.CommandDeniedError{ + Path: path, + Layer: d.Layer, + PolicySource: d.PolicySource, + RuleName: d.RuleName, + ReasonCode: d.ReasonCode, + Reason: d.Reason, + } +} + +// DenialDetailMap is the canonical detail.* shape every `command_denied` +// envelope shares (see docs/extension/reason-codes.md). Use it as +// ErrDetail.Detail when constructing an envelope outside BuildDenialError. +func DenialDetailMap(cd *platform.CommandDeniedError) map[string]any { + return map[string]any{ + "path": cd.Path, + "layer": cd.Layer, + "policy_source": cd.PolicySource, + "rule_name": cd.RuleName, + "reason_code": cd.ReasonCode, + "reason": cd.Reason, + } +} + +// BuildDenialError is the default envelope for user-layer denials: +// Message comes from CommandDeniedError.Error(), no Hint. Callers that +// need a custom Message or an independent Hint (strict-mode) should +// compose CommandDeniedFromDenial + DenialDetailMap themselves. +func BuildDenialError(path string, d Denial) *output.ExitError { + cd := CommandDeniedFromDenial(path, d) + return &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: "command_denied", + Message: cd.Error(), + Detail: DenialDetailMap(cd), + }, + Err: cd, + } +} + +// installDenyStub mutates a cobra.Command in place. Unlike cmd/prune.go +// which does RemoveCommand+AddCommand (changing the pointer), we modify +// the existing node so any external reference (snapshots, alias targets) +// continues to point at the same cmd. +// +// Help fields (cmd.Short / cmd.Long / cmd.Flags()) are deliberately +// preserved so `--help` on a denied command still describes what the +// command was intended to do. +// +// Two cobra Annotations are set as a denial signal that internal/hook +// reads (without taking a dependency on this package): +// +// - AnnotationDenialLayer -> "policy" or "strict_mode" +// - AnnotationDenialSource -> the PolicySource ("yaml", "plugin:foo", ...) +// +// Returns true when the stub was actually installed and false on the +// strict-mode early-return so callers can compute an accurate "commands +// modified" count. +func installDenyStub(cmd *cobra.Command, path string, d Denial) bool { + // strict-mode wins over user-layer pruning. If the command was + // already replaced by a strict-mode stub (cmd/prune.go::strictModeStubFrom + // writes layer=strict_mode), do NOT overwrite -- the user-layer + // rule cannot relax or relabel a credential-hard boundary. + // + // Behaviour without this guard (pre-fix): a user yaml rule matching + // a strict-mode stub's path would replace the RunE with the pruning + // denyStub, hiding the original strict-mode error message AND + // re-labelling detail.layer from "strict_mode" to "policy". + if cmd.Annotations != nil && + cmd.Annotations[AnnotationDenialLayer] == LayerStrictMode { + return false + } + cmd.Hidden = true + cmd.DisableFlagParsing = true + + // Bypass cobra's pre-RunE gates that would otherwise short-circuit + // before the wrapped RunE (= where observers + denial guard live): + // + // 1. Args validator: original commands often declare cobra.NoArgs + // or a custom Args function. With DisableFlagParsing=true, + // `--doc xxx` looks like positional args; cobra.ValidateArgs + // fires BEFORE PersistentPreRunE / PreRunE / RunE and would + // surface a Cobra usage error instead of our pruning envelope. + // ArbitraryArgs accepts everything. + // + // 2. Parent's PersistentPreRunE: cobra's "first PersistentPreRunE + // wins" walks UP from the leaf. cmd/auth/auth.go declares a + // PersistentPreRunE that returns external_provider when env + // credentials are set; without our leaf-level override, that + // fires before pruning's RunE and the caller sees the wrong + // envelope. We set a no-op leaf PersistentPreRunE that just + // silences usage and returns nil, so dispatch proceeds to the + // wrapped RunE (which produces the real pruning envelope and + // lets Before/After observers fire). + cmd.Args = cobra.ArbitraryArgs + cmd.PersistentPreRunE = func(c *cobra.Command, _ []string) error { + c.SilenceUsage = true + return nil + } + cmd.PersistentPreRun = nil + cmd.PreRunE = nil + cmd.PreRun = nil + + if cmd.Annotations == nil { + cmd.Annotations = map[string]string{} + } + cmd.Annotations[AnnotationDenialLayer] = d.Layer + cmd.Annotations[AnnotationDenialSource] = d.PolicySource + + denial := d // capture by value for the closure + cmd.RunE = func(c *cobra.Command, args []string) error { + // error.type is the user-facing semantic ("a command was denied by + // policy"). detail.layer carries the implementation distinction + // ("policy" vs "strict_mode") for debugging. + return BuildDenialError(path, denial) + } + // Clear any pre-existing Run hook: cobra prefers RunE when both are + // set, but leaving a stale Run around is a foot-gun for future + // maintainers. + cmd.Run = nil + return true +} diff --git a/internal/cmdpolicy/denial.go b/internal/cmdpolicy/denial.go new file mode 100644 index 000000000..3411984d0 --- /dev/null +++ b/internal/cmdpolicy/denial.go @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +import "sort" + +// Layer values match CommandDeniedError.Layer and the detail.layer +// field of the JSON envelope (under error.type = "command_denied"). +const ( + LayerStrictMode = "strict_mode" + // LayerPolicy is the user-layer enforcement label. The string value + // is "policy" — the package name "cmdpolicy" matches it. This + // replaces the older "pruning" label. + LayerPolicy = "policy" +) + +// Denial is the merged record for a single rejected command path. It +// is distinct from the user-layer-only Decision type: Denial only +// exists when the command is rejected (the Allowed bool would be +// wasted here, hence not reusing Decision). +type Denial struct { + Layer string // "strict_mode" | "policy" + PolicySource string // "plugin:secaudit" | "yaml:mywork" | "strict-mode" | "" + RuleName string // matched Rule.Name (if any) + ReasonCode string // closed enum, see docs/extension/reason-codes.md + Reason string // human-readable +} + +// ChildDenial is what AggregateChildren consumes — it pairs a Denial +// with the child command's path so the aggregate can carry that +// breakdown for envelope.detail.children_denied. +type ChildDenial struct { + Path string + Denial Denial +} + +// AggregateChildren produces the parent-group Denial when every child +// of a command group is itself denied. The rules: +// +// - all children share Layer "strict_mode" → parent Layer = +// strict_mode, parent ReasonCode = single child's ReasonCode (if +// consistent) or "mixed_children_strict_mode" otherwise. +// - all children share Layer "policy" → parent Layer = policy, +// ReasonCode behaves analogously. +// - mixed layers across children → parent Layer = "policy", +// ReasonCode = "all_children_denied", PolicySource = "mixed". +// +// Calling with an empty slice returns a zero Denial — callers should +// treat this as "no aggregation needed". +func AggregateChildren(children []ChildDenial) Denial { + if len(children) == 0 { + return Denial{} + } + + layers := map[string]struct{}{} + reasonCodes := map[string]struct{}{} + sources := map[string]struct{}{} + ruleNames := map[string]struct{}{} + for _, c := range children { + layers[c.Denial.Layer] = struct{}{} + reasonCodes[c.Denial.ReasonCode] = struct{}{} + if c.Denial.PolicySource != "" { + sources[c.Denial.PolicySource] = struct{}{} + } + if c.Denial.RuleName != "" { + ruleNames[c.Denial.RuleName] = struct{}{} + } + } + + // Mixed: layers differ across children. Parent goes to Layer=policy + // (the more "user-recoverable" of the two — swapping policy can + // flip children, swapping credential cannot). + if len(layers) > 1 { + return Denial{ + Layer: LayerPolicy, + PolicySource: "mixed", + ReasonCode: "all_children_denied", + Reason: "all child commands are denied (mixed reasons)", + } + } + + var layer string + for l := range layers { + layer = l + } + + d := Denial{Layer: layer} + + switch len(reasonCodes) { + case 1: + for rc := range reasonCodes { + d.ReasonCode = rc + } + default: + switch layer { + case LayerStrictMode: + d.ReasonCode = "mixed_children_strict_mode" + default: + d.ReasonCode = "mixed_children_policy" + } + } + + if len(sources) == 1 { + for s := range sources { + d.PolicySource = s + } + } + if layer == LayerStrictMode { + d.PolicySource = "strict-mode" + } + + if len(ruleNames) == 1 { + for n := range ruleNames { + d.RuleName = n + } + } + + d.Reason = "all child commands are denied" + return d +} + +// SortChildren orders children by Path. The aggregate output of +// AggregateChildren is deterministic regardless of slice order, but +// tests and the envelope's children_denied list want a stable order. +func SortChildren(children []ChildDenial) { + sort.Slice(children, func(i, j int) bool { + return children[i].Path < children[j].Path + }) +} diff --git a/internal/cmdpolicy/denial_test.go b/internal/cmdpolicy/denial_test.go new file mode 100644 index 000000000..6c66665cb --- /dev/null +++ b/internal/cmdpolicy/denial_test.go @@ -0,0 +1,98 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy_test + +import ( + "testing" + + "github.com/larksuite/cli/internal/cmdpolicy" +) + +func TestAggregateChildren_allSameLayerAndReason(t *testing.T) { + got := cmdpolicy.AggregateChildren([]cmdpolicy.ChildDenial{ + {Path: "docs/+update", Denial: cmdpolicy.Denial{ + Layer: cmdpolicy.LayerPolicy, PolicySource: "yaml:agent", + ReasonCode: "write_not_allowed", RuleName: "agent-policy", + }}, + {Path: "docs/+delete", Denial: cmdpolicy.Denial{ + Layer: cmdpolicy.LayerPolicy, PolicySource: "yaml:agent", + ReasonCode: "write_not_allowed", RuleName: "agent-policy", + }}, + }) + if got.Layer != cmdpolicy.LayerPolicy || got.ReasonCode != "write_not_allowed" { + t.Fatalf("got %+v, want layer=policy reason=write_not_allowed", got) + } + if got.PolicySource != "yaml:agent" || got.RuleName != "agent-policy" { + t.Fatalf("Source / RuleName should propagate when consistent, got %+v", got) + } +} + +func TestAggregateChildren_sameLayerMixedReasons(t *testing.T) { + got := cmdpolicy.AggregateChildren([]cmdpolicy.ChildDenial{ + {Denial: cmdpolicy.Denial{Layer: cmdpolicy.LayerPolicy, ReasonCode: "write_not_allowed"}}, + {Denial: cmdpolicy.Denial{Layer: cmdpolicy.LayerPolicy, ReasonCode: "domain_not_allowed"}}, + }) + if got.Layer != cmdpolicy.LayerPolicy || got.ReasonCode != "mixed_children_policy" { + t.Fatalf("got %+v, want layer=policy reason=mixed_children_policy", got) + } +} + +func TestAggregateChildren_strictModeBranch(t *testing.T) { + got := cmdpolicy.AggregateChildren([]cmdpolicy.ChildDenial{ + {Denial: cmdpolicy.Denial{Layer: cmdpolicy.LayerStrictMode, ReasonCode: "identity_not_supported"}}, + {Denial: cmdpolicy.Denial{Layer: cmdpolicy.LayerStrictMode, ReasonCode: "identity_not_supported"}}, + }) + if got.Layer != cmdpolicy.LayerStrictMode || got.ReasonCode != "identity_not_supported" { + t.Fatalf("got %+v", got) + } + if got.PolicySource != "strict-mode" { + t.Fatalf("PolicySource = %q, want strict-mode", got.PolicySource) + } +} + +// Mixed layers (some strict_mode, some policy) collapse to Layer=policy +// per the design rule — a parent group failing for "both" reasons is +// most actionable framed as a user-policy issue (swappable) rather than +// a credential capability one (not swappable). +func TestAggregateChildren_mixedLayersFallsToPolicy(t *testing.T) { + got := cmdpolicy.AggregateChildren([]cmdpolicy.ChildDenial{ + {Path: "docs/+update", Denial: cmdpolicy.Denial{ + Layer: cmdpolicy.LayerStrictMode, ReasonCode: "identity_not_supported", + }}, + {Path: "docs/+fetch", Denial: cmdpolicy.Denial{ + Layer: cmdpolicy.LayerPolicy, ReasonCode: "domain_not_allowed", + }}, + }) + if got.Layer != cmdpolicy.LayerPolicy { + t.Fatalf("Layer = %q, want policy (mixed-children rule)", got.Layer) + } + if got.ReasonCode != "all_children_denied" { + t.Fatalf("ReasonCode = %q, want all_children_denied", got.ReasonCode) + } + if got.PolicySource != "mixed" { + t.Fatalf("PolicySource = %q, want mixed", got.PolicySource) + } +} + +func TestAggregateChildren_emptySlice(t *testing.T) { + got := cmdpolicy.AggregateChildren(nil) + if (got != cmdpolicy.Denial{}) { + t.Fatalf("empty slice should produce zero Denial, got %+v", got) + } +} + +func TestSortChildren_stableOrder(t *testing.T) { + children := []cmdpolicy.ChildDenial{ + {Path: "docs/+update"}, + {Path: "docs/+delete"}, + {Path: "docs/+create"}, + } + cmdpolicy.SortChildren(children) + want := []string{"docs/+create", "docs/+delete", "docs/+update"} + for i, c := range children { + if c.Path != want[i] { + t.Fatalf("children[%d].Path = %q, want %q", i, c.Path, want[i]) + } + } +} diff --git a/internal/cmdpolicy/diagnostic.go b/internal/cmdpolicy/diagnostic.go new file mode 100644 index 000000000..9b2393248 --- /dev/null +++ b/internal/cmdpolicy/diagnostic.go @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +// diagnosticPaths lists command paths that are unconditionally allowed, +// regardless of any user-layer Rule. Entries must satisfy two properties: +// +// 1. Read-only. The command performs no I/O outside the local process +// and never mutates remote state. +// 2. Self-reflective. Denying the command would produce a UX dead-end +// where the operator can no longer inspect / validate the policy +// that is locking them out. +// +// Today this is `config policy show` and `config plugins show` -- +// both purely local introspection over the resolved policy. Keep the +// list small and audited: every entry is a permanent hole in the +// fail-closed boundary. +var diagnosticPaths = map[string]bool{ + "config/policy/show": true, + "config/plugins/show": true, +} + +// IsDiagnosticPath reports whether the given canonical command path is +// exempt from user-layer pruning. Exported for test packages; callers +// inside this package use the unexported helper. +func IsDiagnosticPath(path string) bool { + return diagnosticPaths[path] +} diff --git a/internal/cmdpolicy/diagnostic_test.go b/internal/cmdpolicy/diagnostic_test.go new file mode 100644 index 000000000..cc1c3ffa6 --- /dev/null +++ b/internal/cmdpolicy/diagnostic_test.go @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy_test + +import ( + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" +) + +// configPolicyTree builds the minimal slice of the real command tree +// where diagnostic exemption applies: root -> config -> policy -> show. +func configPolicyTree() *cobra.Command { + root := &cobra.Command{Use: "lark-cli"} + config := &cobra.Command{Use: "config"} + root.AddCommand(config) + policy := &cobra.Command{Use: "policy"} + config.AddCommand(policy) + policy.AddCommand(&cobra.Command{Use: "show", RunE: noop}) + // Plus an unrelated command that the Rule will deny, to anchor the + // "everything except diagnostics" check. + im := &cobra.Command{Use: "im"} + root.AddCommand(im) + im.AddCommand(&cobra.Command{Use: "+send", RunE: noop}) + return root +} + +func TestEvaluate_diagnosticAllowedDespiteStrictAllow(t *testing.T) { + root := configPolicyTree() + // Rule that allows ONLY docs/** -- normally locks out everything else. + e := cmdpolicy.New(&platform.Rule{ + Allow: []string{"docs/**"}, + }) + got := e.EvaluateAll(root) + + if !got["config/policy/show"].Allowed { + t.Errorf("config/policy/show must be unconditionally allowed; got Allowed=false reason=%q", + got["config/policy/show"].ReasonCode) + } + // Sanity: a non-diagnostic command is still denied so we know the + // rule itself is active. + if got["im/+send"].Allowed { + t.Errorf("im/+send should be denied by Allow=[docs/**]; got Allowed=true") + } +} + +func TestEvaluate_diagnosticAllowedDespiteExplicitDeny(t *testing.T) { + // Even a Rule that explicitly Denies the path must not lock the + // operator out -- diagnostic is a permanent hole. If a security- + // sensitive deployment needs to block introspection, they should + // strip the binary, not rely on Rule. + root := configPolicyTree() + e := cmdpolicy.New(&platform.Rule{ + Allow: []string{"**"}, + Deny: []string{"config/policy/**"}, + }) + got := e.EvaluateAll(root) + + if !got["config/policy/show"].Allowed { + t.Errorf("config/policy/show must override explicit Deny; got Allowed=false reason=%q", + got["config/policy/show"].ReasonCode) + } +} + +func TestIsDiagnosticPath(t *testing.T) { + cases := []struct { + path string + want bool + }{ + {"config/policy/show", true}, + {"config/plugins/show", true}, + {"config/policy", false}, // parent group itself is not exempt + {"config/plugins", false}, // parent group itself is not exempt + {"docs/+fetch", false}, + {"", false}, + } + for _, tc := range cases { + if got := cmdpolicy.IsDiagnosticPath(tc.path); got != tc.want { + t.Errorf("IsDiagnosticPath(%q) = %v, want %v", tc.path, got, tc.want) + } + } +} diff --git a/internal/cmdpolicy/engine.go b/internal/cmdpolicy/engine.go new file mode 100644 index 000000000..c2e7e0162 --- /dev/null +++ b/internal/cmdpolicy/engine.go @@ -0,0 +1,392 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +// Package cmdpolicy is the user-layer command policy engine. It consumes a +// platform.Rule and the cobra command tree, evaluates each runnable command +// against the rule's four-axis filter (Allow / Deny / MaxRisk / Identities), +// and produces a path -> Decision map. A separate BuildDeniedByPath step +// converts those leaf decisions into a deniedByPath map (with parent-group +// aggregation), which the Apply step consumes to install denyStubs. +// +// This package only implements the user-layer half. Strict-mode is handled +// by cmd/prune.go, which produces command_denied envelopes of the same +// shape via BuildDenialError so external agents can dispatch on +// detail.layer / reason_code uniformly regardless of which layer rejected +// the call. +package cmdpolicy + +import ( + "fmt" + + "github.com/bmatcuk/doublestar/v4" + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdmeta" +) + +// Decision is the user-layer single-rule evaluation result. Distinct from +// Denial: Decision carries Allowed=true/false and the +// rejection reason when Allowed=false; Denial only ever exists when the +// command is rejected. Keeping them separate avoids a perpetually-false +// Allowed field on Denial. +type Decision struct { + Allowed bool + ReasonCode string // "" when Allowed=true + Reason string // human-readable +} + +// Engine evaluates a Rule against the command tree. It is stateless except +// for the Rule snapshot it was constructed with. +type Engine struct { + rule *platform.Rule +} + +// New returns an Engine bound to a Rule. A nil Rule means "no user-layer +// restriction" -- EvaluateOne always returns Allowed=true. +func New(rule *platform.Rule) *Engine { + return &Engine{rule: rule} +} + +// EvaluateAll walks the command tree and evaluates every **runnable** +// command against the Rule. Pure parent groups (no RunE) are deliberately +// skipped here: their decision is derived from children by +// BuildDeniedByPath. Evaluating groups directly would incorrectly deny +// "docs" under an Allow:["docs/**"] rule (the group's own path "docs" +// does not match the "**"-requiring glob). +// +// Hybrid commands (own RunE plus children) are evaluated as ordinary +// leaves here; the aggregation pass treats them specially. +func (e *Engine) EvaluateAll(root *cobra.Command) map[string]Decision { + out := map[string]Decision{} + walkTree(root, func(c *cobra.Command) { + if !c.Runnable() { + return + } + // Pure parent groups carrying the AnnotationPureGroup marker + // (installed by cmd.installUnknownSubcommandGuard) look + // Runnable to cobra but are not a real leaf: skip them just + // like cobra-native parent groups, so a user-level Rule does + // not block ` --help` discovery. + if IsPureGroup(c) { + return + } + path := CanonicalPath(c) + if path == "" { + return + } + out[path] = e.EvaluateOne(c) + }) + return out +} + +// EvaluateOne returns the user-layer decision for a single command. Always +// Allowed=true when the engine has no Rule. +func (e *Engine) EvaluateOne(cmd *cobra.Command) Decision { + if e.rule == nil { + return Decision{Allowed: true} + } + r := e.rule + path := CanonicalPath(cmd) + + if IsDiagnosticPath(path) { + return Decision{Allowed: true} + } + + // A registered Rule expresses intent over the closed risk taxonomy + // (read / write / high-risk-write). Two ways a command can fall + // outside that taxonomy: + // + // - "absent" (no risk_level annotation) — fail-closed by default, + // but Rule.AllowUnannotated=true opts out for gradual adoption. + // - "invalid" (annotation exists but is a typo / not in the + // closed enum) — always fail-closed regardless of + // AllowUnannotated. Typo is a code bug, not a migration phase. + cmdRiskStr, hasRisk := cmdmeta.Risk(cmd) + cmdRisk := platform.Risk(cmdRiskStr) + var ( + cmdRank int + cmdRankOk bool + ) + if hasRisk { + cmdRank, cmdRankOk = cmdRisk.Rank() + if !cmdRankOk { + return Decision{ + Allowed: false, + ReasonCode: "risk_invalid", + Reason: fmt.Sprintf("invalid risk %q; did you mean %q?", cmdRiskStr, suggestRisk(cmdRiskStr)), + } + } + } else if !r.AllowUnannotated { + return Decision{ + Allowed: false, + ReasonCode: "risk_not_annotated", + Reason: "command has no risk_level annotation; rule denies unannotated commands", + } + } + + // Axis 1: Deny has priority. + if matched, ok := firstMatch(r.Deny, path); ok { + return Decision{ + Allowed: false, + ReasonCode: "command_denylisted", + Reason: fmt.Sprintf("command path %q matched deny pattern %q", path, matched), + } + } + + // Axis 2: Allow gate (empty allow means "no restriction"). + if len(r.Allow) > 0 && !matchesAny(r.Allow, path) { + return Decision{ + Allowed: false, + ReasonCode: "domain_not_allowed", + Reason: fmt.Sprintf("command path %q not in allow list %v", path, r.Allow), + } + } + + // Axis 3: MaxRisk. Skipped when cmd risk is absent + AllowUnannotated: + // the engine has no rank to compare against, and AllowUnannotated + // is the explicit "allow this through" opt-in. + if r.MaxRisk != "" && cmdRankOk { + if limit, limitOk := r.MaxRisk.Rank(); limitOk && cmdRank > limit { + return Decision{ + Allowed: false, + ReasonCode: reasonCodeForRisk(cmdRisk), + Reason: fmt.Sprintf("command risk %q exceeds rule max_risk %q", cmdRisk, r.MaxRisk), + } + } + } + + // Axis 4: Identities. Unknown command identities is treated as ALLOW. + if len(r.Identities) > 0 { + cmdIdents := cmdmeta.Identities(cmd) + if cmdIdents != nil && !hasIdentityIntersection(r.Identities, cmdIdents) { + return Decision{ + Allowed: false, + ReasonCode: "identity_mismatch", + Reason: fmt.Sprintf("command supports identities %v; rule allows %v", cmdIdents, r.Identities), + } + } + } + + return Decision{Allowed: true} +} + +// BuildDeniedByPath converts engine Decisions to a deniedByPath map keyed +// by canonical path. It performs the parent-group aggregation defined in +// the tech doc: a non-runnable parent whose every runnable descendant is +// denied gets an aggregate denial (via AggregateChildren); +// hybrid commands (own RunE + children) get one only when both their own +// RunE and all children are denied. +// +// The root command (no parent) is never installed with a denyStub even if +// every child is denied -- the binary entry point must remain dispatchable +// so `--help` and similar remain available. +// +// source / ruleName populate PolicySource and RuleName on the produced +// Denial values, so envelope output can attribute denials. +func BuildDeniedByPath(root *cobra.Command, decisions map[string]Decision, source ResolveSource, ruleName string) map[string]Denial { + out := map[string]Denial{} + + sourceLabel := policySourceLabel(source) + for path, d := range decisions { + if !d.Allowed { + out[path] = Denial{ + Layer: LayerPolicy, + PolicySource: sourceLabel, + RuleName: ruleName, + ReasonCode: d.ReasonCode, + Reason: d.Reason, + } + } + } + + aggregateParents(root, out) + return out +} + +// aggregateParents recursively examines each parent group. Returns true +// when every runnable descendant beneath cmd (including cmd itself when +// runnable) is denied; in that case the function also inserts an aggregate +// Denial for cmd, unless cmd is the binary root or cmd is already in the +// map (own RunE denial preserved). +// +// "Live" children are those with at least one runnable descendant; pure +// non-runnable placeholders neither count toward "all denied" nor block +// the aggregation. +func aggregateParents(cmd *cobra.Command, denied map[string]Denial) bool { + if cmd == nil { + return false + } + + children := cmd.Commands() + // A pure parent group decorated with the unknown-subcommand guard + // looks Runnable() to cobra but is not a true hybrid: treat it + // exactly like cobra-native parent groups so the aggregation pass + // can still install an aggregate deny stub when every live child + // is denied. + cmdRunnable := cmd.Runnable() && !IsPureGroup(cmd) + cmdPath := CanonicalPath(cmd) + + // Pure leaf + if len(children) == 0 { + if !cmdRunnable { + return false // placeholder, doesn't contribute + } + _, ok := denied[cmdPath] + return ok + } + + // Has children: recurse first, collect direct-child denials for the + // aggregation message. + childDenials := make([]ChildDenial, 0, len(children)) + liveChildSeen := false + allLiveChildrenDenied := true + for _, child := range children { + childDenied := aggregateParents(child, denied) + if hasRunnableDescendant(child) { + liveChildSeen = true + if !childDenied { + allLiveChildrenDenied = false + } + } + if cp := CanonicalPath(child); cp != "" { + if d, ok := denied[cp]; ok { + childDenials = append(childDenials, ChildDenial{Path: cp, Denial: d}) + } + } + } + + if !liveChildSeen { + // No reachable runnable descendant in children, but cmd itself + // may still be a runnable hybrid (own RunE + placeholder + // children). The contract is "every runnable descendant + // beneath cmd (including cmd itself when runnable) is denied", + // so when cmd is runnable, the answer depends on whether cmd + // itself was denied. Returning false unconditionally here lost + // that signal and blocked aggregation up the chain. + if cmdRunnable { + _, ownDenied := denied[cmdPath] + return ownDenied + } + return false + } + + // Hybrid: own RunE must also be denied for the group to count as denied. + if cmdRunnable { + if _, ownDenied := denied[cmdPath]; !ownDenied { + return false + } + } + + if !allLiveChildrenDenied { + return false + } + + // Everything reachable below this command is denied. Install the + // aggregate denyStub if there isn't already an own denial here, and + // skip the binary root. + if cmd.HasParent() && cmdPath != "" { + if _, exists := denied[cmdPath]; !exists { + SortChildren(childDenials) + denied[cmdPath] = AggregateChildren(childDenials) + } + } + return true +} + +// hasRunnableDescendant reports whether cmd or any descendant has RunE. +// We use it to ignore pure placeholder branches when aggregating. +func hasRunnableDescendant(cmd *cobra.Command) bool { + if cmd == nil { + return false + } + if cmd.Runnable() && !IsPureGroup(cmd) { + return true + } + for _, c := range cmd.Commands() { + if hasRunnableDescendant(c) { + return true + } + } + return false +} + +// policySourceLabel produces the "plugin:foo" / "yaml" / "" label that goes +// into CommandDeniedError.PolicySource and envelope.detail.policy_source. +// +// **Plugin name is included** because plugins live inside the binary and +// their names are part of the implementation contract; an integrator +// debugging a denial wants to know which plugin's Restrict() fired. +// +// **YAML file path is deliberately omitted** -- the envelope is observable +// by agents, CI logs, and other downstream systems, and the path leaks +// the user's home directory (e.g. /Users/alice/.lark-cli/policy.yml). +// The Denial.RuleName field already carries the human-identifier the user +// chose for their rule (yaml's "name:" field), which suffices for +// disambiguation. Use `config policy show` if the absolute path matters +// for a local debugging session. +func policySourceLabel(s ResolveSource) string { + switch s.Kind { + case SourcePlugin: + return "plugin:" + s.Name + case SourceYAML: + return "yaml" + } + return "" +} + +// reasonCodeForRisk picks the canonical reason_code for an exceeds-max-risk +// rejection. +func reasonCodeForRisk(risk platform.Risk) string { + if risk == platform.RiskWrite || risk == platform.RiskHighRiskWrite { + return "write_not_allowed" + } + return "risk_too_high" +} + +// matchesAny reports whether path matches any of the doublestar globs. +// Invalid globs are skipped here -- they're rejected upstream by +// ValidateRule when the rule first enters the system. +func matchesAny(globs []string, path string) bool { + _, ok := firstMatch(globs, path) + return ok +} + +// firstMatch returns the first glob in globs that matches path. Used by +// command_denylisted so the envelope can name the specific deny pattern +// that fired. +func firstMatch(globs []string, path string) (string, bool) { + for _, g := range globs { + if ok, err := doublestar.Match(g, path); err == nil && ok { + return g, true + } + } + return "", false +} + +// hasIdentityIntersection reports whether the rule's typed identities +// share any value with the command's raw identity strings. Both slices +// are short (usually 1-2 identities) so a nested loop beats allocating +// a set. +func hasIdentityIntersection(rule []platform.Identity, cmd []string) bool { + for _, x := range rule { + for _, y := range cmd { + if string(x) == y { + return true + } + } + } + return false +} + +// walkTree applies fn to every command in the tree, depth-first. Hidden +// commands are visited too -- they can still be invoked. +func walkTree(root *cobra.Command, fn func(*cobra.Command)) { + if root == nil { + return + } + fn(root) + for _, c := range root.Commands() { + walkTree(c, fn) + } +} diff --git a/internal/cmdpolicy/engine_test.go b/internal/cmdpolicy/engine_test.go new file mode 100644 index 000000000..e2e9ca559 --- /dev/null +++ b/internal/cmdpolicy/engine_test.go @@ -0,0 +1,505 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy_test + +import ( + "errors" + "strings" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdmeta" + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/cmdutil" +) + +// buildTree assembles a tiny realistic tree for engine tests: +// +// lark-cli (root) +// ├── docs +// │ ├── +fetch risk=read identities=[user,bot] +// │ ├── +update risk=write identities=[user] +// │ └── +delete-doc risk=high-risk-write +// └── im +// └── +send risk=write identities=[bot] +func buildTree() *cobra.Command { + root := &cobra.Command{Use: "lark-cli"} + + docs := &cobra.Command{Use: "docs"} + cmdmeta.SetDomain(docs, "docs") + root.AddCommand(docs) + + fetch := &cobra.Command{Use: "+fetch", RunE: noop} + cmdutil.SetRisk(fetch, "read") + cmdutil.SetSupportedIdentities(fetch, []string{"user", "bot"}) + docs.AddCommand(fetch) + + update := &cobra.Command{Use: "+update", RunE: noop} + cmdutil.SetRisk(update, "write") + cmdutil.SetSupportedIdentities(update, []string{"user"}) + docs.AddCommand(update) + + deleteDoc := &cobra.Command{Use: "+delete-doc", RunE: noop} + cmdutil.SetRisk(deleteDoc, "high-risk-write") + docs.AddCommand(deleteDoc) + + im := &cobra.Command{Use: "im"} + cmdmeta.SetDomain(im, "im") + root.AddCommand(im) + + send := &cobra.Command{Use: "+send", RunE: noop} + cmdutil.SetRisk(send, "write") + cmdutil.SetSupportedIdentities(send, []string{"bot"}) + im.AddCommand(send) + + return root +} + +func noop(*cobra.Command, []string) error { return nil } + +func TestEvaluate_nilRuleAllowsAll(t *testing.T) { + root := buildTree() + got := cmdpolicy.New(nil).EvaluateAll(root) + for path, d := range got { + if !d.Allowed { + t.Fatalf("nil rule should allow all, got Allowed=false for %s", path) + } + } +} + +func TestEvaluate_allowGlob(t *testing.T) { + root := buildTree() + e := cmdpolicy.New(&platform.Rule{ + Allow: []string{"docs/**"}, + }) + got := e.EvaluateAll(root) + + if !got["docs/+fetch"].Allowed { + t.Errorf("docs/+fetch should be allowed by docs/** glob") + } + if got["im/+send"].Allowed { + t.Errorf("im/+send should NOT be allowed when Allow=docs/**") + } + if got["im/+send"].ReasonCode != "domain_not_allowed" { + t.Errorf("im/+send ReasonCode = %q, want domain_not_allowed", + got["im/+send"].ReasonCode) + } +} + +func TestEvaluate_denyTakesPriorityOverAllow(t *testing.T) { + root := buildTree() + e := cmdpolicy.New(&platform.Rule{ + Allow: []string{"docs/**"}, + Deny: []string{"docs/+delete-doc"}, + }) + got := e.EvaluateAll(root) + + if got["docs/+delete-doc"].Allowed { + t.Errorf("docs/+delete-doc should be denied by Deny rule") + } + if got["docs/+delete-doc"].ReasonCode != "command_denylisted" { + t.Errorf("ReasonCode = %q, want command_denylisted", + got["docs/+delete-doc"].ReasonCode) + } + if !got["docs/+fetch"].Allowed { + t.Errorf("docs/+fetch should still be allowed (not in Deny)") + } +} + +func TestEvaluate_maxRiskCutoff(t *testing.T) { + root := buildTree() + e := cmdpolicy.New(&platform.Rule{ + MaxRisk: "write", // allow read+write, deny high-risk-write + }) + got := e.EvaluateAll(root) + + if !got["docs/+update"].Allowed { + t.Errorf("+update (risk=write) should pass MaxRisk=write") + } + if !got["docs/+fetch"].Allowed { + t.Errorf("+fetch (risk=read) should pass MaxRisk=write") + } + if got["docs/+delete-doc"].Allowed { + t.Errorf("+delete-doc (risk=high-risk-write) should fail MaxRisk=write") + } + if rc := got["docs/+delete-doc"].ReasonCode; rc != "write_not_allowed" { + t.Errorf("ReasonCode = %q, want write_not_allowed", rc) + } +} + +// Unannotated commands are implicit-deny when any Rule is registered. +// The closed risk taxonomy (read / write / high-risk-write) is the only +// vocabulary a Rule can reason about; an unannotated command falls +// outside that vocabulary and is denied with reason_code +// "risk_not_annotated", regardless of whether the rule sets MaxRisk. +func TestEvaluate_unannotatedRiskIsDeny(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + docs := &cobra.Command{Use: "docs"} + root.AddCommand(docs) + // Note: no SetRisk on this command -> unannotated + orphan := &cobra.Command{Use: "+orphan", RunE: noop} + docs.AddCommand(orphan) + + // Rule without MaxRisk still triggers the implicit deny. + e := cmdpolicy.New(&platform.Rule{Allow: []string{"docs/**"}}) + got := e.EvaluateAll(root) + if got["docs/+orphan"].Allowed { + t.Fatalf("unannotated risk must be denied when a Rule is registered") + } + if got["docs/+orphan"].ReasonCode != "risk_not_annotated" { + t.Errorf("ReasonCode = %q, want risk_not_annotated", got["docs/+orphan"].ReasonCode) + } + + // And with MaxRisk it still uses risk_not_annotated (the missing- + // annotation gate runs before the MaxRisk axis). + e = cmdpolicy.New(&platform.Rule{MaxRisk: "read"}) + got = e.EvaluateAll(root) + if got["docs/+orphan"].ReasonCode != "risk_not_annotated" { + t.Errorf("ReasonCode under MaxRisk = %q, want risk_not_annotated", got["docs/+orphan"].ReasonCode) + } + + // An empty Rule{} (no Allow / Deny / MaxRisk / Identities) still + // triggers the implicit deny. "any registered Rule = enter the safety + // boundary" is the design contract; pin it so future edits cannot + // silently weaken it. + e = cmdpolicy.New(&platform.Rule{}) + got = e.EvaluateAll(root) + if got["docs/+orphan"].Allowed { + t.Fatalf("empty Rule{} must still deny unannotated commands") + } + if got["docs/+orphan"].ReasonCode != "risk_not_annotated" { + t.Errorf("empty Rule{} ReasonCode = %q, want risk_not_annotated", got["docs/+orphan"].ReasonCode) + } + + // Without any Rule, unannotated commands are still allowed (no + // policy engine is invoked when no plugin registers a Rule). + e = cmdpolicy.New(nil) + got = e.EvaluateAll(root) + if !got["docs/+orphan"].Allowed { + t.Fatalf("nil Rule must allow unannotated commands (no main-flow impact)") + } +} + +// AllowUnannotated=true opts out of the "unannotated = deny" rule for +// gradual adoption. The flag does NOT loosen any other axis: Deny still +// rejects, MaxRisk is skipped (no rank to compare), Allow/Identities still +// apply. +func TestEvaluate_allowUnannotatedOptsOutOfDeny(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + docs := &cobra.Command{Use: "docs"} + root.AddCommand(docs) + orphan := &cobra.Command{Use: "+orphan", RunE: noop} + docs.AddCommand(orphan) + + // Without opt-in: still denied + e := cmdpolicy.New(&platform.Rule{Allow: []string{"docs/**"}}) + if got := e.EvaluateAll(root); got["docs/+orphan"].Allowed { + t.Fatalf("default behaviour must deny unannotated; AllowUnannotated should be opt-in") + } + + // With opt-in: allowed + e = cmdpolicy.New(&platform.Rule{ + Allow: []string{"docs/**"}, + AllowUnannotated: true, + }) + got := e.EvaluateAll(root) + if !got["docs/+orphan"].Allowed { + t.Fatalf("AllowUnannotated=true must allow unannotated commands; got %+v", got["docs/+orphan"]) + } + + // AllowUnannotated does NOT bypass Deny: an unannotated command + // hitting a Deny glob is still rejected. + e = cmdpolicy.New(&platform.Rule{ + Deny: []string{"docs/+orphan"}, + AllowUnannotated: true, + }) + got = e.EvaluateAll(root) + if got["docs/+orphan"].Allowed { + t.Fatalf("AllowUnannotated must not bypass Deny; got %+v", got["docs/+orphan"]) + } + if got["docs/+orphan"].ReasonCode != "command_denylisted" { + t.Errorf("ReasonCode under Deny+AllowUnannotated = %q, want command_denylisted", + got["docs/+orphan"].ReasonCode) + } +} + +// risk_invalid (typo) is unaffected by AllowUnannotated and emits a +// "did you mean" suggestion in the reason text. +func TestEvaluate_invalidRiskAlwaysDeny_andSuggests(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + docs := &cobra.Command{Use: "docs"} + root.AddCommand(docs) + typo := &cobra.Command{Use: "+typo", RunE: noop} + cmdutil.SetRisk(typo, "wrtie") + docs.AddCommand(typo) + + // AllowUnannotated=true must NOT bypass risk_invalid — typo is a + // code bug, not a missing annotation. + e := cmdpolicy.New(&platform.Rule{ + MaxRisk: "read", + AllowUnannotated: true, + }) + got := e.EvaluateAll(root) + if got["docs/+typo"].Allowed { + t.Fatalf("AllowUnannotated must not bypass risk_invalid; got %+v", got["docs/+typo"]) + } + if got["docs/+typo"].ReasonCode != "risk_invalid" { + t.Errorf("ReasonCode = %q, want risk_invalid", got["docs/+typo"].ReasonCode) + } + if !strings.Contains(got["docs/+typo"].Reason, "write") { + t.Errorf("Reason should contain suggestion 'write', got %q", got["docs/+typo"].Reason) + } +} + +// Invalid risk annotations (typos like "wrtie" or anything outside the +// read|write|high-risk-write taxonomy) are denied with reason_code +// "risk_invalid". Without this gate they used to pass the MaxRisk axis +// because RiskRank returned ok=false and the comparison was skipped -- +// a typo SetRisk would silently slip past an "agent read-only" rule. +func TestEvaluate_invalidRiskIsDeny(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + docs := &cobra.Command{Use: "docs"} + root.AddCommand(docs) + typo := &cobra.Command{Use: "+typo", RunE: noop} + cmdutil.SetRisk(typo, "wrtie") // typo for "write" + docs.AddCommand(typo) + + // Even under MaxRisk=read the typo command must not slip through. + e := cmdpolicy.New(&platform.Rule{MaxRisk: "read"}) + got := e.EvaluateAll(root) + if got["docs/+typo"].Allowed { + t.Fatalf("invalid risk must be denied under MaxRisk=read, got allowed") + } + if got["docs/+typo"].ReasonCode != "risk_invalid" { + t.Errorf("ReasonCode = %q, want risk_invalid", got["docs/+typo"].ReasonCode) + } + + // Same when no MaxRisk is set -- the taxonomy check runs unconditionally + // once a Rule is present. + e = cmdpolicy.New(&platform.Rule{Allow: []string{"docs/**"}}) + got = e.EvaluateAll(root) + if got["docs/+typo"].ReasonCode != "risk_invalid" { + t.Errorf("ReasonCode without MaxRisk = %q, want risk_invalid", got["docs/+typo"].ReasonCode) + } + + // The risk_invalid gate must fire BEFORE Deny matching, otherwise a + // typo command landing in the deny list would surface as + // command_denylisted and mask the underlying taxonomy violation. + e = cmdpolicy.New(&platform.Rule{Deny: []string{"docs/+typo"}}) + got = e.EvaluateAll(root) + if got["docs/+typo"].ReasonCode != "risk_invalid" { + t.Errorf("ReasonCode under Deny match = %q, want risk_invalid (taxonomy gate must precede Deny)", got["docs/+typo"].ReasonCode) + } + + // Without any Rule, invalid risk is not policed (same main-flow + // no-impact rule as risk_not_annotated). + e = cmdpolicy.New(nil) + got = e.EvaluateAll(root) + if !got["docs/+typo"].Allowed { + t.Fatalf("nil Rule must allow invalid risk (no main-flow impact)") + } +} + +func TestEvaluate_identitiesIntersection(t *testing.T) { + root := buildTree() + e := cmdpolicy.New(&platform.Rule{ + Identities: []platform.Identity{"bot"}, // bot-only rule + }) + got := e.EvaluateAll(root) + + // docs/+fetch has [user, bot] -- intersection includes bot -> ALLOW + if !got["docs/+fetch"].Allowed { + t.Errorf("+fetch (identities=user,bot) should intersect bot rule") + } + // docs/+update has [user] -- no intersection with bot -> DENY + if got["docs/+update"].Allowed { + t.Errorf("+update (identities=user) should fail bot-only rule") + } + if got["docs/+update"].ReasonCode != "identity_mismatch" { + t.Errorf("ReasonCode = %q, want identity_mismatch", + got["docs/+update"].ReasonCode) + } +} + +// Reason strings must carry both the attempted value and the rule's +// constraint so the envelope is self-contained for AI consumers. +// Asserting on substrings (not exact match) leaves room for minor wording +// tweaks while pinning the value-carrying behaviour. +func TestEvaluate_reasonCarriesAttemptAndConstraint(t *testing.T) { + root := buildTree() + + cases := []struct { + name string + rule *platform.Rule + path string + wantInReason []string + }{ + { + name: "identity_mismatch surfaces both identity sets", + rule: &platform.Rule{Identities: []platform.Identity{"bot"}}, + path: "docs/+update", // identities=[user] + wantInReason: []string{"[user]", "[bot]"}, + }, + { + name: "domain_not_allowed surfaces path and allow list", + rule: &platform.Rule{Allow: []string{"docs/**"}}, + path: "im/+send", + wantInReason: []string{`"im/+send"`, "docs/**"}, + }, + { + name: "command_denylisted surfaces matched deny pattern", + rule: &platform.Rule{Deny: []string{"docs/+delete-*"}}, + path: "docs/+delete-doc", + wantInReason: []string{`"docs/+delete-doc"`, `"docs/+delete-*"`}, + }, + { + name: "risk_too_high surfaces cmd risk and max_risk", + rule: &platform.Rule{MaxRisk: "write"}, + path: "docs/+delete-doc", // risk=high-risk-write + wantInReason: []string{`"high-risk-write"`, `"write"`}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := cmdpolicy.New(tc.rule).EvaluateAll(root) + d, ok := got[tc.path] + if !ok { + t.Fatalf("no decision for %q", tc.path) + } + if d.Allowed { + t.Fatalf("%q should have been denied", tc.path) + } + for _, sub := range tc.wantInReason { + if !strings.Contains(d.Reason, sub) { + t.Errorf("reason %q missing %q", d.Reason, sub) + } + } + }) + } +} + +// Unknown identities defaults to ALLOW. A command with risk annotated +// but without supportedIdentities passes any identity filter. +func TestEvaluate_unknownIdentitiesIsAllow(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + cmd := &cobra.Command{Use: "+x", RunE: noop} + cmdutil.SetRisk(cmd, "read") + root.AddCommand(cmd) + // no SetSupportedIdentities + + e := cmdpolicy.New(&platform.Rule{Identities: []platform.Identity{"bot"}}) + got := e.EvaluateAll(root) + if !got["+x"].Allowed { + t.Fatalf("unknown identities must pass any identity rule") + } +} + +// Apply must install denyStubs only on Layer="policy" entries. A +// "strict_mode" denial in the same map must be left for +// applyStrictModeDenials in cmd/. +func TestApply_onlyTouchesPruningLayer(t *testing.T) { + root := buildTree() + denied := map[string]cmdpolicy.Denial{ + "docs/+update": {Layer: "policy", ReasonCode: "write_not_allowed"}, + "docs/+fetch": {Layer: "strict_mode", ReasonCode: "identity_not_supported"}, + } + + count := cmdpolicy.Apply(root, denied) + if count != 1 { + t.Fatalf("Apply count = %d, want 1 (only pruning-layer entries)", count) + } + + update := findChild(t, root, "docs", "+update") + if !update.Hidden { + t.Errorf("+update should be Hidden after Apply") + } + if !update.DisableFlagParsing { + t.Errorf("+update should have DisableFlagParsing=true (constraint #4)") + } + + // strict-mode entry must NOT have been touched here. + fetch := findChild(t, root, "docs", "+fetch") + if fetch.Hidden || fetch.DisableFlagParsing { + t.Errorf("+fetch (strict_mode layer) should NOT be touched by cmdpolicy.Apply") + } +} + +// Calling the denied RunE must produce a typed CommandDeniedError with the +// right Layer/ReasonCode. This is the contract every external consumer +// (agent, integration) depends on. +func TestApply_runEReturnsTypedError(t *testing.T) { + root := buildTree() + cmdpolicy.Apply(root, map[string]cmdpolicy.Denial{ + "docs/+update": { + Layer: "policy", + PolicySource: "plugin:secaudit", + RuleName: "secaudit-policy", + ReasonCode: "write_not_allowed", + Reason: "write disabled", + }, + }) + + update := findChild(t, root, "docs", "+update") + err := update.RunE(update, []string{}) + if err == nil { + t.Fatalf("denied command should return error") + } + var denied *platform.CommandDeniedError + if !errors.As(err, &denied) { + t.Fatalf("error should be *platform.CommandDeniedError, got %T", err) + } + if denied.Layer != "policy" || denied.ReasonCode != "write_not_allowed" { + t.Errorf("denial = %+v, want layer=pruning code=write_not_allowed", denied) + } + if denied.Path != "docs/+update" { + t.Errorf("Path = %q, want docs/+update", denied.Path) + } + if denied.PolicySource != "plugin:secaudit" || denied.RuleName != "secaudit-policy" { + t.Errorf("policy source / rule name lost in stub: %+v", denied) + } +} + +func TestApply_emptyMapNoop(t *testing.T) { + root := buildTree() + if got := cmdpolicy.Apply(root, nil); got != 0 { + t.Fatalf("nil deniedByPath should yield count=0, got %d", got) + } +} + +// CanonicalPath strips the root and joins with slashes -- the form +// doublestar globs need to work. +func TestCanonicalPath(t *testing.T) { + root := buildTree() + update := findChild(t, root, "docs", "+update") + if got := cmdpolicy.CanonicalPath(update); got != "docs/+update" { + t.Fatalf("CanonicalPath = %q, want docs/+update", got) + } + if got := cmdpolicy.CanonicalPath(root); got != "lark-cli" { + t.Fatalf("CanonicalPath(root) = %q, want lark-cli (orphan fallback)", got) + } +} + +// findChild is a test helper: descend a path of cmd.Use names through the +// tree, failing the test if any step is missing. +func findChild(t *testing.T, parent *cobra.Command, names ...string) *cobra.Command { + t.Helper() + cur := parent + for _, n := range names { + var next *cobra.Command + for _, c := range cur.Commands() { + if c.Use == n { + next = c + break + } + } + if next == nil { + t.Fatalf("child %q not found under %q", n, cur.Use) + } + cur = next + } + return cur +} diff --git a/internal/cmdpolicy/path.go b/internal/cmdpolicy/path.go new file mode 100644 index 000000000..fe0124db2 --- /dev/null +++ b/internal/cmdpolicy/path.go @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +import ( + "strings" + + "github.com/spf13/cobra" +) + +// CanonicalPath returns the rootless slash-separated path used everywhere in +// the pruning framework. Cobra's CommandPath() yields space-separated +// segments ("lark-cli docs +update"); doublestar globs ("docs/**") require +// slashes, so all internal lookups go through this conversion. +func CanonicalPath(cmd *cobra.Command) string { + if cmd == nil { + return "" + } + parts := make([]string, 0, 4) + for c := cmd; c != nil && c.HasParent(); c = c.Parent() { + parts = append(parts, useName(c)) + } + for i, j := 0, len(parts)-1; i < j; i, j = i+1, j-1 { + parts[i], parts[j] = parts[j], parts[i] + } + if len(parts) == 0 { + return useName(cmd) + } + return strings.Join(parts, "/") +} + +func useName(cmd *cobra.Command) string { + name := cmd.Use + if i := strings.IndexByte(name, ' '); i >= 0 { + name = name[:i] + } + return name +} diff --git a/internal/cmdpolicy/resolver.go b/internal/cmdpolicy/resolver.go new file mode 100644 index 000000000..d70335f58 --- /dev/null +++ b/internal/cmdpolicy/resolver.go @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +import ( + "errors" + "fmt" + "os" + + "github.com/larksuite/cli/extension/platform" + pyaml "github.com/larksuite/cli/internal/cmdpolicy/yaml" + "github.com/larksuite/cli/internal/vfs" +) + +type SourceKind string + +const ( + SourcePlugin SourceKind = "plugin" + SourceYAML SourceKind = "yaml" + SourceNone SourceKind = "none" +) + +type ResolveSource struct { + Kind SourceKind + Name string +} + +type PluginRule struct { + PluginName string + Rule *platform.Rule +} + +type Sources struct { + PluginRules []PluginRule + YAMLRule *platform.Rule + YAMLPath string +} + +var ErrMultipleRestricts = errors.New("multiple plugins called Restrict; only one is permitted") + +// Resolve picks by precedence: plugin > yaml > none. Pure function; load +// yaml via LoadYAMLPolicy first. Winner is validated. +func Resolve(s Sources) (*platform.Rule, ResolveSource, error) { + if len(s.PluginRules) > 1 { + names := make([]string, len(s.PluginRules)) + for i, pr := range s.PluginRules { + names[i] = pr.PluginName + } + return nil, ResolveSource{}, fmt.Errorf("%w: %v", ErrMultipleRestricts, names) + } + + if len(s.PluginRules) == 1 { + pr := s.PluginRules[0] + if err := ValidateRule(pr.Rule); err != nil { + return nil, ResolveSource{}, fmt.Errorf("plugin %q rule invalid: %w", pr.PluginName, err) + } + return pr.Rule, ResolveSource{Kind: SourcePlugin, Name: pr.PluginName}, nil + } + + if s.YAMLRule != nil { + if err := ValidateRule(s.YAMLRule); err != nil { + return nil, ResolveSource{}, fmt.Errorf("policy yaml %q: %w", s.YAMLPath, err) + } + return s.YAMLRule, ResolveSource{Kind: SourceYAML, Name: s.YAMLPath}, nil + } + + return nil, ResolveSource{Kind: SourceNone}, nil +} + +// LoadYAMLPolicy returns (nil, nil) when path is empty or file is absent, +// so callers can pass the result straight into Sources.YAMLRule. +func LoadYAMLPolicy(path string) (*platform.Rule, error) { + if path == "" { + return nil, nil + } + if _, err := vfs.Stat(path); err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("stat policy yaml %q: %w", path, err) + } + data, err := vfs.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read policy yaml %q: %w", path, err) + } + rule, err := pyaml.Parse(data) + if err != nil { + return nil, fmt.Errorf("policy yaml %q: %w", path, err) + } + return rule, nil +} diff --git a/internal/cmdpolicy/resolver_test.go b/internal/cmdpolicy/resolver_test.go new file mode 100644 index 000000000..1631cb6c7 --- /dev/null +++ b/internal/cmdpolicy/resolver_test.go @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy_test + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" +) + +func TestResolve_singlePluginWins(t *testing.T) { + rule := &platform.Rule{Name: "secaudit"} + got, src, err := cmdpolicy.Resolve(cmdpolicy.Sources{ + PluginRules: []cmdpolicy.PluginRule{{PluginName: "secaudit", Rule: rule}}, + }) + if err != nil { + t.Fatalf("Resolve err: %v", err) + } + if got != rule || src.Kind != cmdpolicy.SourcePlugin || src.Name != "secaudit" { + t.Fatalf("Resolve = (%v, %+v)", got, src) + } +} + +func TestResolve_pluginShadowsYaml(t *testing.T) { + pluginRule := &platform.Rule{Name: "from-plugin"} + yamlRule := &platform.Rule{Name: "from-yaml"} + got, src, err := cmdpolicy.Resolve(cmdpolicy.Sources{ + PluginRules: []cmdpolicy.PluginRule{{PluginName: "secaudit", Rule: pluginRule}}, + YAMLRule: yamlRule, + YAMLPath: "/some/policy.yml", + }) + if err != nil { + t.Fatalf("Resolve err: %v", err) + } + if got.Name != "from-plugin" || src.Kind != cmdpolicy.SourcePlugin { + t.Fatalf("plugin should shadow yaml, got %+v / %+v", got, src) + } +} + +func TestResolve_yamlWhenNoPlugin(t *testing.T) { + yamlRule := &platform.Rule{Name: "from-yaml", MaxRisk: "read"} + got, src, err := cmdpolicy.Resolve(cmdpolicy.Sources{ + YAMLRule: yamlRule, + YAMLPath: "/some/policy.yml", + }) + if err != nil { + t.Fatalf("Resolve err: %v", err) + } + if got.Name != "from-yaml" || src.Kind != cmdpolicy.SourceYAML { + t.Fatalf("yaml should win when no plugin, got %+v / %+v", got, src) + } + if src.Name != "/some/policy.yml" { + t.Errorf("yaml source Name should carry path, got %q", src.Name) + } +} + +func TestResolve_emptyEverythingIsNone(t *testing.T) { + got, src, err := cmdpolicy.Resolve(cmdpolicy.Sources{}) + if err != nil { + t.Fatalf("Resolve err: %v", err) + } + if got != nil || src.Kind != cmdpolicy.SourceNone { + t.Fatalf("expected (nil, SourceNone), got (%v, %+v)", got, src) + } +} + +// Two plugins both contributing a Rule must produce the typed error so +// the bootstrap pipeline aborts (hard-constraint #7). +func TestResolve_multipleRestrictIsError(t *testing.T) { + _, _, err := cmdpolicy.Resolve(cmdpolicy.Sources{ + PluginRules: []cmdpolicy.PluginRule{ + {PluginName: "a", Rule: &platform.Rule{Name: "a"}}, + {PluginName: "b", Rule: &platform.Rule{Name: "b"}}, + }, + }) + if !errors.Is(err, cmdpolicy.ErrMultipleRestricts) { + t.Fatalf("err = %v, want ErrMultipleRestricts", err) + } +} + +// LoadYAMLPolicy: missing file returns (nil, nil) silently so callers +// can pass the result straight into Sources.YAMLRule without special- +// casing not-exist. +func TestLoadYAMLPolicy_missingIsSilent(t *testing.T) { + missing := filepath.Join(t.TempDir(), "absent-policy.yml") + rule, err := cmdpolicy.LoadYAMLPolicy(missing) + if err != nil { + t.Fatalf("missing yaml should not error, got %v", err) + } + if rule != nil { + t.Fatalf("missing yaml should return nil rule, got %+v", rule) + } +} + +func TestLoadYAMLPolicy_emptyPathIsNoop(t *testing.T) { + rule, err := cmdpolicy.LoadYAMLPolicy("") + if err != nil { + t.Fatalf("empty path should not error, got %v", err) + } + if rule != nil { + t.Fatalf("empty path should return nil rule, got %+v", rule) + } +} + +func TestLoadYAMLPolicy_parsesValid(t *testing.T) { + dir := t.TempDir() + yamlPath := filepath.Join(dir, "policy.yml") + if err := os.WriteFile(yamlPath, []byte("name: from-yaml\nmax_risk: read\n"), 0o644); err != nil { + t.Fatalf("write yaml: %v", err) + } + rule, err := cmdpolicy.LoadYAMLPolicy(yamlPath) + if err != nil { + t.Fatalf("LoadYAMLPolicy err: %v", err) + } + if rule == nil || rule.Name != "from-yaml" { + t.Fatalf("expected rule with name=from-yaml, got %+v", rule) + } +} diff --git a/internal/cmdpolicy/source_label_test.go b/internal/cmdpolicy/source_label_test.go new file mode 100644 index 000000000..dbd31d560 --- /dev/null +++ b/internal/cmdpolicy/source_label_test.go @@ -0,0 +1,96 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy_test + +import ( + "errors" + "strings" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/output" +) + +// The envelope's policy_source must never leak the absolute home path. +// "yaml:/Users/alice/.lark-cli/policy.yml" would expose Alice's username +// to any agent or log consumer; the contract is to emit just "yaml" and +// rely on rule_name (from the yaml's "name:" field) for disambiguation. +func TestEnvelope_yamlPolicySourceDoesNotLeakHomePath(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + docs := &cobra.Command{Use: "docs"} + root.AddCommand(docs) + leaf := &cobra.Command{Use: "+write", RunE: func(*cobra.Command, []string) error { return nil }} + docs.AddCommand(leaf) + + e := cmdpolicy.New(&platform.Rule{ + Name: "my-readonly-rule", + Allow: []string{"contact/**"}, // docs/* falls outside, denied + }) + denied := cmdpolicy.BuildDeniedByPath(root, e.EvaluateAll(root), + cmdpolicy.ResolveSource{ + Kind: cmdpolicy.SourceYAML, + Name: "/Users/alice/.lark-cli/policy.yml", // simulate an absolute path + }, "my-readonly-rule") + + cmdpolicy.Apply(root, denied) + err := leaf.RunE(leaf, nil) + + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected denial ExitError, got %v", err) + } + detail := exitErr.Detail.Detail.(map[string]any) + src, _ := detail["policy_source"].(string) + if src != "yaml" { + t.Errorf("policy_source = %q, want %q (no path leak)", src, "yaml") + } + // rule_name carries the disambiguating identifier. + if detail["rule_name"] != "my-readonly-rule" { + t.Errorf("rule_name = %v, want my-readonly-rule", detail["rule_name"]) + } + // Direct probe: the absolute path must not appear anywhere in the + // envelope detail (key OR value). + for k, v := range detail { + if strings.Contains(k, "/Users/alice") || strings.Contains(asString(v), "/Users/alice") { + t.Errorf("envelope detail must not leak '/Users/alice', found in %s = %v", k, v) + } + } +} + +// Plugin name IS allowed in policy_source because plugins are in-binary +// and their names are part of the contract (an integrator debugging a +// denial wants to know which plugin fired). This test pins that intent +// so a future change does not silently strip the plugin name too. +func TestEnvelope_pluginPolicySourceCarriesName(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + leaf := &cobra.Command{Use: "+block", RunE: func(*cobra.Command, []string) error { return nil }} + root.AddCommand(leaf) + + e := cmdpolicy.New(&platform.Rule{ + Name: "secaudit-policy", + Deny: []string{"+block"}, + }) + denied := cmdpolicy.BuildDeniedByPath(root, e.EvaluateAll(root), + cmdpolicy.ResolveSource{Kind: cmdpolicy.SourcePlugin, Name: "secaudit"}, + "secaudit-policy") + cmdpolicy.Apply(root, denied) + + err := leaf.RunE(leaf, nil) + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected ExitError") + } + detail := exitErr.Detail.Detail.(map[string]any) + if detail["policy_source"] != "plugin:secaudit" { + t.Errorf("policy_source = %v, want plugin:secaudit", detail["policy_source"]) + } +} + +func asString(v any) string { + s, _ := v.(string) + return s +} diff --git a/internal/cmdpolicy/strict_mode_skip_test.go b/internal/cmdpolicy/strict_mode_skip_test.go new file mode 100644 index 000000000..90276cab5 --- /dev/null +++ b/internal/cmdpolicy/strict_mode_skip_test.go @@ -0,0 +1,163 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy_test + +import ( + "errors" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdpolicy" +) + +// cmdpolicy.Apply MUST NOT overwrite the denial annotation on a command +// already marked as strict-mode denied. strict-mode is a hard boundary +// (credential-derived); a user-layer rule cannot relabel or replace +// the error path. +// +// Without this invariant: when a user yaml rule happened to match the +// path of a strict-mode stub, Apply would change layer=strict_mode to +// layer=pruning, and the user-visible error would say "denied by yaml" +// instead of "strict mode". The hard-boundary contract demands +// strict_mode wins. +func TestApply_PreservesStrictModeAnnotation(t *testing.T) { + root := &cobra.Command{Use: "root"} + stub := &cobra.Command{ + Use: "victim", + Hidden: true, + Annotations: map[string]string{ + cmdpolicy.AnnotationDenialLayer: cmdpolicy.LayerStrictMode, + cmdpolicy.AnnotationDenialSource: "strict-mode", + }, + RunE: func(*cobra.Command, []string) error { return nil }, + } + root.AddCommand(stub) + + // User-layer pruning denies the same path. + denied := map[string]cmdpolicy.Denial{ + "victim": { + Layer: cmdpolicy.LayerPolicy, + PolicySource: "yaml", + Reason: "denied by user yaml", + ReasonCode: "command_denylisted", + }, + } + cmdpolicy.Apply(root, denied) + + if got := stub.Annotations[cmdpolicy.AnnotationDenialLayer]; got != cmdpolicy.LayerStrictMode { + t.Errorf("strict-mode layer overwritten by pruning: got %q want %q", + got, cmdpolicy.LayerStrictMode) + } + if got := stub.Annotations[cmdpolicy.AnnotationDenialSource]; got != "strict-mode" { + t.Errorf("strict-mode source overwritten: got %q", got) + } +} + +// Regression for codex H13 / C6: a denied command that carries +// flag-like positional args (because DisableFlagParsing=true makes +// every `--doc xxx` look positional) MUST surface the pruning +// envelope, not a cobra usage error. Pre-fix, the original command's +// Args validator (e.g. cobra.NoArgs from shortcut registration) would +// fire BEFORE PersistentPreRunE / RunE and produce +// "Error: positional arguments are not supported". +// +// Fix: installDenyStub sets Args=ArbitraryArgs so cobra's validate +// step always passes, letting dispatch reach the wrapped RunE. +func TestApply_DenyStubBypassesArgsValidator(t *testing.T) { + root := &cobra.Command{Use: "root"} + leaf := &cobra.Command{ + Use: "+update", + Args: cobra.NoArgs, // shortcut style: refuse all positional args + RunE: func(*cobra.Command, []string) error { return nil }, + } + root.AddCommand(leaf) + + denied := map[string]cmdpolicy.Denial{ + "+update": { + Layer: cmdpolicy.LayerPolicy, + PolicySource: "yaml", + ReasonCode: "command_denylisted", + Reason: "denied by user yaml", + }, + } + cmdpolicy.Apply(root, denied) + + if leaf.Args == nil { + t.Fatal("denied command must have non-nil Args validator after Apply") + } + // ArbitraryArgs returns nil for every input -> Args validation no-ops. + if err := leaf.Args(leaf, []string{"--doc", "xxx", "--mode", "append"}); err != nil { + t.Errorf("denied command Args validator should accept any input, got %v", err) + } +} + +// Regression for codex C11 / C13: a denied command whose PARENT +// declares a PersistentPreRunE (e.g. cmd/auth/auth.go's +// external_provider check) MUST surface the pruning envelope, not +// the parent's error. Cobra's "first PersistentPreRunE walking up +// from leaf wins" semantics will pick the parent's PersistentPreRunE +// unless the denied leaf carries its own. +// +// Fix: installDenyStub installs a no-op PersistentPreRunE on the leaf +// so cobra stops there and proceeds to the wrapped RunE (which holds +// the real pruning envelope). +func TestApply_DenyStubBypassesParentPersistentPreRunE(t *testing.T) { + root := &cobra.Command{Use: "root"} + parent := &cobra.Command{ + Use: "auth", + PersistentPreRunE: func(*cobra.Command, []string) error { + return errors.New("parent PersistentPreRunE fired (would mask pruning)") + }, + } + root.AddCommand(parent) + leaf := &cobra.Command{ + Use: "login", + RunE: func(*cobra.Command, []string) error { return nil }, + } + parent.AddCommand(leaf) + + denied := map[string]cmdpolicy.Denial{ + "auth/login": { + Layer: cmdpolicy.LayerPolicy, + PolicySource: "yaml", + ReasonCode: "identity_mismatch", + Reason: "denied", + }, + } + cmdpolicy.Apply(root, denied) + + if leaf.PersistentPreRunE == nil { + t.Fatal("denied command must have leaf-level PersistentPreRunE") + } + // Our PersistentPreRunE must NOT propagate the parent's error. + if err := leaf.PersistentPreRunE(leaf, nil); err != nil { + t.Errorf("denied command leaf PersistentPreRunE should be no-op, got %v", err) + } +} + +// Sanity: a normal command (no prior annotation) still gets the +// pruning denial annotations after Apply. +func TestApply_NonStrictCommandStillGetsPruningAnnotation(t *testing.T) { + root := &cobra.Command{Use: "root"} + leaf := &cobra.Command{ + Use: "normal", + RunE: func(*cobra.Command, []string) error { return nil }, + } + root.AddCommand(leaf) + + denied := map[string]cmdpolicy.Denial{ + "normal": { + Layer: cmdpolicy.LayerPolicy, + PolicySource: "yaml", + Reason: "denied", + ReasonCode: "command_denylisted", + }, + } + cmdpolicy.Apply(root, denied) + + if got := leaf.Annotations[cmdpolicy.AnnotationDenialLayer]; got != cmdpolicy.LayerPolicy { + t.Errorf("expected pruning layer annotation, got %q", got) + } +} diff --git a/internal/cmdpolicy/suggest.go b/internal/cmdpolicy/suggest.go new file mode 100644 index 000000000..2f7362e31 --- /dev/null +++ b/internal/cmdpolicy/suggest.go @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +import ( + "github.com/larksuite/cli/extension/platform" +) + +// suggestRisk returns the closest valid Risk literal by edit distance +// for risk_invalid diagnostics; input is never silently substituted. +// Case-insensitive ("WRITE" → "write"); empty in, empty out (the +// absent-annotation case goes to risk_not_annotated, not here). +func suggestRisk(bad string) string { + if bad == "" { + return "" + } + lowered := toLower(bad) + candidates := []platform.Risk{ + platform.RiskRead, platform.RiskWrite, platform.RiskHighRiskWrite, + } + best := string(candidates[0]) + bestDist := levenshtein(lowered, best) + for _, c := range candidates[1:] { + if d := levenshtein(lowered, string(c)); d < bestDist { + bestDist, best = d, string(c) + } + } + return best +} + +// toLower is an ASCII-only lowercase. Risk taxonomy values are +// ASCII; pulling in unicode here would be overkill. +func toLower(s string) string { + b := []byte(s) + for i, c := range b { + if c >= 'A' && c <= 'Z' { + b[i] = c + ('a' - 'A') + } + } + return string(b) +} + +// levenshtein computes the classic edit distance between two strings. +// O(len(a)*len(b)) time, O(min(a,b)) space. Three-element string set +// makes raw performance irrelevant — clarity beats trickiness here. +func levenshtein(a, b string) int { + if len(a) == 0 { + return len(b) + } + if len(b) == 0 { + return len(a) + } + prev := make([]int, len(b)+1) + curr := make([]int, len(b)+1) + for j := 0; j <= len(b); j++ { + prev[j] = j + } + for i := 1; i <= len(a); i++ { + curr[0] = i + for j := 1; j <= len(b); j++ { + cost := 1 + if a[i-1] == b[j-1] { + cost = 0 + } + curr[j] = min3( + prev[j]+1, // deletion + curr[j-1]+1, // insertion + prev[j-1]+cost, // substitution + ) + } + prev, curr = curr, prev + } + return prev[len(b)] +} + +func min3(a, b, c int) int { + m := a + if b < m { + m = b + } + if c < m { + m = c + } + return m +} diff --git a/internal/cmdpolicy/suggest_test.go b/internal/cmdpolicy/suggest_test.go new file mode 100644 index 000000000..da91495a2 --- /dev/null +++ b/internal/cmdpolicy/suggest_test.go @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +import "testing" + +// suggest is unexported, so the test lives in the same package. + +func TestSuggestRisk(t *testing.T) { + cases := []struct { + input string + want string + }{ + {"wrtie", "write"}, + {"WRITE", "write"}, + {"reed", "read"}, + {"rad", "read"}, + {"high-rik-write", "high-risk-write"}, + // "highrisk" is genuinely ambiguous between "write" and + // "high-risk-write" — not testing it. + {"", ""}, // empty input has no meaningful suggestion; the engine + // routes the absent case to risk_not_annotated, not risk_invalid. + } + for _, c := range cases { + got := suggestRisk(c.input) + if got != c.want { + t.Errorf("suggestRisk(%q) = %q, want %q", c.input, got, c.want) + } + } +} + +func TestLevenshtein(t *testing.T) { + cases := []struct { + a, b string + want int + }{ + {"", "", 0}, + {"", "abc", 3}, + {"abc", "", 3}, + {"abc", "abc", 0}, + {"wrtie", "write", 2}, + {"kitten", "sitting", 3}, + } + for _, c := range cases { + got := levenshtein(c.a, c.b) + if got != c.want { + t.Errorf("levenshtein(%q,%q) = %d, want %d", c.a, c.b, got, c.want) + } + } +} diff --git a/internal/cmdpolicy/validate.go b/internal/cmdpolicy/validate.go new file mode 100644 index 000000000..21bb168fb --- /dev/null +++ b/internal/cmdpolicy/validate.go @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy + +import ( + "fmt" + + "github.com/bmatcuk/doublestar/v4" + + "github.com/larksuite/cli/extension/platform" +) + +// ValidateRule is the single Rule-validation entry point. It runs from +// every source: yaml file load, Plugin.Restrict (once the Hook surface +// lands), and the policy CLI's validate subcommand. Catching invalid +// rules HERE rather than during evaluation prevents silent fail-open +// scenarios: +// +// - bad MaxRisk string ("readd") would skip the risk check entirely +// - malformed doublestar pattern ("docs/[abc") never matches, so a +// plugin that meant to allow "docs/*" silently allows nothing, +// and a deny list with the same typo silently denies nothing +// +// A typo in either field by a plugin author or admin must abort the load +// rather than continue with a degraded rule (hard-constraint #6 / #11 +// safety contract). +// +// A nil rule is a no-op (treated as "no restriction" everywhere -- not an +// error). +func ValidateRule(r *platform.Rule) error { + if r == nil { + return nil + } + + if r.MaxRisk != "" { + if !r.MaxRisk.IsValid() { + return fmt.Errorf("invalid max_risk %q: must be one of read|write|high-risk-write", r.MaxRisk) + } + } + + for _, id := range r.Identities { + if !id.IsValid() { + return fmt.Errorf("invalid identities entry %q: must be 'user' or 'bot'", id) + } + } + + for _, g := range r.Allow { + if err := validateGlob(g); err != nil { + return fmt.Errorf("invalid allow glob %q: %w", g, err) + } + } + for _, g := range r.Deny { + if err := validateGlob(g); err != nil { + return fmt.Errorf("invalid deny glob %q: %w", g, err) + } + } + return nil +} + +// validateGlob rejects malformed doublestar patterns. doublestar.Match +// returns an error for unbalanced brackets / bad escape sequences; that +// error path is the canonical signal for "this pattern is not valid". +// +// We probe with an empty string -- the goal is to exercise the parser, +// not to compute a match. +func validateGlob(g string) error { + if g == "" { + return fmt.Errorf("empty pattern") + } + if _, err := doublestar.Match(g, ""); err != nil { + return err + } + return nil +} diff --git a/internal/cmdpolicy/validate_test.go b/internal/cmdpolicy/validate_test.go new file mode 100644 index 000000000..3961f12a3 --- /dev/null +++ b/internal/cmdpolicy/validate_test.go @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdpolicy_test + +import ( + "strings" + "testing" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" +) + +// nil rule is "no restriction" everywhere -- validation must agree. +func TestValidateRule_nilIsOk(t *testing.T) { + if err := cmdpolicy.ValidateRule(nil); err != nil { + t.Fatalf("nil rule should validate, got %v", err) + } +} + +func TestValidateRule_validRule(t *testing.T) { + r := &platform.Rule{ + Allow: []string{"docs/**", "contact/+search-*"}, + Deny: []string{"docs/+delete-doc"}, + MaxRisk: "write", + Identities: []platform.Identity{"user", "bot"}, + } + if err := cmdpolicy.ValidateRule(r); err != nil { + t.Fatalf("valid rule rejected: %v", err) + } +} + +// A typo in MaxRisk must abort the load; otherwise the engine would skip +// the risk check entirely and let high-risk-write commands pass under +// what the operator thought was a "read" cap. +func TestValidateRule_badMaxRisk(t *testing.T) { + cases := []string{"readd", "Read", "high_risk_write", "anything"} + for _, bad := range cases { + r := &platform.Rule{MaxRisk: platform.Risk(bad)} + err := cmdpolicy.ValidateRule(r) + if err == nil { + t.Errorf("ValidateRule should reject MaxRisk=%q", bad) + continue + } + if !strings.Contains(err.Error(), "max_risk") { + t.Errorf("error should mention max_risk for MaxRisk=%q, got %v", bad, err) + } + } +} + +// Identities must come from the closed taxonomy {"user","bot"}. A typo +// like "users" would silently lock out everyone (no command intersects +// the typo), so it must abort. +func TestValidateRule_badIdentity(t *testing.T) { + r := &platform.Rule{Identities: []platform.Identity{"user", "admin"}} + err := cmdpolicy.ValidateRule(r) + if err == nil { + t.Fatalf("ValidateRule should reject identity 'admin'") + } + if !strings.Contains(err.Error(), "identities") { + t.Fatalf("error should mention identities, got %v", err) + } +} + +// Malformed doublestar globs are silent fail-open if not caught here +// (doublestar.Match returns an error which matchesAny() ignores). +func TestValidateRule_malformedGlob(t *testing.T) { + cases := []struct { + name string + rule *platform.Rule + }{ + {"bad allow", &platform.Rule{Allow: []string{"docs/[abc"}}}, + {"bad deny", &platform.Rule{Deny: []string{"docs/[abc"}}}, + {"empty allow entry", &platform.Rule{Allow: []string{"", "docs/**"}}}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + err := cmdpolicy.ValidateRule(c.rule) + if err == nil { + t.Fatalf("ValidateRule should reject %+v", c.rule) + } + }) + } +} + +// Empty MaxRisk and Empty Identities slices are both "no restriction" -- +// not an error. +func TestValidateRule_emptyFieldsAreOk(t *testing.T) { + r := &platform.Rule{ + Allow: []string{"docs/**"}, + MaxRisk: "", + Identities: nil, + } + if err := cmdpolicy.ValidateRule(r); err != nil { + t.Fatalf("empty optional fields should validate, got %v", err) + } +} diff --git a/internal/cmdpolicy/yaml/reader.go b/internal/cmdpolicy/yaml/reader.go new file mode 100644 index 000000000..41e85a4c9 --- /dev/null +++ b/internal/cmdpolicy/yaml/reader.go @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package yaml + +import "io" + +// bytesReader avoids pulling in bytes.NewReader at the call site -- yaml.v3 +// only needs an io.Reader. Plain wrapper, no allocation surprises. +type byteReader struct { + data []byte + pos int +} + +func bytesReader(data []byte) io.Reader { return &byteReader{data: data} } + +func (b *byteReader) Read(p []byte) (int, error) { + if b.pos >= len(b.data) { + return 0, io.EOF + } + n := copy(p, b.data[b.pos:]) + b.pos += n + return n, nil +} diff --git a/internal/cmdpolicy/yaml/schema.go b/internal/cmdpolicy/yaml/schema.go new file mode 100644 index 000000000..718d2a8bd --- /dev/null +++ b/internal/cmdpolicy/yaml/schema.go @@ -0,0 +1,77 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +// Package yaml parses a Rule from yaml bytes. It is kept separate from the +// public extension/platform package so that platform stays free of yaml +// library dependencies -- plugins constructing a Rule in Go code never +// import yaml, only the file loader does. +// +// This package does **structural** parsing only (yaml syntax + unknown-field +// rejection). Semantic validation (valid MaxRisk enum, valid identity +// values, valid doublestar glob syntax) is centralised in +// internal/cmdpolicy.ValidateRule so a single contract is enforced regardless +// of whether the Rule came from yaml or from Plugin.Restrict. +package yaml + +import ( + "errors" + "fmt" + "io" + + gopkgyaml "gopkg.in/yaml.v3" + + "github.com/larksuite/cli/extension/platform" +) + +// schema is the internal yaml-tagged shape. Mirrors platform.Rule but lives +// here so the public Rule has no yaml tag baggage. +type schema struct { + Name string `yaml:"name"` + Description string `yaml:"description,omitempty"` + Allow []string `yaml:"allow,omitempty"` + Deny []string `yaml:"deny,omitempty"` + MaxRisk string `yaml:"max_risk,omitempty"` + Identities []string `yaml:"identities,omitempty"` + AllowUnannotated bool `yaml:"allow_unannotated,omitempty"` +} + +// Parse decodes yaml bytes into a *platform.Rule. Unknown fields are +// rejected so an old binary cannot silently ignore new schema additions +// (forward-compat safeguard). +// +// Semantic validation (MaxRisk taxonomy, identity values, glob syntax) is +// the caller's responsibility -- run the result through +// internal/cmdpolicy.ValidateRule before handing it to the engine. +func Parse(data []byte) (*platform.Rule, error) { + var s schema + dec := gopkgyaml.NewDecoder(bytesReader(data)) + dec.KnownFields(true) + if err := dec.Decode(&s); err != nil { + return nil, fmt.Errorf("parse policy yaml: %w", err) + } + + // Reject multi-document input: yaml.v3 only decodes one document + // per call, so a stray "---" followed by another document would + // silently drop the trailing rule. + var extra schema + if err := dec.Decode(&extra); !errors.Is(err, io.EOF) { + if err == nil { + return nil, fmt.Errorf("parse policy yaml: multiple YAML documents are not allowed") + } + return nil, fmt.Errorf("parse policy yaml: %w", err) + } + + idents := make([]platform.Identity, len(s.Identities)) + for i, id := range s.Identities { + idents[i] = platform.Identity(id) + } + return &platform.Rule{ + Name: s.Name, + Description: s.Description, + Allow: s.Allow, + Deny: s.Deny, + MaxRisk: platform.Risk(s.MaxRisk), + Identities: idents, + AllowUnannotated: s.AllowUnannotated, + }, nil +} diff --git a/internal/cmdpolicy/yaml/schema_test.go b/internal/cmdpolicy/yaml/schema_test.go new file mode 100644 index 000000000..912c8b2a5 --- /dev/null +++ b/internal/cmdpolicy/yaml/schema_test.go @@ -0,0 +1,131 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package yaml_test + +import ( + "reflect" + "testing" + + "github.com/larksuite/cli/extension/platform" + pyaml "github.com/larksuite/cli/internal/cmdpolicy/yaml" +) + +func TestParse_validRule(t *testing.T) { + data := []byte(` +name: agent-docs-readonly +description: only-read docs +allow: + - docs/** + - contact/** +deny: + - docs/+update +max_risk: read +identities: + - user +`) + rule, err := pyaml.Parse(data) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + want := &platform.Rule{ + Name: "agent-docs-readonly", + Description: "only-read docs", + Allow: []string{"docs/**", "contact/**"}, + Deny: []string{"docs/+update"}, + MaxRisk: "read", + Identities: []platform.Identity{"user"}, + } + if !reflect.DeepEqual(rule, want) { + t.Fatalf("rule = %+v, want %+v", rule, want) + } +} + +// allow_unannotated is documented in the README / author guide as the +// gradual-adoption opt-in. The yaml schema must carry it through to +// platform.Rule, otherwise a user following the docs would either hit +// "unknown field" (under KnownFields strict mode) or silently lose the +// opt-in and end up with a safer-but-broken policy. +func TestParse_allowUnannotatedPassesThrough(t *testing.T) { + data := []byte(` +name: agent-readonly +max_risk: read +allow_unannotated: true +`) + rule, err := pyaml.Parse(data) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if !rule.AllowUnannotated { + t.Fatalf("AllowUnannotated = false, want true (yaml field must propagate)") + } + if rule.MaxRisk != "read" || rule.Name != "agent-readonly" { + t.Errorf("other fields lost: %+v", rule) + } +} + +// Default is false when the key is absent: pin the fail-closed default so +// future schema edits cannot accidentally flip it. +func TestParse_allowUnannotatedDefaultsFalse(t *testing.T) { + data := []byte(` +name: x +max_risk: read +`) + rule, err := pyaml.Parse(data) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + if rule.AllowUnannotated { + t.Fatalf("AllowUnannotated must default to false when key is absent") + } +} + +// Unknown fields must be rejected so the old binary cannot silently ignore +// new schema additions (forward-compat safeguard). +func TestParse_rejectsUnknownFields(t *testing.T) { + data := []byte(` +name: x +mystery_field: oh no +`) + if _, err := pyaml.Parse(data); err == nil { + t.Fatalf("Parse should reject unknown yaml field 'mystery_field'") + } +} + +// Semantic validation lives in cmdpolicy.ValidateRule. Parse only checks +// structural yaml; an invalid max_risk passes through (validation happens +// downstream). +func TestParse_doesNotValidateSemantics(t *testing.T) { + rule, err := pyaml.Parse([]byte("max_risk: nuclear\n")) + if err != nil { + t.Fatalf("structural parse should succeed, got %v", err) + } + if rule.MaxRisk != "nuclear" { + t.Fatalf("MaxRisk = %q, want passed through as-is", rule.MaxRisk) + } +} + +// An entirely empty file is rejected: the resolver should fall back to +// "no rule" by skipping the file in the first place, not by feeding empty +// bytes through Parse. +func TestParse_emptyIsError(t *testing.T) { + if _, err := pyaml.Parse([]byte{}); err == nil { + t.Fatalf("Parse should reject empty input; the resolver handles 'no file' separately") + } +} + +// A stray "---" separator followed by another document would silently +// drop the trailing rule if yaml.v3 stopped after the first Decode. +// Parse must reject multi-document input so the operator can't typo a +// separator and end up with an unintentionally empty policy. +func TestParse_rejectsMultipleDocuments(t *testing.T) { + data := []byte(`name: first +max_risk: read +--- +name: second +max_risk: write +`) + if _, err := pyaml.Parse(data); err == nil { + t.Fatalf("Parse should reject multi-document YAML input") + } +} diff --git a/internal/cmdutil/factory.go b/internal/cmdutil/factory.go index 1ccfee440..5eff1931f 100644 --- a/internal/cmdutil/factory.go +++ b/internal/cmdutil/factory.go @@ -161,7 +161,7 @@ func (f *Factory) ResolveStrictMode(ctx context.Context) core.StrictMode { func (f *Factory) CheckStrictMode(ctx context.Context, as core.Identity) error { mode := f.ResolveStrictMode(ctx) if mode.IsActive() && !mode.AllowsIdentity(as) { - return output.ErrWithHint(output.ExitValidation, "strict_mode", + return output.ErrWithHint(output.ExitValidation, "command_denied", fmt.Sprintf("strict mode is %q, only %s-identity commands are available", mode, mode.ForcedIdentity()), "if the user explicitly wants to switch policy, see `lark-cli config strict-mode --help` (confirm with the user before switching; switching does NOT require re-bind)") } diff --git a/internal/cmdutil/identity_flag_test.go b/internal/cmdutil/identity_flag_test.go index 54d539583..f4d1c0fb5 100644 --- a/internal/cmdutil/identity_flag_test.go +++ b/internal/cmdutil/identity_flag_test.go @@ -5,6 +5,7 @@ package cmdutil import ( "context" + "strings" "testing" "github.com/larksuite/cli/internal/core" @@ -66,3 +67,49 @@ func TestAddShortcutIdentityFlag_NoDefault(t *testing.T) { t.Fatalf("default value = %q, want empty string", got) } } + +// TC-10: AuthTypes=["user"] → usage contains "identity type: user" and NOT "bot". +func TestAddShortcutIdentityFlag_UserOnlyAuthTypes(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + cmd := &cobra.Command{Use: "test"} + + AddShortcutIdentityFlag(context.Background(), cmd, f, []string{"user"}) + + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if flag.Hidden { + t.Fatal("expected --as flag to be visible") + } + wantUsage := "identity type: user" + if flag.Usage != wantUsage { + t.Errorf("Usage = %q, want %q", flag.Usage, wantUsage) + } + if strings.Contains(flag.Usage, "bot") { + t.Errorf("Usage should not contain \"bot\" for user-only shortcut, got %q", flag.Usage) + } +} + +// TC-11: AuthTypes=["user","bot"] → usage == "identity type: user | bot". +func TestAddShortcutIdentityFlag_UserBotAuthTypes(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + cmd := &cobra.Command{Use: "test"} + + AddShortcutIdentityFlag(context.Background(), cmd, f, []string{"user", "bot"}) + + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if flag.Hidden { + t.Fatal("expected --as flag to be visible") + } + if got := flag.DefValue; got != "" { + t.Fatalf("default value = %q, want empty string", got) + } + wantUsage := "identity type: user | bot" + if flag.Usage != wantUsage { + t.Errorf("Usage = %q, want %q", flag.Usage, wantUsage) + } +} diff --git a/internal/hook/doc.go b/internal/hook/doc.go new file mode 100644 index 000000000..6993cb1bb --- /dev/null +++ b/internal/hook/doc.go @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +// Package hook is the internal Hook dispatch implementation. It owns: +// +// - Registry the in-memory data store mapping (Stage|Event) -> +// registered hooks for fast dispatch +// - Install(root, …) the entry point that wraps every command's RunE +// so Before/After Observers and Wrap chains fire +// around the command's business logic, including +// the denial guard that physically isolates +// pruned commands from Wrap. +// - Emit(event, …) the lifecycle event firing helper used by the +// Bootstrap pipeline. +// +// Plugins NEVER import this package -- they only ever see +// extension/platform. The Registrar contract is implemented inside +// internal/platform, which delegates to this Registry after validating +// the plugin's calls (staging + atomic commit). +package hook diff --git a/internal/hook/emit.go b/internal/hook/emit.go new file mode 100644 index 000000000..c7cf6ed26 --- /dev/null +++ b/internal/hook/emit.go @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook + +import ( + "context" + "fmt" + "time" + + "github.com/larksuite/cli/extension/platform" +) + +// shutdownDeadline is the hard upper bound on how long Shutdown +// handlers in total may run. Past this, the framework returns control +// to the caller regardless of unfinished handlers. 2s matches the +// design-doc constraint. +const shutdownDeadline = 2 * time.Second + +// LifecycleError is the typed failure returned by Emit for non-Shutdown +// events when a LifecycleHandler returns an error or panics. Callers can +// errors.As to extract HookName, Event, and the Panic discriminator +// (panic vs returned error) so the envelope writer can produce +// distinct reason_code values: +// +// - Panic == false -> reason_code = "lifecycle_failed" +// - Panic == true -> reason_code = "lifecycle_panic" +// +// Shutdown handler failures are logged inside emitShutdown and never +// returned through this type (Shutdown is non-recoverable; the contract +// is "best effort, never block exit"). +type LifecycleError struct { + Event platform.LifecycleEvent + HookName string + Panic bool + Cause error +} + +func (e *LifecycleError) Error() string { + kind := "failed" + if e.Panic { + kind = "panic" + } + return fmt.Sprintf("lifecycle hook %q %s: %v", e.HookName, kind, e.Cause) +} + +func (e *LifecycleError) Unwrap() error { return e.Cause } + +// Emit fires every LifecycleHandler registered for event in +// registration order. lastErr is propagated to handlers via +// LifecycleContext.Err (typical use: Shutdown handlers see the error +// the command exited with). +// +// Behaviour by event: +// +// - Startup: any handler returning a non-nil error aborts the +// bootstrap (caller decides whether to fail-closed). The first +// such error is returned as *LifecycleError. +// +// - Shutdown: handler errors are logged but do not affect the +// returned error; the framework also caps the total time at +// shutdownDeadline. +func Emit(ctx context.Context, reg *Registry, event platform.LifecycleEvent, lastErr error) error { + if reg == nil { + return nil + } + handlers := reg.LifecycleHandlers(event) + if len(handlers) == 0 { + return nil + } + lc := &platform.LifecycleContext{Event: event, Err: lastErr} + + if event == platform.Shutdown { + return emitShutdown(ctx, handlers, lc) + } + for _, h := range handlers { + if err := callLifecycleSafe(ctx, h, lc); err != nil { + return err + } + } + return nil +} + +// emitShutdown enforces the 2-second total deadline. Handlers receive +// a derived context with the remaining budget; once the budget is +// exhausted, the remaining handlers are skipped (with a stderr +// warning) and Emit returns. +func emitShutdown(parent context.Context, handlers []LifecycleEntry, lc *platform.LifecycleContext) error { + ctx, cancel := context.WithTimeout(parent, shutdownDeadline) + defer cancel() + deadline := time.Now().Add(shutdownDeadline) + + for _, h := range handlers { + if time.Now().After(deadline) { + fmt.Fprintf(stderr(), "warning: shutdown deadline exceeded; skipping hook %q\n", h.Name) + continue + } + if err := callLifecycleSafe(ctx, h, lc); err != nil { + // Shutdown errors are logged, not propagated -- exit is + // non-recoverable anyway. + fmt.Fprintf(stderr(), "warning: shutdown hook %q: %v\n", h.Name, err) + } + } + return nil +} + +// callLifecycleSafe invokes a LifecycleHandler with panic recovery. +// Returns *LifecycleError with Panic=true on recovered panic, Panic=false +// on a regular returned error. nil if the handler succeeded. +func callLifecycleSafe(ctx context.Context, h LifecycleEntry, lc *platform.LifecycleContext) (err error) { + defer func() { + if r := recover(); r != nil { + err = &LifecycleError{ + Event: lc.Event, + HookName: h.Name, + Panic: true, + Cause: fmt.Errorf("%v", r), + } + } + }() + if e := h.Fn(ctx, lc); e != nil { + return &LifecycleError{ + Event: lc.Event, + HookName: h.Name, + Panic: false, + Cause: e, + } + } + return nil +} diff --git a/internal/hook/emit_test.go b/internal/hook/emit_test.go new file mode 100644 index 000000000..df6b0af61 --- /dev/null +++ b/internal/hook/emit_test.go @@ -0,0 +1,110 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook + +import ( + "context" + "errors" + "testing" + + "github.com/larksuite/cli/extension/platform" +) + +// A Startup handler returning a regular error must surface as a typed +// *LifecycleError with Panic=false so the cmd-layer guard can pick +// reason_code=lifecycle_failed. +func TestEmit_StartupHandlerError_TypedError(t *testing.T) { + reg := NewRegistry() + want := errors.New("backend down") + reg.AddLifecycle(LifecycleEntry{ + Event: platform.Startup, + Name: "p.boot", + Fn: func(context.Context, *platform.LifecycleContext) error { return want }, + }) + + got := Emit(context.Background(), reg, platform.Startup, nil) + if got == nil { + t.Fatal("expected error from Emit, got nil") + } + var le *LifecycleError + if !errors.As(got, &le) { + t.Fatalf("expected *LifecycleError, got %T %v", got, got) + } + if le.Panic { + t.Errorf("Panic = true, want false (returned error)") + } + if le.HookName != "p.boot" { + t.Errorf("HookName = %q, want p.boot", le.HookName) + } + if !errors.Is(got, want) { + t.Errorf("unwrap should reach original error") + } +} + +// A Startup handler that panics must be recovered and surface as a +// typed *LifecycleError with Panic=true so the cmd-layer guard can +// pick reason_code=lifecycle_panic. +func TestEmit_StartupHandlerPanic_TypedError(t *testing.T) { + reg := NewRegistry() + reg.AddLifecycle(LifecycleEntry{ + Event: platform.Startup, + Name: "p.boot", + Fn: func(context.Context, *platform.LifecycleContext) error { panic("boom") }, + }) + + got := Emit(context.Background(), reg, platform.Startup, nil) + if got == nil { + t.Fatal("expected error from Emit, got nil") + } + var le *LifecycleError + if !errors.As(got, &le) { + t.Fatalf("expected *LifecycleError, got %T %v", got, got) + } + if !le.Panic { + t.Errorf("Panic = false, want true (recovered panic)") + } + if le.HookName != "p.boot" { + t.Errorf("HookName = %q, want p.boot", le.HookName) + } +} + +// A Startup handler that succeeds returns nil; subsequent handlers run. +func TestEmit_StartupAllHandlersRun(t *testing.T) { + reg := NewRegistry() + var calls []string + reg.AddLifecycle(LifecycleEntry{ + Event: platform.Startup, Name: "a", + Fn: func(context.Context, *platform.LifecycleContext) error { + calls = append(calls, "a") + return nil + }, + }) + reg.AddLifecycle(LifecycleEntry{ + Event: platform.Startup, Name: "b", + Fn: func(context.Context, *platform.LifecycleContext) error { + calls = append(calls, "b") + return nil + }, + }) + if err := Emit(context.Background(), reg, platform.Startup, nil); err != nil { + t.Fatalf("Emit: %v", err) + } + if len(calls) != 2 || calls[0] != "a" || calls[1] != "b" { + t.Errorf("handlers fired in unexpected order: %v", calls) + } +} + +// Shutdown handler errors are logged, not propagated; Emit returns nil. +func TestEmit_ShutdownErrorsSwallowed(t *testing.T) { + reg := NewRegistry() + reg.AddLifecycle(LifecycleEntry{ + Event: platform.Shutdown, Name: "flush", + Fn: func(context.Context, *platform.LifecycleContext) error { + return errors.New("flush failed") + }, + }) + if err := Emit(context.Background(), reg, platform.Shutdown, nil); err != nil { + t.Errorf("Shutdown errors must NOT propagate, got: %v", err) + } +} diff --git a/internal/hook/install.go b/internal/hook/install.go new file mode 100644 index 000000000..53fbe6f78 --- /dev/null +++ b/internal/hook/install.go @@ -0,0 +1,358 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook + +import ( + "context" + "errors" + "fmt" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/output" +) + +// Install wraps every runnable command's RunE so the hook chain fires +// around it. The wrapper is: +// +// Before observers (always run, panic-safe) +// denial guard: +// if cmd is denied -> denyStub returns its CommandDeniedError +// else -> compose(matched Wrappers)(originalRunE) runs +// After observers (always run, panic-safe, sees inv.Err) +// +// Critical invariants enforced here (constraint #2): +// +// - **Denied commands NEVER reach the Wrap chain.** The guard runs +// denyStub directly so no plugin Wrapper can suppress or rewrite +// the denial. Observers still fire (audit must see the attempted +// call), but Wrap is physically out of the path. +// +// - **After observers always fire**, even when RunE returned an +// error. Wrap short-circuits via AbortError get converted to +// *output.ExitError so cmd/root.go emits the right envelope. +// +// - **Denial layer / source are populated from cobra annotations +// before any hook fires.** populateInvocationDenial reads the +// annotations attached by cmdpolicy.Apply and strictModeStubFrom, +// avoiding an import cycle between hook and cmdpolicy. +// +// Install must be called once during the Bootstrap pipeline after +// policy pruning has finished. Calling it twice on the same tree is a +// bug (each command's RunE would be wrapped multiple times). +func Install(root *cobra.Command, reg *Registry, snapshot CommandViewSource) { + if root == nil || reg == nil { + return + } + walkTree(root, func(c *cobra.Command) { + if !c.Runnable() { + return + } + if !c.HasParent() { + return // do not wrap the binary root itself + } + wrapRunE(c, reg, snapshot) + }) +} + +// CommandViewSource resolves a *cobra.Command into a CommandView. The +// default implementation returns a live view over the cobra node; +// strict-mode's replacement stubs (cmd/prune.go) carry the original +// command's annotations forward so the view keeps reporting accurate +// Risk / Identities / Domain after replacement. +type CommandViewSource interface { + View(cmd *cobra.Command) platform.CommandView +} + +// wrapRunE replaces cmd.RunE with a hook-aware wrapper. The original +// RunE is captured by closure so the Wrapper chain can still call it +// as the innermost handler. +// +// The wrapper preserves the Run vs RunE distinction: cmd.Run is +// cleared because RunE wins when both are set and leaving a stale Run +// around is a hazard for future maintainers. +func wrapRunE(cmd *cobra.Command, reg *Registry, snapshot CommandViewSource) { + originalRunE := cmd.RunE + originalRun := cmd.Run + cmd.Run = nil + + cmd.RunE = func(c *cobra.Command, args []string) error { + view := snapshot.View(c) + inv := newInvocation(view, args) + + // Detect denial: a denied command's original RunE was already + // replaced by cmdpolicy.Apply with a denyStub that returns + // *output.ExitError wrapping *platform.CommandDeniedError. We + // invoke originalRunE once with a probe-only context (no args + // matter because DisableFlagParsing is set on denied commands) + // to extract its CommandDeniedError, but for V1 we use a + // simpler shortcut: cmdpolicy.Apply itself marks the command + // via cobra annotation; install reads the annotation directly. + populateInvocationDenial(inv, c) + + ctx := c.Context() + if ctx == nil { + ctx = context.Background() + } + + // === Before observers (panic-safe, always run) === + for _, obs := range reg.MatchingObservers(view, platform.Before) { + runObserverSafe(ctx, obs, inv) + } + + // === Denial guard === + // If denied, run the originalRunE directly (it is the denyStub + // installed by cmdpolicy.Apply). The Wrap chain is bypassed. + var err error + if inv.DeniedByPolicy() { + err = invokeOriginal(ctx, c, args, originalRunE, originalRun) + } else { + // Compose matching Wrappers around the originalRunE. Each + // Wrapper is wrapped with a thin namespacing shim so any + // *AbortError returned has its HookName replaced with the + // framework-namespaced WrapperEntry.Name -- a plugin + // cannot impersonate another plugin's hook even by + // accident. + matched := reg.MatchingWrappers(view) + wrappers := make([]platform.Wrapper, 0, len(matched)) + for _, w := range matched { + // Each plugin Wrapper is wrapped twice: once by the + // namespacing shim (AbortError attribution) and once + // by the panic shim (so a plugin panic becomes a + // structured hook envelope instead of crashing the + // process). + wrappers = append(wrappers, recoverWrap(w.Name, namespacedWrap(w.Name, w.Fn))) + } + composed := ComposeWrappers(wrappers) + // Pass the wrapRunE-local args, not i.Args(): the original + // RunE must see what cobra parsed, not what a hook may have + // observed via the read-only interface. + finalHandler := composed(func(c2 context.Context, _ platform.Invocation) error { + return invokeOriginal(c2, c, args, originalRunE, originalRun) + }) + err = finalHandler(ctx, inv) + } + + // Convert AbortError -> *output.ExitError so the envelope writer + // renders the structured "hook" type. + err = wrapAbortError(err) + + inv.setErr(err) + + // === After observers (panic-safe, always run, including + // when err != nil) === + for _, obs := range reg.MatchingObservers(view, platform.After) { + runObserverSafe(ctx, obs, inv) + } + + return err + } +} + +// invokeOriginal runs whatever the original command logic was. If +// originalRunE is non-nil (the common case), use it; otherwise fall +// back to the Run variant. Commands without either are a programming +// error caught at registration time (cmd.Runnable() returns false). +// +// The wrapper-propagated ctx is set on cmd via SetContext *before* the +// inner RunE/Run is invoked, so any context values injected by an +// upstream Wrapper (auth tokens, request-scoped IDs, trace spans, +// cancellation deadlines) reach the original handler. Without this +// hand-off the inner handler would observe c.Context() — the +// pre-wrapper context — and silently lose every value the Wrap chain +// added. +// +// We restore the previous context on return so a single command's +// SetContext mutation cannot leak to sibling dispatches that share the +// same *cobra.Command pointer (cobra reuses the tree across calls in +// long-running embedders). +func invokeOriginal(ctx context.Context, c *cobra.Command, args []string, runE func(*cobra.Command, []string) error, run func(*cobra.Command, []string)) error { + prev := c.Context() + c.SetContext(ctx) + defer c.SetContext(prev) + + if runE != nil { + return runE(c, args) + } + if run != nil { + run(c, args) + return nil + } + return nil +} + +// runObserverSafe invokes an Observer with panic recovery. Observers +// must not break the main flow; their job is side-effect-only and a +// broken plugin should not cascade into a failed CLI run. +func runObserverSafe(ctx context.Context, obs ObserverEntry, inv platform.Invocation) { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(stderr(), "warning: hook %q panicked: %v\n", obs.Name, r) + } + }() + obs.Fn(ctx, inv) +} + +// wrapAbortError converts *platform.AbortError into the equivalent +// *output.ExitError so cmd/root.go's envelope writer emits the right +// JSON structure (type="hook"). Non-AbortError values pass through +// unchanged. +func wrapAbortError(err error) error { + if err == nil { + return nil + } + var ab *platform.AbortError + if !errors.As(err, &ab) { + return err + } + return &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: "hook", + Message: ab.Error(), + Detail: map[string]any{ + "hook_name": ab.HookName, + "reason": ab.Reason, + "reason_code": "aborted", + "detail": ab.Detail, + }, + }, + Err: ab, + } +} + +// recoverWrap wraps a Wrapper so any panic anywhere in the plugin's +// implementation -- including the wrapper FACTORY call (the +// `func(next Handler) Handler` step) and the inner Handler call -- is +// recovered and surfaced as a structured *output.ExitError with +// type="hook" and reason_code="panic". Without this guard, a panicking +// plugin would crash the entire CLI process and break the structured- +// error contract (downstream automation cannot parse a stack trace). +// +// The recovered panic keeps the fully-qualified hook name (the same +// namespacing as namespacedWrap below uses) so on-call can pinpoint +// the offending plugin without grepping logs. +// +// **Why the factory call is inside the deferred recover**: a plugin +// can write something like +// +// func(next Handler) Handler { +// state := mustInit() // panics on bad config +// return func(...) error { ... use state ... } +// } +// +// If `mustInit` panics, the panic happens during composition +// (ComposeWrappers -> ws[i](next)) which runs at invocation time inside +// wrapRunE. Without recovering this branch, the whole CLI crashes. +// We pay a tiny per-invocation cost (one factory call per command +// dispatch) in exchange for total panic isolation. +// +// **Factory-local state lifetime contract**: any value the plugin's +// outer factory captures (`state` in the example above) is now created +// PER INVOCATION of the wrapped command -- it is NOT a one-shot init +// the way Plugin.Install is. Plugins that need long-lived state (a +// connection pool, an LRU cache, a metrics counter) MUST hold it on +// the Plugin struct or in a package-level variable; relying on +// closure-local memoisation inside the wrapper factory will silently +// reset on every command dispatch. +func recoverWrap(fullName string, w platform.Wrapper) platform.Wrapper { + return func(next platform.Handler) platform.Handler { + return func(ctx context.Context, inv platform.Invocation) (returned error) { + defer func() { + if r := recover(); r != nil { + returned = &output.ExitError{ + Code: output.ExitValidation, + Detail: &output.ErrDetail{ + Type: "hook", + Message: fmt.Sprintf("hook %q panicked: %v", fullName, r), + Detail: map[string]any{ + "hook_name": fullName, + "reason_code": "panic", + "reason": fmt.Sprintf("%v", r), + }, + }, + Err: fmt.Errorf("hook %q panic: %v", fullName, r), + } + } + }() + // Construct AFTER the recover is armed so a panicking + // factory becomes a hook envelope instead of a process + // crash. + inner := w(next) + return inner(ctx, inv) + } + } +} + +// namespacedWrap wraps a plugin's Wrapper so any *platform.AbortError it +// returns is replaced with a fresh copy whose HookName is the +// framework-namespaced name (e.g. "policy-plugin.policy"). Plugin +// authors do not need to know their own plugin name; the framework +// attribution is authoritative. +// +// **Why a copy, not mutation**: an AbortError value may be shared +// across concurrent command invocations (e.g. a plugin's package-level +// sentinel). Mutating it would race; copy keeps each invocation's +// attribution isolated. +// +// **Why only top-level AbortError, not wrapped**: a wrapped AbortError +// in a chain via fmt.Errorf("...: %w", ab) would require rebuilding +// the entire chain to substitute the value. The simpler contract -- +// "plugin returns AbortError directly to short-circuit" -- is what we +// document, so we only namespace the top-level case. Wrapped +// AbortErrors keep whatever HookName the plugin set; that is still +// surfaced unchanged by the envelope writer. +func namespacedWrap(fullName string, w platform.Wrapper) platform.Wrapper { + return func(next platform.Handler) platform.Handler { + inner := w(next) + return func(ctx context.Context, inv platform.Invocation) error { + err := inner(ctx, inv) + if err == nil { + return nil + } + if ab, ok := err.(*platform.AbortError); ok { + copied := *ab + copied.HookName = fullName + return &copied + } + return err + } + } +} + +// stderr returns the stderr writer the wrapper uses for safe warnings. +// Indirected through a func so tests can substitute it. +var stderr = func() interface{ Write(p []byte) (int, error) } { + // Avoid pulling os just for stderr access -- the real impl lives + // in install_default.go (see file). The function is overridable + // to keep test isolation tight. + return defaultStderr +} + +// populateInvocationDenial reads the cobra annotation set by +// cmdpolicy.Apply and propagates it onto the framework-internal +// invocation. +// +// V1 contract: a denial is signalled by the cobra annotation +// "lark:policy_denied_layer" being set on the command. The layer +// value is the enforcement layer ("policy" / "strict_mode") that +// gets emitted as detail.layer in the envelope; the source follows +// the annotation "lark:policy_denied_source". +// +// This indirection lets us avoid an import cycle between hook and +// pruning packages. +func populateInvocationDenial(inv *invocation, c *cobra.Command) { + const layerKey = "lark:policy_denied_layer" + const sourceKey = "lark:policy_denied_source" + if c.Annotations == nil { + return + } + layer, ok := c.Annotations[layerKey] + if !ok || layer == "" { + return + } + source := c.Annotations[sourceKey] + inv.setDenial(layer, source) +} diff --git a/internal/hook/install_default.go b/internal/hook/install_default.go new file mode 100644 index 000000000..2c382a76e --- /dev/null +++ b/internal/hook/install_default.go @@ -0,0 +1,11 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook + +import "os" + +// defaultStderr is the real os.Stderr writer. Kept in a separate file so +// tests can replace `stderr` (in install.go) with a buffer without +// shadowing this variable. +var defaultStderr = os.Stderr //nolint:forbidigo // framework-level fallback writer; hooks fire before IOStreams plumbing is available diff --git a/internal/hook/install_test.go b/internal/hook/install_test.go new file mode 100644 index 000000000..7f11f2897 --- /dev/null +++ b/internal/hook/install_test.go @@ -0,0 +1,397 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/hook" + "github.com/larksuite/cli/internal/output" +) + +// fakeViewSource is a minimal CommandView for tests -- it ignores the +// cobra command and returns a fixed view. +type fakeViewSource struct{ view platform.CommandView } + +func (f fakeViewSource) View(*cobra.Command) platform.CommandView { return f.view } + +type fakeView struct { + path string + risk string +} + +func (v fakeView) Path() string { return v.path } +func (v fakeView) Domain() string { return "" } +func (v fakeView) Risk() (platform.Risk, bool) { return platform.Risk(v.risk), v.risk != "" } +func (v fakeView) Identities() []platform.Identity { return nil } +func (v fakeView) Annotation(string) (string, bool) { return "", false } + +func makeLeaf(use string) *cobra.Command { + return &cobra.Command{Use: use, RunE: func(*cobra.Command, []string) error { return nil }} +} + +// Observers fire on Before AND After even when RunE returns an error. +// This is the failure-path observability contract -- After must always +// run so audit hooks see completion regardless of outcome. +func TestInstall_observersBeforeAndAfterAlwaysRun(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + leaf := &cobra.Command{Use: "+x", RunE: func(*cobra.Command, []string) error { + return errors.New("boom") + }} + root.AddCommand(leaf) + + reg := hook.NewRegistry() + + var seen []string + reg.AddObserver(hook.ObserverEntry{ + Name: "before", When: platform.Before, Selector: platform.All(), + Fn: func(_ context.Context, inv platform.Invocation) { + seen = append(seen, fmt.Sprintf("before:err=%v", inv.Err())) + }, + }) + reg.AddObserver(hook.ObserverEntry{ + Name: "after", When: platform.After, Selector: platform.All(), + Fn: func(_ context.Context, inv platform.Invocation) { + seen = append(seen, fmt.Sprintf("after:err=%v", inv.Err())) + }, + }) + + hook.Install(root, reg, fakeViewSource{view: fakeView{path: "+x"}}) + + err := leaf.RunE(leaf, nil) + if err == nil || err.Error() != "boom" { + t.Fatalf("expected RunE to return original error, got %v", err) + } + + wantBefore := "before:err=" // before fires with Err still nil + wantAfter := "after:err=boom" // after sees the failed RunE error + if len(seen) != 2 || seen[0] != wantBefore || seen[1] != wantAfter { + t.Fatalf("observer ordering / Err propagation broken, got %v", seen) + } +} + +// Wrap chain composes outermost-first (registration order). A regression +// that inverts the composition would change which Wrapper short-circuits +// first for safety-sensitive layers. +func TestInstall_wrapperChainOrder(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + var order []string + leaf := &cobra.Command{Use: "+x", RunE: func(*cobra.Command, []string) error { + order = append(order, "RunE") + return nil + }} + root.AddCommand(leaf) + + reg := hook.NewRegistry() + reg.AddWrapper(hook.WrapperEntry{ + Name: "outer", Selector: platform.All(), + Fn: func(next platform.Handler) platform.Handler { + return func(ctx context.Context, inv platform.Invocation) error { + order = append(order, "outer-before") + err := next(ctx, inv) + order = append(order, "outer-after") + return err + } + }, + }) + reg.AddWrapper(hook.WrapperEntry{ + Name: "inner", Selector: platform.All(), + Fn: func(next platform.Handler) platform.Handler { + return func(ctx context.Context, inv platform.Invocation) error { + order = append(order, "inner-before") + err := next(ctx, inv) + order = append(order, "inner-after") + return err + } + }, + }) + + hook.Install(root, reg, fakeViewSource{view: fakeView{path: "+x"}}) + if err := leaf.RunE(leaf, nil); err != nil { + t.Fatalf("RunE: %v", err) + } + want := []string{"outer-before", "inner-before", "RunE", "inner-after", "outer-after"} + if !equalStrings(order, want) { + t.Fatalf("Wrapper order = %v, want %v", order, want) + } +} + +// Denial guard physical isolation: the most safety-critical invariant. +// A denied command must NEVER reach a Wrap chain. We register a Wrap +// that, given the chance, would silently allow the call (return nil, +// don't call next, no AbortError). The guard must skip Wrap entirely +// so the denyStub's error reaches the caller. +// +// Without this guarantee, any plugin Wrap matching All() could +// bypass user policy / strict-mode denials. +func TestInstall_denialGuard_physicalIsolation(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + denyStubCalled := false + leaf := &cobra.Command{ + Use: "+forbidden", + RunE: func(*cobra.Command, []string) error { + denyStubCalled = true + return errors.New("CommandPruned: this is the denyStub") + }, + Annotations: map[string]string{ + "lark:policy_denied_layer": "policy", + "lark:policy_denied_source": "yaml", + }, + } + root.AddCommand(leaf) + + reg := hook.NewRegistry() + + maliciousWrapCalled := false + reg.AddWrapper(hook.WrapperEntry{ + Name: "malicious", Selector: platform.All(), + Fn: func(next platform.Handler) platform.Handler { + return func(ctx context.Context, inv platform.Invocation) error { + maliciousWrapCalled = true + return nil // suppress the denial + } + }, + }) + + hook.Install(root, reg, fakeViewSource{view: fakeView{path: "+forbidden"}}) + + err := leaf.RunE(leaf, nil) + if maliciousWrapCalled { + t.Errorf("denial guard violated: Wrap was invoked on a denied command") + } + if !denyStubCalled { + t.Errorf("denyStub (original RunE) should still run on the denial path") + } + if err == nil { + t.Fatalf("denyStub error must propagate, got nil") + } +} + +// Observer panics must not break the main flow. The guard converts the +// panic to a stderr warning and continues; the command still runs. +func TestInstall_observerPanicIsolated(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + runECalled := false + leaf := &cobra.Command{Use: "+x", RunE: func(*cobra.Command, []string) error { + runECalled = true + return nil + }} + root.AddCommand(leaf) + + reg := hook.NewRegistry() + reg.AddObserver(hook.ObserverEntry{ + Name: "buggy", When: platform.Before, Selector: platform.All(), + Fn: func(context.Context, platform.Invocation) { + panic("plugin author wrote bad code") + }, + }) + + // Capture stderr to make sure the warning was emitted. Restore the + // previous sink so a subsequent test isn't stuck writing into our + // discarded buffer. + t.Cleanup(hook.SetStderrForTesting(&bytes.Buffer{})) // discard + + hook.Install(root, reg, fakeViewSource{view: fakeView{path: "+x"}}) + if err := leaf.RunE(leaf, nil); err != nil { + t.Fatalf("RunE should still succeed when an Observer panicked, got %v", err) + } + if !runECalled { + t.Errorf("RunE must execute despite Observer panic") + } +} + +// A Wrapper returning AbortError surfaces as *output.ExitError with +// type="hook" so cmd/root.go's envelope writer can serialise it. +func TestInstall_abortErrorBecomesExitError(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + leaf := makeLeaf("+x") + root.AddCommand(leaf) + + reg := hook.NewRegistry() + reg.AddWrapper(hook.WrapperEntry{ + Name: "rejecter", Selector: platform.All(), + Fn: func(_ platform.Handler) platform.Handler { + return func(context.Context, platform.Invocation) error { + return &platform.AbortError{ + HookName: "rejecter", + Reason: "policy says no", + } + } + }, + }) + + hook.Install(root, reg, fakeViewSource{view: fakeView{path: "+x"}}) + + err := leaf.RunE(leaf, nil) + if err == nil { + t.Fatalf("Wrap aborted; expected error") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("AbortError must convert to *output.ExitError, got %T %+v", err, err) + } + if exitErr.Detail.Type != "hook" { + t.Errorf("envelope type = %q, want hook", exitErr.Detail.Type) + } + detail := exitErr.Detail.Detail.(map[string]any) + if detail["reason_code"] != "aborted" || detail["hook_name"] != "rejecter" { + t.Errorf("detail = %+v", detail) + } + // The original AbortError must still be reachable via errors.As. + var ab *platform.AbortError + if !errors.As(err, &ab) { + t.Errorf("error chain should expose *platform.AbortError") + } +} + +// namespacedWrap must not mutate a shared *AbortError. A plugin author +// might construct a sentinel at package scope and return it from +// multiple Wrap invocations; mutating it would let attribution leak +// across concurrent command runs and would also race. +// +// Production path test: drive a real cobra.Command through Install +// so namespacedWrap inside install.go is exercised. The plugin returns +// the same sentinel pointer twice. Both observed envelopes must have +// the framework-namespaced HookName, but the sentinel's own HookName +// must remain whatever the plugin originally set. +func TestInstall_namespacedWrap_doesNotMutateSentinel(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + leafA := makeLeaf("+a") + leafB := makeLeaf("+b") + root.AddCommand(leafA) + root.AddCommand(leafB) + + sentinel := &platform.AbortError{HookName: "sentinel-original", Reason: "no"} + + reg := hook.NewRegistry() + // Two Wrappers, different namespaced names, return the SAME + // sentinel. + reg.AddWrapper(hook.WrapperEntry{ + Name: "plugin-a.wrap", + Selector: platform.ByCommandPath("+a"), + Fn: func(platform.Handler) platform.Handler { + return func(context.Context, platform.Invocation) error { return sentinel } + }, + }) + reg.AddWrapper(hook.WrapperEntry{ + Name: "plugin-b.wrap", + Selector: platform.ByCommandPath("+b"), + Fn: func(platform.Handler) platform.Handler { + return func(context.Context, platform.Invocation) error { return sentinel } + }, + }) + + hook.Install(root, reg, fakeViewSourceByPath{}) + + // Invoke both leaves. + errA := leafA.RunE(leafA, nil) + errB := leafB.RunE(leafB, nil) + + // Sentinel must remain untouched: the framework must copy before + // rewriting HookName. + if sentinel.HookName != "sentinel-original" { + t.Errorf("sentinel AbortError was mutated: HookName = %q", sentinel.HookName) + } + + // Each invocation's envelope must carry the correct namespace -- + // proving the framework DID set the right name on its own copy. + checkHookName(t, errA, "plugin-a.wrap") + checkHookName(t, errB, "plugin-b.wrap") +} + +// fakeViewSourceByPath returns a CommandView whose Path matches the +// leaf's Use field (so ByCommandPath selectors discriminate). +type fakeViewSourceByPath struct{} + +func (fakeViewSourceByPath) View(c *cobra.Command) platform.CommandView { + return fakeView{path: c.Use} +} + +func checkHookName(t *testing.T, err error, want string) { + t.Helper() + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected ExitError, got %T", err) + } + detail := exitErr.Detail.Detail.(map[string]any) + if detail["hook_name"] != want { + t.Errorf("hook_name = %v, want %v", detail["hook_name"], want) + } +} + +// A Before observer mutating inv.Args() must not affect what the +// original RunE sees: pins the slice-level read-only contract. +func TestInstall_argsNotMutableByObserver(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + + var seenByRunE []string + leaf := &cobra.Command{ + Use: "+echo", + RunE: func(_ *cobra.Command, args []string) error { + seenByRunE = append([]string(nil), args...) + return nil + }, + } + root.AddCommand(leaf) + + reg := hook.NewRegistry() + reg.AddObserver(hook.ObserverEntry{ + Name: "tamper", When: platform.Before, Selector: platform.All(), + Fn: func(_ context.Context, inv platform.Invocation) { + got := inv.Args() + if len(got) > 0 { + got[0] = "HIJACKED" + } + }, + }) + hook.Install(root, reg, fakeViewSource{view: fakeView{path: "+echo"}}) + + originalArgs := []string{"hello", "world"} + if err := leaf.RunE(leaf, originalArgs); err != nil { + t.Fatalf("RunE returned %v", err) + } + if !equalStrings(seenByRunE, originalArgs) { + t.Fatalf("RunE saw mutated args: got %v, want %v", seenByRunE, originalArgs) + } + if originalArgs[0] != "hello" { + t.Fatalf("caller's original args were mutated: %v", originalArgs) + } +} + +// Root command (no parent) must never be wrapped -- it dispatches help +// and other framework concerns. The root has no RunE so we instead +// verify the root's children are wrapped while the root itself remains +// untouched (RunE stays nil). +func TestInstall_rootStaysUntouched(t *testing.T) { + root := &cobra.Command{Use: "lark-cli"} + leaf := makeLeaf("+x") + root.AddCommand(leaf) + reg := hook.NewRegistry() + hook.Install(root, reg, fakeViewSource{view: fakeView{path: "+x"}}) + if root.RunE != nil { + t.Fatalf("root.RunE should remain nil after Install") + } + if leaf.RunE == nil { + t.Fatalf("child leaf.RunE must remain non-nil (wrapped)") + } +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/hook/invocation.go b/internal/hook/invocation.go new file mode 100644 index 000000000..804755bc0 --- /dev/null +++ b/internal/hook/invocation.go @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook + +import ( + "time" + + "github.com/larksuite/cli/extension/platform" +) + +// invocation is the framework-side concrete implementation of +// platform.Invocation. All setters are unexported so plugin code +// (which only sees the platform.Invocation interface) cannot mutate +// state. +type invocation struct { + cmd platform.CommandView + args []string + started time.Time + err error + + denied bool + layer string + source string +} + +// newInvocation copies args so the read-only platform.Invocation +// contract holds at the slice level: a hook cannot mutate the args +// the original RunE will see. +func newInvocation(cmd platform.CommandView, args []string) *invocation { + argsCopy := append([]string(nil), args...) + return &invocation{ + cmd: cmd, + args: argsCopy, + started: time.Now(), + } +} + +// --- platform.Invocation read interface --- + +func (i *invocation) Cmd() platform.CommandView { return i.cmd } + +// Args returns a fresh copy every call; see newInvocation. +func (i *invocation) Args() []string { + out := make([]string, len(i.args)) + copy(out, i.args) + return out +} +func (i *invocation) Started() time.Time { return i.started } +func (i *invocation) Err() error { return i.err } + +func (i *invocation) DeniedByPolicy() bool { return i.denied } +func (i *invocation) DenialLayer() string { return i.layer } +func (i *invocation) DenialPolicySource() string { + return i.source +} + +// --- framework-internal setters (unexported) --- + +func (i *invocation) setDenial(layer, source string) { + i.denied = true + i.layer = layer + i.source = source +} + +func (i *invocation) setErr(err error) { + i.err = err +} diff --git a/internal/hook/registry.go b/internal/hook/registry.go new file mode 100644 index 000000000..90235c270 --- /dev/null +++ b/internal/hook/registry.go @@ -0,0 +1,184 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook + +import ( + "context" + "sync" + + "github.com/larksuite/cli/extension/platform" +) + +// ObserverEntry stores one Observer registration. The full hook name +// (already namespaced with plugin prefix by the caller) lets diagnostic +// output point at the responsible plugin. +type ObserverEntry struct { + Name string + When platform.When + Selector platform.Selector + Fn platform.Observer +} + +// WrapperEntry stores one Wrapper registration. Wrappers compose in +// registration order; the outermost (registered first) runs first. +type WrapperEntry struct { + Name string + Selector platform.Selector + Fn platform.Wrapper +} + +// LifecycleEntry stores one lifecycle handler. Selector is unused +// (lifecycle events are global), but Name is preserved for diagnostics. +type LifecycleEntry struct { + Name string + Event platform.LifecycleEvent + Fn platform.LifecycleHandler +} + +// Registry holds all registered hooks. The framework constructs one +// Registry per binary execution; concurrent reads after Install +// commits are safe because the maps are not mutated thereafter. Writes +// (during Install) are serialised by the internalplatform. +type Registry struct { + mu sync.RWMutex + + observers []ObserverEntry + wrappers []WrapperEntry + lifecycles []LifecycleEntry +} + +// NewRegistry returns an empty Registry. +func NewRegistry() *Registry { return &Registry{} } + +// Observers returns a snapshot of all registered observers. Order is +// registration order. Diagnostic commands (config plugins show) call +// this to enumerate every hook attached to the binary. +func (r *Registry) Observers() []ObserverEntry { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]ObserverEntry, len(r.observers)) + copy(out, r.observers) + return out +} + +// Wrappers returns a snapshot of all registered wrappers. Order is +// registration order (outermost first). +func (r *Registry) Wrappers() []WrapperEntry { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]WrapperEntry, len(r.wrappers)) + copy(out, r.wrappers) + return out +} + +// Lifecycles returns a snapshot of all registered lifecycle handlers. +func (r *Registry) Lifecycles() []LifecycleEntry { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]LifecycleEntry, len(r.lifecycles)) + copy(out, r.lifecycles) + return out +} + +// AddObserver registers an Observer. Caller is responsible for namespacing +// (the platformhost does this). Nil fn is silently skipped -- the staging +// Registrar should reject invalid registrations before this layer. +func (r *Registry) AddObserver(e ObserverEntry) { + if e.Fn == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.observers = append(r.observers, e) +} + +// AddWrapper registers a Wrapper. +func (r *Registry) AddWrapper(e WrapperEntry) { + if e.Fn == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.wrappers = append(r.wrappers, e) +} + +// AddLifecycle registers a LifecycleHandler. +func (r *Registry) AddLifecycle(e LifecycleEntry) { + if e.Fn == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.lifecycles = append(r.lifecycles, e) +} + +// MatchingObservers returns the observers whose selector matches the +// command at the given When stage. Result is a slice (not a generator) +// so callers can iterate without holding the registry lock. +func (r *Registry) MatchingObservers(cmd platform.CommandView, when platform.When) []ObserverEntry { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]ObserverEntry, 0, len(r.observers)) + for _, e := range r.observers { + if e.When == when && e.Selector != nil && e.Selector(cmd) { + out = append(out, e) + } + } + return out +} + +// MatchingWrappers returns the wrappers whose selector matches the +// command. Order matches registration order. +func (r *Registry) MatchingWrappers(cmd platform.CommandView) []WrapperEntry { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]WrapperEntry, 0, len(r.wrappers)) + for _, e := range r.wrappers { + if e.Selector != nil && e.Selector(cmd) { + out = append(out, e) + } + } + return out +} + +// LifecycleHandlers returns handlers for a given event in registration +// order. +func (r *Registry) LifecycleHandlers(event platform.LifecycleEvent) []LifecycleEntry { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]LifecycleEntry, 0, len(r.lifecycles)) + for _, e := range r.lifecycles { + if e.Event == event { + out = append(out, e) + } + } + return out +} + +// ComposeWrappers folds a slice of Wrappers into a single Wrapper that +// applies them in registration order (outermost first). Empty slice +// returns the identity Wrapper (next as-is). Inspired by +// grpc.ChainUnaryInterceptor. +func ComposeWrappers(ws []platform.Wrapper) platform.Wrapper { + if len(ws) == 0 { + return identityWrapper + } + return func(next platform.Handler) platform.Handler { + // Build from the inside out so the first registered Wrapper + // ends up outermost. + for i := len(ws) - 1; i >= 0; i-- { + next = ws[i](next) + } + return next + } +} + +// identityWrapper is the no-op wrapper used when there are no matching +// Wrappers for a command -- callers can always compose into +// next(ctx, inv) without a nil check. +func identityWrapper(next platform.Handler) platform.Handler { + return func(ctx context.Context, inv platform.Invocation) error { + return next(ctx, inv) + } +} diff --git a/internal/hook/testing.go b/internal/hook/testing.go new file mode 100644 index 000000000..611257e1b --- /dev/null +++ b/internal/hook/testing.go @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook + +import "io" + +// SetStderrForTesting redirects the hook layer's warning output to a +// custom writer and returns a restore function the caller MUST defer +// (or pass to `t.Cleanup`). Without the restore step, a later test in +// the same binary would inherit the override and either race on a +// shared bytes.Buffer or write user-visible garbage into a real test +// stderr. +// +// Production code never calls this; the default writer is os.Stderr +// via defaultStderr. +func SetStderrForTesting(w io.Writer) (restore func()) { + prev := stderr + stderr = func() interface{ Write(p []byte) (int, error) } { + return w + } + return func() { stderr = prev } +} diff --git a/internal/hook/walk.go b/internal/hook/walk.go new file mode 100644 index 000000000..fe5b0dbf9 --- /dev/null +++ b/internal/hook/walk.go @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package hook + +import "github.com/spf13/cobra" + +// walkTree applies fn to every command in the tree, depth-first. Hidden +// commands are visited too -- they can still be invoked. +func walkTree(root *cobra.Command, fn func(*cobra.Command)) { + if root == nil { + return + } + fn(root) + for _, c := range root.Commands() { + walkTree(c, fn) + } +} diff --git a/internal/identitydiag/diagnostics.go b/internal/identitydiag/diagnostics.go new file mode 100644 index 000000000..f8eb648d6 --- /dev/null +++ b/internal/identitydiag/diagnostics.go @@ -0,0 +1,325 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package identitydiag + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + larkauth "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" +) + +const ( + StatusReady = "ready" + StatusNotConfigured = "not_configured" + StatusMissing = "missing" + StatusNeedsRefresh = "needs_refresh" + StatusVerifyFailed = "verify_failed" +) + +// verifyTimeout bounds each network call made during --verify so that a +// hanging server cannot wedge `auth status --verify` or `doctor`. Mirrors +// the 10s timeout used by the doctor endpoint probe. +const verifyTimeout = 10 * time.Second + +// Result describes the independently usable bot and user identities. +type Result struct { + Bot Identity `json:"bot"` + User Identity `json:"user"` +} + +// Identity is a single identity diagnostic result. +type Identity struct { + Status string `json:"status"` + Available bool `json:"available"` + Verified *bool `json:"verified,omitempty"` + Message string `json:"message,omitempty"` + Hint string `json:"hint,omitempty"` + OpenID string `json:"openId,omitempty"` + AppName string `json:"appName,omitempty"` + UserName string `json:"userName,omitempty"` + TokenStatus string `json:"tokenStatus,omitempty"` + Scope string `json:"scope,omitempty"` + ExpiresAt string `json:"expiresAt,omitempty"` + RefreshExpiresAt string `json:"refreshExpiresAt,omitempty"` + GrantedAt string `json:"grantedAt,omitempty"` +} + +// Diagnose checks bot and user identities separately. When verify is false, +// it only reports local readiness and skips server calls. +func Diagnose(ctx context.Context, f *cmdutil.Factory, cfg *core.CliConfig, verify bool) Result { + if ctx == nil { + ctx = context.Background() + } + return Result{ + Bot: diagnoseBot(ctx, f, cfg, verify), + User: diagnoseUser(ctx, f, cfg, verify), + } +} + +func diagnoseBot(ctx context.Context, f *cmdutil.Factory, cfg *core.CliConfig, verify bool) Identity { + if cfg == nil || cfg.AppID == "" { + return Identity{ + Status: StatusNotConfigured, + Message: "Bot identity: not configured (missing app config)", + Hint: "run: lark-cli config --help", + } + } + if !cfg.CanBot() { + return Identity{ + Status: StatusNotConfigured, + Message: "Bot identity: not configured (bot identity is not available in current credential context)", + Hint: "check strict mode or the active credential provider", + } + } + if cfg.SupportedIdentities == 0 && !credential.HasRealAppSecret(cfg.AppSecret) { + return Identity{ + Status: StatusNotConfigured, + Message: "Bot identity: not configured (missing app secret or bot token)", + Hint: "run: lark-cli config --help", + } + } + + id := Identity{ + Status: StatusReady, + Available: true, + Message: "Bot identity: ready", + } + if !verify { + return id + } + + token, err := resolveBotToken(ctx, f, cfg) + if err != nil { + status := StatusVerifyFailed + var unavailable *credential.TokenUnavailableError + if errors.As(err, &unavailable) { + status = StatusNotConfigured + } + return Identity{ + Status: status, + Verified: boolPtr(false), + Message: "Bot identity: " + StatusMessage(status) + ": " + err.Error(), + Hint: "check app credentials or the active credential provider", + } + } + + info, err := fetchBotInfo(ctx, f, cfg, token) + if err != nil { + return Identity{ + Status: StatusVerifyFailed, + Verified: boolPtr(false), + Message: "Bot identity: verify failed: " + err.Error(), + Hint: "check app credentials, scopes, network, or tenant access token configuration", + } + } + + id.Verified = boolPtr(true) + id.OpenID = info.OpenID + id.AppName = info.AppName + return id +} + +func diagnoseUser(ctx context.Context, f *cmdutil.Factory, cfg *core.CliConfig, verify bool) Identity { + if cfg == nil || cfg.AppID == "" { + return Identity{ + Status: StatusNotConfigured, + Message: "User identity: not configured (missing app config)", + Hint: "run: lark-cli config --help", + } + } + if cfg.UserOpenId == "" { + return Identity{ + Status: StatusMissing, + Message: "User identity: missing (no user logged in)", + Hint: "run: lark-cli auth login --help", + } + } + + id := Identity{ + UserName: cfg.UserName, + OpenID: cfg.UserOpenId, + } + stored := larkauth.GetStoredToken(cfg.AppID, cfg.UserOpenId) + if stored == nil { + id.Status = StatusMissing + id.Message = "User identity: missing (no token in keychain for " + cfg.UserOpenId + ")" + id.Hint = "run: lark-cli auth login --help" + return id + } + + fillTokenFields(&id, stored) + switch larkauth.TokenStatus(stored) { + case "valid": + id.Status = StatusReady + id.Available = true + id.Message = "User identity: ready" + case "needs_refresh": + id.Status = StatusNeedsRefresh + id.Available = true + id.Message = "User identity: needs refresh (will auto-refresh on next user API call)" + default: + id.Status = StatusMissing + id.Message = "User identity: missing (refresh token expired)" + id.Hint = "run: lark-cli auth login --help" + return id + } + + if !verify { + return id + } + + markVerifyFailed := func(reason, hint string) Identity { + id.Status = StatusVerifyFailed + id.Available = false + id.Verified = boolPtr(false) + id.Message = "User identity: verify failed: " + reason + if hint != "" { + id.Hint = hint + } + return id + } + + httpClient, err := f.HttpClient() + if err != nil { + return markVerifyFailed("create HTTP client: "+err.Error(), "") + } + token, err := larkauth.GetValidAccessToken(httpClient, larkauth.NewUATCallOptions(cfg, f.IOStreams.ErrOut)) + if err != nil { + return markVerifyFailed("token unusable: "+err.Error(), "run: lark-cli auth login --help") + } + sdk, err := f.LarkClient() + if err != nil { + return markVerifyFailed("SDK init failed: "+err.Error(), "") + } + verifyCtx, cancel := context.WithTimeout(ctx, verifyTimeout) + defer cancel() + if err := larkauth.VerifyUserToken(verifyCtx, sdk, token); err != nil { + return markVerifyFailed("server rejected token: "+err.Error(), "run: lark-cli auth login --help") + } + + id.Verified = boolPtr(true) + if id.Status == StatusReady { + id.Message = "User identity: ready" + } else { + id.Message = "User identity: needs refresh (server verification succeeded after refresh)" + } + return id +} + +func resolveBotToken(ctx context.Context, f *cmdutil.Factory, cfg *core.CliConfig) (string, error) { + if f == nil || f.Credential == nil { + return "", &credential.TokenUnavailableError{Type: credential.TokenTypeTAT} + } + result, err := f.Credential.ResolveToken(ctx, credential.NewTokenSpec(core.AsBot, cfg.AppID)) + if err != nil { + return "", err + } + if result == nil || result.Token == "" { + return "", &credential.TokenUnavailableError{Type: credential.TokenTypeTAT} + } + return result.Token, nil +} + +type botInfo struct { + OpenID string + AppName string +} + +func fetchBotInfo(ctx context.Context, f *cmdutil.Factory, cfg *core.CliConfig, token string) (*botInfo, error) { + httpClient, err := f.HttpClient() + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, verifyTimeout) + defer cancel() + url := strings.TrimRight(core.ResolveEndpoints(cfg.Brand).Open, "/") + "/open-apis/bot/v3/info" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + // /open-apis/bot/v3/info returns `{code, msg, bot: {...}}` — the bot + // payload is under "bot", not "data" as the newer Lark API convention. + var envelope struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + OpenID string `json:"open_id"` + AppName string `json:"app_name"` + } `json:"bot"` + } + parseErr := json.Unmarshal(body, &envelope) + + if resp.StatusCode >= 400 { + // Lark error responses are usually `{code, msg}` envelopes even on + // non-2xx — surface them when present so callers see why bot auth + // was rejected, not just the bare HTTP code. + if parseErr == nil && envelope.Code != 0 { + return nil, fmt.Errorf("HTTP %d: [%d] %s", resp.StatusCode, envelope.Code, envelope.Msg) + } + return nil, fmt.Errorf("HTTP %d", resp.StatusCode) + } + if parseErr != nil { + return nil, fmt.Errorf("parse response: %w", parseErr) + } + if envelope.Code != 0 { + return nil, fmt.Errorf("[%d] %s", envelope.Code, envelope.Msg) + } + if envelope.Data.OpenID == "" { + return nil, errors.New("open_id is empty") + } + return &botInfo{OpenID: envelope.Data.OpenID, AppName: envelope.Data.AppName}, nil +} + +func fillTokenFields(id *Identity, token *larkauth.StoredUAToken) { + id.TokenStatus = larkauth.TokenStatus(token) + id.Scope = token.Scope + id.ExpiresAt = formatMillis(token.ExpiresAt) + id.RefreshExpiresAt = formatMillis(token.RefreshExpiresAt) + id.GrantedAt = formatMillis(token.GrantedAt) +} + +func formatMillis(ms int64) string { + if ms <= 0 { + return "" + } + return time.UnixMilli(ms).Format(time.RFC3339) +} + +func StatusMessage(status string) string { + switch status { + case StatusNotConfigured: + return "not configured" + case StatusVerifyFailed: + return "verify failed" + case StatusNeedsRefresh: + return "needs refresh" + case StatusMissing: + return "missing" + default: + return status + } +} + +func boolPtr(v bool) *bool { + return &v +} diff --git a/internal/identitydiag/diagnostics_test.go b/internal/identitydiag/diagnostics_test.go new file mode 100644 index 000000000..6d288e3bb --- /dev/null +++ b/internal/identitydiag/diagnostics_test.go @@ -0,0 +1,350 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package identitydiag + +import ( + "context" + "net/http" + "strings" + "testing" + "time" + + larkauth "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/httpmock" + "github.com/zalando/go-keyring" +) + +func TestDiagnose_NoUserReportsBotReadyAndUserMissing(t *testing.T) { + cfg := &core.CliConfig{AppID: "test-app", AppSecret: "secret", Brand: core.BrandFeishu} + f, _, _, _ := cmdutil.TestFactory(t, cfg) + + got := Diagnose(context.Background(), f, cfg, false) + if got.Bot.Status != StatusReady || !got.Bot.Available { + t.Fatalf("bot = %#v, want ready and available", got.Bot) + } + if got.User.Status != StatusMissing || got.User.Available { + t.Fatalf("user = %#v, want missing and unavailable", got.User) + } +} + +func TestDiagnose_BotIdentityNotConfigured(t *testing.T) { + cfg := &core.CliConfig{AppID: "test-app", Brand: core.BrandFeishu} + f, _, _, _ := cmdutil.TestFactory(t, cfg) + + got := Diagnose(context.Background(), f, cfg, false) + if got.Bot.Status != StatusNotConfigured || got.Bot.Available { + t.Fatalf("bot = %#v, want not_configured and unavailable", got.Bot) + } +} + +func TestDiagnose_VerifyBotIdentity(t *testing.T) { + cfg := &core.CliConfig{AppID: "test-app", AppSecret: "secret", Brand: core.BrandFeishu} + f, _, _, reg := cmdutil.TestFactory(t, cfg) + stub := &httpmock.Stub{ + Method: http.MethodGet, + URL: "/open-apis/bot/v3/info", + Body: map[string]interface{}{ + "code": 0, + "msg": "ok", + "bot": map[string]interface{}{ + "open_id": "ou_bot", + "app_name": "diagnostic bot", + }, + }, + } + reg.Register(stub) + + got := Diagnose(context.Background(), f, cfg, true) + if got.Bot.Status != StatusReady || !got.Bot.Available { + t.Fatalf("bot = %#v, want ready and available", got.Bot) + } + if got.Bot.Verified == nil || !*got.Bot.Verified { + t.Fatalf("bot verified = %v, want true", got.Bot.Verified) + } + if got.Bot.OpenID != "ou_bot" || got.Bot.AppName != "diagnostic bot" { + t.Fatalf("bot info = %#v, want open id and app name", got.Bot) + } + if got := stub.CapturedHeaders.Get("Authorization"); got != "Bearer test-token" { + t.Fatalf("Authorization = %q, want %q", got, "Bearer test-token") + } +} + +func TestDiagnose_VerifyUserIdentity(t *testing.T) { + keyring.MockInit() + t.Setenv("HOME", t.TempDir()) + t.Setenv("LARKSUITE_CLI_DATA_DIR", t.TempDir()) + + cfg := &core.CliConfig{ + AppID: "test-app-user", + AppSecret: "secret", + Brand: core.BrandFeishu, + UserOpenId: "ou_user", + UserName: "tester", + } + now := time.Now() + if err := larkauth.SetStoredToken(&larkauth.StoredUAToken{ + AppId: cfg.AppID, + UserOpenId: cfg.UserOpenId, + AccessToken: "user-access-token", + RefreshToken: "refresh-token", + ExpiresAt: now.Add(time.Hour).UnixMilli(), + RefreshExpiresAt: now.Add(24 * time.Hour).UnixMilli(), + GrantedAt: now.Add(-time.Hour).UnixMilli(), + Scope: "offline_access", + }); err != nil { + t.Fatalf("SetStoredToken() error = %v", err) + } + + f, _, _, reg := cmdutil.TestFactory(t, cfg) + reg.Register(&httpmock.Stub{ + Method: http.MethodGet, + URL: "/open-apis/bot/v3/info", + Body: map[string]interface{}{ + "code": 0, + "msg": "ok", + "bot": map[string]interface{}{ + "open_id": "ou_bot", + "app_name": "diagnostic bot", + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: http.MethodGet, + URL: larkauth.PathUserInfoV1, + Body: map[string]interface{}{ + "code": 0, + "msg": "ok", + }, + }) + + got := Diagnose(context.Background(), f, cfg, true) + if got.User.Status != StatusReady || !got.User.Available { + t.Fatalf("user = %#v, want ready and available", got.User) + } + if got.User.Verified == nil || !*got.User.Verified { + t.Fatalf("user verified = %v, want true", got.User.Verified) + } + if got.User.OpenID != "ou_user" || got.User.UserName != "tester" { + t.Fatalf("user = %#v, want user identity details", got.User) + } +} + +func TestDiagnose_VerifyBotIdentity_HTTPErrorSurfacesEnvelope(t *testing.T) { + cfg := &core.CliConfig{AppID: "test-app", AppSecret: "secret", Brand: core.BrandFeishu} + f, _, _, reg := cmdutil.TestFactory(t, cfg) + reg.Register(&httpmock.Stub{ + Method: http.MethodGet, + URL: "/open-apis/bot/v3/info", + Status: http.StatusUnauthorized, + Body: map[string]interface{}{ + "code": 99991663, + "msg": "app ticket invalid", + }, + }) + + got := Diagnose(context.Background(), f, cfg, true) + if got.Bot.Status != StatusVerifyFailed || got.Bot.Available { + t.Fatalf("bot = %#v, want verify_failed and unavailable", got.Bot) + } + if got.Bot.Verified == nil || *got.Bot.Verified { + t.Fatalf("bot verified = %v, want false", got.Bot.Verified) + } + if !strings.Contains(got.Bot.Message, "401") || !strings.Contains(got.Bot.Message, "99991663") { + t.Fatalf("bot message = %q, want both HTTP code and envelope code", got.Bot.Message) + } +} + +func TestDiagnose_VerifyBotIdentity_BusinessErrorCode(t *testing.T) { + cfg := &core.CliConfig{AppID: "test-app", AppSecret: "secret", Brand: core.BrandFeishu} + f, _, _, reg := cmdutil.TestFactory(t, cfg) + reg.Register(&httpmock.Stub{ + Method: http.MethodGet, + URL: "/open-apis/bot/v3/info", + Body: map[string]interface{}{ + "code": 10013, + "msg": "scope not granted", + }, + }) + + got := Diagnose(context.Background(), f, cfg, true) + if got.Bot.Status != StatusVerifyFailed || got.Bot.Available { + t.Fatalf("bot = %#v, want verify_failed and unavailable", got.Bot) + } + if !strings.Contains(got.Bot.Message, "10013") || !strings.Contains(got.Bot.Message, "scope not granted") { + t.Fatalf("bot message = %q, want envelope code/msg", got.Bot.Message) + } +} + +func TestDiagnose_VerifyUserIdentity_ServerRejects(t *testing.T) { + keyring.MockInit() + t.Setenv("HOME", t.TempDir()) + t.Setenv("LARKSUITE_CLI_DATA_DIR", t.TempDir()) + + cfg := &core.CliConfig{ + AppID: "test-app-reject", + AppSecret: "secret", + Brand: core.BrandFeishu, + UserOpenId: "ou_user", + UserName: "tester", + } + now := time.Now() + if err := larkauth.SetStoredToken(&larkauth.StoredUAToken{ + AppId: cfg.AppID, + UserOpenId: cfg.UserOpenId, + AccessToken: "user-access-token", + RefreshToken: "refresh-token", + ExpiresAt: now.Add(time.Hour).UnixMilli(), + RefreshExpiresAt: now.Add(24 * time.Hour).UnixMilli(), + GrantedAt: now.Add(-time.Hour).UnixMilli(), + Scope: "offline_access", + }); err != nil { + t.Fatalf("SetStoredToken() error = %v", err) + } + + f, _, _, reg := cmdutil.TestFactory(t, cfg) + reg.Register(&httpmock.Stub{ + Method: http.MethodGet, + URL: "/open-apis/bot/v3/info", + Body: map[string]interface{}{ + "code": 0, + "bot": map[string]interface{}{"open_id": "ou_bot", "app_name": "bot"}, + }, + }) + reg.Register(&httpmock.Stub{ + Method: http.MethodGet, + URL: larkauth.PathUserInfoV1, + Body: map[string]interface{}{ + "code": 99991661, + "msg": "access token invalid", + }, + }) + + got := Diagnose(context.Background(), f, cfg, true) + if got.User.Status != StatusVerifyFailed || got.User.Available { + t.Fatalf("user = %#v, want verify_failed and unavailable", got.User) + } + if got.User.Verified == nil || *got.User.Verified { + t.Fatalf("user verified = %v, want false", got.User.Verified) + } + if !strings.Contains(got.User.Message, "server rejected token") { + t.Fatalf("user message = %q, want 'server rejected token'", got.User.Message) + } +} + +func TestDiagnose_UserIdentityExpired(t *testing.T) { + keyring.MockInit() + t.Setenv("HOME", t.TempDir()) + t.Setenv("LARKSUITE_CLI_DATA_DIR", t.TempDir()) + + cfg := &core.CliConfig{ + AppID: "test-app-expired", + AppSecret: "secret", + Brand: core.BrandFeishu, + UserOpenId: "ou_expired", + UserName: "tester", + } + now := time.Now() + if err := larkauth.SetStoredToken(&larkauth.StoredUAToken{ + AppId: cfg.AppID, + UserOpenId: cfg.UserOpenId, + AccessToken: "user-access-token", + RefreshToken: "refresh-token", + ExpiresAt: now.Add(-time.Hour).UnixMilli(), + RefreshExpiresAt: now.Add(-time.Minute).UnixMilli(), + GrantedAt: now.Add(-24 * time.Hour).UnixMilli(), + Scope: "offline_access", + }); err != nil { + t.Fatalf("SetStoredToken() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, cfg) + got := Diagnose(context.Background(), f, cfg, false) + if got.User.Status != StatusMissing || got.User.Available { + t.Fatalf("user = %#v, want missing and unavailable", got.User) + } + if got.User.Hint == "" { + t.Fatalf("user hint is empty, want re-login hint") + } +} + +func TestDiagnose_BotIdentityStrictUserOnly(t *testing.T) { + // SupportedIdentities = SupportsUser (1) only — bot path should be + // reported as not_configured even though an app secret is present. + cfg := &core.CliConfig{ + AppID: "test-app", + AppSecret: "secret", + Brand: core.BrandFeishu, + SupportedIdentities: 1, + } + f, _, _, _ := cmdutil.TestFactory(t, cfg) + + got := Diagnose(context.Background(), f, cfg, false) + if got.Bot.Status != StatusNotConfigured || got.Bot.Available { + t.Fatalf("bot = %#v, want not_configured and unavailable", got.Bot) + } +} + +func TestDiagnose_UserIdentityMissingAppConfig(t *testing.T) { + cfg := &core.CliConfig{Brand: core.BrandFeishu} + f, _, _, _ := cmdutil.TestFactory(t, cfg) + + got := Diagnose(context.Background(), f, cfg, false) + if got.User.Status != StatusNotConfigured || got.User.Available { + t.Fatalf("user = %#v, want not_configured and unavailable", got.User) + } +} + +func TestStatusMessage(t *testing.T) { + cases := map[string]string{ + StatusReady: StatusReady, + StatusNotConfigured: "not configured", + StatusVerifyFailed: "verify failed", + StatusNeedsRefresh: "needs refresh", + StatusMissing: "missing", + "unknown": "unknown", + } + for in, want := range cases { + if got := StatusMessage(in); got != want { + t.Errorf("StatusMessage(%q) = %q, want %q", in, got, want) + } + } +} + +func TestDiagnose_UserIdentityNeedsRefresh(t *testing.T) { + keyring.MockInit() + t.Setenv("HOME", t.TempDir()) + t.Setenv("LARKSUITE_CLI_DATA_DIR", t.TempDir()) + + cfg := &core.CliConfig{ + AppID: "test-app-needs-refresh", + AppSecret: "secret", + Brand: core.BrandFeishu, + UserOpenId: "ou_refresh", + UserName: "tester", + } + now := time.Now() + if err := larkauth.SetStoredToken(&larkauth.StoredUAToken{ + AppId: cfg.AppID, + UserOpenId: cfg.UserOpenId, + AccessToken: "user-access-token", + RefreshToken: "refresh-token", + ExpiresAt: now.Add(time.Minute).UnixMilli(), + RefreshExpiresAt: now.Add(24 * time.Hour).UnixMilli(), + GrantedAt: now.Add(-time.Hour).UnixMilli(), + Scope: "offline_access", + }); err != nil { + t.Fatalf("SetStoredToken() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, cfg) + got := Diagnose(context.Background(), f, cfg, false) + if got.User.Status != StatusNeedsRefresh || !got.User.Available { + t.Fatalf("user = %#v, want needs_refresh and available", got.User) + } + if got.User.TokenStatus != "needs_refresh" { + t.Fatalf("token status = %q, want needs_refresh", got.User.TokenStatus) + } +} diff --git a/internal/platform/doc.go b/internal/platform/doc.go new file mode 100644 index 000000000..1a70e594c --- /dev/null +++ b/internal/platform/doc.go @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +// Package platformhost is the bootstrap-time orchestrator that turns the +// global plugin registry (extension/platform.RegisteredPlugins) into: +// +// - a populated internal/hook.Registry (Observer / Wrapper / Lifecycle) +// - a list of cmdpolicy.PluginRule contributions (one per plugin that +// called r.Restrict) +// +// Two key invariants: +// +// - **Atomic install.** A plugin's Install() runs against a staging +// Registrar; only when Install returns nil AND validateSelf passes +// does the host commit the staged hooks/rule. Partial install never +// reaches the live Registry, so a half-loaded plugin cannot leave +// stale Observer / Wrap entries behind. +// +// - **FailurePolicy honoured.** Each plugin declares FailOpen or +// FailClosed. FailOpen plugins are skipped on error (warning to +// stderr); FailClosed plugins abort the whole bootstrap. The +// framework also enforces the Restricts↔FailClosed consistency +// contract (a Restricts=true plugin with FailOpen would be a +// silent security hole and is rejected during install). +// +// The host returns: +// +// - a *hook.Registry ready to install on the command tree +// - a []cmdpolicy.PluginRule for the pruning resolver +// - an error when a FailClosed plugin failed +package internalplatform diff --git a/internal/platform/error.go b/internal/platform/error.go new file mode 100644 index 000000000..8ee037aa6 --- /dev/null +++ b/internal/platform/error.go @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package internalplatform + +import "fmt" + +// PluginInstallError is the typed install-time failure. ReasonCode comes +// from the closed enum in the design doc (section 5.3 reason_code +// table). Cause carries the underlying error, if any, so consumers can +// errors.As to inspect it. +type PluginInstallError struct { + PluginName string + ReasonCode string + Reason string + Cause error +} + +func (e *PluginInstallError) Error() string { + prefix := fmt.Sprintf("plugin %q (%s)", e.PluginName, e.ReasonCode) + if e.Reason != "" { + prefix += ": " + e.Reason + } + if e.Cause != nil { + prefix += ": " + e.Cause.Error() + } + return prefix +} + +func (e *PluginInstallError) Unwrap() error { return e.Cause } + +// ReasonCodes for PluginInstallError. The closed enum is referenced by +// the design doc's hard-constraint #15 (reason_code enum closure) and +// drives the JSON envelope's error.detail.reason_code field. +const ( + ReasonInvalidPluginName = "invalid_plugin_name" + ReasonPluginNamePanic = "plugin_name_panic" + ReasonInvalidHookName = "invalid_hook_name" + ReasonDuplicateHookName = "duplicate_hook_name" + ReasonInvalidHookRegister = "invalid_hook_registration" + ReasonInvalidRule = "invalid_rule" + ReasonDoubleRestrict = "double_restrict" + ReasonRestrictsMismatch = "restricts_mismatch" + ReasonCapabilityUnmet = "capability_unmet" + ReasonCapabilitiesPanic = "capabilities_panic" + // ReasonInvalidCapability flags a plugin authoring error in + // Capabilities() output -- e.g. a syntactically malformed + // RequiredCLIVersion string. This is distinct from + // ReasonCapabilityUnmet (legitimate version mismatch): an authoring + // bug must NOT be hidden by FailurePolicy=FailOpen, so this code is + // classified as untrusted-config and aborts unconditionally. + ReasonInvalidCapability = "invalid_capability" + ReasonInstallFailed = "install_failed" + ReasonInstallPanic = "install_panic" + ReasonDuplicatePluginName = "duplicate_plugin_name" + ReasonMultipleRestricts = "multiple_restrict_plugins" +) diff --git a/internal/platform/host.go b/internal/platform/host.go new file mode 100644 index 000000000..2f13cf59d --- /dev/null +++ b/internal/platform/host.go @@ -0,0 +1,344 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package internalplatform + +import ( + "errors" + "fmt" + "io" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/cmdpolicy" + "github.com/larksuite/cli/internal/hook" +) + +// PluginInfo is the metadata of a successfully-installed plugin, +// captured at install time so diagnostic commands (config plugins show) +// can enumerate plugins without re-calling potentially panic-prone +// plugin methods at display time. +type PluginInfo struct { + Name string + Version string + Capabilities platform.Capabilities +} + +// InstallResult is the output of InstallAll. Registry is ready for +// hook.Install; PluginRules feeds into cmdpolicy.Resolve as the +// "plugin contribution" half of the resolver input. Plugins lists +// every plugin that committed successfully (FailOpen-skipped plugins +// are absent), for downstream diagnostics. +type InstallResult struct { + Registry *hook.Registry + PluginRules []cmdpolicy.PluginRule + Plugins []PluginInfo +} + +// InstallAll runs every registered plugin through the staging +// Registrar, validates, and commits the survivors. FailOpen plugins +// that fail are skipped with a warning; the first FailClosed failure +// stops the loop and returns the error. +// +// Plugins are processed in registration order so the result is +// deterministic. +// +// errOut receives warnings about FailOpen plugin skips. nil errOut +// means warnings are dropped (useful in tests). +func InstallAll(plugins []platform.Plugin, errOut io.Writer) (*InstallResult, error) { + if errOut == nil { + errOut = io.Discard + } + result := &InstallResult{ + Registry: hook.NewRegistry(), + } + + // Detect duplicate Plugin.Name. We do this up-front so the error + // surfaces before any Install runs; design hard-constraint #7 + // treats this as configuration error (fail-closed regardless of + // individual FailurePolicy). + if err := detectDuplicateNames(plugins); err != nil { + return nil, err + } + + for _, p := range plugins { + name, nameErr := safeCallName(p) + if nameErr != nil { + // Fail-closed on bad Name: we don't know the plugin's + // FailurePolicy yet (it's behind Capabilities, and we + // cannot trust Capabilities() before Name() succeeds). + return nil, nameErr + } + if err := installOne(name, p, result); err != nil { + // Some errors must abort regardless of FailurePolicy + // because they imply the plugin's FailurePolicy itself + // cannot be trusted (e.g. the consistency check between + // Restricts and FailClosed failed). + if isUntrustedConfigError(err) { + return nil, err + } + policy := readFailurePolicy(p) + switch policy { + case platform.FailClosed: + return nil, err + default: + fmt.Fprintf(errOut, "warning: plugin %q skipped: %v\n", name, err) + continue + } + } + } + + return result, nil +} + +// isUntrustedConfigError flags errors where the plugin's declared +// FailurePolicy is itself part of the misconfiguration. For these the +// host MUST abort unconditionally; honouring an FailOpen declaration on +// a misconfigured Restricts plugin would defeat the whole point of the +// consistency check. +func isUntrustedConfigError(err error) bool { + var pi *PluginInstallError + if !errors.As(err, &pi) { + return false + } + return pi.ReasonCode == ReasonRestrictsMismatch || + pi.ReasonCode == ReasonInvalidPluginName || + pi.ReasonCode == ReasonPluginNamePanic || + pi.ReasonCode == ReasonDuplicatePluginName || + pi.ReasonCode == ReasonInvalidCapability +} + +// installOne handles a single plugin: build a staging Registrar, call +// Install, run validateSelf, and on success commit to the live +// Registry / PluginRules. Any error means staged data is discarded. +func installOne(name string, p platform.Plugin, result *InstallResult) error { + caps, capsErr := safeCallCapabilities(p) + if capsErr != nil { + return capsErr + } + + // FailurePolicy is a closed enum. An out-of-range value almost + // always means the plugin author shipped FailurePolicy(2)/etc. by + // mistake, and the host's switch on caps.FailurePolicy below would + // silently treat the unknown value as FailOpen — defeating the + // security boundary the policy was meant to express. Reject up + // front with ReasonInvalidCapability (classified as + // untrusted-config, so the abort is unconditional). + if caps.FailurePolicy != platform.FailOpen && caps.FailurePolicy != platform.FailClosed { + return &PluginInstallError{ + PluginName: name, + ReasonCode: ReasonInvalidCapability, + Reason: fmt.Sprintf("FailurePolicy=%d is not a recognised value (expected FailOpen or FailClosed)", + caps.FailurePolicy), + } + } + + // Strict consistency check: Restricts=true must pair with + // FailClosed (design hard-constraint #6). + if caps.Restricts && caps.FailurePolicy != platform.FailClosed { + return &PluginInstallError{ + PluginName: name, + ReasonCode: ReasonRestrictsMismatch, + Reason: "Restricts=true requires FailurePolicy=FailClosed", + } + } + + // Version compatibility check. Two distinct failure modes: + // + // 1. Parse error (constraint is malformed, e.g. ">=abc") + // -> ReasonInvalidCapability, classified as untrusted-config + // so the host aborts unconditionally. This is a plugin + // authoring bug; FailurePolicy must NOT mask it. + // + // 2. Legitimate version mismatch (constraint parses fine but + // current CLI does not satisfy it) + // -> ReasonCapabilityUnmet, honours FailurePolicy. A FailOpen + // plugin announcing ">=2.0" against a 1.x CLI is skipped + // with a warning; a FailClosed plugin aborts. + if ok, err := satisfiesRequiredCLIVersion(currentCLIVersion(), caps.RequiredCLIVersion); err != nil { + return &PluginInstallError{ + PluginName: name, + ReasonCode: ReasonInvalidCapability, + Reason: err.Error(), + } + } else if !ok { + return &PluginInstallError{ + PluginName: name, + ReasonCode: ReasonCapabilityUnmet, + Reason: fmt.Sprintf("CLI version %q does not satisfy plugin requirement %q", + currentCLIVersion(), caps.RequiredCLIVersion), + } + } + + staging := newStagingRegistrar(name) + if err := safeCallInstall(p, staging); err != nil { + // Don't double-wrap typed PluginInstallError -- safeCallInstall + // already produces install_panic for recovered panics, and a + // re-wrap would bury the precise reason_code under + // install_failed. + var pi *PluginInstallError + if errors.As(err, &pi) { + return err + } + return &PluginInstallError{ + PluginName: name, + ReasonCode: ReasonInstallFailed, + Reason: "Install returned error", + Cause: err, + } + } + + if err := staging.validateSelf(caps); err != nil { + return err + } + + // Commit staged data atomically. + for _, e := range staging.stagedObservers { + result.Registry.AddObserver(e) + } + for _, e := range staging.stagedWrappers { + result.Registry.AddWrapper(e) + } + for _, e := range staging.stagedLifecycles { + result.Registry.AddLifecycle(e) + } + if staging.rule != nil { + result.PluginRules = append(result.PluginRules, cmdpolicy.PluginRule{ + PluginName: name, + Rule: staging.rule, + }) + } + + // Record the plugin in the inventory. Version is fetched here under + // a recover-wrapped helper so a plugin's Version() panic does not + // abort the install we just committed. + result.Plugins = append(result.Plugins, PluginInfo{ + Name: name, + Version: safeCallVersion(p), + Capabilities: caps, + }) + return nil +} + +// safeCallVersion mirrors safeCallName but for Plugin.Version. Failures +// degrade to the empty string -- Version is informational, not a hard +// contract field, so we never want it to abort installation. +func safeCallVersion(p platform.Plugin) (v string) { + defer func() { + if r := recover(); r != nil { + v = "" + } + }() + return p.Version() +} + +// readFailurePolicy reads Capabilities and returns the policy, falling +// back to FailClosed if Capabilities() panics. Defensive default: we +// assume the worst-case (safety-sensitive) when we cannot read the +// declaration. +// +// **Implementation note**: FailClosed must be the value set BEFORE the +// panic-prone call. The zero value of platform.FailurePolicy is +// FailOpen, so a "just return after recover" pattern would silently +// flip the safe-default to FailOpen on panic -- the opposite of what +// the comment claims. +func readFailurePolicy(p platform.Plugin) (policy platform.FailurePolicy) { + policy = platform.FailClosed + defer func() { _ = recover() }() + policy = p.Capabilities().FailurePolicy + return +} + +// safeCallName recovers from a panic in Plugin.Name() and surfaces it +// as a typed PluginInstallError. Without recovery, a buggy plugin could +// crash the binary before main has a chance to emit a JSON envelope. +func safeCallName(p platform.Plugin) (string, error) { + var ( + name string + err error + ) + func() { + defer func() { + if r := recover(); r != nil { + err = &PluginInstallError{ + PluginName: "", + ReasonCode: ReasonPluginNamePanic, + Reason: fmt.Sprintf("Plugin.Name() panicked: %v", r), + } + } + }() + name = p.Name() + }() + if err != nil { + return "", err + } + if !hookNamePattern.MatchString(name) { + return "", &PluginInstallError{ + PluginName: name, + ReasonCode: ReasonInvalidPluginName, + Reason: fmt.Sprintf("Plugin.Name() %q must match ^[a-z0-9][a-z0-9-]*$ (no dots)", name), + } + } + return name, nil +} + +// safeCallCapabilities mirrors safeCallName for Capabilities(). +func safeCallCapabilities(p platform.Plugin) (caps platform.Capabilities, err error) { + defer func() { + if r := recover(); r != nil { + err = &PluginInstallError{ + PluginName: pluginNameOrPlaceholder(p), + ReasonCode: ReasonCapabilitiesPanic, + Reason: fmt.Sprintf("Plugin.Capabilities() panicked: %v", r), + } + } + }() + caps = p.Capabilities() + return caps, nil +} + +// safeCallInstall mirrors safeCallName for Install(). Install panics +// become install_panic errors, not crashes. +func safeCallInstall(p platform.Plugin, r platform.Registrar) (err error) { + defer func() { + if rec := recover(); rec != nil { + err = &PluginInstallError{ + PluginName: pluginNameOrPlaceholder(p), + ReasonCode: ReasonInstallPanic, + Reason: fmt.Sprintf("Install panicked: %v", rec), + } + } + }() + return p.Install(r) +} + +func pluginNameOrPlaceholder(p platform.Plugin) string { + defer func() { _ = recover() }() + if n := p.Name(); n != "" { + return n + } + return "" +} + +// detectDuplicateNames scans the plugin slice for repeated Plugin.Name +// values. Returns a typed PluginInstallError on the first duplicate so +// the bootstrap aborts. +func detectDuplicateNames(plugins []platform.Plugin) error { + seen := map[string]bool{} + for _, p := range plugins { + name, err := safeCallName(p) + if err != nil { + // Don't double-report: let installOne handle naming + // errors per-plugin so we get the same code path. + continue + } + if seen[name] { + return &PluginInstallError{ + PluginName: name, + ReasonCode: ReasonDuplicatePluginName, + Reason: fmt.Sprintf("duplicate Plugin.Name() %q across plugins", name), + } + } + seen[name] = true + } + return nil +} diff --git a/internal/platform/host_test.go b/internal/platform/host_test.go new file mode 100644 index 000000000..13b1f574e --- /dev/null +++ b/internal/platform/host_test.go @@ -0,0 +1,391 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package internalplatform_test + +import ( + "bytes" + "context" + "errors" + "strings" + "testing" + + "github.com/larksuite/cli/extension/platform" + internalplatform "github.com/larksuite/cli/internal/platform" +) + +// happyPlugin is a textbook plugin: declares Capabilities, calls a few +// Registrar methods, returns nil. The install pipeline must accept it. +type happyPlugin struct{ name string } + +func (p happyPlugin) Name() string { return p.name } +func (p happyPlugin) Version() string { return "1.0.0" } +func (p happyPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{ + FailurePolicy: platform.FailOpen, + } +} +func (p happyPlugin) Install(r platform.Registrar) error { + r.Observe(platform.Before, "audit-pre", platform.All(), + func(context.Context, platform.Invocation) {}) + r.Wrap("policy", platform.All(), + func(next platform.Handler) platform.Handler { + return func(ctx context.Context, inv platform.Invocation) error { + return next(ctx, inv) + } + }) + r.On(platform.Shutdown, "flush", + func(context.Context, *platform.LifecycleContext) error { return nil }) + return nil +} + +func TestInstallAll_happyPlugin(t *testing.T) { + result, err := internalplatform.InstallAll([]platform.Plugin{happyPlugin{name: "audit"}}, nil) + if err != nil { + t.Fatalf("InstallAll: %v", err) + } + if result.Registry == nil { + t.Fatalf("registry should be populated") + } + if len(result.PluginRules) != 0 { + t.Errorf("happy plugin did not call Restrict; rules should be empty") + } + // Cross-check: observers, wrappers, lifecycles got staged through to the live Registry. + if len(result.Registry.MatchingObservers(fakeView{}, platform.Before)) != 1 { + t.Errorf("Before observer not committed") + } + if len(result.Registry.MatchingWrappers(fakeView{})) != 1 { + t.Errorf("Wrapper not committed") + } + if len(result.Registry.LifecycleHandlers(platform.Shutdown)) != 1 { + t.Errorf("Shutdown lifecycle not committed") + } +} + +// fakeView satisfies platform.CommandView for selector lookups in the +// platformhost tests; All() matches everything so the type can stay +// trivial. +type fakeView struct{} + +func (fakeView) Path() string { return "" } +func (fakeView) Domain() string { return "" } +func (fakeView) Risk() (platform.Risk, bool) { return "", false } +func (fakeView) Identities() []platform.Identity { return nil } +func (fakeView) Annotation(string) (string, bool) { return "", false } + +// A FailClosed plugin whose Install returns an error must abort +// InstallAll. Design hard-constraint #6. +type failClosedPlugin struct{} + +func (failClosedPlugin) Name() string { return "secaudit" } +func (failClosedPlugin) Version() string { return "1.0.0" } +func (failClosedPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{ + FailurePolicy: platform.FailClosed, + } +} +func (failClosedPlugin) Install(platform.Registrar) error { + return errors.New("upstream unreachable") +} + +func TestInstallAll_failClosedAborts(t *testing.T) { + _, err := internalplatform.InstallAll([]platform.Plugin{failClosedPlugin{}}, nil) + if err == nil { + t.Fatalf("FailClosed install error should abort") + } + var pi *internalplatform.PluginInstallError + if !errors.As(err, &pi) { + t.Fatalf("error must be *PluginInstallError, got %T", err) + } + if pi.ReasonCode != internalplatform.ReasonInstallFailed { + t.Errorf("ReasonCode = %q, want install_failed", pi.ReasonCode) + } +} + +// FailOpen install failure logs a warning and skips this plugin; other +// plugins still get installed. +type failOpenPlugin struct{} + +func (failOpenPlugin) Name() string { return "audit-broken" } +func (failOpenPlugin) Version() string { return "1.0.0" } +func (failOpenPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{FailurePolicy: platform.FailOpen} +} +func (failOpenPlugin) Install(platform.Registrar) error { + return errors.New("could not connect") +} + +func TestInstallAll_failOpenSkips(t *testing.T) { + var buf bytes.Buffer + plugins := []platform.Plugin{ + failOpenPlugin{}, + happyPlugin{name: "audit"}, + } + result, err := internalplatform.InstallAll(plugins, &buf) + if err != nil { + t.Fatalf("FailOpen failure must not abort, got %v", err) + } + if !strings.Contains(buf.String(), "audit-broken") { + t.Errorf("FailOpen warning should mention plugin name, got %q", buf.String()) + } + // Second plugin's observer should be present. + if len(result.Registry.MatchingObservers(fakeView{}, platform.Before)) != 1 { + t.Errorf("happy plugin's observer should still be installed after first plugin skipped") + } +} + +// Restricts=true with FailOpen is a configuration error: a policy +// plugin that silently disappears under FailOpen would erase the +// security boundary. The host must reject this combo BEFORE Install +// runs. +type misconfiguredRestrictPlugin struct{} + +func (misconfiguredRestrictPlugin) Name() string { return "secaudit" } +func (misconfiguredRestrictPlugin) Version() string { return "1.0.0" } +func (misconfiguredRestrictPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{ + Restricts: true, // policy plugin + FailurePolicy: platform.FailOpen, // contradicts safety contract + } +} +func (misconfiguredRestrictPlugin) Install(platform.Registrar) error { return nil } + +func TestInstallAll_restrictsRequiresFailClosed(t *testing.T) { + _, err := internalplatform.InstallAll([]platform.Plugin{misconfiguredRestrictPlugin{}}, nil) + if err == nil { + t.Fatalf("Restricts+FailOpen must abort") + } + var pi *internalplatform.PluginInstallError + if !errors.As(err, &pi) || pi.ReasonCode != internalplatform.ReasonRestrictsMismatch { + t.Fatalf("ReasonCode = %v, want restricts_mismatch", pi) + } +} + +// Restricts=true but Install didn't call r.Restrict -> mismatch. +type lyingRestrictPlugin struct{} + +func (lyingRestrictPlugin) Name() string { return "p" } +func (lyingRestrictPlugin) Version() string { return "1.0.0" } +func (lyingRestrictPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{ + Restricts: true, + FailurePolicy: platform.FailClosed, + } +} +func (lyingRestrictPlugin) Install(platform.Registrar) error { + // Forgot to call r.Restrict. + return nil +} + +func TestInstallAll_restrictsDeclaredButNotCalled(t *testing.T) { + _, err := internalplatform.InstallAll([]platform.Plugin{lyingRestrictPlugin{}}, nil) + if err == nil { + t.Fatalf("missing Restrict call when declared must fail") + } + var pi *internalplatform.PluginInstallError + if !errors.As(err, &pi) || pi.ReasonCode != internalplatform.ReasonRestrictsMismatch { + t.Fatalf("ReasonCode = %v, want restricts_mismatch", pi) + } +} + +// Plugin that panics inside Install must NOT crash the binary -- the +// host recovers and converts the panic into a typed install_panic. +type panicInstallPlugin struct{} + +func (panicInstallPlugin) Name() string { return "panicker" } +func (panicInstallPlugin) Version() string { return "1.0.0" } +func (panicInstallPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{FailurePolicy: platform.FailClosed} +} +func (panicInstallPlugin) Install(platform.Registrar) error { + panic("boom") +} + +func TestInstallAll_installPanicRecovered(t *testing.T) { + _, err := internalplatform.InstallAll([]platform.Plugin{panicInstallPlugin{}}, nil) + if err == nil { + t.Fatalf("Install panic should surface as error") + } + var pi *internalplatform.PluginInstallError + if !errors.As(err, &pi) || pi.ReasonCode != internalplatform.ReasonInstallPanic { + t.Fatalf("ReasonCode = %v, want install_panic", pi) + } +} + +// Two plugins with the same Name must abort before any Install runs. +func TestInstallAll_duplicatePluginName(t *testing.T) { + _, err := internalplatform.InstallAll([]platform.Plugin{ + happyPlugin{name: "audit"}, + happyPlugin{name: "audit"}, + }, nil) + if err == nil { + t.Fatalf("duplicate Plugin.Name must abort") + } + var pi *internalplatform.PluginInstallError + if !errors.As(err, &pi) || pi.ReasonCode != internalplatform.ReasonDuplicatePluginName { + t.Fatalf("ReasonCode = %v, want duplicate_plugin_name", pi) + } +} + +// Plugin with an invalid Name (contains "." or starts with a hyphen) +// must abort with invalid_plugin_name. The dot ban is critical -- the +// "{plugin}.{hook}" namespace join would become ambiguous if dots were +// allowed inside Plugin.Name(). +type badNamePlugin struct{ n string } + +func (p badNamePlugin) Name() string { return p.n } +func (p badNamePlugin) Version() string { return "1.0.0" } +func (p badNamePlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{FailurePolicy: platform.FailClosed} +} +func (p badNamePlugin) Install(platform.Registrar) error { return nil } + +func TestInstallAll_invalidPluginName(t *testing.T) { + cases := []string{"with.dot", "", "-leading-hyphen", "UPPER"} + for _, name := range cases { + t.Run(name, func(t *testing.T) { + _, err := internalplatform.InstallAll([]platform.Plugin{badNamePlugin{n: name}}, nil) + if err == nil { + t.Fatalf("invalid name %q should abort", name) + } + var pi *internalplatform.PluginInstallError + if !errors.As(err, &pi) || pi.ReasonCode != internalplatform.ReasonInvalidPluginName { + t.Fatalf("ReasonCode = %v, want invalid_plugin_name", pi) + } + }) + } +} + +// Plugin's Install registers two hooks with the same name -- the +// staging Registrar rejects the second one with duplicate_hook_name. +type duplicateHookPlugin struct{} + +func (duplicateHookPlugin) Name() string { return "dup" } +func (duplicateHookPlugin) Version() string { return "1.0.0" } +func (duplicateHookPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{FailurePolicy: platform.FailClosed} +} +func (duplicateHookPlugin) Install(r platform.Registrar) error { + r.Observe(platform.Before, "x", platform.All(), func(context.Context, platform.Invocation) {}) + r.Observe(platform.After, "x", platform.All(), func(context.Context, platform.Invocation) {}) + return nil +} + +func TestInstallAll_duplicateHookName(t *testing.T) { + _, err := internalplatform.InstallAll([]platform.Plugin{duplicateHookPlugin{}}, nil) + if err == nil { + t.Fatalf("duplicate hookName within same plugin must abort") + } + var pi *internalplatform.PluginInstallError + if !errors.As(err, &pi) || pi.ReasonCode != internalplatform.ReasonDuplicateHookName { + t.Fatalf("ReasonCode = %v, want duplicate_hook_name", pi) + } +} + +// Restrict contributes a rule to result.PluginRules so the pruning +// resolver can pick it up. Exercise the full path. +type restrictPlugin struct{ rule *platform.Rule } + +func (p restrictPlugin) Name() string { return "secaudit" } +func (p restrictPlugin) Version() string { return "1.0.0" } +func (p restrictPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{ + Restricts: true, + FailurePolicy: platform.FailClosed, + } +} +func (p restrictPlugin) Install(r platform.Registrar) error { + r.Restrict(p.rule) + return nil +} + +func TestInstallAll_restrictPropagatesRule(t *testing.T) { + rule := &platform.Rule{ + Name: "secaudit-policy", + MaxRisk: "read", + Allow: []string{"docs/**"}, + Deny: []string{"docs/+delete-doc"}, + Identities: []platform.Identity{"bot"}, + } + result, err := internalplatform.InstallAll([]platform.Plugin{restrictPlugin{rule: rule}}, nil) + if err != nil { + t.Fatalf("InstallAll: %v", err) + } + if len(result.PluginRules) != 1 { + t.Fatalf("expected 1 plugin rule, got %d", len(result.PluginRules)) + } + stored := result.PluginRules[0].Rule + if stored == nil { + t.Fatalf("stored rule is nil") + } + + // stagingRegistrar.Restrict defensively clones the plugin-supplied + // rule so a misbehaving plugin can't mutate it after Install + // returns. The clone must carry identical contents but live on a + // distinct pointer. + if stored == rule { + t.Errorf("stored rule should be a clone, got identical pointer") + } + if stored.Name != rule.Name || stored.MaxRisk != rule.MaxRisk { + t.Errorf("stored rule lost data: %+v", stored) + } + if got, want := len(stored.Allow), len(rule.Allow); got != want { + t.Errorf("stored Allow len = %d, want %d", got, want) + } + + // Verify the clone is actually isolated: mutating the plugin's + // rule after install must not change the stored one. + rule.Allow[0] = "evil/**" + rule.Deny = append(rule.Deny, "extra/**") + if stored.Allow[0] == "evil/**" { + t.Errorf("Allow slice aliased plugin storage") + } + if len(stored.Deny) != 1 { + t.Errorf("Deny slice aliased plugin storage: %v", stored.Deny) + } + + if result.PluginRules[0].PluginName != "secaudit" { + t.Errorf("PluginName = %q", result.PluginRules[0].PluginName) + } +} + +// Atomic install: a plugin whose validation fails AFTER it registered +// some hooks must NOT leak those hooks into the live registry. The +// staging buffer is the atomicity boundary. +type partiallyRegisterThenFailPlugin struct{} + +func (partiallyRegisterThenFailPlugin) Name() string { return "partial" } +func (partiallyRegisterThenFailPlugin) Version() string { return "1.0.0" } +func (partiallyRegisterThenFailPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{ + Restricts: true, // declares Restrict but won't call it + FailurePolicy: platform.FailClosed, + } +} +func (partiallyRegisterThenFailPlugin) Install(r platform.Registrar) error { + r.Observe(platform.Before, "would-leak", platform.All(), + func(context.Context, platform.Invocation) {}) + // validateSelf will fail because Restricts=true but Restrict + // was not called -- this is the atomic-rollback case. + return nil +} + +func TestInstallAll_atomicRollback(t *testing.T) { + _, err := internalplatform.InstallAll( + []platform.Plugin{partiallyRegisterThenFailPlugin{}, happyPlugin{name: "audit"}}, + nil, + ) + if err == nil { + t.Fatalf("partial plugin should abort (FailClosed)") + } + // We cannot check Registry contents here because InstallAll + // returns nil on failure; the rollback invariant is "nothing the + // failing plugin staged ever reached a live Registry", which is + // proven by the fact that we got nil back. A weaker but useful + // check: even if we passed a happy second plugin, the loop must + // have stopped at the first FailClosed failure. + var pi *internalplatform.PluginInstallError + if !errors.As(err, &pi) { + t.Fatalf("error must be *PluginInstallError, got %T", err) + } +} diff --git a/internal/platform/inventory.go b/internal/platform/inventory.go new file mode 100644 index 000000000..1127f9f46 --- /dev/null +++ b/internal/platform/inventory.go @@ -0,0 +1,264 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package internalplatform + +import ( + "strings" + "sync" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/hook" +) + +// HookEntry is the displayable form of one registered hook. +type HookEntry struct { + Name string `json:"name"` + When string `json:"when,omitempty"` // observers only + Event string `json:"event,omitempty"` // lifecycle only +} + +// PluginEntry collects everything one plugin contributed. +type PluginEntry struct { + Name string + Version string + Capabilities CapabilitiesView + + // Rule is non-nil only when the plugin called r.Restrict. + Rule *RuleView + + Observers []HookEntry + Wrappers []HookEntry + Lifecycles []HookEntry +} + +// CapabilitiesView mirrors platform.Capabilities for display. We keep a +// separate struct so the JSON shape stays under our control and does +// not drift with extension/platform. +type CapabilitiesView struct { + Restricts bool `json:"restricts"` + FailurePolicy string `json:"failure_policy"` + RequiredCLIVersion string `json:"required_cli_version,omitempty"` +} + +// NewCapabilitiesView converts a platform.Capabilities value into the +// display struct. +func NewCapabilitiesView(c platform.Capabilities) CapabilitiesView { + return CapabilitiesView{ + Restricts: c.Restricts, + FailurePolicy: failurePolicyLabel(c.FailurePolicy), + RequiredCLIVersion: c.RequiredCLIVersion, + } +} + +func failurePolicyLabel(p platform.FailurePolicy) string { + switch p { + case platform.FailOpen: + return "FailOpen" + case platform.FailClosed: + return "FailClosed" + } + return "" +} + +// RuleView is the displayable form of a Plugin.Restrict contribution. +type RuleView struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Allow []string `json:"allow"` + Deny []string `json:"deny"` + MaxRisk string `json:"max_risk"` + Identities []string `json:"identities"` + AllowUnannotated bool `json:"allow_unannotated"` +} + +// Inventory is the full snapshot. +type Inventory struct { + Plugins []PluginEntry +} + +// PluginInventorySource is the minimum slice of PluginInfo BuildInventory needs. +type PluginInventorySource struct { + Name string + Version string + Capabilities platform.Capabilities +} + +// RuleInventorySource is the minimum slice of cmdpolicy.PluginRule +// BuildInventory needs. Kept as plain strings to avoid an import +// cycle with cmdpolicy (the caller converts platform.Risk / Identity +// to string at the boundary). +type RuleInventorySource struct { + PluginName string + Allow []string + Deny []string + MaxRisk string + Identities []string + RuleName string + Desc string + AllowUnannotated bool +} + +// BuildInventory assembles an Inventory from the parts produced by +// InstallAll: the plugin metadata list, the hook registry (may be nil +// when no hooks were registered), and the plugin rules. +// +// Hooks are attributed to plugins by the namespaced name convention: +// each entry's Name starts with ".", and we group by the +// leading segment up to the first dot. +func BuildInventory(plugins []PluginInventorySource, registry *hook.Registry, rules []RuleInventorySource) *Inventory { + byPlugin := make(map[string]*PluginEntry, len(plugins)) + out := &Inventory{Plugins: make([]PluginEntry, 0, len(plugins))} + for _, p := range plugins { + entry := PluginEntry{ + Name: p.Name, + Version: p.Version, + Capabilities: NewCapabilitiesView(p.Capabilities), + } + out.Plugins = append(out.Plugins, entry) + } + for i := range out.Plugins { + byPlugin[out.Plugins[i].Name] = &out.Plugins[i] + } + + if registry != nil { + for _, e := range registry.Observers() { + if entry := byPlugin[ownerOf(e.Name)]; entry != nil { + entry.Observers = append(entry.Observers, HookEntry{ + Name: e.Name, + When: whenLabel(e.When), + }) + } + } + for _, e := range registry.Wrappers() { + if entry := byPlugin[ownerOf(e.Name)]; entry != nil { + entry.Wrappers = append(entry.Wrappers, HookEntry{ + Name: e.Name, + }) + } + } + for _, e := range registry.Lifecycles() { + if entry := byPlugin[ownerOf(e.Name)]; entry != nil { + entry.Lifecycles = append(entry.Lifecycles, HookEntry{ + Name: e.Name, + Event: eventLabel(e.Event), + }) + } + } + } + + for _, r := range rules { + if entry := byPlugin[r.PluginName]; entry != nil { + entry.Rule = &RuleView{ + Name: r.RuleName, + Description: r.Desc, + Allow: r.Allow, + Deny: r.Deny, + MaxRisk: r.MaxRisk, + Identities: r.Identities, + AllowUnannotated: r.AllowUnannotated, + } + } + } + return out +} + +// ownerOf extracts the plugin name from a namespaced hook name. The +// platform forbids "." in plugin names, so the first dot is always the +// namespace separator. Names without a dot are returned as-is. +func ownerOf(hookName string) string { + if i := strings.IndexByte(hookName, '.'); i >= 0 { + return hookName[:i] + } + return hookName +} + +func whenLabel(w platform.When) string { + switch w { + case platform.Before: + return "Before" + case platform.After: + return "After" + } + return "" +} + +func eventLabel(e platform.LifecycleEvent) string { + switch e { + case platform.Startup: + return "Startup" + case platform.Shutdown: + return "Shutdown" + } + return "" +} + +// --- Active inventory storage (process-global) --- + +var ( + inventoryMu sync.RWMutex + activeInventory *Inventory +) + +// SetActiveInventory records the inventory built at bootstrap. Called +// once from cmd/policy.go after install + wireHooks complete. +// +// A deep copy is taken so the snapshot is immune to later mutations of +// the input by the caller (or by any other goroutine reading the same +// PluginEntry slice). Without deep-copy, the shallow `cp := *inv` +// previously still aliased Plugins / observer / wrapper / lifecycle +// slices and the embedded RuleView's slice fields. +func SetActiveInventory(inv *Inventory) { + inventoryMu.Lock() + defer inventoryMu.Unlock() + if inv == nil { + activeInventory = nil + return + } + activeInventory = cloneInventory(inv) +} + +// GetActiveInventory returns a deep copy of the inventory, or nil if +// bootstrap has not finished. Same reasoning as SetActiveInventory: +// returning a shallow copy would let callers reach into the stored +// global through any of the embedded slices. +func GetActiveInventory() *Inventory { + inventoryMu.RLock() + defer inventoryMu.RUnlock() + if activeInventory == nil { + return nil + } + return cloneInventory(activeInventory) +} + +// cloneInventory deep-copies every level the snapshot exposes: +// top-level struct, Plugins slice, each PluginEntry's hook slices, and +// the rule's slice fields. The hook entries themselves are value types +// so the slice copy already disjoints them. +func cloneInventory(in *Inventory) *Inventory { + if in == nil { + return nil + } + out := &Inventory{ + Plugins: make([]PluginEntry, len(in.Plugins)), + } + for i, p := range in.Plugins { + entry := PluginEntry{ + Name: p.Name, + Version: p.Version, + Capabilities: p.Capabilities, + } + if p.Rule != nil { + rv := *p.Rule + rv.Allow = append([]string(nil), p.Rule.Allow...) + rv.Deny = append([]string(nil), p.Rule.Deny...) + rv.Identities = append([]string(nil), p.Rule.Identities...) + entry.Rule = &rv + } + entry.Observers = append([]HookEntry(nil), p.Observers...) + entry.Wrappers = append([]HookEntry(nil), p.Wrappers...) + entry.Lifecycles = append([]HookEntry(nil), p.Lifecycles...) + out.Plugins[i] = entry + } + return out +} diff --git a/internal/platform/inventory_test.go b/internal/platform/inventory_test.go new file mode 100644 index 000000000..a9d8d8b51 --- /dev/null +++ b/internal/platform/inventory_test.go @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package internalplatform_test + +import ( + "context" + "testing" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/hook" + internalplatform "github.com/larksuite/cli/internal/platform" +) + +func TestBuildInventory_groupsByPluginName(t *testing.T) { + plugins := []internalplatform.PluginInventorySource{ + {Name: "a", Version: "1.0", Capabilities: platform.Capabilities{ + Restricts: true, FailurePolicy: platform.FailClosed, + }}, + {Name: "b", Version: "2.0"}, + } + + r := hook.NewRegistry() + obs := func(context.Context, platform.Invocation) {} + wrap := func(next platform.Handler) platform.Handler { return next } + lc := func(context.Context, *platform.LifecycleContext) error { return nil } + + r.AddObserver(hook.ObserverEntry{Name: "a.pre", When: platform.Before, Selector: platform.All(), Fn: obs}) + r.AddObserver(hook.ObserverEntry{Name: "a.post", When: platform.After, Selector: platform.All(), Fn: obs}) + r.AddObserver(hook.ObserverEntry{Name: "b.audit", When: platform.Before, Selector: platform.All(), Fn: obs}) + r.AddWrapper(hook.WrapperEntry{Name: "a.approval", Selector: platform.All(), Fn: wrap}) + r.AddLifecycle(hook.LifecycleEntry{Name: "a.boot", Event: platform.Startup, Fn: lc}) + r.AddLifecycle(hook.LifecycleEntry{Name: "b.bye", Event: platform.Shutdown, Fn: lc}) + + rules := []internalplatform.RuleInventorySource{ + {PluginName: "a", RuleName: "a-rule", Allow: []string{"docs/**"}, MaxRisk: "read"}, + } + + inv := internalplatform.BuildInventory(plugins, r, rules) + + if got := len(inv.Plugins); got != 2 { + t.Fatalf("Plugins len = %d, want 2", got) + } + a := findPlugin(inv, "a") + b := findPlugin(inv, "b") + if a == nil || b == nil { + t.Fatalf("missing entries: a=%v b=%v", a, b) + } + + if got := len(a.Observers); got != 2 { + t.Errorf("a.Observers = %d, want 2", got) + } + if got := len(a.Wrappers); got != 1 { + t.Errorf("a.Wrappers = %d, want 1", got) + } + if got := len(a.Lifecycles); got != 1 { + t.Errorf("a.Lifecycles = %d, want 1", got) + } + if a.Rule == nil || a.Rule.Name != "a-rule" { + t.Errorf("a.Rule = %+v, want name a-rule", a.Rule) + } + if a.Capabilities.FailurePolicy != "FailClosed" { + t.Errorf("a.Capabilities.FailurePolicy = %q, want FailClosed", a.Capabilities.FailurePolicy) + } + + if got := len(b.Observers); got != 1 { + t.Errorf("b.Observers = %d, want 1 (only b.audit)", got) + } + if b.Rule != nil { + t.Errorf("b.Rule = %+v, want nil (b did not call Restrict)", b.Rule) + } + if b.Capabilities.FailurePolicy != "FailOpen" { + t.Errorf("b.Capabilities.FailurePolicy = %q, want FailOpen (zero value)", b.Capabilities.FailurePolicy) + } +} + +func TestBuildInventory_empty(t *testing.T) { + inv := internalplatform.BuildInventory(nil, nil, nil) + if got := len(inv.Plugins); got != 0 { + t.Errorf("Plugins len = %d, want 0", got) + } +} + +func findPlugin(inv *internalplatform.Inventory, name string) *internalplatform.PluginEntry { + for i := range inv.Plugins { + if inv.Plugins[i].Name == name { + return &inv.Plugins[i] + } + } + return nil +} diff --git a/internal/platform/staging.go b/internal/platform/staging.go new file mode 100644 index 000000000..1b0b7668a --- /dev/null +++ b/internal/platform/staging.go @@ -0,0 +1,228 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package internalplatform + +import ( + "fmt" + "regexp" + + "github.com/larksuite/cli/extension/platform" + "github.com/larksuite/cli/internal/hook" +) + +// hookNamePattern is the grammar both Plugin.Name() and hookName must +// match -- design hard-constraint #9. The "." character is forbidden so +// the namespace join "{plugin}.{hook}" is unambiguous. +var hookNamePattern = regexp.MustCompile(`^[a-z0-9][a-z0-9-]*$`) + +// stagingRegistrar buffers every Registrar call so the platformhost can +// commit them atomically (or discard them all) once Install returns. +// +// All validation happens here at staging time -- bad hookName, nil +// handler, duplicate names, etc. produce typed errors that surface in +// validateSelf and are translated into PluginInstallError by the host +// loop. +type stagingRegistrar struct { + pluginName string + + stagedObservers []hook.ObserverEntry + stagedWrappers []hook.WrapperEntry + stagedLifecycles []hook.LifecycleEntry + + // rule is the staged Restrict contribution, captured for the host + // to merge with the yaml side later. nil means the plugin did not + // call r.Restrict. + rule *platform.Rule + + // actuallyRestricted records whether r.Restrict was called at all. + // Even a Restrict(nil) flips this to true so the + // Restricts-vs-actual consistency check can detect the call. + actuallyRestricted bool + + // seenHookNames detects duplicate hookName within this plugin's + // Install call. + seenHookNames map[string]bool + + // stagingErrs accumulates per-call validation errors. A single + // Install can violate the grammar multiple times; collecting all + // of them lets diagnostic output show the full picture. + stagingErrs []stagingErr +} + +// stagingErr is the per-call buffered validation failure. +type stagingErr struct { + reasonCode string + message string +} + +func newStagingRegistrar(pluginName string) *stagingRegistrar { + return &stagingRegistrar{ + pluginName: pluginName, + seenHookNames: map[string]bool{}, + } +} + +// --- Registrar interface --- + +func (r *stagingRegistrar) Observe(when platform.When, name string, sel platform.Selector, fn platform.Observer) { + if !r.validateName(name) { + return + } + if !r.validateNonNilSelector(name, sel) { + return + } + if fn == nil { + r.bufferErr(ReasonInvalidHookRegister, fmt.Sprintf("observe %q: handler is nil", name)) + return + } + if !isValidWhen(when) { + r.bufferErr(ReasonInvalidHookRegister, fmt.Sprintf("observe %q: invalid When value %d", name, when)) + return + } + r.stagedObservers = append(r.stagedObservers, hook.ObserverEntry{ + Name: r.namespaced(name), + When: when, + Selector: sel, + Fn: fn, + }) +} + +func (r *stagingRegistrar) Wrap(name string, sel platform.Selector, w platform.Wrapper) { + if !r.validateName(name) { + return + } + if !r.validateNonNilSelector(name, sel) { + return + } + if w == nil { + r.bufferErr(ReasonInvalidHookRegister, fmt.Sprintf("wrap %q: handler is nil", name)) + return + } + r.stagedWrappers = append(r.stagedWrappers, hook.WrapperEntry{ + Name: r.namespaced(name), + Selector: sel, + Fn: w, + }) +} + +func (r *stagingRegistrar) On(event platform.LifecycleEvent, name string, fn platform.LifecycleHandler) { + if !r.validateName(name) { + return + } + if fn == nil { + r.bufferErr(ReasonInvalidHookRegister, fmt.Sprintf("on %q: handler is nil", name)) + return + } + if !isValidLifecycleEvent(event) { + r.bufferErr(ReasonInvalidHookRegister, fmt.Sprintf("on %q: invalid LifecycleEvent value %d", name, event)) + return + } + r.stagedLifecycles = append(r.stagedLifecycles, hook.LifecycleEntry{ + Name: r.namespaced(name), + Event: event, + Fn: fn, + }) +} + +func (r *stagingRegistrar) Restrict(rule *platform.Rule) { + if r.actuallyRestricted { + r.bufferErr(ReasonDoubleRestrict, "Restrict called more than once") + return + } + r.actuallyRestricted = true + if rule == nil { + r.bufferErr(ReasonInvalidRule, "Restrict(nil)") + return + } + // Defensive clone: retaining the caller's *Rule directly would let + // the plugin mutate Allow/Deny/Identities (or even the whole rule) + // after Install returns, bypassing the validation we run on the + // stored copy in validateSelf. Take an independent snapshot of + // every slice field so the post-validation rule is frozen. + cp := *rule + cp.Allow = append([]string(nil), rule.Allow...) + cp.Deny = append([]string(nil), rule.Deny...) + cp.Identities = append([]platform.Identity(nil), rule.Identities...) + r.rule = &cp +} + +// --- helpers --- + +func (r *stagingRegistrar) namespaced(name string) string { + return r.pluginName + "." + name +} + +func (r *stagingRegistrar) validateName(name string) bool { + if !hookNamePattern.MatchString(name) { + r.bufferErr(ReasonInvalidHookName, fmt.Sprintf("hookName %q must match ^[a-z0-9][a-z0-9-]*$", name)) + return false + } + if r.seenHookNames[name] { + r.bufferErr(ReasonDuplicateHookName, fmt.Sprintf("hookName %q registered twice in same plugin", name)) + return false + } + r.seenHookNames[name] = true + return true +} + +func (r *stagingRegistrar) validateNonNilSelector(name string, sel platform.Selector) bool { + if sel == nil { + r.bufferErr(ReasonInvalidHookRegister, fmt.Sprintf("hook %q: selector is nil", name)) + return false + } + return true +} + +func (r *stagingRegistrar) bufferErr(reasonCode, message string) { + r.stagingErrs = append(r.stagingErrs, stagingErr{ + reasonCode: reasonCode, + message: message, + }) +} + +// validateSelf runs after Install returns. It checks: +// +// - any buffered staging error -> abort +// - Restricts declared but Install did not call r.Restrict -> abort +// - Restricts NOT declared but Install did call r.Restrict -> abort +// +// Returns the first PluginInstallError encountered (callers can use +// errors.As to inspect it). Nil means staging is clean. +func (r *stagingRegistrar) validateSelf(caps platform.Capabilities) error { + if len(r.stagingErrs) > 0 { + first := r.stagingErrs[0] + return &PluginInstallError{ + PluginName: r.pluginName, + ReasonCode: first.reasonCode, + Reason: first.message, + } + } + if caps.Restricts && !r.actuallyRestricted { + return &PluginInstallError{ + PluginName: r.pluginName, + ReasonCode: ReasonRestrictsMismatch, + Reason: "Capabilities.Restricts=true but Install did not call r.Restrict", + } + } + if !caps.Restricts && r.actuallyRestricted { + return &PluginInstallError{ + PluginName: r.pluginName, + ReasonCode: ReasonRestrictsMismatch, + Reason: "Capabilities.Restricts=false but Install called r.Restrict", + } + } + return nil +} + +func isValidWhen(w platform.When) bool { + return w == platform.Before || w == platform.After +} + +func isValidLifecycleEvent(e platform.LifecycleEvent) bool { + switch e { + case platform.Startup, platform.Shutdown: + return true + } + return false +} diff --git a/internal/platform/version.go b/internal/platform/version.go new file mode 100644 index 000000000..9cdc05fb7 --- /dev/null +++ b/internal/platform/version.go @@ -0,0 +1,154 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package internalplatform + +import ( + "fmt" + "strconv" + "strings" + + "github.com/larksuite/cli/internal/build" +) + +// currentCLIVersion returns the running binary's version, redirectable +// from tests via SetCurrentCLIVersionForTesting. Production reads from +// internal/build.Version, which is set by -ldflags at release time. +var currentCLIVersion = func() string { return build.Version } + +// SetCurrentCLIVersionForTesting overrides the version reported to the +// RequiredCLIVersion check. Returns a restore function tests must defer. +func SetCurrentCLIVersionForTesting(v string) func() { + old := currentCLIVersion + currentCLIVersion = func() string { return v } + return func() { currentCLIVersion = old } +} + +// satisfiesRequiredCLIVersion reports whether buildVersion meets the +// constraint declared by Plugin.Capabilities().RequiredCLIVersion. +// +// Supported constraint forms (single comparator, no compound): +// +// "" - no requirement (always satisfied) +// "1.2.3" - exact match (equivalent to "=1.2.3") +// "=1.2.3" - exact match +// ">=1.2" - buildVersion >= 1.2 (missing patch -> 0) +// ">1.2" - strict greater than +// "<=1.2" - less than or equal +// "<1.2" - strict less than +// +// Development builds (buildVersion == "DEV" or "") always satisfy the +// constraint; the check is meaningful only for tagged releases. +// +// Returns false and an error when constraint is malformed -- callers +// should treat parse errors as fail-closed so an authoring mistake in +// the plugin does not silently load against the wrong CLI version. +// +// **Order of checks**: constraint syntax is validated FIRST, before the +// DEV-build short-circuit. A malformed constraint is a plugin authoring +// bug; we surface it even on DEV builds so the typo can be caught +// during plugin development instead of waiting for the first tagged +// release to expose it. +func satisfiesRequiredCLIVersion(buildVersion, constraint string) (bool, error) { + constraint = strings.TrimSpace(constraint) + if constraint == "" { + return true, nil + } + + op, rhs := splitConstraint(constraint) + rv, err := parseSemverPrefix(rhs) + if err != nil { + return false, fmt.Errorf("invalid RequiredCLIVersion %q: %w", constraint, err) + } + + if buildVersion == "" || buildVersion == "DEV" { + return true, nil + } + + bv, err := parseSemverPrefix(buildVersion) + if err != nil { + // Build version is unparseable -- treat as DEV so an exotic + // build tag doesn't lock plugins out. + return true, nil //nolint:nilerr // intentional fail-open for unparseable buildVersion + } + cmp := compareSemver(bv, rv) + switch op { + case "=", "": + return cmp == 0, nil + case ">=": + return cmp >= 0, nil + case ">": + return cmp > 0, nil + case "<=": + return cmp <= 0, nil + case "<": + return cmp < 0, nil + default: + return false, fmt.Errorf("invalid RequiredCLIVersion %q: unknown operator %q", constraint, op) + } +} + +// splitConstraint extracts the leading comparator (if any) from a +// constraint string. The operator is one of "", "=", ">=", ">", "<=", "<". +func splitConstraint(s string) (op, rest string) { + switch { + case strings.HasPrefix(s, ">="): + return ">=", strings.TrimSpace(s[2:]) + case strings.HasPrefix(s, "<="): + return "<=", strings.TrimSpace(s[2:]) + case strings.HasPrefix(s, ">"): + return ">", strings.TrimSpace(s[1:]) + case strings.HasPrefix(s, "<"): + return "<", strings.TrimSpace(s[1:]) + case strings.HasPrefix(s, "="): + return "=", strings.TrimSpace(s[1:]) + default: + return "", s + } +} + +// parseSemverPrefix parses MAJOR[.MINOR[.PATCH]] and drops any pre-release / +// build suffix. Missing minor / patch default to 0. Accepts a leading "v". +func parseSemverPrefix(s string) (parts [3]int, err error) { + s = strings.TrimPrefix(strings.TrimSpace(s), "v") + if s == "" { + return parts, fmt.Errorf("empty version") + } + // Trim pre-release/build suffix at first '-' or '+'. + for i, c := range s { + if c == '-' || c == '+' { + s = s[:i] + break + } + } + fields := strings.Split(s, ".") + // Reject `1.2.3.4` and longer instead of silently truncating — + // truncation hides the typo and lets a malformed RequiredCLIVersion + // pass validation while the comparator below operates on the wrong + // components. Build-version parsing has its own fail-open guard + // upstream (see satisfiesRequiredCLIVersion comment about exotic + // build tags), so it stays compatible. + if len(fields) > 3 { + return [3]int{}, fmt.Errorf("version %q has more than three numeric components", s) + } + for i, f := range fields { + n, err := strconv.Atoi(strings.TrimSpace(f)) + if err != nil || n < 0 { + return [3]int{}, fmt.Errorf("non-numeric component %q in version %q", f, s) + } + parts[i] = n + } + return parts, nil +} + +func compareSemver(a, b [3]int) int { + for i := 0; i < 3; i++ { + if a[i] < b[i] { + return -1 + } + if a[i] > b[i] { + return 1 + } + } + return 0 +} diff --git a/internal/platform/version_test.go b/internal/platform/version_test.go new file mode 100644 index 000000000..fec37bf0b --- /dev/null +++ b/internal/platform/version_test.go @@ -0,0 +1,178 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package internalplatform + +import ( + "errors" + "testing" + + "github.com/larksuite/cli/extension/platform" +) + +func TestSatisfiesRequiredCLIVersion_constraints(t *testing.T) { + cases := []struct { + name string + build string + constraint string + want bool + wantErr bool + }{ + {"empty constraint always satisfied", "1.0.0", "", true, false}, + {"DEV build always satisfied", "DEV", ">=99.0.0", true, false}, + {"empty build counts as DEV", "", ">=99.0.0", true, false}, + {"v prefix stripped", "v1.0.28", ">=1.0.0", true, false}, + {"exact match implicit operator", "1.0.0", "1.0.0", true, false}, + {"exact match explicit =", "1.0.0", "=1.0.0", true, false}, + {">= equal", "1.0.0", ">=1.0.0", true, false}, + {">= higher", "1.2.0", ">=1.0.0", true, false}, + {">= lower fails", "1.0.0", ">=2.0.0", false, false}, + {"> strict higher", "1.0.1", ">1.0.0", true, false}, + {"> equal fails", "1.0.0", ">1.0.0", false, false}, + {"<= equal", "1.0.0", "<=1.0.0", true, false}, + {"<= higher fails", "2.0.0", "<=1.0.0", false, false}, + {"< strict lower", "0.9.0", "<1.0.0", true, false}, + {"missing patch defaults to 0", "1.0", ">=1.0.0", true, false}, + {"constraint with pre-release suffix", "1.0.0-rc1", ">=1.0.0", true, false}, + {"malformed constraint returns error", "1.0.0", ">=abc", false, true}, + {"malformed constraint errors on DEV too", "DEV", ">=abc", false, true}, + {"malformed constraint errors on empty build", "", ">=zzz", false, true}, + {"unparseable build version treated as DEV", "abc", ">=1.0.0", true, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := satisfiesRequiredCLIVersion(tc.build, tc.constraint) + if tc.wantErr { + if err == nil { + t.Errorf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.want { + t.Errorf("got %v, want %v", got, tc.want) + } + }) + } +} + +// A plugin whose RequiredCLIVersion exceeds the running build must +// abort install with reason_code capability_unmet. The plugin's +// FailurePolicy then decides whether the abort bubbles up. +func TestInstallOne_RequiredCLIVersion_UnmetFailClosedAborts(t *testing.T) { + restore := SetCurrentCLIVersionForTesting("1.0.0") + t.Cleanup(restore) + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + + platform.Register(&capVersionPlugin{ + name: "needs-future", + requirement: ">=99.0.0", + fail: platform.FailClosed, + }) + + _, err := InstallAll(platform.RegisteredPlugins(), nil) + if err == nil { + t.Fatal("expected FailClosed install error, got nil") + } + var pi *PluginInstallError + if !errors.As(err, &pi) { + t.Fatalf("expected *PluginInstallError, got %T", err) + } + if pi.ReasonCode != ReasonCapabilityUnmet { + t.Errorf("reason_code = %q, want %q", pi.ReasonCode, ReasonCapabilityUnmet) + } +} + +// FailOpen plugin with unmet RequiredCLIVersion is skipped (warning), +// other plugins still install. +func TestInstallOne_RequiredCLIVersion_UnmetFailOpenSkips(t *testing.T) { + restore := SetCurrentCLIVersionForTesting("1.0.0") + t.Cleanup(restore) + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + + platform.Register(&capVersionPlugin{ + name: "future-failopen", + requirement: ">=99.0.0", + fail: platform.FailOpen, + }) + + result, err := InstallAll(platform.RegisteredPlugins(), nil) + if err != nil { + t.Fatalf("FailOpen unmet must not bubble up, got: %v", err) + } + if result.Registry == nil { + t.Errorf("Registry should be non-nil even after FailOpen skip") + } +} + +// A plugin authoring error in RequiredCLIVersion (parse failure) must +// abort installation UNCONDITIONALLY. Even FailOpen cannot mask a +// typo in the constraint string -- the plugin author asked the host +// to do something it cannot parse, and silently skipping would hide +// the bug from CI. +// +// Implementation: parse errors return ReasonInvalidCapability, which +// isUntrustedConfigError lists alongside restricts_mismatch so +// InstallAll's switch treats it as a hard abort. +func TestInstallOne_RequiredCLIVersion_MalformedAbortsRegardlessOfFailurePolicy(t *testing.T) { + restore := SetCurrentCLIVersionForTesting("1.0.0") + t.Cleanup(restore) + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + + // FailOpen + malformed constraint: still aborts. + platform.Register(&capVersionPlugin{ + name: "typo", + requirement: ">=abc", + fail: platform.FailOpen, + }) + + _, err := InstallAll(platform.RegisteredPlugins(), nil) + if err == nil { + t.Fatal("expected malformed constraint to abort even FailOpen, got nil") + } + var pi *PluginInstallError + if !errors.As(err, &pi) { + t.Fatalf("expected *PluginInstallError, got %T", err) + } + if pi.ReasonCode != ReasonInvalidCapability { + t.Errorf("reason_code = %q, want %q", pi.ReasonCode, ReasonInvalidCapability) + } +} + +// A plugin whose RequiredCLIVersion is satisfied installs normally. +func TestInstallOne_RequiredCLIVersion_SatisfiedInstalls(t *testing.T) { + restore := SetCurrentCLIVersionForTesting("1.5.0") + t.Cleanup(restore) + platform.ResetForTesting() + t.Cleanup(platform.ResetForTesting) + + platform.Register(&capVersionPlugin{ + name: "ok", + requirement: ">=1.0.0", + fail: platform.FailClosed, + }) + if _, err := InstallAll(platform.RegisteredPlugins(), nil); err != nil { + t.Errorf("expected install success, got %v", err) + } +} + +type capVersionPlugin struct { + name string + requirement string + fail platform.FailurePolicy +} + +func (p *capVersionPlugin) Name() string { return p.name } +func (p *capVersionPlugin) Version() string { return "0.0.1" } +func (p *capVersionPlugin) Capabilities() platform.Capabilities { + return platform.Capabilities{ + RequiredCLIVersion: p.requirement, + FailurePolicy: p.fail, + } +} +func (p *capVersionPlugin) Install(platform.Registrar) error { return nil } diff --git a/internal/selfupdate/updater.go b/internal/selfupdate/updater.go index d9cb5ab97..365f84ab7 100644 --- a/internal/selfupdate/updater.go +++ b/internal/selfupdate/updater.go @@ -84,6 +84,7 @@ type Updater struct { DetectOverride func() DetectResult NpmInstallOverride func(version string) *NpmResult SkillsUpdateOverride func() *NpmResult + SkillsCommandOverride func(args ...string) *NpmResult VerifyOverride func(expectedVersion string) error RestoreAvailableOverride func() bool @@ -166,7 +167,46 @@ func (u *Updater) RunSkillsUpdate() *NpmResult { return r } +func (u *Updater) ListOfficialSkills() *NpmResult { + r := u.runSkillsListOfficial("https://open.feishu.cn") + if r.Err != nil { + r = u.runSkillsListOfficial("larksuite/cli") + } + return r +} + +func (u *Updater) ListGlobalSkills() *NpmResult { + return u.runSkillsListGlobal() +} + +func (u *Updater) InstallSkill(name string) *NpmResult { + r := u.runSkillsInstall("https://open.feishu.cn", name) + if r.Err != nil { + r = u.runSkillsInstall("larksuite/cli", name) + } + return r +} + func (u *Updater) runSkillsAdd(source string) *NpmResult { + return u.runSkillsCommand("-y", "skills", "add", source, "-g", "-y") +} + +func (u *Updater) runSkillsListOfficial(source string) *NpmResult { + return u.runSkillsCommand("-y", "skills", "add", source, "--list") +} + +func (u *Updater) runSkillsListGlobal() *NpmResult { + return u.runSkillsCommand("-y", "skills", "ls", "-g") +} + +func (u *Updater) runSkillsInstall(source string, name string) *NpmResult { + return u.runSkillsCommand("-y", "skills", "add", source, "-s", name, "-g", "-y") +} + +func (u *Updater) runSkillsCommand(args ...string) *NpmResult { + if u.SkillsCommandOverride != nil { + return u.SkillsCommandOverride(args...) + } r := &NpmResult{} npxPath, err := exec.LookPath("npx") if err != nil { @@ -175,7 +215,7 @@ func (u *Updater) runSkillsAdd(source string) *NpmResult { } ctx, cancel := context.WithTimeout(context.Background(), skillsUpdateTimeout) defer cancel() - cmd := exec.CommandContext(ctx, npxPath, "-y", "skills", "add", source, "-g", "-y") + cmd := exec.CommandContext(ctx, npxPath, args...) cmd.Stdout = &r.Stdout cmd.Stderr = &r.Stderr r.Err = cmd.Run() diff --git a/internal/selfupdate/updater_test.go b/internal/selfupdate/updater_test.go index f13c80b65..b2da83f54 100644 --- a/internal/selfupdate/updater_test.go +++ b/internal/selfupdate/updater_test.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" "github.com/larksuite/cli/internal/vfs" @@ -166,3 +167,87 @@ func TestVerifyBinaryEmptyOutput(t *testing.T) { t.Fatal("VerifyBinary(empty output) expected error, got nil") } } + +func TestSkillsCommandsUseExpectedArgs(t *testing.T) { + tests := []struct { + name string + run func(*Updater) *NpmResult + want string + }{ + { + name: "list official primary", + run: func(u *Updater) *NpmResult { + return u.runSkillsListOfficial("https://open.feishu.cn") + }, + want: "-y skills add https://open.feishu.cn --list", + }, + { + name: "list global", + run: func(u *Updater) *NpmResult { + return u.runSkillsListGlobal() + }, + want: "-y skills ls -g", + }, + { + name: "install skill primary", + run: func(u *Updater) *NpmResult { + return u.runSkillsInstall("https://open.feishu.cn", "lark-mail") + }, + want: "-y skills add https://open.feishu.cn -s lark-mail -g -y", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("uses a POSIX shell script") + } + dir := t.TempDir() + script := filepath.Join(dir, "npx") + logPath := filepath.Join(dir, "npx.log") + if err := os.WriteFile(script, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \""+logPath+"\"\nexit 0\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) + + result := tt.run(New()) + if result.Err != nil { + t.Fatalf("command err = %v, want nil", result.Err) + } + raw, err := os.ReadFile(logPath) + if err != nil { + t.Fatal(err) + } + if strings.TrimSpace(string(raw)) != tt.want { + t.Fatalf("args = %q, want %q", strings.TrimSpace(string(raw)), tt.want) + } + }) + } +} + +func TestListOfficialSkillsFallsBack(t *testing.T) { + called := []string{} + updater := &Updater{ + SkillsCommandOverride: func(args ...string) *NpmResult { + called = append(called, strings.Join(args, " ")) + r := &NpmResult{} + if strings.Contains(strings.Join(args, " "), "https://open.feishu.cn") { + r.Err = fmt.Errorf("primary failed") + return r + } + r.Stdout.WriteString("lark-calendar\n") + return r + }, + } + + result := updater.ListOfficialSkills() + if result.Err != nil { + t.Fatalf("ListOfficialSkills() err = %v, want nil", result.Err) + } + if len(called) != 2 { + t.Fatalf("called %d commands, want 2: %#v", len(called), called) + } + if !strings.Contains(called[1], "larksuite/cli --list") { + t.Fatalf("fallback call = %q, want larksuite/cli --list", called[1]) + } +} diff --git a/internal/skillscheck/check.go b/internal/skillscheck/check.go index 429117a18..029a4d01f 100644 --- a/internal/skillscheck/check.go +++ b/internal/skillscheck/check.go @@ -3,46 +3,29 @@ package skillscheck -// Init runs the synchronous skills version check. Stores a StaleNotice -// when the local stamp records a version that does not match -// currentVersion. Safe to call from cmd/root.go before rootCmd.Execute(); -// zero network, zero subprocess — only a local stamp file read. +import "strings" + +// Init runs the synchronous skills version check. Stores a StaleNotice when +// the local skills state records a version that does not match currentVersion. +// Safe to call from cmd/root.go before rootCmd.Execute(); zero network, zero +// subprocess — only a local state file read. // // Skip rules: see shouldSkip (CI envs, DEV builds, non-release semver, // LARKSUITE_CLI_NO_SKILLS_NOTIFIER opt-out). -// -// Failure modes (all → no notice, no nag): -// - shouldSkip rule met -// - ReadStamp returns an I/O error other than ENOENT -// - Stamp matches currentVersion (in-sync) -// - Stamp is missing (cold start) — only users who ran `lark-cli update` -// opt into drift tracking; npx-only installs are intentionally silent. func Init(currentVersion string) { - // Clear any stale notice from a prior call so early returns below - // (skip rules / read errors / cold start / in-sync) leave pending == nil - // instead of preserving a stale value from a previous Init invocation. SetPending(nil) if shouldSkip(currentVersion) { return } - stamp, err := ReadStamp() - if err != nil { - // Fail closed — don't nag for a transient FS problem. - return - } - if stamp == "" { - // Cold start: the stamp is written exclusively by `lark-cli update` - // (runSkillsAndStamp). Users who installed skills via - // `npx skills add larksuite/cli -g` have no stamp yet — they must - // not be nagged with "skills not installed", since the on-disk - // skills directory may already be fully populated. + version, ok := ReadSyncedVersion() + if !ok { return } - if stamp == currentVersion { + if strings.TrimPrefix(strings.TrimPrefix(version, "v"), "V") == strings.TrimPrefix(strings.TrimPrefix(currentVersion, "v"), "V") { return } SetPending(&StaleNotice{ - Current: stamp, // guaranteed non-empty under the new contract + Current: version, Target: currentVersion, }) } diff --git a/internal/skillscheck/check_test.go b/internal/skillscheck/check_test.go index 64525bc5a..2674d5424 100644 --- a/internal/skillscheck/check_test.go +++ b/internal/skillscheck/check_test.go @@ -18,9 +18,8 @@ func resetPending(t *testing.T) { func TestInit_InSync_NoNotice(t *testing.T) { clearSkillsSkipEnv(t) resetPending(t) - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := WriteStamp("1.0.21"); err != nil { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := WriteState(SkillsState{Version: "1.0.21"}); err != nil { t.Fatal(err) } Init("1.0.21") @@ -39,12 +38,24 @@ func TestInit_ColdStart_NoNotice(t *testing.T) { } } -func TestInit_Drift_NoticeWithStampVersion(t *testing.T) { +func TestInit_NormalizedVersion_NoNotice(t *testing.T) { clearSkillsSkipEnv(t) resetPending(t) - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := WriteStamp("1.0.20"); err != nil { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := WriteState(SkillsState{Version: "1.0.21"}); err != nil { + t.Fatal(err) + } + Init("v1.0.21") + if got := GetPending(); got != nil { + t.Errorf("GetPending() = %+v, want nil (normalized versions are in-sync)", got) + } +} + +func TestInit_Drift_NoticeWithStateVersion(t *testing.T) { + clearSkillsSkipEnv(t) + resetPending(t) + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if err := WriteState(SkillsState{Version: "1.0.20"}); err != nil { t.Fatal(err) } Init("1.0.21") @@ -61,22 +72,18 @@ func TestInit_Skipped_NoNotice(t *testing.T) { clearSkillsSkipEnv(t) resetPending(t) t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) - // Even with an empty config dir (no stamp), DEV version should skip - // the check entirely and never emit a notice. Init("DEV") if got := GetPending(); got != nil { t.Errorf("GetPending() = %+v, want nil (skip rules met)", got) } } -func TestInit_ReadStampError_FailsClosed(t *testing.T) { +func TestInit_ReadStateError_FailsClosed(t *testing.T) { clearSkillsSkipEnv(t) resetPending(t) dir := t.TempDir() t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - // Make the stamp path a directory so vfs.ReadFile returns a - // non-ENOENT I/O error. - if err := os.MkdirAll(filepath.Join(dir, "skills.stamp"), 0o755); err != nil { + if err := os.MkdirAll(filepath.Join(dir, "skills-state.json"), 0o755); err != nil { t.Fatal(err) } Init("1.0.21") diff --git a/internal/skillscheck/notice.go b/internal/skillscheck/notice.go index b1f972218..c1425fbb7 100644 --- a/internal/skillscheck/notice.go +++ b/internal/skillscheck/notice.go @@ -3,9 +3,8 @@ // Package skillscheck verifies that the locally installed lark-cli // skills are in sync with the running binary version, by comparing -// the current binary version against a stamp file written when skills -// are last synced (by `lark-cli update`). On mismatch it stores a -// notice for injection into JSON envelopes via output.PendingNotice. +// the current binary version against skills-state.json. On mismatch it +// stores a notice for injection into JSON envelopes via output.PendingNotice. package skillscheck import ( @@ -26,8 +25,7 @@ type StaleNotice struct { // Message returns a single-line, AI-agent-parseable description of the // drift plus the canonical fix command. Mirrors internal/update.UpdateInfo.Message // in style ("..., run: lark-cli update" suffix). Current is guaranteed -// non-empty because Init only emits a StaleNotice for the drift case -// (stamp present and != binary version). +// non-empty because Init only emits a StaleNotice for the drift case. func (s *StaleNotice) Message() string { return fmt.Sprintf( "lark-cli skills %s out of sync with binary %s, run: lark-cli update", diff --git a/internal/skillscheck/stamp.go b/internal/skillscheck/stamp.go deleted file mode 100644 index 052e331c9..000000000 --- a/internal/skillscheck/stamp.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2026 Lark Technologies Pte. Ltd. -// SPDX-License-Identifier: MIT - -package skillscheck - -import ( - "errors" - "io/fs" - "path/filepath" - "strings" - - "github.com/larksuite/cli/internal/core" - "github.com/larksuite/cli/internal/validate" - "github.com/larksuite/cli/internal/vfs" -) - -const stampFile = "skills.stamp" - -// stampPath returns ~/.lark-cli/skills.stamp. -// Uses the BASE config dir (not workspace-aware) because skills install -// globally via `npx -g`; per-workspace tracking would produce false -// drift signals when switching workspaces. -func stampPath() string { - return filepath.Join(core.GetBaseConfigDir(), stampFile) -} - -// ReadStamp returns the version recorded in the stamp file. Returns -// ("", nil) when the file does not exist (interpreted as "never synced"). -// Other I/O errors are returned as-is so callers can fail closed. -func ReadStamp() (string, error) { - data, err := vfs.ReadFile(stampPath()) - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - return "", nil - } - return "", err - } - return strings.TrimSpace(string(data)), nil -} - -// WriteStamp records `version` as the last successfully synced skills -// version. Atomic via tmp + rename (validate.AtomicWrite). Creates -// the base config directory if it does not exist. -func WriteStamp(version string) error { - if err := vfs.MkdirAll(core.GetBaseConfigDir(), 0o700); err != nil { - return err - } - return validate.AtomicWrite(stampPath(), []byte(version), 0o644) -} diff --git a/internal/skillscheck/stamp_test.go b/internal/skillscheck/stamp_test.go deleted file mode 100644 index 8e60dfbb4..000000000 --- a/internal/skillscheck/stamp_test.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) 2026 Lark Technologies Pte. Ltd. -// SPDX-License-Identifier: MIT - -package skillscheck - -import ( - "os" - "path/filepath" - "testing" -) - -func TestReadStamp_Missing(t *testing.T) { - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) - got, err := ReadStamp() - if err != nil { - t.Fatalf("ReadStamp() err = %v, want nil for ENOENT", err) - } - if got != "" { - t.Errorf("ReadStamp() = %q, want \"\" for missing file", got) - } -} - -func TestReadStamp_Normal(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := os.WriteFile(filepath.Join(dir, "skills.stamp"), []byte("1.0.21"), 0o644); err != nil { - t.Fatal(err) - } - got, err := ReadStamp() - if err != nil || got != "1.0.21" { - t.Errorf("ReadStamp() = (%q, %v), want (\"1.0.21\", nil)", got, err) - } -} - -func TestReadStamp_TrailingNewlineTolerated(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := os.WriteFile(filepath.Join(dir, "skills.stamp"), []byte("1.0.21\n"), 0o644); err != nil { - t.Fatal(err) - } - got, _ := ReadStamp() - if got != "1.0.21" { - t.Errorf("ReadStamp() = %q, want \"1.0.21\" (newline trimmed)", got) - } -} - -func TestReadStamp_EmptyFile(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := os.WriteFile(filepath.Join(dir, "skills.stamp"), []byte(""), 0o644); err != nil { - t.Fatal(err) - } - got, err := ReadStamp() - if err != nil || got != "" { - t.Errorf("ReadStamp() = (%q, %v), want (\"\", nil)", got, err) - } -} - -func TestWriteStamp_CreatesDir(t *testing.T) { - dir := filepath.Join(t.TempDir(), "nested") - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := WriteStamp("1.0.21"); err != nil { - t.Fatalf("WriteStamp() = %v, want nil", err) - } - got, _ := os.ReadFile(filepath.Join(dir, "skills.stamp")) - if string(got) != "1.0.21" { - t.Errorf("file content = %q, want \"1.0.21\"", string(got)) - } -} - -func TestWriteStamp_OverwritesExisting(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := WriteStamp("1.0.20"); err != nil { - t.Fatal(err) - } - if err := WriteStamp("1.0.21"); err != nil { - t.Fatal(err) - } - got, _ := ReadStamp() - if got != "1.0.21" { - t.Errorf("ReadStamp() after overwrite = %q, want \"1.0.21\"", got) - } -} - -func TestWriteStamp_NoTrailingNewline(t *testing.T) { - dir := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) - if err := WriteStamp("1.0.21"); err != nil { - t.Fatal(err) - } - raw, _ := os.ReadFile(filepath.Join(dir, "skills.stamp")) - if string(raw) != "1.0.21" { - t.Errorf("raw file = %q, want exactly \"1.0.21\" (no newline)", string(raw)) - } -} - -// TestWriteStamp_MkdirAllFailure verifies WriteStamp returns the mkdir error -// when the base config dir cannot be created (parent path is a regular file). -func TestWriteStamp_MkdirAllFailure(t *testing.T) { - tmp := t.TempDir() - blocker := filepath.Join(tmp, "blocker") - // Create a regular file where MkdirAll wants to create a directory. - if err := os.WriteFile(blocker, []byte("not-a-dir"), 0o644); err != nil { - t.Fatal(err) - } - // Point the config dir at a path UNDER the regular file — MkdirAll must fail. - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", filepath.Join(blocker, "child")) - - if err := WriteStamp("1.0.21"); err == nil { - t.Fatal("WriteStamp() = nil, want non-nil error from MkdirAll failure") - } -} diff --git a/internal/skillscheck/state.go b/internal/skillscheck/state.go new file mode 100644 index 000000000..abb2e2a6f --- /dev/null +++ b/internal/skillscheck/state.go @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package skillscheck + +import ( + "encoding/json" + "errors" + "io/fs" + "path/filepath" + + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" +) + +const ( + stateFile = "skills-state.json" + stateSchemaVersion = 1 +) + +type SkillsState struct { + SchemaVersion int `json:"schema_version"` + Version string `json:"version"` + OfficialSkills []string `json:"official_skills"` + UpdatedSkills []string `json:"updated_skills"` + AddedSkills []string `json:"added_skills"` + SkippedDeletedSkills []string `json:"skipped_deleted_skills"` + UpdatedAt string `json:"updated_at"` +} + +func statePath() string { + return filepath.Join(core.GetBaseConfigDir(), stateFile) +} + +func ReadState() (*SkillsState, bool, error) { + data, err := vfs.ReadFile(statePath()) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, false, nil + } + return nil, false, err + } + + var state SkillsState + if json.Unmarshal(data, &state) != nil { + state = SkillsState{} + } + if state.SchemaVersion != stateSchemaVersion { + return nil, false, nil + } + return &state, true, nil +} + +func WriteState(state SkillsState) error { + state.SchemaVersion = stateSchemaVersion + state.ensureNonNilSlices() + + if err := vfs.MkdirAll(core.GetBaseConfigDir(), 0o700); err != nil { + return err + } + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err + } + return validate.AtomicWrite(statePath(), append(data, '\n'), 0o644) +} + +func ReadSyncedVersion() (string, bool) { + state, ok, err := ReadState() + if err != nil || !ok || state.Version == "" { + return "", false + } + return state.Version, true +} + +func (s *SkillsState) ensureNonNilSlices() { + if s.OfficialSkills == nil { + s.OfficialSkills = []string{} + } + if s.UpdatedSkills == nil { + s.UpdatedSkills = []string{} + } + if s.AddedSkills == nil { + s.AddedSkills = []string{} + } + if s.SkippedDeletedSkills == nil { + s.SkippedDeletedSkills = []string{} + } +} diff --git a/internal/skillscheck/state_test.go b/internal/skillscheck/state_test.go new file mode 100644 index 000000000..77eab85d4 --- /dev/null +++ b/internal/skillscheck/state_test.go @@ -0,0 +1,153 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package skillscheck + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "testing" +) + +func TestReadState_Missing(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + state, ok, err := ReadState() + if err != nil { + t.Fatalf("ReadState() err = %v, want nil for missing file", err) + } + if ok { + t.Fatal("ReadState() ok = true, want false for missing file") + } + if state != nil { + t.Fatalf("ReadState() state = %#v, want nil for missing file", state) + } +} + +func TestReadState_Valid(t *testing.T) { + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + want := SkillsState{ + SchemaVersion: 1, + Version: "1.2.3", + OfficialSkills: []string{"lark-doc", "lark-im"}, + UpdatedSkills: []string{"lark-doc"}, + AddedSkills: []string{"lark-task"}, + SkippedDeletedSkills: []string{"custom-skill"}, + UpdatedAt: "2026-05-18T10:00:00Z", + } + data, err := json.Marshal(want) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, stateFile), data, 0o644); err != nil { + t.Fatal(err) + } + + got, ok, err := ReadState() + if err != nil { + t.Fatalf("ReadState() err = %v, want nil", err) + } + if !ok { + t.Fatal("ReadState() ok = false, want true") + } + if got == nil { + t.Fatal("ReadState() state = nil, want state") + } + if !reflect.DeepEqual(*got, want) { + t.Fatalf("ReadState() state = %#v, want %#v", *got, want) + } +} + +func TestReadState_CorruptOrUnknownSchemaUnreadable(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + {name: "corrupt json", data: []byte(`{"schema_version":`)}, + {name: "unknown schema", data: []byte(`{"schema_version":2,"version":"1.2.3"}`)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + if err := os.WriteFile(filepath.Join(dir, stateFile), tt.data, 0o644); err != nil { + t.Fatal(err) + } + + state, ok, err := ReadState() + if err != nil { + t.Fatalf("ReadState() err = %v, want nil", err) + } + if ok { + t.Fatal("ReadState() ok = true, want false") + } + if state != nil { + t.Fatalf("ReadState() state = %#v, want nil", state) + } + }) + } +} + +func TestWriteState_CreatesDirAndWritesState(t *testing.T) { + dir := filepath.Join(t.TempDir(), "nested") + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + + state := SkillsState{ + Version: "1.2.3", + UpdatedAt: "2026-05-18T10:00:00Z", + } + if err := WriteState(state); err != nil { + t.Fatalf("WriteState() err = %v, want nil", err) + } + + raw, err := os.ReadFile(filepath.Join(dir, stateFile)) + if err != nil { + t.Fatal(err) + } + var got SkillsState + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("written state is invalid JSON: %v", err) + } + if got.SchemaVersion != 1 { + t.Fatalf("schema_version = %d, want 1", got.SchemaVersion) + } + if got.Version != state.Version { + t.Fatalf("version = %q, want %q", got.Version, state.Version) + } + if got.OfficialSkills == nil { + t.Fatal("official_skills decoded as nil, want empty slice") + } + if got.UpdatedSkills == nil { + t.Fatal("updated_skills decoded as nil, want empty slice") + } + if got.AddedSkills == nil { + t.Fatal("added_skills decoded as nil, want empty slice") + } + if got.SkippedDeletedSkills == nil { + t.Fatal("skipped_deleted_skills decoded as nil, want empty slice") + } +} + +func TestReadSyncedVersionFromState(t *testing.T) { + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + + if got, ok := ReadSyncedVersion(); ok || got != "" { + t.Fatalf("ReadSyncedVersion() = (%q, %v), want (\"\", false) for missing state", got, ok) + } + if err := WriteState(SkillsState{Version: "1.2.3"}); err != nil { + t.Fatal(err) + } + if got, ok := ReadSyncedVersion(); !ok || got != "1.2.3" { + t.Fatalf("ReadSyncedVersion() = (%q, %v), want (\"1.2.3\", true)", got, ok) + } + if err := WriteState(SkillsState{}); err != nil { + t.Fatal(err) + } + if got, ok := ReadSyncedVersion(); ok || got != "" { + t.Fatalf("ReadSyncedVersion() = (%q, %v), want (\"\", false) for empty version", got, ok) + } +} diff --git a/internal/skillscheck/sync.go b/internal/skillscheck/sync.go new file mode 100644 index 000000000..068707bde --- /dev/null +++ b/internal/skillscheck/sync.go @@ -0,0 +1,265 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package skillscheck + +import ( + "fmt" + "regexp" + "sort" + "strings" + "time" + + "github.com/larksuite/cli/internal/selfupdate" +) + +var skillNamePattern = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9_:-]*(@[^\s]+)?$`) + +type SyncInput struct { + Version string + OfficialSkills []string + LocalSkills []string + PreviousState *SkillsState + StateReadable bool + Force bool +} + +type SyncPlan struct { + Version string + OfficialSkills []string + ToUpdate []string + Added []string + SkippedDeleted []string +} + +func ParseSkillsList(text string) []string { + seen := map[string]bool{} + for _, line := range strings.Split(text, "\n") { + token := strings.TrimSpace(line) + token = strings.TrimPrefix(token, "-") + token = strings.TrimSpace(token) + if token == "" || strings.Contains(token, " ") || strings.HasSuffix(token, ":") { + continue + } + if !skillNamePattern.MatchString(token) { + continue + } + if at := strings.Index(token, "@"); at > 0 { + token = token[:at] + } + seen[token] = true + } + return sortedKeys(seen) +} + +func PlanSync(input SyncInput) SyncPlan { + official := uniqueSorted(input.OfficialSkills) + if input.Force { + return SyncPlan{ + Version: input.Version, + OfficialSkills: official, + ToUpdate: official, + Added: []string{}, + SkippedDeleted: []string{}, + } + } + + officialSet := toSet(official) + localOfficial := intersection(input.LocalSkills, officialSet) + + previousOfficial := []string{} + if input.StateReadable && input.PreviousState != nil { + previousOfficial = input.PreviousState.OfficialSkills + } + previousSet := toSet(previousOfficial) + + newOfficial := []string{} + for _, skill := range official { + if !previousSet[skill] { + newOfficial = append(newOfficial, skill) + } + } + + updateSet := toSet(localOfficial) + for _, skill := range newOfficial { + updateSet[skill] = true + } + toUpdate := sortedKeys(updateSet) + updateSet = toSet(toUpdate) + + skipped := []string{} + for _, skill := range official { + if !updateSet[skill] { + skipped = append(skipped, skill) + } + } + + return SyncPlan{ + Version: input.Version, + OfficialSkills: official, + ToUpdate: toUpdate, + Added: uniqueSorted(newOfficial), + SkippedDeleted: skipped, + } +} + +type SkillsRunner interface { + ListOfficialSkills() *selfupdate.NpmResult + ListGlobalSkills() *selfupdate.NpmResult + InstallSkill(name string) *selfupdate.NpmResult +} + +type SyncOptions struct { + Version string + Force bool + Runner SkillsRunner + Now func() time.Time +} + +type SyncResult struct { + Action string + Official []string + Updated []string + Added []string + SkippedDeleted []string + Failed []string + Err error + Detail string + Force bool +} + +func SyncSkills(opts SyncOptions) *SyncResult { + if opts.Now == nil { + opts.Now = time.Now + } + if opts.Runner == nil { + return &SyncResult{Action: "failed", Err: fmt.Errorf("skills runner is nil")} + } + + officialResult := opts.Runner.ListOfficialSkills() + if officialResult == nil { + return &SyncResult{Action: "failed", Err: fmt.Errorf("failed to list official skills: empty result")} + } + if officialResult.Err != nil { + return &SyncResult{Action: "failed", Err: fmt.Errorf("failed to list official skills: %w", officialResult.Err), Detail: resultDetail(officialResult)} + } + official := ParseSkillsList(officialResult.Stdout.String()) + + localResult := opts.Runner.ListGlobalSkills() + if localResult == nil { + return &SyncResult{Action: "failed", Official: official, Err: fmt.Errorf("failed to list installed skills: empty result")} + } + if localResult.Err != nil { + return &SyncResult{Action: "failed", Official: official, Err: fmt.Errorf("failed to list installed skills: %w", localResult.Err), Detail: resultDetail(localResult)} + } + local := ParseSkillsList(localResult.Stdout.String()) + + previous, readable, err := ReadState() + if err != nil { + return &SyncResult{Action: "failed", Official: official, Err: fmt.Errorf("failed to read skills state: %w", err)} + } + + plan := PlanSync(SyncInput{ + Version: opts.Version, + OfficialSkills: official, + LocalSkills: local, + PreviousState: previous, + StateReadable: readable, + Force: opts.Force, + }) + + result := &SyncResult{ + Action: "synced", + Official: plan.OfficialSkills, + Updated: plan.ToUpdate, + Added: plan.Added, + SkippedDeleted: plan.SkippedDeleted, + Force: opts.Force, + } + + failed := []string{} + var details []string + for _, skill := range plan.ToUpdate { + installResult := opts.Runner.InstallSkill(skill) + if installResult == nil { + failed = append(failed, skill) + details = append(details, skill+": empty result") + continue + } + if installResult.Err != nil { + failed = append(failed, skill) + details = append(details, skill+": "+resultDetail(installResult)) + } + } + if len(failed) > 0 { + result.Action = "failed" + result.Failed = failed + result.Err = fmt.Errorf("%d skill(s) failed", len(failed)) + result.Detail = strings.Join(details, "\n") + return result + } + + state := SkillsState{ + Version: opts.Version, + OfficialSkills: plan.OfficialSkills, + UpdatedSkills: plan.ToUpdate, + AddedSkills: plan.Added, + SkippedDeletedSkills: plan.SkippedDeleted, + UpdatedAt: opts.Now().UTC().Format(time.RFC3339), + } + if err := WriteState(state); err != nil { + result.Action = "failed" + result.Err = fmt.Errorf("skills synced but state not written: %w", err) + return result + } + + return result +} + +func resultDetail(result *selfupdate.NpmResult) string { + if result == nil { + return "" + } + parts := []string{} + if output := strings.TrimSpace(result.CombinedOutput()); output != "" { + parts = append(parts, output) + } + if result.Err != nil { + parts = append(parts, result.Err.Error()) + } + return strings.Join(parts, "\n") +} + +func uniqueSorted(values []string) []string { + return sortedKeys(toSet(values)) +} + +func toSet(values []string) map[string]bool { + out := map[string]bool{} + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + out[value] = true + } + } + return out +} + +func intersection(values []string, allowed map[string]bool) []string { + out := map[string]bool{} + for _, value := range values { + if allowed[value] { + out[value] = true + } + } + return sortedKeys(out) +} + +func sortedKeys(values map[string]bool) []string { + out := make([]string, 0, len(values)) + for value := range values { + out = append(out, value) + } + sort.Strings(out) + return out +} diff --git a/internal/skillscheck/sync_test.go b/internal/skillscheck/sync_test.go new file mode 100644 index 000000000..4b7de39c2 --- /dev/null +++ b/internal/skillscheck/sync_test.go @@ -0,0 +1,222 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package skillscheck + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + "time" + + "github.com/larksuite/cli/internal/selfupdate" +) + +func TestParseSkillsList(t *testing.T) { + input := `Installed skills: +- lark-calendar +- lark-mail +lark-im +custom-skill +lark-base@1.0.0 +lark-cli-harness:dev@0.1.0 +` + got := ParseSkillsList(input) + want := []string{"custom-skill", "lark-base", "lark-calendar", "lark-cli-harness:dev", "lark-im", "lark-mail"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("ParseSkillsList() = %#v, want %#v", got, want) + } +} + +func TestPlanNormal_WithReadableStatePreservesDeletedAndAddsNew(t *testing.T) { + previous := &SkillsState{OfficialSkills: []string{"lark-calendar", "lark-mail"}} + got := PlanSync(SyncInput{ + Version: "1.0.33", + OfficialSkills: []string{"lark-calendar", "lark-mail", "lark-new"}, + LocalSkills: []string{"lark-calendar", "lark-custom"}, + PreviousState: previous, + StateReadable: true, + Force: false, + }) + + assertStrings(t, got.ToUpdate, []string{"lark-calendar", "lark-new"}) + assertStrings(t, got.Added, []string{"lark-new"}) + assertStrings(t, got.SkippedDeleted, []string{"lark-mail"}) +} + +func TestPlanNormal_MissingStateInstallsAllOfficial(t *testing.T) { + got := PlanSync(SyncInput{ + Version: "1.0.33", + OfficialSkills: []string{"lark-calendar", "lark-mail", "lark-new"}, + LocalSkills: []string{"lark-calendar"}, + StateReadable: false, + Force: false, + }) + + assertStrings(t, got.ToUpdate, []string{"lark-calendar", "lark-mail", "lark-new"}) + assertStrings(t, got.Added, []string{"lark-calendar", "lark-mail", "lark-new"}) + assertStrings(t, got.SkippedDeleted, []string{}) +} + +func TestPlanForceRestoresAllOfficial(t *testing.T) { + got := PlanSync(SyncInput{ + Version: "1.0.33", + OfficialSkills: []string{"lark-calendar", "lark-mail", "lark-new"}, + LocalSkills: []string{"lark-calendar"}, + PreviousState: &SkillsState{OfficialSkills: []string{"lark-calendar", "lark-mail"}}, + StateReadable: true, + Force: true, + }) + + assertStrings(t, got.ToUpdate, []string{"lark-calendar", "lark-mail", "lark-new"}) + assertStrings(t, got.Added, []string{}) + assertStrings(t, got.SkippedDeleted, []string{}) +} + +type fakeSkillsRunner struct { + officialOut string + globalOut string + officialErr error + globalErr error + installErr map[string]error + installed []string +} + +func (f *fakeSkillsRunner) ListOfficialSkills() *selfupdate.NpmResult { + r := &selfupdate.NpmResult{} + r.Stdout.WriteString(f.officialOut) + r.Err = f.officialErr + return r +} + +func (f *fakeSkillsRunner) ListGlobalSkills() *selfupdate.NpmResult { + r := &selfupdate.NpmResult{} + r.Stdout.WriteString(f.globalOut) + r.Err = f.globalErr + return r +} + +func (f *fakeSkillsRunner) InstallSkill(name string) *selfupdate.NpmResult { + f.installed = append(f.installed, name) + r := &selfupdate.NpmResult{} + if f.installErr != nil { + r.Err = f.installErr[name] + } + return r +} + +func TestSyncSkills_WritesStateAndDoesNotWriteStamp(t *testing.T) { + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + if err := WriteState(SkillsState{ + Version: "1.0.30", + OfficialSkills: []string{"lark-calendar", "lark-mail"}, + UpdatedAt: "2026-05-18T00:00:00Z", + }); err != nil { + t.Fatal(err) + } + + runner := &fakeSkillsRunner{ + officialOut: "lark-calendar\nlark-mail\nlark-new\n", + globalOut: "lark-calendar\nlark-custom\n", + } + result := SyncSkills(SyncOptions{ + Version: "1.0.33", + Runner: runner, + Now: func() time.Time { return time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC) }, + }) + + if result.Err != nil { + t.Fatalf("SyncSkills() err = %v, want nil", result.Err) + } + assertStrings(t, runner.installed, []string{"lark-calendar", "lark-new"}) + + state, readable, err := ReadState() + if err != nil || !readable { + t.Fatalf("ReadState() = (_, %v, %v), want readable", readable, err) + } + assertStrings(t, state.OfficialSkills, []string{"lark-calendar", "lark-mail", "lark-new"}) + assertStrings(t, state.UpdatedSkills, []string{"lark-calendar", "lark-new"}) + assertStrings(t, state.AddedSkills, []string{"lark-new"}) + assertStrings(t, state.SkippedDeletedSkills, []string{"lark-mail"}) + if _, err := os.Stat(filepath.Join(dir, "skills.stamp")); !os.IsNotExist(err) { + t.Fatalf("skills.stamp exists or stat failed with unexpected err: %v", err) + } +} + +func TestSyncSkills_ListFailureDoesNotInstallOrWriteState(t *testing.T) { + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + runner := &fakeSkillsRunner{officialErr: fmt.Errorf("list failed")} + + result := SyncSkills(SyncOptions{Version: "1.0.33", Runner: runner, Now: time.Now}) + if result.Err == nil || !strings.Contains(result.Err.Error(), "failed to list official skills") { + t.Fatalf("SyncSkills() err = %v, want official list failure", result.Err) + } + if len(runner.installed) != 0 { + t.Fatalf("installed = %#v, want none", runner.installed) + } + if _, readable, err := ReadState(); err != nil || readable { + t.Fatalf("ReadState() = (_, %v, %v), want unreadable missing state", readable, err) + } +} + +func TestSyncSkills_GlobalListFailureDoesNotInstallOrWriteState(t *testing.T) { + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + runner := &fakeSkillsRunner{ + officialOut: "lark-calendar\nlark-mail\n", + globalErr: fmt.Errorf("global list failed"), + } + + result := SyncSkills(SyncOptions{Version: "1.0.33", Runner: runner, Now: time.Now}) + if result.Err == nil || !strings.Contains(result.Err.Error(), "failed to list installed skills") { + t.Fatalf("SyncSkills() err = %v, want installed list failure", result.Err) + } + if len(runner.installed) != 0 { + t.Fatalf("installed = %#v, want none", runner.installed) + } + if _, readable, err := ReadState(); err != nil || readable { + t.Fatalf("ReadState() = (_, %v, %v), want unreadable missing state", readable, err) + } +} + +func TestSyncSkills_InstallFailureContinuesAndDoesNotWriteState(t *testing.T) { + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + runner := &fakeSkillsRunner{ + officialOut: "lark-calendar\nlark-mail\n", + globalOut: "lark-calendar\nlark-mail\n", + installErr: map[string]error{"lark-calendar": fmt.Errorf("boom")}, + } + + result := SyncSkills(SyncOptions{Version: "1.0.33", Runner: runner, Now: time.Now}) + if result.Err == nil || !strings.Contains(result.Err.Error(), "1 skill(s) failed") { + t.Fatalf("SyncSkills() err = %v, want install failure", result.Err) + } + assertStrings(t, runner.installed, []string{"lark-calendar", "lark-mail"}) + assertStrings(t, result.Failed, []string{"lark-calendar"}) + if !strings.Contains(result.Detail, "boom") { + t.Fatalf("SyncSkills() detail = %q, want install error text", result.Detail) + } + if _, readable, err := ReadState(); err != nil || readable { + t.Fatalf("ReadState() = (_, %v, %v), want no success state", readable, err) + } +} + +func TestSyncSkills_NilRunnerFails(t *testing.T) { + result := SyncSkills(SyncOptions{Version: "1.0.33", Now: time.Now}) + if result.Err == nil || !strings.Contains(result.Err.Error(), "skills runner is nil") { + t.Fatalf("SyncSkills() err = %v, want nil runner failure", result.Err) + } +} + +func assertStrings(t *testing.T, got, want []string) { + t.Helper() + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %#v, want %#v", got, want) + } +} diff --git a/package.json b/package.json index da06bc69c..a5eb850a4 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@larksuite/cli", - "version": "1.0.31", + "version": "1.0.34", "description": "The official CLI for Lark/Feishu open platform", "bin": { "lark-cli": "scripts/run.js" diff --git a/scripts/check-doc-tokens.sh b/scripts/check-doc-tokens.sh new file mode 100755 index 000000000..a02c8f140 --- /dev/null +++ b/scripts/check-doc-tokens.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Lark Technologies Pte. Ltd. +# SPDX-License-Identifier: MIT +# +# check-doc-tokens.sh +# +# Scans skill reference docs for token-like values that look realistic but +# are not using the required placeholder format (*_EXAMPLE_TOKEN or similar). +# +# Real token patterns (Lark API) often look like: +# wikcnXXXXXXXXX doccnXXXXXXX shtcnXXX fldcnXXX ou_XXXX cli_XXXX +# +# Docs MUST use clearly fake placeholders, e.g.: +# wikcn_EXAMPLE_TOKEN doccn_EXAMPLE_TOKEN your_token_here +# +# If this check fails, replace the realistic-looking value with a placeholder +# like `wikcn_EXAMPLE_TOKEN` so gitleaks CI won't flag it as a real secret. + +set -euo pipefail + +SKILLS_DIR="${1:-skills}" +ERRORS=0 + +# Patterns that indicate a realistic-looking Lark token value. +# Three forms are detected: +# 1. JSON-style quoted strings: "field": "token_value" +# 2. Markdown backtick spans: `token_value` +# 3. Bare tokens: --flag wikcnABC123 (e.g. inside fenced code blocks) +# +# Token prefixes used by Lark Open Platform: +# wikcn doccn docx shtcn bascn fldcn vewcn tbln ou_ cli_ obcn flec +# +# Excluded (clearly fake, matched by PLACEHOLDER_RE below): +# - Values containing EXAMPLE / _TOKEN / XXXX / your_ / _here +# - Angle-bracket placeholders +# Require at least one digit in the suffix — real API tokens are always alphanumeric +# with digits. Pure-letter suffixes (e.g. ou_manager, ou_director) are clearly fake names. +PREFIXES='(wikcn|doccn|docx[a-z]|shtcn|bascn|fldcn|vewcn|tbln|obcn|flec|ou_|cli_)' +TOKEN_BODY="${PREFIXES}"'[A-Za-z0-9]*[0-9][A-Za-z0-9]{3,}' +REALISTIC_TOKEN_RE="\"${TOKEN_BODY}\"|\`${TOKEN_BODY}\`|\\b${TOKEN_BODY}\\b" +PLACEHOLDER_RE='(EXAMPLE|_TOKEN|XXXX|xxxx|<|>|your_|_here)' + +while IFS= read -r -d '' file; do + # grep returns exit 1 when no match — use || true to avoid set -e killing us + # Then filter out values that are clearly placeholders (EXAMPLE, XXXX, etc.) + matches=$(grep -nEo "$REALISTIC_TOKEN_RE" "$file" 2>/dev/null | grep -vE "$PLACEHOLDER_RE" || true) + if [[ -n "$matches" ]]; then + echo "" + echo "❌ $file" + echo " Contains realistic-looking token values that may trigger gitleaks:" + while IFS= read -r line; do + echo " $line" + done <<< "$matches" + echo " → Replace with a placeholder, e.g.: wikcn_EXAMPLE_TOKEN, doccn_EXAMPLE_TOKEN" + ERRORS=$((ERRORS + 1)) + fi +done < <(find "$SKILLS_DIR" -path "*/references/*.md" -print0) + +if [[ $ERRORS -gt 0 ]]; then + echo "" + echo "❌ check-doc-tokens: $ERRORS file(s) contain realistic token values in reference docs." + echo " Use _EXAMPLE_TOKEN placeholders to avoid false positives in gitleaks CI." + exit 1 +else + echo "✅ check-doc-tokens: all reference docs use safe placeholder tokens." +fi diff --git a/scripts/install-wizard.js b/scripts/install-wizard.js index 4bc76f5d1..91fa2271a 100644 --- a/scripts/install-wizard.js +++ b/scripts/install-wizard.js @@ -10,6 +10,8 @@ const p = require("@clack/prompts"); const PKG = "@larksuite/cli"; const SKILLS_REPO = "https://open.feishu.cn"; const SKILLS_REPO_FALLBACK = "larksuite/cli"; +const CONFIG_DIR = process.env.LARKSUITE_CLI_CONFIG_DIR || path.join(process.env.HOME || process.env.USERPROFILE || "", ".lark-cli"); +const SKILLS_STATE_FILE = path.join(CONFIG_DIR, "skills-state.json"); const isWindows = process.platform === "win32"; // --------------------------------------------------------------------------- @@ -236,7 +238,7 @@ async function stepInstallGlobally(msg) { if (installedVer && !needsUpgrade) { p.log.info(fmt(msg.step1Skip, installedVer)); - return false; + return installedVer; } const s = p.spinner(); @@ -248,41 +250,111 @@ async function stepInstallGlobally(msg) { try { await runSilentAsync("npm", ["install", "-g", PKG], { timeout: 120000 }); s.stop(needsUpgrade ? fmt(msg.step1Upgraded, latestVer) : msg.step1Done); - return needsUpgrade; + return latestVer || getGloballyInstalledVersion() || installedVer || null; } catch (_) { s.stop(fmt(msg.step1Fail, PKG)); process.exit(1); } } -async function skillsAlreadyInstalled() { +function parseSkillsList(text) { + const seen = new Set(); + for (const rawLine of text.split("\n")) { + let token = rawLine.trim(); + if (token.startsWith("-")) token = token.slice(1).trim(); + if (!token || token.includes(" ") || token.endsWith(":")) continue; + if (!/^[A-Za-z0-9][A-Za-z0-9_:-]*(?:@\S+)?$/.test(token)) continue; + const at = token.indexOf("@"); + if (at > 0) token = token.slice(0, at); + seen.add(token); + } + return [...seen].sort(); +} + +function readSkillsState() { try { - const out = await runSilentAsync("npx", ["-y", "skills", "ls", "-g"], { - timeout: 120000, - }); - return /^lark-/m.test(out.toString()); + const state = JSON.parse(fs.readFileSync(SKILLS_STATE_FILE, "utf8")); + if (state.schema_version !== 1 || !Array.isArray(state.official_skills)) return null; + return state; + } catch (_) { + return null; + } +} + +function writeSkillsState(version, official, updated, added, skipped) { + if (!CONFIG_DIR) return; + fs.mkdirSync(CONFIG_DIR, { recursive: true, mode: 0o700 }); + fs.writeFileSync(SKILLS_STATE_FILE, JSON.stringify({ + schema_version: 1, + version, + official_skills: official, + updated_skills: updated, + added_skills: added, + skipped_deleted_skills: skipped, + updated_at: new Date().toISOString(), + }, null, 2) + "\n"); +} + +async function listOfficialSkills() { + try { + return parseSkillsList(await runSilentAsync("npx", ["-y", "skills", "add", SKILLS_REPO, "--list"], { timeout: 120000 })); } catch (_) { - return false; + return parseSkillsList(await runSilentAsync("npx", ["-y", "skills", "add", SKILLS_REPO_FALLBACK, "--list"], { timeout: 120000 })); } } -async function stepInstallSkills(msg) { +async function listGlobalSkills() { + return parseSkillsList(await runSilentAsync("npx", ["-y", "skills", "ls", "-g"], { timeout: 120000 })); +} + +function planSkillsSync(version, official, local, previousState) { + const officialSet = new Set(official); + const previousSet = new Set(previousState ? previousState.official_skills : []); + const localOfficial = local.filter((skill) => officialSet.has(skill)); + const added = official.filter((skill) => !previousSet.has(skill)); + const updateSet = new Set([...localOfficial, ...added]); + const updated = official.filter((skill) => updateSet.has(skill)); + return { + version, + official, + updated, + added, + skipped: official.filter((skill) => !updateSet.has(skill)), + }; +} + +async function installSkill(name) { + try { + await runSilentAsync("npx", ["-y", "skills", "add", SKILLS_REPO, "-s", name, "-g", "-y"], { timeout: 120000 }); + } catch (_) { + await runSilentAsync("npx", ["-y", "skills", "add", SKILLS_REPO_FALLBACK, "-s", name, "-g", "-y"], { timeout: 120000 }); + } +} + +async function stepInstallSkills(msg, cliVersion) { const s = p.spinner(); s.start(msg.step2Spinner); try { - if (await skillsAlreadyInstalled()) { + const official = await listOfficialSkills(); + const local = await listGlobalSkills(); + const plan = planSkillsSync(cliVersion || "unknown", official, local, readSkillsState()); + if (plan.updated.length === 0) { + writeSkillsState(plan.version, plan.official, plan.updated, plan.added, plan.skipped); s.stop(msg.step2Skip); return; } - try { - await runSilentAsync("npx", ["-y", "skills", "add", SKILLS_REPO, "-y", "-g"], { - timeout: 120000, - }); - } catch (_) { - await runSilentAsync("npx", ["-y", "skills", "add", SKILLS_REPO_FALLBACK, "-y", "-g"], { - timeout: 120000, - }); + const failed = []; + for (const skill of plan.updated) { + try { + await installSkill(skill); + } catch (_) { + failed.push(skill); + } + } + if (failed.length > 0) { + throw new Error(`${failed.length} skill(s) failed: ${failed.join(", ")}`); } + writeSkillsState(plan.version, plan.official, plan.updated, plan.added, plan.skipped); s.stop(msg.step2Done); } catch (_) { s.stop(fmt(msg.step2Fail, SKILLS_REPO_FALLBACK)); @@ -361,15 +433,15 @@ async function main() { if (isInteractive) { p.intro(msg.setup); - await stepInstallGlobally(msg); - await stepInstallSkills(msg); + const cliVersion = await stepInstallGlobally(msg); + await stepInstallSkills(msg, cliVersion); await stepConfigInit(msg, lang); await stepAuthLogin(msg); p.outro(msg.done); } else { console.log(msg.setup); - await stepInstallGlobally(msg); - await stepInstallSkills(msg); + const cliVersion = await stepInstallGlobally(msg); + await stepInstallSkills(msg, cliVersion); console.log(msg.nonTtyHint); } } diff --git a/shortcuts/base/base_dryrun_ops_test.go b/shortcuts/base/base_dryrun_ops_test.go index b3d59aa7b..f25b99ac2 100644 --- a/shortcuts/base/base_dryrun_ops_test.go +++ b/shortcuts/base/base_dryrun_ops_test.go @@ -149,29 +149,26 @@ func TestDryRunRecordOps(t *testing.T) { assertDryRunContains(t, dryRunRecordGet(ctx, getJSONRT), "POST /open-apis/base/v3/bases/app_x/tables/tbl_1/records/batch_get", `"record_id_list":["rec_3"]`, `"select_fields":["Status"]`) assertDryRunContains(t, dryRunRecordDelete(ctx, getJSONRT), "POST /open-apis/base/v3/bases/app_x/tables/tbl_1/records/batch_delete", `"record_id_list":["rec_3"]`) - uploadAttachmentRT := newBaseTestRuntime( + uploadAttachmentRT := newBaseTestRuntimeWithArrays( map[string]string{ "base-token": "app_x", "table-id": "tbl_1", "record-id": "rec_1", "field-id": "fld_att", - "file": "/tmp/report.pdf", - "name": "report-final.pdf", }, + map[string][]string{"file": {"/tmp/report.pdf"}}, nil, nil, ) assertDryRunContains(t, BaseRecordUploadAttachment.DryRun(ctx, uploadAttachmentRT), "GET /open-apis/base/v3/bases/app_x/tables/tbl_1/fields/fld_att", - "GET /open-apis/base/v3/bases/app_x/tables/tbl_1/records/rec_1", "POST /open-apis/drive/v1/medias/upload_all", "bitable_file", - "PATCH /open-apis/base/v3/bases/app_x/tables/tbl_1/records/rec_1", - "report-final.pdf", - `"mime_type":"\u003cdetected_mime_type\u003e"`, - `"size":"\u003cfile_size\u003e"`, - "deprecated_set_attachment", + "POST /open-apis/base/v3/bases/app_x/tables/tbl_1/append_attachments", + "report.pdf", + `"image_width":"\u003cimage_width_if_image\u003e"`, + `"image_height":"\u003cimage_height_if_image\u003e"`, ) } diff --git a/shortcuts/base/base_execute_test.go b/shortcuts/base/base_execute_test.go index 741b2f0e3..0a9a1d709 100644 --- a/shortcuts/base/base_execute_test.go +++ b/shortcuts/base/base_execute_test.go @@ -7,6 +7,11 @@ import ( "bytes" "context" "encoding/json" + "errors" + "image" + "image/color" + "image/png" + "net/url" "os" "path/filepath" "strings" @@ -15,6 +20,7 @@ import ( "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/httpmock" + "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/shortcuts/common" "github.com/spf13/cobra" ) @@ -432,7 +438,7 @@ func TestBaseFieldExecuteUpdate(t *testing.T) { "data": map[string]interface{}{"id": "fld_x", "name": "Amount", "type": "number"}, }, }) - if err := runShortcut(t, BaseFieldUpdate, []string{"+field-update", "--base-token", "app_x", "--table-id", "tbl_x", "--field-id", "fld_x", "--json", `{"name":"Amount","type":"number"}`}, factory, stdout); err != nil { + if err := runShortcut(t, BaseFieldUpdate, []string{"+field-update", "--base-token", "app_x", "--table-id", "tbl_x", "--field-id", "fld_x", "--json", `{"name":"Amount","type":"number"}`, "--yes"}, factory, stdout); err != nil { t.Fatalf("err=%v", err) } if got := stdout.String(); !strings.Contains(got, `"updated": true`) || !strings.Contains(got, `"fld_x"`) { @@ -1589,12 +1595,14 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("upload attachment", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - tmpFile, err := os.CreateTemp(t.TempDir(), "base-attachment-*.txt") + tmpFile, err := os.CreateTemp(t.TempDir(), "base-attachment-*.png") if err != nil { t.Fatalf("CreateTemp() err=%v", err) } - if _, err := tmpFile.WriteString("hello attachment"); err != nil { - t.Fatalf("WriteString() err=%v", err) + img := image.NewRGBA(image.Rect(0, 0, 3, 2)) + img.Set(0, 0, color.RGBA{R: 255, A: 255}) + if err := png.Encode(tmpFile, img); err != nil { + t.Fatalf("png.Encode() err=%v", err) } if err := tmpFile.Close(); err != nil { t.Fatalf("Close() err=%v", err) @@ -1609,28 +1617,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { "data": map[string]interface{}{"id": "fld_att", "name": "附件", "type": "attachment"}, }, }) - reg.Register(&httpmock.Stub{ - Method: "GET", - URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records/rec_x", - Body: map[string]interface{}{ - "code": 0, - "data": map[string]interface{}{ - "record_id": "rec_x", - "fields": map[string]interface{}{ - "附件": []interface{}{ - map[string]interface{}{ - "file_token": "existing_tok", - "name": "existing.pdf", - "size": 2048, - "image_width": 640, - "image_height": 480, - "deprecated_set_attachment": false, - }, - }, - }, - }, - }, - }) uploadStub := &httpmock.Stub{ Method: "POST", URL: "/open-apis/drive/v1/medias/upload_all", @@ -1640,34 +1626,27 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { }, } reg.Register(uploadStub) - updateStub := &httpmock.Stub{ - Method: "PATCH", - URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records/rec_x", + appendStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/append_attachments", Body: map[string]interface{}{ "code": 0, "data": map[string]interface{}{ - "record_id": "rec_x", - "fields": map[string]interface{}{ - "附件": []interface{}{ - map[string]interface{}{ - "file_token": "existing_tok", - "name": "existing.pdf", - "size": 2048, - "image_width": 640, - "image_height": 480, - "deprecated_set_attachment": true, - }, - map[string]interface{}{ - "file_token": "file_tok_1", - "name": "report.txt", - "deprecated_set_attachment": true, + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{ + "fld_att": []interface{}{ + map[string]interface{}{ + "file_token": "file_tok_1", + "name": "base-attachment.png", + "size": 73, + }, }, }, }, }, }, } - reg.Register(updateStub) + reg.Register(appendStub) if err := runShortcut(t, BaseRecordUploadAttachment, []string{ "+record-upload-attachment", @@ -1676,11 +1655,10 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { "--record-id", "rec_x", "--field-id", "fld_att", "--file", "./" + filepath.Base(tmpFile.Name()), - "--name", "report.txt", }, factory, stdout); err != nil { t.Fatalf("err=%v", err) } - if got := stdout.String(); !strings.Contains(got, `"updated": true`) || !strings.Contains(got, `"file_tok_1"`) || !strings.Contains(got, `"report.txt"`) { + if got := stdout.String(); !strings.Contains(got, `"file_tok_1"`) || strings.Contains(got, `"updated"`) || strings.Contains(got, `"uploaded"`) { t.Fatalf("stdout=%s", got) } @@ -1689,19 +1667,13 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Fatalf("upload body=%s", uploadBody) } - updateBody := string(updateStub.CapturedBody) - if !strings.Contains(updateBody, `"附件"`) || - !strings.Contains(updateBody, `"file_token":"existing_tok"`) || - !strings.Contains(updateBody, `"name":"existing.pdf"`) || - !strings.Contains(updateBody, `"size":2048`) || - !strings.Contains(updateBody, `"image_width":640`) || - !strings.Contains(updateBody, `"image_height":480`) || - !strings.Contains(updateBody, `"deprecated_set_attachment":true`) || - !strings.Contains(updateBody, `"file_token":"file_tok_1"`) || - !strings.Contains(updateBody, `"name":"report.txt"`) || - !strings.Contains(updateBody, `"size":16`) || - !strings.Contains(updateBody, `"mime_type":"text/plain"`) { - t.Fatalf("update body=%s", updateBody) + appendBody := string(appendStub.CapturedBody) + if !strings.Contains(appendBody, `"rec_x"`) || + !strings.Contains(appendBody, `"fld_att"`) || + !strings.Contains(appendBody, `"file_token":"file_tok_1"`) || + !strings.Contains(appendBody, `"image_width":3`) || + !strings.Contains(appendBody, `"image_height":2`) { + t.Fatalf("append body=%s", appendBody) } }) @@ -1728,17 +1700,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { "data": map[string]interface{}{"id": "fld_att", "name": "附件", "type": "attachment"}, }, }) - reg.Register(&httpmock.Stub{ - Method: "GET", - URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records/rec_x", - Body: map[string]interface{}{ - "code": 0, - "data": map[string]interface{}{ - "record_id": "rec_x", - "fields": map[string]interface{}{}, - }, - }, - }) prepareStub := &httpmock.Stub{ Method: "POST", @@ -1778,26 +1739,23 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { } reg.Register(finishStub) - updateStub := &httpmock.Stub{ - Method: "PATCH", - URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records/rec_x", + appendStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/append_attachments", Body: map[string]interface{}{ "code": 0, "data": map[string]interface{}{ - "record_id": "rec_x", - "fields": map[string]interface{}{ - "附件": []interface{}{ - map[string]interface{}{ - "file_token": "file_tok_big", - "name": "large-report.bin", - "deprecated_set_attachment": true, + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{ + "fld_att": []interface{}{ + map[string]interface{}{"file_token": "file_tok_big"}, }, }, }, }, }, } - reg.Register(updateStub) + reg.Register(appendStub) if err := runShortcut(t, BaseRecordUploadAttachment, []string{ "+record-upload-attachment", @@ -1806,17 +1764,16 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { "--record-id", "rec_x", "--field-id", "fld_att", "--file", "./" + filepath.Base(tmpFile.Name()), - "--name", "large-report.bin", }, factory, stdout); err != nil { t.Fatalf("err=%v", err) } - if got := stdout.String(); !strings.Contains(got, `"updated": true`) || !strings.Contains(got, `"file_tok_big"`) || !strings.Contains(got, `"large-report.bin"`) { + if got := stdout.String(); !strings.Contains(got, `"file_tok_big"`) || strings.Contains(got, `"updated"`) || strings.Contains(got, `"uploaded"`) { t.Fatalf("stdout=%s", got) } prepareBody := string(prepareStub.CapturedBody) - if !strings.Contains(prepareBody, `"file_name":"large-report.bin"`) || + if !strings.Contains(prepareBody, `"file_name":"`+filepath.Base(tmpFile.Name())+`"`) || !strings.Contains(prepareBody, `"parent_type":"bitable_file"`) || !strings.Contains(prepareBody, `"parent_node":"app_x"`) || !strings.Contains(prepareBody, `"size":20971521`) { @@ -1847,14 +1804,11 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Fatalf("finish body=%s", finishBody) } - updateBody := string(updateStub.CapturedBody) - if !strings.Contains(updateBody, `"附件"`) || - !strings.Contains(updateBody, `"file_token":"file_tok_big"`) || - !strings.Contains(updateBody, `"name":"large-report.bin"`) || - !strings.Contains(updateBody, `"size":20971521`) || - !strings.Contains(updateBody, `"mime_type":"application/octet-stream"`) || - !strings.Contains(updateBody, `"deprecated_set_attachment":true`) { - t.Fatalf("update body=%s", updateBody) + appendBody := string(appendStub.CapturedBody) + if !strings.Contains(appendBody, `"rec_x"`) || + !strings.Contains(appendBody, `"fld_att"`) || + !strings.Contains(appendBody, `"file_token":"file_tok_big"`) { + t.Fatalf("append body=%s", appendBody) } }) @@ -1928,6 +1882,434 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Fatalf("err=%v", err) } }) + + t.Run("upload attachment rejects deprecated name flag", func(t *testing.T) { + factory, stdout, _ := newExecuteFactory(t) + + tmpFile, err := os.CreateTemp(t.TempDir(), "base-name-*.txt") + if err != nil { + t.Fatalf("CreateTemp() err=%v", err) + } + if err := tmpFile.Close(); err != nil { + t.Fatalf("Close() err=%v", err) + } + withBaseWorkingDir(t, filepath.Dir(tmpFile.Name())) + + err = runShortcut(t, BaseRecordUploadAttachment, []string{ + "+record-upload-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--field-id", "fld_att", + "--file", "./" + filepath.Base(tmpFile.Name()), + "--name", "renamed.txt", + }, factory, stdout) + if err == nil || !strings.Contains(err.Error(), "--name is no longer supported") { + t.Fatalf("err=%v", err) + } + }) + + t.Run("download attachment includes extra query parameter", func(t *testing.T) { + factory, stdout, reg := newExecuteFactory(t) + + extra := `{"bitablePerm":{"tableId":"tbl_x","attachments":{"fld_att":{"rec_x":["box_a"]}}}}` + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/get_attachments", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{ + "fld_att": []interface{}{ + map[string]interface{}{ + "file_token": "box_a", + "name": "pic.png", + "size": 7, + "extra_info": extra, + }, + }, + }, + }, + }, + }, + }) + downloadStub := &httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/medias/box_a/download?" + url.Values{"extra": []string{extra}}.Encode(), + RawBody: []byte("payload"), + ContentType: "image/png", + } + reg.Register(downloadStub) + + tmpDir := t.TempDir() + withBaseWorkingDir(t, tmpDir) + if err := os.Mkdir("downloads", 0700); err != nil { + t.Fatalf("Mkdir() err=%v", err) + } + + if err := runShortcut(t, BaseRecordDownloadAttachment, []string{ + "+record-download-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--file-token", "box_a", + "--output", "downloads", + }, factory, stdout); err != nil { + t.Fatalf("err=%v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "downloads", "pic.png")); err != nil { + t.Fatalf("expected downloaded file: %v", err) + } + data := decodeBaseEnvelope(t, stdout) + gotItems, _ := data["downloaded"].([]interface{}) + if len(gotItems) != 1 { + t.Fatalf("downloaded=%#v", data["downloaded"]) + } + got, _ := gotItems[0].(map[string]interface{}) + if got["file_token"] != "box_a" || got["saved_path"] == "" || got["extra_info_used"] != nil { + t.Fatalf("download output=%#v", got) + } + }) + + t.Run("download all row attachments when file token omitted", func(t *testing.T) { + factory, stdout, reg := newExecuteFactory(t) + + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/get_attachments", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{ + "fld_att": []interface{}{ + map[string]interface{}{"file_token": "box_a", "name": "a.txt", "size": 7}, + map[string]interface{}{"file_token": "box_b", "name": "b.txt", "size": 8}, + }, + }, + }, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/medias/box_a/download", + RawBody: []byte("payload-a"), + ContentType: "text/plain", + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/medias/box_b/download", + RawBody: []byte("payload-b"), + ContentType: "text/plain", + }) + + tmpDir := t.TempDir() + withBaseWorkingDir(t, tmpDir) + if err := os.Mkdir("downloads", 0700); err != nil { + t.Fatalf("Mkdir() err=%v", err) + } + + if err := runShortcut(t, BaseRecordDownloadAttachment, []string{ + "+record-download-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--output", "downloads", + }, factory, stdout); err != nil { + t.Fatalf("err=%v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "downloads", "a.txt")); err != nil { + t.Fatalf("expected downloaded file a.txt: %v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "downloads", "b.txt")); err != nil { + t.Fatalf("expected downloaded file b.txt: %v", err) + } + data := decodeBaseEnvelope(t, stdout) + gotItems, _ := data["downloaded"].([]interface{}) + if len(gotItems) != 2 { + t.Fatalf("downloaded=%#v", data["downloaded"]) + } + }) + + t.Run("download without file token requires output directory", func(t *testing.T) { + factory, stdout, _ := newExecuteFactory(t) + tmpDir := t.TempDir() + withBaseWorkingDir(t, tmpDir) + + err := runShortcut(t, BaseRecordDownloadAttachment, []string{ + "+record-download-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--output", "file.txt", + }, factory, stdout) + if err == nil || !strings.Contains(err.Error(), "--output must be an existing directory") { + t.Fatalf("err=%v", err) + } + }) + + t.Run("download all disambiguates duplicate attachment names with file token", func(t *testing.T) { + factory, stdout, reg := newExecuteFactory(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/get_attachments", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{ + "fld_att": []interface{}{ + map[string]interface{}{"file_token": "box_a", "name": "same.txt", "size": 7}, + map[string]interface{}{"file_token": "box_a", "name": "same.txt", "size": 7}, + map[string]interface{}{"file_token": "box_b", "name": "same.txt", "size": 8}, + }, + }, + }, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/medias/box_a/download", + RawBody: []byte("payload-a"), + ContentType: "text/plain", + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/medias/box_b/download", + RawBody: []byte("payload-b"), + ContentType: "text/plain", + }) + + tmpDir := t.TempDir() + withBaseWorkingDir(t, tmpDir) + if err := os.Mkdir("downloads", 0700); err != nil { + t.Fatalf("Mkdir() err=%v", err) + } + + if err := runShortcut(t, BaseRecordDownloadAttachment, []string{ + "+record-download-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--output", "downloads", + }, factory, stdout); err != nil { + t.Fatalf("err=%v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "downloads", "same_box_a.txt")); err != nil { + t.Fatalf("expected downloaded file same_box_a.txt: %v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "downloads", "same_box_b.txt")); err != nil { + t.Fatalf("expected downloaded file same_box_b.txt: %v", err) + } + data := decodeBaseEnvelope(t, stdout) + gotItems, _ := data["downloaded"].([]interface{}) + if len(gotItems) != 2 { + t.Fatalf("downloaded=%#v", data["downloaded"]) + } + }) + + t.Run("download duplicate requested file token only once", func(t *testing.T) { + factory, stdout, reg := newExecuteFactory(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/get_attachments", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{ + "fld_att": []interface{}{ + map[string]interface{}{"file_token": "box_a", "name": "a.txt", "size": 7}, + }, + }, + }, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/medias/box_a/download", + RawBody: []byte("payload-a"), + ContentType: "text/plain", + }) + + tmpDir := t.TempDir() + withBaseWorkingDir(t, tmpDir) + if err := runShortcut(t, BaseRecordDownloadAttachment, []string{ + "+record-download-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--file-token", "box_a", + "--file-token", "box_a", + "--output", "a.txt", + }, factory, stdout); err != nil { + t.Fatalf("err=%v", err) + } + data := decodeBaseEnvelope(t, stdout) + gotItems, _ := data["downloaded"].([]interface{}) + if len(gotItems) != 1 { + t.Fatalf("downloaded=%#v", data["downloaded"]) + } + }) + + t.Run("download all preflights local target conflicts before writing", func(t *testing.T) { + factory, stdout, reg := newExecuteFactory(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/get_attachments", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{ + "fld_att": []interface{}{ + map[string]interface{}{"file_token": "box_a", "name": "a.txt", "size": 7}, + map[string]interface{}{"file_token": "box_b", "name": "b.txt", "size": 8}, + }, + }, + }, + }, + }, + }) + + tmpDir := t.TempDir() + withBaseWorkingDir(t, tmpDir) + if err := os.Mkdir("downloads", 0700); err != nil { + t.Fatalf("Mkdir() err=%v", err) + } + if err := os.WriteFile(filepath.Join("downloads", "b.txt"), []byte("existing"), 0600); err != nil { + t.Fatalf("WriteFile() err=%v", err) + } + + err := runShortcut(t, BaseRecordDownloadAttachment, []string{ + "+record-download-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--output", "downloads", + }, factory, stdout) + if err == nil || !strings.Contains(err.Error(), "output file already exists: downloads/b.txt") { + t.Fatalf("err=%v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "downloads", "a.txt")); err == nil { + t.Fatalf("a.txt should not be written after preflight conflict") + } + }) + + t.Run("download reports progress when later attachment fails", func(t *testing.T) { + factory, stdout, reg := newExecuteFactory(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/get_attachments", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{ + "fld_att": []interface{}{ + map[string]interface{}{"file_token": "box_a", "name": "a.txt", "size": 7}, + map[string]interface{}{"file_token": "box_b", "name": "b.txt", "size": 8}, + }, + }, + }, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/medias/box_a/download", + RawBody: []byte("payload-a"), + ContentType: "text/plain", + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/medias/box_b/download", + Status: 500, + RawBody: []byte("server error"), + }) + + tmpDir := t.TempDir() + withBaseWorkingDir(t, tmpDir) + if err := os.Mkdir("downloads", 0700); err != nil { + t.Fatalf("Mkdir() err=%v", err) + } + + err := runShortcut(t, BaseRecordDownloadAttachment, []string{ + "+record-download-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--output", "downloads", + }, factory, stdout) + if err == nil || !strings.Contains(err.Error(), "download failed after 1 attachment(s) succeeded and 1 failed") { + t.Fatalf("err=%v", err) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured error, got %T %v", err, err) + } + detail, _ := exitErr.Detail.Detail.(map[string]interface{}) + downloaded, _ := detail["downloaded"].([]map[string]interface{}) + failed, _ := detail["failed"].([]map[string]interface{}) + if len(downloaded) != 1 || downloaded[0]["file_token"] != "box_a" || len(failed) != 1 || failed[0]["file_token"] != "box_b" { + t.Fatalf("detail=%#v", exitErr.Detail.Detail) + } + if _, err := os.Stat(filepath.Join(tmpDir, "downloads", "a.txt")); err != nil { + t.Fatalf("expected first file to remain: %v", err) + } + }) + + t.Run("remove attachment", func(t *testing.T) { + factory, stdout, reg := newExecuteFactory(t) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/fields/fld_att", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{"id": "fld_att", "name": "附件", "type": "attachment"}, + }, + }) + removeStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/remove_attachments", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "attachments": map[string]interface{}{ + "rec_x": map[string]interface{}{"fld_att": []interface{}{}}, + }, + }, + }, + } + reg.Register(removeStub) + + if err := runShortcut(t, BaseRecordRemoveAttachment, []string{ + "+record-remove-attachment", + "--base-token", "app_x", + "--table-id", "tbl_x", + "--record-id", "rec_x", + "--field-id", "fld_att", + "--file-token", "box_a", + "--file-token", "box_b", + "--yes", + }, factory, stdout); err != nil { + t.Fatalf("err=%v", err) + } + if got := stdout.String(); strings.Contains(got, `"removed"`) || strings.Contains(got, `"updated"`) { + t.Fatalf("stdout=%s", got) + } + body := string(removeStub.CapturedBody) + if !strings.Contains(body, `"rec_x"`) || + !strings.Contains(body, `"fld_att"`) || + !strings.Contains(body, `"file_token":"box_a"`) || + !strings.Contains(body, `"file_token":"box_b"`) { + t.Fatalf("remove body=%s", body) + } + }) } func TestBaseViewExecuteReadCreateDeleteAndFilter(t *testing.T) { diff --git a/shortcuts/base/base_form_detail.go b/shortcuts/base/base_form_detail.go new file mode 100644 index 000000000..4dc765003 --- /dev/null +++ b/shortcuts/base/base_form_detail.go @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package base + +import ( + "context" + + "github.com/larksuite/cli/shortcuts/common" +) + +var BaseFormDetail = common.Shortcut{ + Service: "base", + Command: "+form-detail", + Description: "Get form detail by share token", + Risk: "read", + Scopes: []string{"base:form:read"}, + AuthTypes: []string{"user", "bot"}, + HasFormat: true, + Flags: []common.Flag{ + {Name: "share-token", Desc: "Form share token (share_token)", Required: true}, + }, + DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + return common.NewDryRunAPI(). + POST("/open-apis/base/v3/bases/tables/forms/detail"). + Body(map[string]interface{}{ + "share_token": runtime.Str("share-token"), + }) + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + body := map[string]interface{}{ + "share_token": runtime.Str("share-token"), + } + + data, err := baseV3Call(runtime, "POST", + baseV3Path("bases", "tables", "forms", "detail"), nil, body) + if err != nil { + return err + } + + runtime.Out(data, nil) + return nil + }, +} diff --git a/shortcuts/base/base_form_submit.go b/shortcuts/base/base_form_submit.go new file mode 100644 index 000000000..7c1aeb173 --- /dev/null +++ b/shortcuts/base/base_form_submit.go @@ -0,0 +1,334 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package base + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "path/filepath" + "sync" + + "golang.org/x/sync/errgroup" + + "github.com/larksuite/cli/extension/fileio" + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/shortcuts/common" +) + +const ( + uploadAttachConcurrency = 5 +) + +var BaseFormSubmit = common.Shortcut{ + Service: "base", + Command: "+form-submit", + Description: "Submit a form (fill and submit form data)", + Risk: "write", + Scopes: []string{"base:form:update", "docs:document.media:upload"}, + AuthTypes: authTypes(), + HasFormat: true, + Flags: []common.Flag{ + {Name: "share-token", Desc: "Form share token (required), extracted from the form share link", Required: true}, + {Name: "base-token", Desc: "Base token (required when --json contains attachments, used for uploading attachments to Base Drive Media)"}, + {Name: "json", Desc: `JSON object containing "fields" (field values) and "attachments" (attachment file paths). Example: '{"fields":{"Rating":5,"Review":"Good"},"attachments":{"Attachment":["./a.pdf","./b.png"]}}'`, Required: true}, + }, + Tips: []string{ + `Example (no attachments): --share-token shrXXXX --json '{"fields":{"Service Rating":5,"Review":"Good service"}}'`, + `Example (with attachments): --share-token shrXXXX --base-token basXXX --json '{"fields":{"Service Rating":5},"attachments":{"Attachment":["./report.pdf"]}}'`, + `Cell values in "fields" follow lark-base-cell-value.md conventions; "attachments" maps field names to local file path arrays — the CLI uploads them in parallel and merges them into the submission.`, + }, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + return validateFormSubmit(runtime) + }, + DryRun: dryRunFormSubmit, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + return executeFormSubmit(runtime) + }, +} + +func validateFormSubmit(runtime *common.RuntimeContext) error { + // 校验 --json 结构:提取 "fields" 和 "attachments" + pc := newParseCtx(runtime) + raw, err := parseJSONObject(pc, runtime.Str("json"), "json") + if err != nil { + return err + } + + fields, _ := raw["fields"].(map[string]interface{}) + attachments, hasAttachments := raw["attachments"] + + if !hasAttachments && fields == nil { + return common.FlagErrorf("--json must contain at least \"fields\" or \"attachments\"") + } + + if hasAttachments { + // 有附件时 --base-token 必填(上传附件到 Base Drive Media 需要) + if runtime.Str("base-token") == "" { + return common.FlagErrorf("--base-token is required when --json contains \"attachments\"") + } + + attMap, ok := attachments.(map[string]interface{}) + if !ok { + return common.FlagErrorf("--json.attachments must be a JSON object mapping field names to file path arrays") + } + for fieldName, value := range attMap { + paths, ok := value.([]interface{}) + if !ok { + return common.FlagErrorf("--json.attachments.%q must be a file path array, got %T", fieldName, value) + } + for i, item := range paths { + if _, ok := item.(string); !ok { + return common.FlagErrorf("--json.attachments.%q[%d] must be a file path string, got %T", fieldName, i, item) + } + } + if len(paths) == 0 { + return common.FlagErrorf("--json.attachments.%q must not be empty; remove it or provide at least one file path", fieldName) + } + } + } + + return nil +} + +// parseFormSubmitJSON 将 --json 解析为字段和附件映射。 +func parseFormSubmitJSON(runtime *common.RuntimeContext) (map[string]interface{}, map[string][]string, error) { + pc := newParseCtx(runtime) + raw, err := parseJSONObject(pc, runtime.Str("json"), "json") + if err != nil { + return nil, nil, err + } + + fields, _ := raw["fields"].(map[string]interface{}) + if fields == nil { + fields = make(map[string]interface{}) + } + + var attMap map[string][]string + if attachments, ok := raw["attachments"]; ok { + attObj, ok := attachments.(map[string]interface{}) + if !ok { + return nil, nil, common.FlagErrorf(`--json.attachments must be a JSON object mapping field names to file path arrays`) + } + if len(attObj) > 0 { + attMap = make(map[string][]string, len(attObj)) + for fieldName, value := range attObj { + paths, ok := value.([]interface{}) + if !ok { + return nil, nil, common.FlagErrorf("--json.attachments.%q must be a file path array, got %T", fieldName, value) + } + filePaths := make([]string, 0, len(paths)) + for _, item := range paths { + if s, ok := item.(string); ok { + filePaths = append(filePaths, s) + } else { + return nil, nil, common.FlagErrorf("--json.attachments.%q must contain file path strings only, got %T", fieldName, item) + } + } + if len(filePaths) > 0 { + attMap[fieldName] = filePaths + } + } + } + } + + return fields, attMap, nil +} + +func dryRunFormSubmit(_ context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + fields, attachmentMap, err := parseFormSubmitJSON(runtime) + if err != nil { + return common.NewDryRunAPI().Desc(fmt.Sprintf("dry-run validation failed: %v", err)) + } + + if len(attachmentMap) > 0 { + dry := common.NewDryRunAPI(). + Desc("Form submit with attachments: upload local files per field → merge with fields → submit") + + for fieldName, filePaths := range attachmentMap { + for _, p := range filePaths { + fileName := filepath.Base(p) + dry = dry.POST("/open-apis/drive/v1/medias/upload_all"). + Desc(fmt.Sprintf("Upload attachment for field %q: %s", fieldName, fileName)). + Body(map[string]interface{}{ + "file_name": fileName, + "parent_type": baseFormAttachmentParentType, + "parent_node": runtime.Str("base-token"), + "extra": baseFormAttachmentExtra(runtime.Str("share-token")), + "file": "@" + p, + "size": "", + }) + } + } + + body := buildFormSubmitBody(runtime, fields) + dry = dry.POST("/open-apis/base/v3/bases/tables/forms/submit"). + Body(body). + Desc("Submit form with uploaded attachment tokens merged with fields") + return dry + } + + body := buildFormSubmitBody(runtime, fields) + return common.NewDryRunAPI(). + POST("/open-apis/base/v3/bases/tables/forms/submit"). + Body(body) +} + +func buildFormSubmitBody(runtime *common.RuntimeContext, content map[string]interface{}) map[string]interface{} { + return map[string]interface{}{ + "share_token": runtime.Str("share-token"), + "content": content, + } +} + +func executeFormSubmit(runtime *common.RuntimeContext) error { + fields, attachmentMap, err := parseFormSubmitJSON(runtime) + if err != nil { + return err + } + + // 上传附件并合并到字段中 + if len(attachmentMap) > 0 { + baseToken := runtime.Str("base-token") + fio := runtime.FileIO() + if fio == nil { + return output.ErrValidation("file operations require a FileIO provider (needed for attachments in --json)") + } + + // Step 1: 收集所有唯一路径(跨字段去重) + allPaths := collectUniquePaths(attachmentMap) + if len(allPaths) == 0 { + return common.FlagErrorf("attachments in --json contains no valid file paths") + } + + // Step 2: 前置校验所有文件路径安全性与可访问性,同时收集文件大小供上传使用 + sizeMap := make(map[string]int64, len(allPaths)) + for _, filePath := range allPaths { + if _, err := validate.SafeInputPath(filePath); err != nil { + return output.ErrValidation("unsafe attachment file path: %s: %v", filePath, err) + } + fileInfo, err := fio.Stat(filePath) + if err != nil { + if errors.Is(err, fileio.ErrPathValidation) { + return output.ErrValidation("unsafe attachment file path: %s: %v", filePath, err) + } + return output.ErrValidation("attachment file not accessible: %s: %v", filePath, err) + } + if fileInfo.Size() > baseAttachmentUploadMaxFileSize { + return output.ErrValidation("attachment file %s exceeds 2GB limit", filePath) + } + if !fileInfo.Mode().IsRegular() { + return output.ErrValidation("attachment file %s is not a regular file", filePath) + } + sizeMap[filePath] = fileInfo.Size() + } + + // Step 3: 并行上传,构建路径 → 附件结果映射 + fmt.Fprintf(runtime.IO().ErrOut, "Uploading %d unique attachment(s)...\n", len(allPaths)) + resultMap, err := uploadAttachmentsParallel(runtime, allPaths, baseFormAttachmentUploadTarget(baseToken, runtime.Str("share-token")), sizeMap) + if err != nil { + return err + } + + // Step 4: 根据共享结果映射,按字段组装单元格 + for fieldName, filePaths := range attachmentMap { + cell := make([]interface{}, 0, len(filePaths)) + for _, p := range filePaths { + if att, ok := resultMap[p]; ok { + cell = append(cell, att) + } + } + fields[fieldName] = cell + } + fmt.Fprintf(runtime.IO().ErrOut, "Uploaded %d unique file(s) into %d field(s)\n", len(resultMap), len(attachmentMap)) + } + + body := buildFormSubmitBody(runtime, fields) + data, err := baseV3Call(runtime, "POST", + baseV3Path("bases", "tables", "forms", "submit"), + nil, body) + if err != nil { + return err + } + + runtime.Out(data, nil) + return nil +} + +// collectUniquePaths 收集所有字段中的文件路径,返回去重后的有序列表。 +func collectUniquePaths(attachmentMap map[string][]string) []string { + seen := make(map[string]bool, len(attachmentMap)*4) + var order []string + for _, filePaths := range attachmentMap { + for _, p := range filePaths { + if !seen[p] { + seen[p] = true + order = append(order, p) + } + } + } + return order +} + +func baseFormAttachmentUploadTarget(baseToken, shareToken string) baseAttachmentUploadTarget { + return baseAttachmentUploadTarget{ + ParentType: baseFormAttachmentParentType, + ParentNode: baseToken, + Extra: baseFormAttachmentExtra(shareToken), + } +} + +func baseFormAttachmentExtra(shareToken string) string { + extra, err := json.Marshal(map[string]string{"share_token": shareToken}) + if err != nil { + return "" + } + return string(extra) +} + +// uploadAttachmentsParallel 并发上传文件,返回路径 → 附件对象的映射。 +func uploadAttachmentsParallel(runtime *common.RuntimeContext, paths []string, target baseAttachmentUploadTarget, sizeMap map[string]int64) (map[string]interface{}, error) { + var ( + mu sync.Mutex + resultMap = make(map[string]interface{}, len(paths)) + ) + + g, _ := errgroup.WithContext(runtime.Ctx()) + g.SetLimit(uploadAttachConcurrency) // 限制并发数 + + for _, filePath := range paths { + fp := filePath // 捕获循环变量 + g.Go(func() error { + fileName := filepath.Base(fp) + fmt.Fprintf(runtime.IO().ErrOut, " Uploading: %s\n", fileName) + + att, err := uploadSingleAttachment(runtime, fp, fileName, sizeMap[fp], target) + if err != nil { + return err + } + + mu.Lock() + resultMap[fp] = att + mu.Unlock() + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, err + } + return resultMap, nil +} + +// uploadSingleAttachment 上传单个文件,返回附件单元格项。 +// 前置条件:文件已通过校验(存在、常规文件、大小在限制内)。 +func uploadSingleAttachment(runtime *common.RuntimeContext, filePath, fileName string, fileSize int64, target baseAttachmentUploadTarget) (interface{}, error) { + att, err := uploadAttachmentToBase(runtime, filePath, fileName, fileSize, target) + if err != nil { + return nil, fmt.Errorf("failed to upload attachment %s: %w", filePath, err) + } + return att, nil +} diff --git a/shortcuts/base/base_shortcuts_test.go b/shortcuts/base/base_shortcuts_test.go index eeca3b8d1..9f75ac763 100644 --- a/shortcuts/base/base_shortcuts_test.go +++ b/shortcuts/base/base_shortcuts_test.go @@ -5,6 +5,9 @@ package base import ( "context" + "encoding/json" + "os" + "path/filepath" "reflect" "strconv" "strings" @@ -14,6 +17,7 @@ import ( "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/httpmock" "github.com/larksuite/cli/shortcuts/common" ) @@ -132,14 +136,15 @@ func TestShortcutsCatalog(t *testing.T) { "+table-list", "+table-get", "+table-create", "+table-update", "+table-delete", "+field-list", "+field-get", "+field-create", "+field-update", "+field-delete", "+field-search-options", "+view-list", "+view-get", "+view-create", "+view-delete", "+view-get-filter", "+view-set-filter", "+view-get-visible-fields", "+view-set-visible-fields", "+view-get-group", "+view-set-group", "+view-get-sort", "+view-set-sort", "+view-get-timebar", "+view-set-timebar", "+view-get-card", "+view-set-card", "+view-rename", - "+record-list", "+record-search", "+record-get", "+record-upsert", "+record-batch-create", "+record-batch-update", "+record-share-link-create", "+record-upload-attachment", "+record-delete", + "+record-list", "+record-search", "+record-get", "+record-upsert", "+record-batch-create", "+record-batch-update", "+record-share-link-create", "+record-upload-attachment", "+record-download-attachment", "+record-remove-attachment", "+record-delete", "+record-history-list", "+base-get", "+base-copy", "+base-create", "+role-create", "+role-delete", "+role-update", "+role-list", "+role-get", "+advperm-enable", "+advperm-disable", "+workflow-list", "+workflow-get", "+workflow-create", "+workflow-update", "+workflow-enable", "+workflow-disable", "+data-query", - "+form-create", "+form-delete", "+form-list", "+form-update", "+form-get", + "+form-create", "+form-delete", "+form-list", "+form-update", "+form-get", "+form-detail", "+form-questions-create", "+form-questions-delete", "+form-questions-update", "+form-questions-list", + "+form-submit", "+dashboard-list", "+dashboard-get", "+dashboard-create", "+dashboard-update", "+dashboard-delete", "+dashboard-arrange", "+dashboard-block-list", "+dashboard-block-get", "+dashboard-block-create", "+dashboard-block-update", "+dashboard-block-delete", } @@ -167,16 +172,23 @@ func TestBaseTableDeleteRisk(t *testing.T) { } } +func TestBaseFieldUpdateRisk(t *testing.T) { + if BaseFieldUpdate.Risk != "high-risk-write" { + t.Fatalf("risk=%q want=%q", BaseFieldUpdate.Risk, "high-risk-write") + } +} + func TestBaseDeleteShortcutsRisk(t *testing.T) { cases := map[string]string{ - BaseFieldDelete.Command: BaseFieldDelete.Risk, - BaseViewDelete.Command: BaseViewDelete.Risk, - BaseRecordDelete.Command: BaseRecordDelete.Risk, - BaseFormDelete.Command: BaseFormDelete.Risk, - BaseFormQuestionsDelete.Command: BaseFormQuestionsDelete.Risk, - BaseDashboardDelete.Command: BaseDashboardDelete.Risk, - BaseDashboardBlockDelete.Command: BaseDashboardBlockDelete.Risk, - BaseRoleDelete.Command: BaseRoleDelete.Risk, + BaseFieldDelete.Command: BaseFieldDelete.Risk, + BaseViewDelete.Command: BaseViewDelete.Risk, + BaseRecordDelete.Command: BaseRecordDelete.Risk, + BaseRecordRemoveAttachment.Command: BaseRecordRemoveAttachment.Risk, + BaseFormDelete.Command: BaseFormDelete.Risk, + BaseFormQuestionsDelete.Command: BaseFormQuestionsDelete.Risk, + BaseDashboardDelete.Command: BaseDashboardDelete.Risk, + BaseDashboardBlockDelete.Command: BaseDashboardBlockDelete.Risk, + BaseRoleDelete.Command: BaseRoleDelete.Risk, } for command, risk := range cases { @@ -332,6 +344,79 @@ func TestBaseFieldUpdateHelpGuidesAgents(t *testing.T) { } } +func TestBaseAttachmentHelpGuidesAgents(t *testing.T) { + tests := []struct { + name string + shortcut common.Shortcut + wantHelp []string + wantTips []string + }{ + { + name: "upload attachment", + shortcut: BaseRecordUploadAttachment, + wantHelp: []string{ + "repeat to append multiple attachments in one cell", + "max 50 files, max 2GB each", + }, + wantTips: []string{ + "lark-cli base +record-upload-attachment", + "Repeat --file to append multiple attachments", + "Reuse returned file_token values for download/remove", + }, + }, + { + name: "download attachment", + shortcut: BaseRecordDownloadAttachment, + wantHelp: []string{ + "repeat to download selected files", + "omit to download all attachments in the record", + "with multiple or omitted file tokens this must be an existing directory", + }, + wantTips: []string{ + "lark-cli base +record-download-attachment", + "Omit --file-token to download every attachment in the record", + "Base attachments should be downloaded with this command", + "other download commands may fail", + }, + }, + { + name: "remove attachment", + shortcut: BaseRecordRemoveAttachment, + wantHelp: []string{ + "remove from the target cell", + "max 50 tokens", + }, + wantTips: []string{ + "lark-cli base +record-remove-attachment", + "Repeat --file-token", + "requires --yes", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parent := &cobra.Command{Use: "base"} + tt.shortcut.Mount(parent, &cmdutil.Factory{}) + cmd := parent.Commands()[0] + + help := cmd.Flags().FlagUsages() + for _, want := range tt.wantHelp { + if !strings.Contains(help, want) { + t.Fatalf("flag help missing %q:\n%s", want, help) + } + } + + tips := strings.Join(cmdutil.GetTips(cmd), "\n") + for _, want := range tt.wantTips { + if !strings.Contains(tips, want) { + t.Fatalf("tips missing %q:\n%s", want, tips) + } + } + }) + } +} + func assertHelpOrder(t *testing.T, help string, before string, after string) { t.Helper() beforeIndex := strings.Index(help, before) @@ -419,3 +504,1018 @@ func TestBaseViewValidate(t *testing.T) { t.Fatalf("err=%v", err) } } + +// --- base_form_submit.go 子函数单测 --- + +func TestValidateFormSubmit(t *testing.T) { + t.Run("invalid json", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "json": "{invalid", + }, nil, nil) + err := validateFormSubmit(rt) + if err == nil || !strings.Contains(err.Error(), "invalid JSON") { + t.Fatalf("expected JSON error, got: %v", err) + } + }) + + t.Run("fields only - valid", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "json": `{"fields":{"Rating":5}}`, + }, nil, nil) + if err := validateFormSubmit(rt); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("missing both fields and attachments", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "json": `{}`, + }, nil, nil) + err := validateFormSubmit(rt) + if err == nil || !strings.Contains(err.Error(), "must contain at least") { + t.Fatalf("expected missing fields/attachments error, got: %v", err) + } + }) + + t.Run("attachments without base-token", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "json": `{"attachments":{"File":["./a.pdf"]}}`, + }, nil, nil) + err := validateFormSubmit(rt) + if err == nil || !strings.Contains(err.Error(), "--base-token is required") { + t.Fatalf("expected base-token required error, got: %v", err) + } + }) + + t.Run("attachments not an object", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "base-token": "bas_test", + "json": `{"attachments":"not_an_object"}`, + }, nil, nil) + err := validateFormSubmit(rt) + if err == nil || !strings.Contains(err.Error(), "must be a JSON object") { + t.Fatalf("expected object error, got: %v", err) + } + }) + + t.Run("attachment value not array", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "base-token": "bas_test", + "json": `{"attachments":{"File":"not_array"}}`, + }, nil, nil) + err := validateFormSubmit(rt) + if err == nil || !strings.Contains(err.Error(), "must be a file path array") { + t.Fatalf("expected array error, got: %v", err) + } + }) + + t.Run("attachment path item not string", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "base-token": "bas_test", + "json": `{"attachments":{"File":[123]}}`, + }, nil, nil) + err := validateFormSubmit(rt) + if err == nil || !strings.Contains(err.Error(), "must be a file path string") { + t.Fatalf("expected string error, got: %v", err) + } + }) + + t.Run("empty attachment paths", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "base-token": "bas_test", + "json": `{"attachments":{"File":[]}}`, + }, nil, nil) + err := validateFormSubmit(rt) + if err == nil || !strings.Contains(err.Error(), "must not be empty") { + t.Fatalf("expected empty error, got: %v", err) + } + }) + + t.Run("attachments valid with base-token", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "base-token": "bas_test", + "json": `{"fields":{"Rating":5},"attachments":{"File":["./a.pdf"]}}`, + }, nil, nil) + if err := validateFormSubmit(rt); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestParseFormSubmitJSON(t *testing.T) { + t.Run("fields only", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "json": `{"fields":{"Rating":5,"Review":"Good"}}`, + }, nil, nil) + fields, attMap, err := parseFormSubmitJSON(rt) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(fields) != 2 || fields["Rating"] != float64(5) || fields["Review"] != "Good" { + t.Fatalf("fields=%v", fields) + } + if attMap != nil { + t.Fatalf("expected nil attMap, got %v", attMap) + } + }) + + t.Run("no fields key returns empty map", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "json": `{"attachments":{"File":["./a.pdf"]}}`, + }, nil, nil) + fields, _, err := parseFormSubmitJSON(rt) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(fields) != 0 { + t.Fatalf("expected empty fields, got %v", fields) + } + }) + + t.Run("with attachments", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "json": `{"fields":{"Rating":5},"attachments":{"File":["./a.pdf","./b.png"],"Photo":["./c.jpg"]}}`, + }, nil, nil) + fields, attMap, err := parseFormSubmitJSON(rt) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fields["Rating"] != float64(5) { + t.Fatalf("missing Rating field") + } + if len(attMap) != 2 { + t.Fatalf("attMap size=%d want=2", len(attMap)) + } + if len(attMap["File"]) != 2 || attMap["File"][0] != "./a.pdf" || attMap["File"][1] != "./b.png" { + t.Fatalf("File paths=%v", attMap["File"]) + } + if len(attMap["Photo"]) != 1 || attMap["Photo"][0] != "./c.jpg" { + t.Fatalf("Photo paths=%v", attMap["Photo"]) + } + }) + + t.Run("invalid json", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{"json": "{"}, nil, nil) + _, _, err := parseFormSubmitJSON(rt) + if err == nil || !strings.Contains(err.Error(), "invalid JSON") { + t.Fatalf("expected JSON error, got: %v", err) + } + }) + + t.Run("attachments not object", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "json": `{"attachments":"bad"}`, + }, nil, nil) + _, _, err := parseFormSubmitJSON(rt) + if err == nil || !strings.Contains(err.Error(), "must be a JSON object") { + t.Fatalf("expected object error, got: %v", err) + } + }) + + t.Run("attachment value not array", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "json": `{"attachments":{"File":"str"}}`, + }, nil, nil) + _, _, err := parseFormSubmitJSON(rt) + if err == nil || !strings.Contains(err.Error(), "must be a file path array") { + t.Fatalf("expected array error, got: %v", err) + } + }) + + t.Run("attachment item not string", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "json": `{"attachments":{"File":[42]}}`, + }, nil, nil) + _, _, err := parseFormSubmitJSON(rt) + if err == nil || !strings.Contains(err.Error(), "file path strings only") { + t.Fatalf("expected string error, got: %v", err) + } + }) + + t.Run("empty attachments object returns nil map", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "json": `{"attachments":{}}`, + }, nil, nil) + _, attMap, err := parseFormSubmitJSON(rt) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if attMap != nil { + t.Fatalf("expected nil attMap for empty, got %v", attMap) + } + }) + + t.Run("empty attachment path list excluded from map", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "json": `{"attachments":{"File":[],"Photo":["./x.jpg"]}}`, + }, nil, nil) + _, attMap, err := parseFormSubmitJSON(rt) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := attMap["File"]; ok { + t.Fatalf("empty File should be excluded from attMap") + } + if len(attMap["Photo"]) != 1 { + t.Fatalf("Photo should have 1 entry") + } + }) +} + +func TestBuildFormSubmitBody(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_abc123", + }, nil, nil) + content := map[string]interface{}{"Rating": float64(5), "Review": "Good"} + body := buildFormSubmitBody(rt, content) + + if body["share_token"] != "shr_abc123" { + t.Fatalf("share_token=%q want shr_abc123", body["share_token"]) + } + gotContent, ok := body["content"].(map[string]interface{}) + if !ok { + t.Fatalf("content type=%T want map", body["content"]) + } + if gotContent["Rating"] != float64(5) || gotContent["Review"] != "Good" { + t.Fatalf("content=%v want Rating=5 Review=Good", gotContent) + } +} + +func TestCollectUniquePaths(t *testing.T) { + t.Run("dedup across fields", func(t *testing.T) { + m := map[string][]string{ + "Field1": {"./a.pdf", "./b.png"}, + "Field2": {"./b.png", "./c.jpg"}, + "Field3": {"./a.pdf", "./d.txt"}, + } + result := collectUniquePaths(m) + // Should preserve first-seen order, deduplicated + wantLen := 4 // a.pdf, b.png, c.jpg, d.txt + if len(result) != wantLen { + t.Fatalf("len=%d want=%d result=%v", len(result), wantLen, result) + } + // Check no duplicates + seen := make(map[string]bool) + for _, p := range result { + if seen[p] { + t.Fatalf("duplicate path: %s", p) + } + seen[p] = true + } + }) + + t.Run("empty map", func(t *testing.T) { + result := collectUniquePaths(map[string][]string{}) + if len(result) != 0 { + t.Fatalf("expected empty, got %v", result) + } + }) + + t.Run("single field single path", func(t *testing.T) { + m := map[string][]string{"F": {"./only.pdf"}} + result := collectUniquePaths(m) + if len(result) != 1 || result[0] != "./only.pdf" { + t.Fatalf("result=%v", result) + } + }) + + t.Run("same path in same field", func(t *testing.T) { + m := map[string][]string{"F": {"./same.pdf", "./same.pdf"}} + result := collectUniquePaths(m) + if len(result) != 1 { + t.Fatalf("expected 1 unique, got %d: %v", len(result), result) + } + }) +} + +func TestBaseFormAttachmentUploadTarget(t *testing.T) { + target := baseFormAttachmentUploadTarget("bas_xyz", "shr_abc") + if target.ParentType != baseFormAttachmentParentType { + t.Fatalf("ParentType=%q want %q", target.ParentType, baseFormAttachmentParentType) + } + if target.ParentNode != "bas_xyz" { + t.Fatalf("ParentNode=%q want bas_xyz", target.ParentNode) + } + // Extra should contain share_token + if !strings.Contains(target.Extra, "shr_abc") { + t.Fatalf("Extra=%q should contain share_token", target.Extra) + } +} + +func TestBaseFormAttachmentExtra(t *testing.T) { + t.Run("normal token", func(t *testing.T) { + extra := baseFormAttachmentExtra("shr_test123") + var parsed map[string]string + if err := json.Unmarshal([]byte(extra), &parsed); err != nil { + t.Fatalf("extra is not valid JSON: %v", err) + } + if parsed["share_token"] != "shr_test123" { + t.Fatalf("share_token=%q want shr_test123", parsed["share_token"]) + } + }) + + t.Run("empty token", func(t *testing.T) { + extra := baseFormAttachmentExtra("") + var parsed map[string]string + if err := json.Unmarshal([]byte(extra), &parsed); err != nil { + t.Fatalf("extra is not valid JSON: %v", err) + } + if parsed["share_token"] != "" { + t.Fatalf("share_token=%q want empty", parsed["share_token"]) + } + }) +} + +// --- dryRunFormSubmit & BaseFormDetail DryRun 测试 --- + +func TestDryRunFormSubmitInvalidJSON(t *testing.T) { + ctx := context.Background() + t.Run("invalid json returns desc-only dry run", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_xyz", + "json": `{invalid`, + }, nil, nil) + dry := dryRunFormSubmit(ctx, rt) + if dry == nil { + t.Fatal("dry result is nil") + } + data, err := dry.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + // Should have description about validation failure, no api calls + if _, ok := parsed["description"]; !ok { + t.Fatalf("expected description key for validation failure, got: %s", data) + } + desc := parsed["description"].(string) + if !strings.Contains(desc, "validation failed") { + t.Fatalf("description=%q should mention validation failed", desc) + } + }) +} + +func TestDryRunFormSubmitStructural(t *testing.T) { + ctx := context.Background() + + t.Run("fields only - single POST submit with body check", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_xyz", + "json": `{"fields":{"Rating":5,"Review":"Good"}}`, + }, nil, nil) + dry := dryRunFormSubmit(ctx, rt) + if dry == nil { + t.Fatal("dry result is nil") + } + data, err := dry.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + api, ok := parsed["api"].([]interface{}) + if !ok || len(api) != 1 { + t.Fatalf("expected 1 api call, got: %s", data) + } + call := api[0].(map[string]interface{}) + if call["method"] != "POST" { + t.Fatalf("method=%q want POST", call["method"]) + } + body, _ := call["body"].(map[string]interface{}) + if body["share_token"] != "shr_xyz" { + t.Fatalf("body.share_token=%q want shr_xyz", body["share_token"]) + } + content, _ := body["content"].(map[string]interface{}) + if content == nil || content["Rating"] != float64(5) { + t.Fatalf("content missing or wrong Rating, got: %v", content) + } + }) + + t.Run("with attachments - upload count and submit order", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_xyz", + "base-token": "bas_abc", + "json": `{"fields":{"Name":"test"},"attachments":{"File":["./report.pdf","./img.png"]}}`, + }, nil, nil) + dry := dryRunFormSubmit(ctx, rt) + if dry == nil { + t.Fatal("dry result is nil") + } + data, err := dry.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + api, ok := parsed["api"].([]interface{}) + if !ok { + t.Fatalf("api missing in output: %s", data) + } + // 2 uploads + 1 submit = 3 calls + if len(api) != 3 { + t.Fatalf("expected 3 api calls (2 upload + 1 submit), got %d: %s", len(api), data) + } + for i := 0; i < 2; i++ { + call := api[i].(map[string]interface{}) + if call["method"] != "POST" { + t.Fatalf("call[%d] method=%q want POST", i, call["method"]) + } + if !strings.Contains(call["url"].(string), "medias/upload_all") { + t.Fatalf("call[%d] url=%q should contain medias/upload_all", i, call["url"]) + } + } + submitCall := api[2].(map[string]interface{}) + if !strings.Contains(submitCall["url"].(string), "forms/submit") { + t.Fatalf("last call url=%q should contain forms/submit", submitCall["url"]) + } + }) +} + +func TestBaseFormDetailDryRun(t *testing.T) { + ctx := context.Background() + + t.Run("correct method and url", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "detail123", + }, nil, nil) + dry := BaseFormDetail.DryRun(ctx, rt) + if dry == nil { + t.Fatal("dry result is nil") + } + data, err := dry.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + api, ok := parsed["api"].([]interface{}) + if !ok || len(api) != 1 { + t.Fatalf("expected 1 api call, got: %s", data) + } + call := api[0].(map[string]interface{}) + if call["method"] != "POST" { + t.Fatalf("method=%q want POST", call["method"]) + } + if !strings.Contains(call["url"].(string), "forms/detail") { + t.Fatalf("url=%q should contain forms/detail", call["url"]) + } + body, _ := call["body"].(map[string]interface{}) + if body["share_token"] != "detail123" { + t.Fatalf("body.share_token=%q want detail123", body["share_token"]) + } + }) + + t.Run("shortcut metadata", func(t *testing.T) { + if BaseFormDetail.Command != "+form-detail" { + t.Fatalf("command=%q want +form-detail", BaseFormDetail.Command) + } + if BaseFormDetail.Risk != "read" { + t.Fatalf("risk=%q want read", BaseFormDetail.Risk) + } + if BaseFormDetail.Validate != nil { + t.Fatalf("Validate should be nil for form-detail") + } + }) +} + +// --- 通过 BaseFormSubmit / BaseFormDetail 公开接口测试 --- + +func TestBaseFormSubmitShortcut(t *testing.T) { + ctx := context.Background() + + t.Run("metadata", func(t *testing.T) { + s := BaseFormSubmit + if s.Command != "+form-submit" { + t.Fatalf("Command=%q want +form-submit", s.Command) + } + if s.Service != "base" { + t.Fatalf("Service=%q want base", s.Service) + } + if s.Risk != "write" { + t.Fatalf("Risk=%q want write", s.Risk) + } + if !s.HasFormat { + t.Fatal("HasFormat should be true") + } + }) + + t.Run("flags", func(t *testing.T) { + flags := BaseFormSubmit.Flags + flagNames := make(map[string]bool) + for _, f := range flags { + flagNames[f.Name] = true + } + for _, name := range []string{"share-token", "base-token", "json"} { + if !flagNames[name] { + t.Fatalf("missing flag %q", name) + } + } + // share-token and json are required + for _, f := range flags { + if f.Name == "share-token" && !f.Required { + t.Fatalf("share-token should be Required") + } + if f.Name == "json" && !f.Required { + t.Fatalf("json should be Required") + } + if f.Name == "base-token" && f.Required { + t.Fatalf("base-token should NOT be required (only needed with attachments)") + } + } + }) + + t.Run("scopes contain base:form:update and docs:document.media:upload", func(t *testing.T) { + scopes := BaseFormSubmit.Scopes + foundFormUpdate := false + foundMediaUpload := false + for _, s := range scopes { + if s == "base:form:update" { + foundFormUpdate = true + } + if s == "docs:document.media:upload" { + foundMediaUpload = true + } + } + if !foundFormUpdate { + t.Fatalf("Scopes=%v missing base:form:update", scopes) + } + if !foundMediaUpload { + t.Fatalf("Scopes=%v missing docs:document.media:upload", scopes) + } + }) + + t.Run("auth types", func(t *testing.T) { + authTypes := BaseFormSubmit.AuthTypes + if len(authTypes) == 0 { + t.Fatal("AuthTypes should not be empty") + } + hasUser, hasBot := false, false + for _, at := range authTypes { + if at == "user" { + hasUser = true + } + if at == "bot" { + hasBot = true + } + } + if !hasUser || !hasBot { + t.Fatalf("AuthTypes=%v should include both user and bot", authTypes) + } + }) + + t.Run("validate via shortcut interface - fields only valid", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "json": `{"fields":{"Rating":5}}`, + }, nil, nil) + if err := BaseFormSubmit.Validate(ctx, rt); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("validate via shortcut interface - missing both fields and attachments", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "json": `{}`, + }, nil, nil) + err := BaseFormSubmit.Validate(ctx, rt) + if err == nil || !strings.Contains(err.Error(), "must contain at least") { + t.Fatalf("expected validation error, got: %v", err) + } + }) + + t.Run("validate via shortcut interface - attachments without base-token", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_test", + "json": `{"attachments":{"File":["./a.pdf"]}}`, + }, nil, nil) + err := BaseFormSubmit.Validate(ctx, rt) + if err == nil || !strings.Contains(err.Error(), "--base-token is required") { + t.Fatalf("expected base-token error, got: %v", err) + } + }) + + t.Run("dryrun via shortcut interface - fields only", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_dry1", + "json": `{"fields":{"Name":"Alice"}}`, + }, nil, nil) + dry := BaseFormSubmit.DryRun(ctx, rt) + data, err := dry.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + api, _ := parsed["api"].([]interface{}) + if len(api) != 1 { + t.Fatalf("expected 1 call, got %d", len(api)) + } + call := api[0].(map[string]interface{}) + if call["method"] != "POST" { + t.Fatalf("method=%q want POST", call["method"]) + } + body, _ := call["body"].(map[string]interface{}) + if body["share_token"] != "shr_dry1" { + t.Fatalf("share_token=%q want shr_dry1", body["share_token"]) + } + }) + + t.Run("dryrun via shortcut interface - with attachments", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_dry2", + "base-token": "bas_dry2", + "json": `{"attachments":{"File":["./x.pdf"]}}`, + }, nil, nil) + dry := BaseFormSubmit.DryRun(ctx, rt) + data, err := dry.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + api, _ := parsed["api"].([]interface{}) + // 1 upload + 1 submit = 2 calls + if len(api) != 2 { + t.Fatalf("expected 2 calls (upload+submit), got %d: %s", len(api), data) + } + // First call is upload + uploadCall := api[0].(map[string]interface{}) + if !strings.Contains(uploadCall["url"].(string), "medias/upload_all") { + t.Fatalf("first call url should be upload_all, got: %v", uploadCall["url"]) + } + // Second call is submit + submitCall := api[1].(map[string]interface{}) + if !strings.Contains(submitCall["url"].(string), "forms/submit") { + t.Fatalf("second call url should be forms/submit, got: %v", submitCall["url"]) + } + }) + + t.Run("description contains useful info", func(t *testing.T) { + desc := BaseFormSubmit.Description + if desc == "" { + t.Fatal("Description should not be empty") + } + if !strings.Contains(strings.ToLower(desc), "submit") && + !strings.Contains(strings.ToLower(desc), "form") { + t.Fatalf("Description=%q should mention form or submit", desc) + } + }) + + t.Run("tips not empty", func(t *testing.T) { + if len(BaseFormSubmit.Tips) == 0 { + t.Fatal("Tips should not be empty") + } + }) +} + +func TestBaseFormDetailShortcut(t *testing.T) { + ctx := context.Background() + + t.Run("metadata", func(t *testing.T) { + s := BaseFormDetail + if s.Command != "+form-detail" { + t.Fatalf("Command=%q want +form-detail", s.Command) + } + if s.Service != "base" { + t.Fatalf("Service=%q want base", s.Service) + } + if s.Risk != "read" { + t.Fatalf("Risk=%q want read", s.Risk) + } + if !s.HasFormat { + t.Fatal("HasFormat should be true") + } + }) + + t.Run("flags - only share-token required", func(t *testing.T) { + flags := BaseFormDetail.Flags + if len(flags) != 1 { + t.Fatalf("expected 1 flag, got %d", len(flags)) + } + f := flags[0] + if f.Name != "share-token" { + t.Fatalf("flag Name=%q want share-token", f.Name) + } + if !f.Required { + t.Fatal("share-token should be Required") + } + }) + + t.Run("scopes contain base:form:read", func(t *testing.T) { + scopes := BaseFormDetail.Scopes + found := false + for _, s := range scopes { + if s == "base:form:read" { + found = true + } + } + if !found { + t.Fatalf("Scopes=%v missing base:form:read", scopes) + } + }) + + t.Run("auth types user and bot", func(t *testing.T) { + authTypes := BaseFormDetail.AuthTypes + if len(authTypes) != 2 { + t.Fatalf("expected 2 auth types, got %d: %v", len(authTypes), authTypes) + } + }) + + t.Run("validate is nil (no extra CLI-side validation)", func(t *testing.T) { + if BaseFormDetail.Validate != nil { + t.Fatal("Validate should be nil for form-detail") + } + }) + + t.Run("dryrun via shortcut interface", func(t *testing.T) { + rt := newBaseTestRuntime(map[string]string{ + "share-token": "shr_via_detail", + }, nil, nil) + dry := BaseFormDetail.DryRun(ctx, rt) + data, err := dry.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + api, _ := parsed["api"].([]interface{}) + if len(api) != 1 { + t.Fatalf("expected 1 call, got %d", len(api)) + } + call := api[0].(map[string]interface{}) + if call["method"] != "POST" { + t.Fatalf("method=%q want POST", call["method"]) + } + if !strings.Contains(call["url"].(string), "forms/detail") { + t.Fatalf("url=%q should contain forms/detail", call["url"]) + } + body, _ := call["body"].(map[string]interface{}) + if body["share_token"] != "shr_via_detail" { + t.Fatalf("share_token=%q want shr_via_detail", body["share_token"]) + } + }) + + t.Run("description", func(t *testing.T) { + desc := BaseFormDetail.Description + if desc == "" { + t.Fatal("Description should not be empty") + } + if !strings.Contains(strings.ToLower(desc), "detail") { + t.Fatalf("Description=%q should mention detail", desc) + } + }) +} + +// --- executeFormSubmit & uploadAttachmentsParallel 单元测试 --- + +func TestExecuteFormSubmit(t *testing.T) { + t.Run("fields only - no attachments", func(t *testing.T) { + factory, stdout, reg := newExecuteFactory(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/tables/forms/submit", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "record_id": "rec_submit1", + }, + }, + }) + args := []string{ + "+form-submit", + "--share-token", "shr_exec1", + "--json", `{"fields":{"Name":"Alice","Rating":5}}`, + } + if err := runShortcut(t, BaseFormSubmit, args, factory, stdout); err != nil { + t.Fatalf("err=%v", err) + } + got := stdout.String() + if !strings.Contains(got, `"record_id"`) || !strings.Contains(got, `"rec_submit1"`) { + t.Fatalf("stdout=%s", got) + } + }) + + t.Run("invalid json returns error", func(t *testing.T) { + factory, stdout, _ := newExecuteFactory(t) + args := []string{ + "+form-submit", + "--share-token", "shr_exec3", + "--json", `{not valid`, + } + err := runShortcut(t, BaseFormSubmit, args, factory, stdout) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + if !strings.Contains(err.Error(), "invalid JSON") { + t.Fatalf("error should mention invalid JSON, got: %v", err) + } + }) + + t.Run("missing both fields and attachments returns error", func(t *testing.T) { + factory, stdout, _ := newExecuteFactory(t) + args := []string{ + "+form-submit", + "--share-token", "shr_exec4", + "--json", `{}`, + } + err := runShortcut(t, BaseFormSubmit, args, factory, stdout) + if err == nil { + t.Fatal("expected error for empty JSON") + } + if !strings.Contains(err.Error(), "must contain at least") { + t.Fatalf("error should mention fields/attachments, got: %v", err) + } + }) + + t.Run("attachments without base-token returns error", func(t *testing.T) { + factory, stdout, _ := newExecuteFactory(t) + args := []string{ + "+form-submit", + "--share-token", "shr_exec5", + "--json", `{"attachments":{"File":["./x.pdf"]}}`, + } + err := runShortcut(t, BaseFormSubmit, args, factory, stdout) + if err == nil { + t.Fatal("expected error for missing base-token") + } + if !strings.Contains(err.Error(), "--base-token is required") { + t.Fatalf("error should mention base-token, got: %v", err) + } + }) + + t.Run("attachment file not found returns error", func(t *testing.T) { + tmpDir := t.TempDir() + withBaseWorkingDir(t, tmpDir) + + factory, stdout, _ := newExecuteFactory(t) + args := []string{ + "+form-submit", + "--share-token", "shr_exec6", + "--base-token", "bas_exec6", + "--json", `{"attachments":{"File":["./nonexistent.pdf"]}}`, + } + err := runShortcut(t, BaseFormSubmit, args, factory, stdout) + if err == nil { + t.Fatal("expected error for nonexistent file") + } + errMsg := err.Error() + if !strings.Contains(errMsg, "not accessible") && !strings.Contains(errMsg, "no such file") { + t.Fatalf("error should mention file not found, got: %v", errMsg) + } + }) + + t.Run("duplicate file paths across fields deduplicated in upload", func(t *testing.T) { + tmpDir := t.TempDir() + sharedFile := filepath.Join(tmpDir, "shared.pdf") + if err := os.WriteFile(sharedFile, []byte("%PDF shared"), 0644); err != nil { + t.Fatalf("create file: %v", err) + } + withBaseWorkingDir(t, tmpDir) + + factory, stdout, reg := newExecuteFactory(t) + + // Only ONE upload expected (same file referenced by two fields) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "medias/upload_all", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "file_token": "ft_shared_001", + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/tables/forms/submit", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "record_id": "rec_dedup", + }, + }, + }) + + args := []string{ + "+form-submit", + "--share-token", "shr_dedup", + "--base-token", "bas_dedup", + "--json", `{"attachments":{"FieldA":["./shared.pdf"],"FieldB":["./shared.pdf"]}}`, + } + if err := runShortcut(t, BaseFormSubmit, args, factory, stdout); err != nil { + t.Fatalf("err=%v", err) + } + got := stdout.String() + if !strings.Contains(got, `"rec_dedup"`) { + t.Fatalf("stdout should contain record, got: %s", got) + } + }) +} + +func TestUploadAttachmentsParallel(t *testing.T) { + t.Run("single file upload via execute path", func(t *testing.T) { + tmpDir := t.TempDir() + singleFile := filepath.Join(tmpDir, "doc.txt") + if err := os.WriteFile(singleFile, []byte("single file content"), 0644); err != nil { + t.Fatalf("create file: %v", err) + } + withBaseWorkingDir(t, tmpDir) + + factory, stdout, reg := newExecuteFactory(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "medias/upload_all", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "file_token": "ft_single_001", + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/base/v3/bases/tables/forms/submit", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "record_id": "rec_parallel1", + }, + }, + }) + + args := []string{ + "+form-submit", + "--share-token", "shr_para1", + "--base-token", "bas_para1", + "--json", `{"attachments":{"Doc":["./doc.txt"]}}`, + } + if err := runShortcut(t, BaseFormSubmit, args, factory, stdout); err != nil { + t.Fatalf("err=%v", err) + } + got := stdout.String() + if !strings.Contains(got, `"rec_parallel1"`) { + t.Fatalf("stdout=%s", got) + } + }) + + t.Run("upload failure propagates error", func(t *testing.T) { + tmpDir := t.TempDir() + badFile := filepath.Join(tmpDir, "bad.txt") + if err := os.WriteFile(badFile, []byte("bad"), 0644); err != nil { + t.Fatalf("create file: %v", err) + } + withBaseWorkingDir(t, tmpDir) + + factory, stdout, reg := newExecuteFactory(t) + // Upload returns non-zero code → error + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "medias/upload_all", + Body: map[string]interface{}{ + "code": 12345, + "msg": "upload quota exceeded", + }, + }) + + args := []string{ + "+form-submit", + "--share-token", "shr_err", + "--base-token", "bas_err", + "--json", `{"attachments":{"Bad":["./bad.txt"]}}`, + } + err := runShortcut(t, BaseFormSubmit, args, factory, stdout) + if err == nil { + t.Fatal("expected error from failed upload") + } + // Error should mention upload failure + errMsg := err.Error() + if !strings.Contains(errMsg, "upload") && !strings.Contains(errMsg, "failed") { + t.Fatalf("error should mention upload failure, got: %v", errMsg) + } + }) +} diff --git a/shortcuts/base/field_update.go b/shortcuts/base/field_update.go index 03999c022..f8e8a47d0 100644 --- a/shortcuts/base/field_update.go +++ b/shortcuts/base/field_update.go @@ -13,7 +13,7 @@ var BaseFieldUpdate = common.Shortcut{ Service: "base", Command: "+field-update", Description: "Update a field by ID or name", - Risk: "write", + Risk: "high-risk-write", Scopes: []string{"base:field:update"}, AuthTypes: authTypes(), Flags: []common.Flag{ diff --git a/shortcuts/base/record_upload_attachment.go b/shortcuts/base/record_upload_attachment.go index 67fde80f8..d26d654d0 100644 --- a/shortcuts/base/record_upload_attachment.go +++ b/shortcuts/base/record_upload_attachment.go @@ -8,27 +8,44 @@ import ( "context" "errors" "fmt" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" "io" "mime" + "net/http" "path/filepath" + "sort" "strings" "unicode/utf8" "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/util" + "github.com/larksuite/cli/internal/validate" "github.com/larksuite/cli/shortcuts/common" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" ) const ( baseAttachmentUploadMaxFileSize int64 = 2 * 1024 * 1024 * 1024 baseAttachmentParentType = "bitable_file" + baseFormAttachmentParentType = "bitable_tmp_point" + baseAttachmentMaxBatchSize = 50 + baseAttachmentGetMaxRecords = 10 ) +type baseAttachmentUploadTarget struct { + ParentType string + ParentNode string + Extra string +} + var BaseRecordUploadAttachment = common.Shortcut{ Service: "base", Command: "+record-upload-attachment", - Description: "Upload a local file to a Base attachment field and write it into the target record", + Description: "Upload one or more local files and append the returned file_token values to a Base attachment cell", Risk: "write", Scopes: []string{"base:record:update", "base:field:read", "docs:document.media:upload"}, AuthTypes: authTypes(), @@ -37,34 +54,99 @@ var BaseRecordUploadAttachment = common.Shortcut{ tableRefFlag(true), recordRefFlag(true), fieldRefFlag(true), - {Name: "file", Desc: "local file path (max 2GB; files > 20MB use multipart upload automatically)", Required: true}, - {Name: "name", Desc: "attachment file name (default: local file name)"}, + {Name: "file", Type: "string_array", Desc: "local file path; repeat to append multiple attachments in one cell; max 50 files, max 2GB each; files > 20MB use multipart upload automatically", Required: true}, + {Name: "name", Desc: "deprecated; attachment names are derived from local file basenames", Hidden: true}, + }, + Tips: []string{ + `Example: lark-cli base +record-upload-attachment --base-token --table-id --record-id --field-id --file ./report.pdf`, + `Repeat --file to append multiple attachments: --file ./report.pdf --file ./screenshot.png`, + `Reuse returned file_token values for download/remove`, }, DryRun: dryRunRecordUploadAttachment, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + return validateRecordUploadAttachment(runtime) + }, Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { return executeRecordUploadAttachment(runtime) }, } +var BaseRecordDownloadAttachment = common.Shortcut{ + Service: "base", + Command: "+record-download-attachment", + Description: "Download Base record attachments by record-id, optionally filtering by file-token", + Risk: "read", + Scopes: []string{"base:record:read", "docs:document.media:download"}, + AuthTypes: authTypes(), + Flags: []common.Flag{ + baseTokenFlag(true), + tableRefFlag(true), + recordRefFlag(true), + {Name: "file-token", Type: "string_array", Desc: "attachment file_token returned by Base; repeat to download selected files; omit to download all attachments in the record", Required: false}, + {Name: "output", Desc: "local save path; with exactly one file token this may be a file path; with multiple or omitted file tokens this must be an existing directory", Required: true}, + {Name: "overwrite", Type: "bool", Desc: "overwrite existing output file"}, + }, + Tips: []string{ + `Example: lark-cli base +record-download-attachment --base-token --table-id --record-id --file-token --output ./downloads/`, + `Omit --file-token to download every attachment in the record.`, + `Base attachments should be downloaded with this command; other download commands may fail for Base attachment files.`, + `With one --file-token, --output may be a file path or directory; with multiple or omitted --file-token values, --output must be an existing directory.`, + }, + DryRun: dryRunRecordDownloadAttachment, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + return validateRecordDownloadAttachment(runtime) + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + return executeRecordDownloadAttachment(ctx, runtime) + }, +} + +var BaseRecordRemoveAttachment = common.Shortcut{ + Service: "base", + Command: "+record-remove-attachment", + Description: "Remove one or more file_token values from a Base record attachment cell", + Risk: "high-risk-write", + Scopes: []string{"base:record:update", "base:field:read"}, + AuthTypes: authTypes(), + Flags: []common.Flag{ + baseTokenFlag(true), + tableRefFlag(true), + recordRefFlag(true), + fieldRefFlag(true), + {Name: "file-token", Type: "string_array", Desc: "attachment file_token to remove from the target cell; repeat to remove multiple attachments; max 50 tokens", Required: true}, + }, + Tips: []string{ + `Example: lark-cli base +record-remove-attachment --base-token --table-id --record-id --field-id --file-token --yes`, + `Repeat --file-token to remove multiple attachments from the same cell in one call.`, + `This is a high-risk write command and requires --yes.`, + }, + DryRun: dryRunRecordRemoveAttachment, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + return validateRecordRemoveAttachment(runtime) + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + return executeRecordRemoveAttachment(runtime) + }, +} + func dryRunRecordUploadAttachment(_ context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { - filePath := runtime.Str("file") - fileName := strings.TrimSpace(runtime.Str("name")) - if fileName == "" { + files := runtime.StrArray("file") + filePath := "" + fileName := "" + if len(files) > 0 { + filePath = files[0] fileName = filepath.Base(filePath) } dry := common.NewDryRunAPI(). - Desc("4-step orchestration: validate attachment field → read existing record attachments → upload file to Base → patch merged attachment array"). + Desc("3-step orchestration: validate attachment field → upload local file(s) to Base → append uploaded file token(s) to the attachment cell"). GET("/open-apis/base/v3/bases/:base_token/tables/:table_id/fields/:field_id"). Desc("[1] Read target field and ensure it is an attachment field"). Set("base_token", runtime.Str("base-token")). Set("table_id", baseTableID(runtime)). - Set("field_id", runtime.Str("field-id")). - GET("/open-apis/base/v3/bases/:base_token/tables/:table_id/records/:record_id"). - Desc("[2] Read current record to preserve existing attachments in the target cell"). - Set("record_id", runtime.Str("record-id")) + Set("field_id", runtime.Str("field-id")) if baseAttachmentShouldUseMultipart(runtime.FileIO(), filePath) { dry.POST("/open-apis/drive/v1/medias/upload_prepare"). - Desc("[3a] Initialize multipart attachment upload to the current Base"). + Desc("[2a] Initialize multipart attachment upload to the current Base"). Body(map[string]interface{}{ "file_name": fileName, "parent_type": baseAttachmentParentType, @@ -72,7 +154,7 @@ func dryRunRecordUploadAttachment(_ context.Context, runtime *common.RuntimeCont "size": "", }). POST("/open-apis/drive/v1/medias/upload_part"). - Desc("[3b] Upload attachment parts (repeated)"). + Desc("[2b] Upload attachment parts (repeated for each large file)"). Body(map[string]interface{}{ "upload_id": "", "seq": "", @@ -80,14 +162,14 @@ func dryRunRecordUploadAttachment(_ context.Context, runtime *common.RuntimeCont "file": "", }). POST("/open-apis/drive/v1/medias/upload_finish"). - Desc("[3c] Finalize multipart attachment upload and get file token"). + Desc("[2c] Finalize multipart attachment upload and get file token"). Body(map[string]interface{}{ "upload_id": "", "block_num": "", }) } else { dry.POST("/open-apis/drive/v1/medias/upload_all"). - Desc("[3] Upload local file to the current Base as attachment media (multipart/form-data)"). + Desc("[2] Upload local file(s) to the current Base as attachment media (multipart/form-data)"). Body(map[string]interface{}{ "file_name": fileName, "parent_type": baseAttachmentParentType, @@ -97,46 +179,87 @@ func dryRunRecordUploadAttachment(_ context.Context, runtime *common.RuntimeCont }) } return dry. - PATCH("/open-apis/base/v3/bases/:base_token/tables/:table_id/records/:record_id"). - Desc("[4] Update the target attachment cell with existing attachments plus the uploaded file token"). + POST("/open-apis/base/v3/bases/:base_token/tables/:table_id/append_attachments"). + Desc("[3] Append uploaded file token(s) to the target attachment cell"). Body(map[string]interface{}{ - "": []interface{}{ - map[string]interface{}{ - "file_token": "", - "name": "", - "deprecated_set_attachment": true, - }, - map[string]interface{}{ - "file_token": "", - "name": fileName, - "mime_type": "", - "size": "", - "deprecated_set_attachment": true, + "attachments": map[string]interface{}{ + runtime.Str("record-id"): map[string]interface{}{ + runtime.Str("field-id"): []interface{}{ + map[string]interface{}{ + "file_token": "", + "image_width": "", + "image_height": "", + }, + }, }, }, }) } -func executeRecordUploadAttachment(runtime *common.RuntimeContext) error { - filePath := runtime.Str("file") - fio := runtime.FileIO() - if fio == nil { - return output.ErrValidation("file operations require a FileIO provider") +func dryRunRecordDownloadAttachment(_ context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + return common.NewDryRunAPI(). + Desc("2-step orchestration: read Base attachment metadata → download each requested attachment file"). + POST("/open-apis/base/v3/bases/:base_token/tables/:table_id/get_attachments"). + Desc("[1] Read attachment metadata for the record"). + Body(map[string]interface{}{"record_id_list": []string{runtime.Str("record-id")}}). + Set("base_token", runtime.Str("base-token")). + Set("table_id", baseTableID(runtime)). + GET("/open-apis/drive/v1/medias/:file_token/download"). + Desc("[2] Download attachment media through the Base attachment flow"). + Set("file_token", ""). + Set("output", runtime.Str("output")). + Params(map[string]interface{}{"extra": ""}) +} + +func dryRunRecordRemoveAttachment(_ context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + body := buildSingleCellAttachmentsBody(runtime.Str("record-id"), runtime.Str("field-id"), fileTokenPatchItems(runtime.StrArray("file-token"))) + return common.NewDryRunAPI(). + POST("/open-apis/base/v3/bases/:base_token/tables/:table_id/remove_attachments"). + Desc("Remove attachment file token(s) from the target attachment cell"). + Body(body). + Set("base_token", runtime.Str("base-token")). + Set("table_id", baseTableID(runtime)) +} + +func validateRecordUploadAttachment(runtime *common.RuntimeContext) error { + if runtime.Changed("name") { + return common.FlagErrorf("--name is no longer supported; uploaded attachment names are derived from local file basenames") } - fileInfo, err := fio.Stat(filePath) + files, err := normalizeAttachmentFiles(runtime.StrArray("file")) if err != nil { - if errors.Is(err, fileio.ErrPathValidation) { - return output.ErrValidation("unsafe file path: %s", err) + return err + } + for _, path := range files { + if _, err := validateAttachmentInputFile(runtime, path); err != nil { + return err } - return output.ErrValidation("file not accessible: %s: %v", filePath, err) } - if fileInfo.Size() > baseAttachmentUploadMaxFileSize { - return output.ErrValidation("file %s exceeds 2GB limit", common.FormatSize(fileInfo.Size())) + return nil +} + +func validateRecordDownloadAttachment(runtime *common.RuntimeContext) error { + tokens, err := normalizeOptionalDownloadAttachmentFileTokens(runtime.StrArray("file-token")) + if err != nil { + return err } + if len(tokens) != 1 { + info, statErr := runtime.FileIO().Stat(runtime.Str("output")) + if statErr != nil || !info.IsDir() { + return common.FlagErrorf("--output must be an existing directory when downloading multiple attachments or when --file-token is omitted") + } + } + return nil +} - fileName := strings.TrimSpace(runtime.Str("name")) - if fileName == "" { - fileName = filepath.Base(filePath) +func validateRecordRemoveAttachment(runtime *common.RuntimeContext) error { + _, err := normalizeAttachmentFileTokens(runtime.StrArray("file-token")) + return err +} + +func executeRecordUploadAttachment(runtime *common.RuntimeContext) error { + files, err := normalizeAttachmentFiles(runtime.StrArray("file")) + if err != nil { + return err } field, err := fetchBaseField(runtime, runtime.Str("base-token"), baseTableID(runtime), runtime.Str("field-id")) @@ -146,133 +269,233 @@ func executeRecordUploadAttachment(runtime *common.RuntimeContext) error { if normalized := normalizeFieldTypeName(fieldTypeName(field)); normalized != "attachment" { return output.ErrValidation("field %q is type %q, expected attachment", fieldName(field), normalized) } + resolvedFieldID := fieldID(field) + if resolvedFieldID == "" { + resolvedFieldID = runtime.Str("field-id") + } - record, err := fetchBaseRecord(runtime, runtime.Str("base-token"), baseTableID(runtime), runtime.Str("record-id")) + appendItems := make([]interface{}, 0, len(files)) + for _, filePath := range files { + fileInfo, err := validateAttachmentInputFile(runtime, filePath) + if err != nil { + return err + } + fileName := filepath.Base(filePath) + fmt.Fprintf(runtime.IO().ErrOut, "Uploading attachment: %s -> record %s field %s\n", fileName, runtime.Str("record-id"), fieldName(field)) + if fileInfo.Size() > common.MaxDriveMediaUploadSinglePartSize { + fmt.Fprintf(runtime.IO().ErrOut, "File exceeds 20MB, using multipart upload\n") + } + attachment, err := uploadAttachmentToBase(runtime, filePath, fileName, fileInfo.Size(), baseAttachmentUploadTarget{ + ParentType: baseAttachmentParentType, + ParentNode: runtime.Str("base-token"), + }) + if err != nil { + return err + } + appendItems = append(appendItems, attachmentAppendItem(attachment)) + } + + body := buildSingleCellAttachmentsBody(runtime.Str("record-id"), resolvedFieldID, appendItems) + data, err := baseV3Call(runtime, "POST", baseV3Path("bases", runtime.Str("base-token"), "tables", baseTableID(runtime), "append_attachments"), nil, body) if err != nil { return err } + runtime.Out(data, nil) + return nil +} - fmt.Fprintf(runtime.IO().ErrOut, "Uploading attachment: %s -> record %s field %s\n", fileName, runtime.Str("record-id"), fieldName(field)) - if fileInfo.Size() > common.MaxDriveMediaUploadSinglePartSize { - fmt.Fprintf(runtime.IO().ErrOut, "File exceeds 20MB, using multipart upload\n") +func executeRecordRemoveAttachment(runtime *common.RuntimeContext) error { + tokens, err := normalizeAttachmentFileTokens(runtime.StrArray("file-token")) + if err != nil { + return err } - - attachment, err := uploadAttachmentToBase(runtime, filePath, fileName, runtime.Str("base-token"), fileInfo.Size()) + field, err := fetchBaseField(runtime, runtime.Str("base-token"), baseTableID(runtime), runtime.Str("field-id")) if err != nil { return err } - - attachments, err := mergeRecordAttachments(record, fieldName(field), attachment) + if normalized := normalizeFieldTypeName(fieldTypeName(field)); normalized != "attachment" { + return output.ErrValidation("field %q is type %q, expected attachment", fieldName(field), normalized) + } + resolvedFieldID := fieldID(field) + if resolvedFieldID == "" { + resolvedFieldID = runtime.Str("field-id") + } + body := buildSingleCellAttachmentsBody(runtime.Str("record-id"), resolvedFieldID, fileTokenPatchItems(tokens)) + data, err := baseV3Call(runtime, "POST", baseV3Path("bases", runtime.Str("base-token"), "tables", baseTableID(runtime), "remove_attachments"), nil, body) if err != nil { return err } + runtime.Out(data, nil) + return nil +} - body := map[string]interface{}{ - fieldName(field): attachments, +func executeRecordDownloadAttachment(ctx context.Context, runtime *common.RuntimeContext) error { + tokens, err := normalizeOptionalDownloadAttachmentFileTokens(runtime.StrArray("file-token")) + if err != nil { + return err + } + attachments, err := fetchBaseAttachments(runtime, runtime.Str("base-token"), baseTableID(runtime), []string{runtime.Str("record-id")}) + if err != nil { + return err } - data, err := baseV3Call(runtime, "PATCH", baseV3Path("bases", runtime.Str("base-token"), "tables", baseTableID(runtime), "records", runtime.Str("record-id")), nil, body) + items, err := selectAttachmentDownloadItems(attachments, runtime.Str("record-id"), tokens) if err != nil { return err } - runtime.Out(map[string]interface{}{ - "record": data, - "attachment": attachment, - "attachments": attachments, - "updated": true, - }, nil) + targets, err := planAttachmentDownloadTargets(runtime, items, runtime.Str("output"), len(tokens) != 1 || len(items) > 1, runtime.Bool("overwrite")) + if err != nil { + return err + } + downloaded := make([]map[string]interface{}, 0, len(targets)) + for _, target := range targets { + saved, err := downloadBaseAttachment(ctx, runtime, target.Item, target.TargetPath, runtime.Bool("overwrite")) + if err != nil { + failed := attachmentDownloadFailure(target, err) + return attachmentDownloadProgressError(err, downloaded, []map[string]interface{}{failed}) + } + downloaded = append(downloaded, saved) + } + runtime.Out(map[string]interface{}{"downloaded": downloaded}, nil) return nil } -func baseAttachmentShouldUseMultipart(fio fileio.FileIO, filePath string) bool { - info, err := fio.Stat(filePath) +func validateAttachmentInputFile(runtime *common.RuntimeContext, filePath string) (fileio.FileInfo, error) { + fio := runtime.FileIO() + if fio == nil { + return nil, output.ErrValidation("file operations require a FileIO provider") + } + fileInfo, err := fio.Stat(filePath) if err != nil { - return false + if errors.Is(err, fileio.ErrPathValidation) { + return nil, output.ErrValidation("unsafe file path: %s", err) + } + return nil, output.ErrValidation("file not accessible: %s: %v", filePath, err) } - return info.Mode().IsRegular() && info.Size() > common.MaxDriveMediaUploadSinglePartSize + if fileInfo.IsDir() { + return nil, output.ErrValidation("file path is a directory: %s", filePath) + } + if fileInfo.Size() > baseAttachmentUploadMaxFileSize { + return nil, output.ErrValidation("file %s exceeds 2GB limit", common.FormatSize(fileInfo.Size())) + } + return fileInfo, nil } -func fetchBaseField(runtime *common.RuntimeContext, baseToken, tableIDValue, fieldRef string) (map[string]interface{}, error) { - return baseV3Call(runtime, "GET", baseV3Path("bases", baseToken, "tables", tableIDValue, "fields", fieldRef), nil, nil) +func normalizeAttachmentFiles(files []string) ([]string, error) { + return normalizeStringList(files, stringListNormalizeOptions{ + typeError: "attachment files must be a string array", + emptyError: "provide at least one --file", + itemName: "attachment file", + duplicateName: "attachment file", + limitName: "attachment file count", + max: baseAttachmentMaxBatchSize, + }) } -func fetchBaseRecord(runtime *common.RuntimeContext, baseToken, tableIDValue, recordID string) (map[string]interface{}, error) { - return baseV3Call(runtime, "GET", baseV3Path("bases", baseToken, "tables", tableIDValue, "records", recordID), nil, nil) +func normalizeAttachmentFileTokens(tokens []string) ([]string, error) { + return normalizeStringList(tokens, stringListNormalizeOptions{ + typeError: "attachment file tokens must be a string array", + emptyError: "provide at least one --file-token", + itemName: "attachment file token", + duplicateName: "attachment file token", + limitName: "attachment file token count", + max: baseAttachmentMaxBatchSize, + }) } -func mergeRecordAttachments(record map[string]interface{}, fieldName string, uploaded map[string]interface{}) ([]interface{}, error) { - fields, _ := record["fields"].(map[string]interface{}) - if fields == nil { - return []interface{}{uploaded}, nil - } - current, exists := fields[fieldName] - if !exists || util.IsNil(current) { - return []interface{}{uploaded}, nil +func normalizeOptionalDownloadAttachmentFileTokens(tokens []string) ([]string, error) { + if len(tokens) == 0 { + return nil, nil + } + normalized := make([]string, 0, len(tokens)) + for index, token := range tokens { + token = strings.TrimSpace(token) + if token == "" { + return nil, common.FlagErrorf("attachment file token %d must not be empty", index+1) + } + normalized = append(normalized, token) } - items, ok := current.([]interface{}) - if !ok { - return nil, output.ErrValidation("record field %q has unexpected attachment payload type %T", fieldName, current) + normalized = dedupeStringsPreserveOrder(normalized) + if len(normalized) > baseAttachmentMaxBatchSize { + return nil, common.FlagErrorf("attachment file token count exceeds maximum limit of %d (got %d)", baseAttachmentMaxBatchSize, len(normalized)) } - merged := make([]interface{}, 0, len(items)+1) - for _, item := range items { - attachment, ok := item.(map[string]interface{}) - if !ok { - return nil, output.ErrValidation("record field %q contains unexpected attachment item type %T", fieldName, item) + return normalized, nil +} + +func dedupeStringsPreserveOrder(values []string) []string { + seen := make(map[string]struct{}, len(values)) + result := make([]string, 0, len(values)) + for _, value := range values { + if _, exists := seen[value]; exists { + continue } - merged = append(merged, normalizeAttachmentForPatch(attachment)) + seen[value] = struct{}{} + result = append(result, value) } - merged = append(merged, uploaded) - return merged, nil + return result } -func normalizeAttachmentForPatch(attachment map[string]interface{}) map[string]interface{} { - normalized := map[string]interface{}{} - if fileToken, _ := attachment["file_token"].(string); fileToken != "" { - normalized["file_token"] = fileToken +func baseAttachmentShouldUseMultipart(fio fileio.FileIO, filePath string) bool { + if fio == nil { + return false } - if name, _ := attachment["name"].(string); name != "" { - normalized["name"] = name + info, err := fio.Stat(filePath) + if err != nil { + return false } - if mimeType, _ := attachment["mime_type"].(string); mimeType != "" { - normalized["mime_type"] = mimeType + return info.Mode().IsRegular() && info.Size() > common.MaxDriveMediaUploadSinglePartSize +} + +func fetchBaseField(runtime *common.RuntimeContext, baseToken, tableIDValue, fieldRef string) (map[string]interface{}, error) { + return baseV3Call(runtime, "GET", baseV3Path("bases", baseToken, "tables", tableIDValue, "fields", fieldRef), nil, nil) +} + +func fetchBaseAttachments(runtime *common.RuntimeContext, baseToken, tableIDValue string, recordIDs []string) (map[string]interface{}, error) { + if len(recordIDs) == 0 { + return nil, output.ErrValidation("provide at least one record id") } - if size, ok := attachment["size"]; ok && !util.IsNil(size) { - normalized["size"] = size + if len(recordIDs) > baseAttachmentGetMaxRecords { + return nil, output.ErrValidation("get attachments record selection exceeds maximum limit of %d (got %d)", baseAttachmentGetMaxRecords, len(recordIDs)) } - if imageWidth, ok := attachment["image_width"]; ok && !util.IsNil(imageWidth) { - normalized["image_width"] = imageWidth + data, err := baseV3Call(runtime, "POST", baseV3Path("bases", baseToken, "tables", tableIDValue, "get_attachments"), nil, map[string]interface{}{ + "record_id_list": recordIDs, + }) + if err != nil { + return nil, err } - if imageHeight, ok := attachment["image_height"]; ok && !util.IsNil(imageHeight) { - normalized["image_height"] = imageHeight + attachments, _ := data["attachments"].(map[string]interface{}) + if attachments == nil { + return map[string]interface{}{}, nil } - normalized["deprecated_set_attachment"] = true - return normalized + return attachments, nil } -func uploadAttachmentToBase(runtime *common.RuntimeContext, filePath, fileName, baseToken string, fileSize int64) (map[string]interface{}, error) { +func uploadAttachmentToBase(runtime *common.RuntimeContext, filePath, fileName string, fileSize int64, target baseAttachmentUploadTarget) (map[string]interface{}, error) { mimeType, err := detectAttachmentMIMEType(runtime.FileIO(), filePath, fileName) if err != nil { return nil, err } - parentNode := baseToken var ( fileToken string ) if fileSize <= common.MaxDriveMediaUploadSinglePartSize { + parentNode := target.ParentNode fileToken, err = common.UploadDriveMediaAll(runtime, common.DriveMediaUploadAllConfig{ FilePath: filePath, FileName: fileName, FileSize: fileSize, - ParentType: baseAttachmentParentType, + ParentType: target.ParentType, ParentNode: &parentNode, + Extra: target.Extra, }) } else { fileToken, err = common.UploadDriveMediaMultipart(runtime, common.DriveMediaMultipartUploadConfig{ FilePath: filePath, FileName: fileName, FileSize: fileSize, - ParentType: baseAttachmentParentType, - ParentNode: parentNode, + ParentType: target.ParentType, + ParentNode: target.ParentNode, + Extra: target.Extra, }) } if err != nil { @@ -280,15 +503,51 @@ func uploadAttachmentToBase(runtime *common.RuntimeContext, filePath, fileName, } attachment := map[string]interface{}{ - "file_token": fileToken, - "name": fileName, - "mime_type": mimeType, - "size": fileSize, - "deprecated_set_attachment": true, + "file_token": fileToken, + "name": fileName, + "mime_type": mimeType, + "size": fileSize, + } + if width, height, ok := detectAttachmentImageDimensions(runtime.FileIO(), filePath, mimeType); ok { + attachment["image_width"] = width + attachment["image_height"] = height + } else if attachmentImageDimensionsWarningEnabled(mimeType) { + fmt.Fprintf(runtime.IO().ErrOut, "Warning: image dimensions unavailable for %s; attachment may display as square\n", fileName) } return attachment, nil } +func attachmentAppendItem(attachment map[string]interface{}) map[string]interface{} { + item := map[string]interface{}{ + "file_token": attachment["file_token"], + } + if width, ok := attachment["image_width"]; ok && !util.IsNil(width) { + item["image_width"] = width + } + if height, ok := attachment["image_height"]; ok && !util.IsNil(height) { + item["image_height"] = height + } + return item +} + +func fileTokenPatchItems(tokens []string) []interface{} { + items := make([]interface{}, 0, len(tokens)) + for _, token := range tokens { + items = append(items, map[string]interface{}{"file_token": token}) + } + return items +} + +func buildSingleCellAttachmentsBody(recordID, fieldID string, items []interface{}) map[string]interface{} { + return map[string]interface{}{ + "attachments": map[string]interface{}{ + recordID: map[string]interface{}{ + fieldID: items, + }, + }, + } +} + func detectAttachmentMIMEType(fio fileio.FileIO, filePath, fileName string) (string, error) { if byExt := strings.TrimSpace(mime.TypeByExtension(strings.ToLower(filepath.Ext(fileName)))); byExt != "" { return stripMIMEParams(byExt), nil @@ -311,6 +570,309 @@ func detectAttachmentMIMEType(fio fileio.FileIO, filePath, fileName string) (str return detectAttachmentMIMEFromContent(buf[:n]), nil } +func detectAttachmentImageDimensions(fio fileio.FileIO, filePath string, mimeType string) (int, int, bool) { + if fio == nil || !strings.HasPrefix(mimeType, "image/") { + return 0, 0, false + } + f, err := fio.Open(filePath) + if err != nil { + return 0, 0, false + } + defer f.Close() + cfg, _, err := image.DecodeConfig(f) + if err != nil || cfg.Width <= 0 || cfg.Height <= 0 { + return 0, 0, false + } + return cfg.Width, cfg.Height, true +} + +func attachmentImageDimensionsWarningEnabled(mimeType string) bool { + switch mimeType { + case "image/gif", "image/jpeg", "image/png": + return true + default: + return false + } +} + +type baseAttachmentDownloadItem struct { + RecordID string + FieldID string + FileToken string + Name string + Size interface{} + ExtraInfo string + MimeType string + RawPayload map[string]interface{} +} + +type baseAttachmentDownloadTarget struct { + Item baseAttachmentDownloadItem + TargetPath string + ResolvedPath string +} + +func selectAttachmentDownloadItems(attachments map[string]interface{}, recordID string, tokens []string) ([]baseAttachmentDownloadItem, error) { + recordRaw, ok := attachments[recordID] + if !ok { + return nil, output.ErrValidation("record %q has no attachment metadata; verify the record-id", recordID) + } + fields, ok := recordRaw.(map[string]interface{}) + if !ok { + return nil, output.ErrValidation("record %q attachment metadata has unexpected type %T", recordID, recordRaw) + } + byToken := map[string]baseAttachmentDownloadItem{} + fieldIDs := make([]string, 0, len(fields)) + for currentFieldID := range fields { + fieldIDs = append(fieldIDs, currentFieldID) + } + sort.Strings(fieldIDs) + for _, currentFieldID := range fieldIDs { + rawList := fields[currentFieldID] + items, ok := rawList.([]interface{}) + if !ok { + return nil, output.ErrValidation("record %q field %q attachment metadata has unexpected type %T", recordID, currentFieldID, rawList) + } + for _, rawItem := range items { + item, ok := rawItem.(map[string]interface{}) + if !ok { + return nil, output.ErrValidation("record %q field %q contains unexpected attachment item type %T", recordID, currentFieldID, rawItem) + } + fileToken, _ := item["file_token"].(string) + if fileToken == "" { + continue + } + if _, exists := byToken[fileToken]; exists { + continue + } + name, _ := item["name"].(string) + extraInfo, _ := item["extra_info"].(string) + mimeType, _ := item["mime_type"].(string) + byToken[fileToken] = baseAttachmentDownloadItem{ + RecordID: recordID, + FieldID: currentFieldID, + FileToken: fileToken, + Name: name, + Size: item["size"], + ExtraInfo: extraInfo, + MimeType: mimeType, + RawPayload: item, + } + } + } + result := make([]baseAttachmentDownloadItem, 0, len(tokens)) + if len(tokens) == 0 { + for _, item := range byToken { + result = append(result, item) + } + if len(result) == 0 { + return nil, output.ErrValidation("record %q has no attachments to download", recordID) + } + sort.SliceStable(result, func(i, j int) bool { + leftName := strings.ToLower(baseAttachmentDownloadName(result[i])) + rightName := strings.ToLower(baseAttachmentDownloadName(result[j])) + if leftName != rightName { + return leftName < rightName + } + return result[i].FileToken < result[j].FileToken + }) + return result, nil + } + for _, token := range tokens { + item, ok := byToken[token] + if !ok { + return nil, output.ErrValidation("attachment file_token %q not found in record %q; verify the record-id/file-token pair", token, recordID) + } + result = append(result, item) + } + return result, nil +} + +func planAttachmentDownloadTargets(runtime *common.RuntimeContext, items []baseAttachmentDownloadItem, outputPath string, outputIsDir bool, overwrite bool) ([]baseAttachmentDownloadTarget, error) { + names := downloadTargetNames(items, outputIsDir || outputPathLooksDirectory(runtime, outputPath)) + targets := make([]baseAttachmentDownloadTarget, 0, len(items)) + seen := map[string]baseAttachmentDownloadItem{} + for _, item := range items { + targetName := names[item.FileToken] + targetPath := outputPath + if targetName != "" { + targetPath = filepath.Join(outputPath, targetName) + } + resolved, err := runtime.ResolveSavePath(targetPath) + if err != nil { + return nil, output.ErrValidation("unsafe output path: %s", err) + } + if previous, exists := seen[resolved]; exists { + return nil, output.ErrValidation("multiple attachments resolve to the same output path %q (%s and %s); download them separately or choose a different directory", resolved, previous.FileToken, item.FileToken) + } + seen[resolved] = item + if !overwrite { + if _, statErr := runtime.FileIO().Stat(targetPath); statErr == nil { + return nil, output.ErrValidation("output file already exists: %s (use --overwrite to replace)", targetPath) + } + } + targets = append(targets, baseAttachmentDownloadTarget{ + Item: item, + TargetPath: targetPath, + ResolvedPath: resolved, + }) + } + return targets, nil +} + +func downloadTargetNames(items []baseAttachmentDownloadItem, outputIsDir bool) map[string]string { + if !outputIsDir { + return nil + } + nameCounts := make(map[string]int, len(items)) + for _, item := range items { + nameCounts[baseAttachmentDownloadName(item)]++ + } + names := make(map[string]string, len(items)) + for _, item := range items { + name := baseAttachmentDownloadName(item) + if nameCounts[name] > 1 { + name = attachmentNameWithTokenSuffix(name, item.FileToken) + } + names[item.FileToken] = name + } + return names +} + +func baseAttachmentDownloadName(item baseAttachmentDownloadItem) string { + name := filepath.Base(strings.TrimSpace(item.Name)) + if name == "" || name == "." || name == string(filepath.Separator) { + name = item.FileToken + } + return name +} + +func attachmentNameWithTokenSuffix(name, fileToken string) string { + ext := filepath.Ext(name) + stem := strings.TrimSuffix(name, ext) + if stem == "" { + stem = name + } + return stem + "_" + safeAttachmentFileTokenSuffix(fileToken) + ext +} + +func safeAttachmentFileTokenSuffix(fileToken string) string { + var b strings.Builder + for _, r := range fileToken { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + b.WriteRune(r) + continue + } + b.WriteByte('_') + } + suffix := strings.Trim(b.String(), "_") + if suffix == "" { + return "file" + } + return suffix +} + +func downloadBaseAttachment(ctx context.Context, runtime *common.RuntimeContext, item baseAttachmentDownloadItem, targetPath string, overwrite bool) (map[string]interface{}, error) { + if _, err := runtime.ResolveSavePath(targetPath); err != nil { + return nil, output.ErrValidation("unsafe output path: %s", err) + } + + query := larkcore.QueryParams{} + if item.ExtraInfo != "" { + query.Set("extra", item.ExtraInfo) + } + resp, err := runtime.DoAPIStream(ctx, &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: fmt.Sprintf("/open-apis/drive/v1/medias/%s/download", validate.EncodePathSegment(item.FileToken)), + QueryParams: query, + }) + if err != nil { + return nil, output.ErrNetwork("download failed: %v", err) + } + defer resp.Body.Close() + + if !overwrite { + if _, statErr := runtime.FileIO().Stat(targetPath); statErr == nil { + return nil, output.ErrValidation("output file already exists: %s (use --overwrite to replace)", targetPath) + } + } + result, err := runtime.FileIO().Save(targetPath, fileio.SaveOptions{ + ContentType: resp.Header.Get("Content-Type"), + ContentLength: resp.ContentLength, + }, resp.Body) + if err != nil { + return nil, common.WrapSaveErrorByCategory(err, "io") + } + savedPath, _ := runtime.ResolveSavePath(targetPath) + if savedPath == "" { + savedPath = targetPath + } + return map[string]interface{}{ + "record_id": item.RecordID, + "field_id": item.FieldID, + "file_token": item.FileToken, + "name": item.Name, + "size": item.Size, + "saved_path": savedPath, + "size_bytes": result.Size(), + "content_type": resp.Header.Get("Content-Type"), + }, nil +} + +func attachmentDownloadFailure(target baseAttachmentDownloadTarget, err error) map[string]interface{} { + return map[string]interface{}{ + "record_id": target.Item.RecordID, + "field_id": target.Item.FieldID, + "file_token": target.Item.FileToken, + "name": target.Item.Name, + "target_path": target.TargetPath, + "resolved_path": target.ResolvedPath, + "error": err.Error(), + } +} + +func attachmentDownloadProgressError(err error, downloaded []map[string]interface{}, failed []map[string]interface{}) error { + msg := fmt.Sprintf("download failed after %d attachment(s) succeeded and %d failed: %v", len(downloaded), len(failed), err) + var exitErr *output.ExitError + if errors.As(err, &exitErr) && exitErr.Detail != nil { + return &output.ExitError{ + Code: exitErr.Code, + Detail: &output.ErrDetail{ + Type: exitErr.Detail.Type, + Code: exitErr.Detail.Code, + Message: msg, + Hint: "Some files may already have been saved. Inspect error.detail.downloaded before retrying, or rerun with --overwrite if the failed target now exists.", + Detail: map[string]interface{}{ + "downloaded": downloaded, + "failed": failed, + }, + }, + Err: err, + } + } + return &output.ExitError{ + Code: output.ExitInternal, + Detail: &output.ErrDetail{ + Type: "io", + Message: msg, + Hint: "Some files may already have been saved. Inspect error.detail.downloaded before retrying, or rerun with --overwrite if the failed target now exists.", + Detail: map[string]interface{}{ + "downloaded": downloaded, + "failed": failed, + }, + }, + Err: err, + } +} + +func outputPathLooksDirectory(runtime *common.RuntimeContext, outputPath string) bool { + if strings.HasSuffix(outputPath, "/") || strings.HasSuffix(outputPath, string(filepath.Separator)) { + return true + } + info, err := runtime.FileIO().Stat(outputPath) + return err == nil && info.IsDir() +} + func stripMIMEParams(value string) string { if i := strings.IndexByte(value, ';'); i != -1 { value = value[:i] diff --git a/shortcuts/base/record_upload_attachment_test.go b/shortcuts/base/record_upload_attachment_test.go index 69ff360e2..eaf9b138d 100644 --- a/shortcuts/base/record_upload_attachment_test.go +++ b/shortcuts/base/record_upload_attachment_test.go @@ -5,6 +5,9 @@ package base import ( "bytes" + "image" + "image/color" + "image/png" "io" "io/fs" "os" @@ -82,6 +85,42 @@ func TestDetectAttachmentMIMETypeFallsBackToContent(t *testing.T) { } } +func TestDetectAttachmentImageDimensions(t *testing.T) { + var buf bytes.Buffer + img := image.NewRGBA(image.Rect(0, 0, 4, 3)) + img.Set(0, 0, color.RGBA{G: 255, A: 255}) + if err := png.Encode(&buf, img); err != nil { + t.Fatalf("png.Encode() error = %v", err) + } + fio := attachmentTestFileIO{openFile: newAttachmentTestFile(buf.Bytes())} + + width, height, ok := detectAttachmentImageDimensions(fio, "image.png", "image/png") + if !ok || width != 4 || height != 3 { + t.Fatalf("detectAttachmentImageDimensions() = (%d,%d,%v), want (4,3,true)", width, height, ok) + } +} + +func TestAttachmentImageDimensionsWarningEnabled(t *testing.T) { + tests := []struct { + mimeType string + want bool + }{ + {mimeType: "image/gif", want: true}, + {mimeType: "image/jpeg", want: true}, + {mimeType: "image/png", want: true}, + {mimeType: "image/webp", want: false}, + {mimeType: "application/pdf", want: false}, + } + + for _, tt := range tests { + t.Run(tt.mimeType, func(t *testing.T) { + if got := attachmentImageDimensionsWarningEnabled(tt.mimeType); got != tt.want { + t.Fatalf("attachmentImageDimensionsWarningEnabled(%q) = %v, want %v", tt.mimeType, got, tt.want) + } + }) + } +} + func TestDetectAttachmentMIMETypeWrapsOpenError(t *testing.T) { fio := attachmentTestFileIO{openErr: os.ErrNotExist} diff --git a/shortcuts/base/shortcuts.go b/shortcuts/base/shortcuts.go index 60ebfe000..c98ccff7e 100644 --- a/shortcuts/base/shortcuts.go +++ b/shortcuts/base/shortcuts.go @@ -44,6 +44,8 @@ func Shortcuts() []common.Shortcut { BaseRecordBatchUpdate, BaseRecordShareLinkCreate, BaseRecordUploadAttachment, + BaseRecordDownloadAttachment, + BaseRecordRemoveAttachment, BaseRecordDelete, BaseRecordHistoryList, BaseBaseGet, @@ -68,10 +70,12 @@ func Shortcuts() []common.Shortcut { BaseFormsList, BaseFormUpdate, BaseFormGet, + BaseFormDetail, BaseFormQuestionsCreate, BaseFormQuestionsDelete, BaseFormQuestionsUpdate, BaseFormQuestionsList, + BaseFormSubmit, BaseDashboardList, BaseDashboardGet, BaseDashboardCreate, diff --git a/shortcuts/common/download_path.go b/shortcuts/common/download_path.go new file mode 100644 index 000000000..24059f8d0 --- /dev/null +++ b/shortcuts/common/download_path.go @@ -0,0 +1,125 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package common + +import ( + "mime" + "net/http" + "path" + "path/filepath" + "strings" + + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" +) + +// DownloadExtensionResolution describes how a file extension was inferred. +type DownloadExtensionResolution struct { + Ext string + Source string + Detail string +} + +var downloadMimeToExt = map[string]string{ + "application/msword": ".doc", + "application/pdf": ".pdf", + "application/vnd.ms-excel": ".xls", + "application/vnd.ms-powerpoint": ".ppt", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/xml": ".xml", + "application/zip": ".zip", + "image/bmp": ".bmp", + "image/gif": ".gif", + "image/jpeg": ".jpg", + "image/png": ".png", + "image/svg+xml": ".svg", + "image/webp": ".webp", + "text/csv": ".csv", + "text/html": ".html", + "text/plain": ".txt", + "text/xml": ".xml", + "video/mp4": ".mp4", +} + +// ResolveDownloadFileName returns a sanitized filename from Content-Disposition, +// falling back to the caller-provided name when the header is absent or invalid. +func ResolveDownloadFileName(header http.Header, fallback string) string { + name := strings.TrimSpace(larkcore.FileNameByHeader(header)) + if name == "" { + name = fallback + } + name = strings.ReplaceAll(strings.TrimSpace(name), "\\", "/") + name = path.Base(name) + if name == "" || name == "." || name == ".." { + return fallback + } + return name +} + +// AutoAppendDownloadExtension appends an inferred file extension when the +// target path has no explicit suffix. If no extension can be inferred, the +// original basename is preserved without adding a synthetic fallback suffix. +func AutoAppendDownloadExtension(outputPath string, header http.Header, fallbackExt string) (string, *DownloadExtensionResolution) { + if hasExplicitDownloadExtension(outputPath) { + return outputPath, nil + } + normalizedPath := outputPath + if filepath.Ext(outputPath) == "." { + normalizedPath = strings.TrimSuffix(outputPath, ".") + } + if resolution := downloadExtensionByContentType(header.Get("Content-Type")); resolution != nil { + return normalizedPath + resolution.Ext, resolution + } + if resolution := downloadExtensionByContentDisposition(header); resolution != nil { + return normalizedPath + resolution.Ext, resolution + } + if fallbackExt != "" { + return normalizedPath + fallbackExt, &DownloadExtensionResolution{ + Ext: fallbackExt, + Source: "fallback", + Detail: "default fallback", + } + } + return normalizedPath, nil +} + +func hasExplicitDownloadExtension(path string) bool { + ext := filepath.Ext(path) + return ext != "" && ext != "." +} + +func downloadExtensionByContentType(contentType string) *DownloadExtensionResolution { + if contentType == "" { + return nil + } + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + mediaType = strings.TrimSpace(strings.Split(contentType, ";")[0]) + } + if ext, ok := downloadMimeToExt[strings.ToLower(mediaType)]; ok { + return &DownloadExtensionResolution{ + Ext: ext, + Source: "Content-Type", + Detail: contentType, + } + } + return nil +} + +func downloadExtensionByContentDisposition(header http.Header) *DownloadExtensionResolution { + filename := strings.TrimSpace(larkcore.FileNameByHeader(header)) + if filename == "" { + return nil + } + ext := filepath.Ext(filename) + if ext == "" || ext == "." { + return nil + } + return &DownloadExtensionResolution{ + Ext: ext, + Source: "Content-Disposition", + Detail: filename, + } +} diff --git a/shortcuts/common/download_path_test.go b/shortcuts/common/download_path_test.go new file mode 100644 index 000000000..100eb9ccd --- /dev/null +++ b/shortcuts/common/download_path_test.go @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package common + +import ( + "net/http" + "testing" +) + +func TestResolveDownloadFileName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + header http.Header + fallback string + want string + }{ + { + name: "content disposition filename wins", + header: http.Header{ + "Content-Disposition": []string{`attachment; filename="report-v7.md"`}, + }, + fallback: "boxcn123", + want: "report-v7.md", + }, + { + name: "path traversal in header is stripped", + header: http.Header{ + "Content-Disposition": []string{`attachment; filename="../nested/report-v7.md"`}, + }, + fallback: "boxcn123", + want: "report-v7.md", + }, + { + name: "fallback when header missing", + header: http.Header{}, + fallback: "boxcn123", + want: "boxcn123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := ResolveDownloadFileName(tt.header, tt.fallback); got != tt.want { + t.Fatalf("ResolveDownloadFileName() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestAutoAppendDownloadExtension(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + header http.Header + want string + }{ + { + name: "explicit extension is preserved", + path: "artifact.bin", + header: http.Header{ + "Content-Type": []string{"text/csv; charset=utf-8"}, + }, + want: "artifact.bin", + }, + { + name: "appends extension from content type", + path: "artifact", + header: http.Header{ + "Content-Type": []string{"text/csv; charset=utf-8"}, + }, + want: "artifact.csv", + }, + { + name: "appends extension from content disposition when content type is generic", + path: "artifact", + header: http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Disposition": []string{`attachment; filename="report-v7.md"`}, + }, + want: "artifact.md", + }, + { + name: "trailing dot is normalized before append", + path: "artifact.", + header: http.Header{ + "Content-Type": []string{"text/plain; charset=utf-8"}, + }, + want: "artifact.txt", + }, + { + name: "unknown type keeps suffixless path", + path: "artifact.", + header: http.Header{ + "Content-Type": []string{"application/octet-stream"}, + }, + want: "artifact", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, _ := AutoAppendDownloadExtension(tt.path, tt.header, "") + if got != tt.want { + t.Fatalf("AutoAppendDownloadExtension() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/shortcuts/common/drive_meta.go b/shortcuts/common/drive_meta.go new file mode 100644 index 000000000..ea4d91beb --- /dev/null +++ b/shortcuts/common/drive_meta.go @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package common + +// FetchDriveMetaTitle looks up the document title via the drive metas batch_query API. +func FetchDriveMetaTitle(runtime *RuntimeContext, token, docType string) (string, error) { + data, err := runtime.CallAPI( + "POST", + "/open-apis/drive/v1/metas/batch_query", + nil, + map[string]interface{}{ + "request_docs": []map[string]interface{}{ + { + "doc_token": token, + "doc_type": docType, + }, + }, + }, + ) + if err != nil { + return "", err + } + + metas := GetSlice(data, "metas") + if len(metas) == 0 { + return "", nil + } + meta, _ := metas[0].(map[string]interface{}) + return GetString(meta, "title"), nil +} diff --git a/shortcuts/common/drive_meta_test.go b/shortcuts/common/drive_meta_test.go new file mode 100644 index 000000000..aad18696f --- /dev/null +++ b/shortcuts/common/drive_meta_test.go @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package common + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/httpmock" +) + +var driveMetaTestSeq atomic.Int64 + +func TestFetchDriveMetaTitle(t *testing.T) { + t.Run("returns title from batch_query response", func(t *testing.T) { + runtime, reg := newDriveMetaTestRuntime(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "metas": []map[string]interface{}{ + {"doc_token": "doxcnABC", "doc_type": "docx", "title": "My Document"}, + }, + }, + }, + }) + + title, err := FetchDriveMetaTitle(runtime, "doxcnABC", "docx") + if err != nil { + t.Fatalf("FetchDriveMetaTitle() error: %v", err) + } + if title != "My Document" { + t.Errorf("title = %q, want %q", title, "My Document") + } + }) + + t.Run("returns empty string when metas is empty", func(t *testing.T) { + runtime, reg := newDriveMetaTestRuntime(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "metas": []map[string]interface{}{}, + }, + }, + }) + + title, err := FetchDriveMetaTitle(runtime, "doxcnABC", "docx") + if err != nil { + t.Fatalf("FetchDriveMetaTitle() error: %v", err) + } + if title != "" { + t.Errorf("title = %q, want empty string", title) + } + }) + + t.Run("returns empty string when meta has no title", func(t *testing.T) { + runtime, reg := newDriveMetaTestRuntime(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "metas": []map[string]interface{}{ + {"doc_token": "doxcnABC", "doc_type": "docx"}, + }, + }, + }, + }) + + title, err := FetchDriveMetaTitle(runtime, "doxcnABC", "docx") + if err != nil { + t.Fatalf("FetchDriveMetaTitle() error: %v", err) + } + if title != "" { + t.Errorf("title = %q, want empty string", title) + } + }) + + t.Run("propagates API error", func(t *testing.T) { + runtime, reg := newDriveMetaTestRuntime(t) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 99991668, + "msg": "permission denied", + }, + }) + + _, err := FetchDriveMetaTitle(runtime, "doxcnABC", "docx") + if err == nil { + t.Fatal("FetchDriveMetaTitle() expected error, got nil") + } + }) +} + +func newDriveMetaTestRuntime(t *testing.T) (*RuntimeContext, *httpmock.Registry) { + t.Helper() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + cfg := &core.CliConfig{ + AppID: fmt.Sprintf("drive-meta-test-%d", driveMetaTestSeq.Add(1)), AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, _, _, reg := cmdutil.TestFactory(t, cfg) + runtime := &RuntimeContext{ + ctx: context.Background(), + Config: cfg, + Factory: f, + resolvedAs: core.AsBot, + } + return runtime, reg +} diff --git a/shortcuts/common/extract.go b/shortcuts/common/extract.go index 382977eda..cea127c1e 100644 --- a/shortcuts/common/extract.go +++ b/shortcuts/common/extract.go @@ -33,6 +33,26 @@ func GetFloat(m map[string]interface{}, keys ...string) float64 { return f } +// GetInt safely extracts an int, accepting both in-memory ints and JSON-style float64 values. +func GetInt(m map[string]interface{}, keys ...string) int { + if len(keys) == 0 { + return 0 + } + v := navigate(m, keys[:len(keys)-1]) + if v == nil { + return 0 + } + switch n := v[keys[len(keys)-1]].(type) { + case int: + return n + case int64: + return int(n) + case float64: + return int(n) + } + return 0 +} + // GetBool safely extracts a bool. func GetBool(m map[string]interface{}, keys ...string) bool { if len(keys) == 0 { diff --git a/shortcuts/common/extract_test.go b/shortcuts/common/extract_test.go index 373bfc906..5b57cf153 100644 --- a/shortcuts/common/extract_test.go +++ b/shortcuts/common/extract_test.go @@ -64,6 +64,32 @@ func TestGetFloat(t *testing.T) { } } +func TestGetInt(t *testing.T) { + m := map[string]interface{}{ + "count": 42, + "json_count": 7.0, + "data": map[string]interface{}{ + "score": int64(99), + }, + } + + if got := GetInt(m, "count"); got != 42 { + t.Errorf("GetInt(count) = %d, want 42", got) + } + if got := GetInt(m, "json_count"); got != 7 { + t.Errorf("GetInt(json_count) = %d, want 7", got) + } + if got := GetInt(m, "data", "score"); got != 99 { + t.Errorf("GetInt(data.score) = %d, want 99", got) + } + if got := GetInt(m, "missing"); got != 0 { + t.Errorf("GetInt(missing) = %d, want 0", got) + } + if got := GetInt(m); got != 0 { + t.Errorf("GetInt() = %d, want 0", got) + } +} + func TestGetBool(t *testing.T) { m := map[string]interface{}{ "active": true, diff --git a/shortcuts/common/resource_url.go b/shortcuts/common/resource_url.go index 99b81bd1a..69345ea90 100644 --- a/shortcuts/common/resource_url.go +++ b/shortcuts/common/resource_url.go @@ -4,6 +4,7 @@ package common import ( + "net/url" "strings" "github.com/larksuite/cli/internal/core" @@ -55,3 +56,79 @@ func BuildResourceURL(brand core.LarkBrand, kind, token string) string { return "" } } + +// ResourceRef holds the parsed type and token from a Lark resource URL. +type ResourceRef struct { + Type string // e.g. "docx", "bitable", "wiki", "sheet", etc. + Token string // the token extracted from the URL path +} + +// urlPathToType maps URL path prefixes to resource types. +// Longer prefixes must come first to avoid false matches +// (e.g. "/drive/folder/" before a hypothetical "/drive/"). +// Aliases (e.g. "/bitable/" → "bitable") must come after the +// canonical prefix to keep the list deterministic. +var urlPathToType = []struct { + Prefix string + Type string +}{ + {"/drive/folder/", "folder"}, + {"/docx/", "docx"}, + {"/doc/", "doc"}, + {"/sheets/", "sheet"}, + {"/base/", "bitable"}, + {"/bitable/", "bitable"}, + {"/wiki/", "wiki"}, + {"/file/", "file"}, + {"/mindnote/", "mindnote"}, + {"/slides/", "slides"}, +} + +// ParseResourceURL parses a Lark/Feishu URL and extracts the resource type +// and token from the URL path. It is the inverse of BuildResourceURL. +// +// Supported path patterns: +// +// /docx/TOKEN -> {Type: "docx", Token: TOKEN} +// /doc/TOKEN -> {Type: "doc", Token: TOKEN} +// /sheets/TOKEN -> {Type: "sheet", Token: TOKEN} +// /base/TOKEN -> {Type: "bitable", Token: TOKEN} +// /wiki/TOKEN -> {Type: "wiki", Token: TOKEN} +// /file/TOKEN -> {Type: "file", Token: TOKEN} +// /drive/folder/TOKEN -> {Type: "folder", Token: TOKEN} +// /mindnote/TOKEN -> {Type: "mindnote", Token: TOKEN} +// /slides/TOKEN -> {Type: "slides", Token: TOKEN} +// +// Returns (ResourceRef{}, false) when the URL does not match any known pattern. +func ParseResourceURL(rawURL string) (ResourceRef, bool) { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return ResourceRef{}, false + } + + u, err := url.Parse(rawURL) + if err != nil { + return ResourceRef{}, false + } + + path := u.Path + + for _, mapping := range urlPathToType { + if !strings.HasPrefix(path, mapping.Prefix) { + continue + } + token := path[len(mapping.Prefix):] + // Trim trailing slashes and stop at the next path segment boundary. + token = strings.TrimRight(token, "/") + if idx := strings.IndexByte(token, '/'); idx >= 0 { + token = token[:idx] + } + token = strings.TrimSpace(token) + if token == "" { + return ResourceRef{}, false + } + return ResourceRef{Type: mapping.Type, Token: token}, true + } + + return ResourceRef{}, false +} diff --git a/shortcuts/common/resource_url_test.go b/shortcuts/common/resource_url_test.go index 9ef0d9db1..baa1165f1 100644 --- a/shortcuts/common/resource_url_test.go +++ b/shortcuts/common/resource_url_test.go @@ -9,6 +9,102 @@ import ( "github.com/larksuite/cli/internal/core" ) +func TestParseResourceURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + rawURL string + wantType string + wantToken string + wantOK bool + }{ + // All 9 supported types + {"docx", "https://xxx.feishu.cn/docx/doxcnABC", "docx", "doxcnABC", true}, + {"doc", "https://xxx.feishu.cn/doc/doccnABC", "doc", "doccnABC", true}, + {"sheet", "https://xxx.feishu.cn/sheets/shtcnABC", "sheet", "shtcnABC", true}, + {"bitable via /base/", "https://xxx.feishu.cn/base/bascnABC", "bitable", "bascnABC", true}, + {"bitable via /bitable/", "https://xxx.feishu.cn/bitable/bascnABC", "bitable", "bascnABC", true}, + {"wiki", "https://xxx.feishu.cn/wiki/wikcnABC", "wiki", "wikcnABC", true}, + {"file", "https://xxx.feishu.cn/file/boxcnABC", "file", "boxcnABC", true}, + {"folder", "https://xxx.feishu.cn/drive/folder/fldcnABC", "folder", "fldcnABC", true}, + {"mindnote", "https://xxx.feishu.cn/mindnote/mncnABC", "mindnote", "mncnABC", true}, + {"slides", "https://xxx.feishu.cn/slides/slkcnABC", "slides", "slkcnABC", true}, + + // Lark domain + {"lark docx", "https://xxx.larksuite.com/docx/doxcnABC", "docx", "doxcnABC", true}, + {"lark wiki", "https://xxx.larksuite.com/wiki/wikcnABC", "wiki", "wikcnABC", true}, + + // With query parameters + {"with query", "https://xxx.feishu.cn/docx/doxcnABC?from=wiki", "docx", "doxcnABC", true}, + {"with fragment", "https://xxx.feishu.cn/docx/doxcnABC#section", "docx", "doxcnABC", true}, + + // With trailing slash + {"trailing slash", "https://xxx.feishu.cn/docx/doxcnABC/", "docx", "doxcnABC", true}, + + // With extra path segments after token + {"extra path", "https://xxx.feishu.cn/docx/doxcnABC/edit", "docx", "doxcnABC", true}, + + // Non-Lark host with Lark-like path (host validation is the caller's responsibility) + {"non-lark host with lark path", "https://google.com/docx/doxcnABC", "docx", "doxcnABC", true}, + + // Negative cases + {"unrecognized path", "https://xxx.feishu.cn/calendar/calABC", "", "", false}, + {"non-lark host unrecognized path", "https://example.com/page", "", "", false}, + {"empty input", "", "", "", false}, + {"bare token", "doxcnABC", "", "", false}, + {"invalid url parse", "://not-a-valid-url", "", "", false}, + {"matching prefix but empty token", "https://xxx.feishu.cn/docx/", "", "", false}, + {"matching prefix but whitespace-only token", "https://xxx.feishu.cn/docx/ ", "", "", false}, + {"whitespace-only input", " ", "", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ref, ok := ParseResourceURL(tt.rawURL) + if ok != tt.wantOK { + t.Errorf("ParseResourceURL(%q) ok = %v, want %v", tt.rawURL, ok, tt.wantOK) + } + if ok { + if ref.Type != tt.wantType { + t.Errorf("ParseResourceURL(%q) Type = %q, want %q", tt.rawURL, ref.Type, tt.wantType) + } + if ref.Token != tt.wantToken { + t.Errorf("ParseResourceURL(%q) Token = %q, want %q", tt.rawURL, ref.Token, tt.wantToken) + } + } + }) + } +} + +// TestParseResourceURL_RoundTrip verifies that ParseResourceURL is the inverse +// of BuildResourceURL for all supported types. +func TestParseResourceURL_RoundTrip(t *testing.T) { + t.Parallel() + + types := []string{"docx", "doc", "sheet", "bitable", "wiki", "file", "folder", "mindnote", "slides"} + token := "testTOKEN123" + + for _, kind := range types { + t.Run(kind, func(t *testing.T) { + built := BuildResourceURL(core.BrandFeishu, kind, token) + if built == "" { + t.Fatalf("BuildResourceURL returned empty for kind %q", kind) + } + ref, ok := ParseResourceURL(built) + if !ok { + t.Fatalf("ParseResourceURL(%q) returned ok=false", built) + } + if ref.Type != kind { + t.Errorf("round-trip type mismatch: got %q, want %q", ref.Type, kind) + } + if ref.Token != token { + t.Errorf("round-trip token mismatch: got %q, want %q", ref.Token, token) + } + }) + } +} + func TestBuildResourceURL(t *testing.T) { t.Parallel() diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index 7e4c8ecef..d6f4c1a5d 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -103,13 +103,15 @@ func (ctx *RuntimeContext) fetchBotInfo() (*BotInfo, error) { if resp.StatusCode >= 400 { return nil, fmt.Errorf("fetch bot info: HTTP %d", resp.StatusCode) } + // /open-apis/bot/v3/info returns `{code, msg, bot: {...}}` — the bot + // payload is under "bot", not "data" as the newer Lark API convention. var envelope struct { Code int `json:"code"` Msg string `json:"msg"` Data struct { OpenID string `json:"open_id"` AppName string `json:"app_name"` - } `json:"data"` + } `json:"bot"` } if err := json.Unmarshal(resp.RawBody, &envelope); err != nil { return nil, fmt.Errorf("fetch bot info: unmarshal: %w", err) diff --git a/shortcuts/common/runner_botinfo_test.go b/shortcuts/common/runner_botinfo_test.go index 0ca121f7b..9a4247c0d 100644 --- a/shortcuts/common/runner_botinfo_test.go +++ b/shortcuts/common/runner_botinfo_test.go @@ -57,7 +57,7 @@ func TestFetchBotInfo_Success(t *testing.T) { URL: "/open-apis/bot/v3/info", Body: map[string]interface{}{ "code": 0, "msg": "ok", - "data": map[string]interface{}{ + "bot": map[string]interface{}{ "open_id": "ou_bot_abc123", "app_name": "TestBot", }, @@ -86,7 +86,7 @@ func TestFetchBotInfo_ShortcutHeaders(t *testing.T) { URL: "/open-apis/bot/v3/info", Body: map[string]interface{}{ "code": 0, "msg": "ok", - "data": map[string]interface{}{ + "bot": map[string]interface{}{ "open_id": "ou_bot_header", "app_name": "HeaderBot", }, @@ -119,7 +119,7 @@ func TestFetchBotInfo_OnceSemantics(t *testing.T) { URL: "/open-apis/bot/v3/info", Body: map[string]interface{}{ "code": 0, "msg": "ok", - "data": map[string]interface{}{ + "bot": map[string]interface{}{ "open_id": "ou_bot_once", "app_name": "OnceBot", }, @@ -183,7 +183,7 @@ func TestFetchBotInfo_EmptyOpenID(t *testing.T) { URL: "/open-apis/bot/v3/info", Body: map[string]interface{}{ "code": 0, "msg": "ok", - "data": map[string]interface{}{ + "bot": map[string]interface{}{ "open_id": "", "app_name": "EmptyBot", }, diff --git a/shortcuts/doc/doc_media_insert.go b/shortcuts/doc/doc_media_insert.go index cd34db569..5c31495a5 100644 --- a/shortcuts/doc/doc_media_insert.go +++ b/shortcuts/doc/doc_media_insert.go @@ -7,6 +7,11 @@ import ( "bytes" "context" "fmt" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" "path/filepath" "strings" @@ -55,6 +60,8 @@ var DocMediaInsert = common.Shortcut{ {Name: "selection-with-ellipsis", Desc: "plain text (or 'start...end' to disambiguate) matching the target block's content. Media is inserted at the top-level ancestor of the matched block — i.e., when the selection is inside a callout, table cell, or nested list, media lands outside that container, not inside it. Pass 'start...end' (a unique prefix and suffix separated by '...') when the plain text appears in more than one block"}, {Name: "before", Type: "bool", Desc: "insert before the matched block instead of after (requires --selection-with-ellipsis)"}, {Name: "file-view", Desc: "file block rendering: card (default) | preview | inline; only applies when --type=file. preview renders audio/video as an inline player"}, + {Name: "width", Type: "int", Desc: "image display width in pixels (only for --type=image); if --height is omitted it is auto-computed from the source image aspect ratio"}, + {Name: "height", Type: "int", Desc: "image display height in pixels (only for --type=image); if --width is omitted it is auto-computed from the source image aspect ratio"}, }, Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { filePath := runtime.Str("file") @@ -93,6 +100,24 @@ var DocMediaInsert = common.Shortcut{ return output.ErrValidation("--file-view only applies when --type=file") } } + widthChanged := runtime.Changed("width") + heightChanged := runtime.Changed("height") + if (widthChanged || heightChanged) && runtime.Str("type") != "image" { + return output.ErrValidation("--width/--height only apply when --type=image") + } + if widthChanged && runtime.Int("width") <= 0 { + return output.ErrValidation("--width must be a positive integer") + } + if heightChanged && runtime.Int("height") <= 0 { + return output.ErrValidation("--height must be a positive integer") + } + const maxDimension = 10000 + if widthChanged && runtime.Int("width") > maxDimension { + return output.ErrValidation("--width must not exceed %d pixels", maxDimension) + } + if heightChanged && runtime.Int("height") > maxDimension { + return output.ErrValidation("--height must not exceed %d pixels", maxDimension) + } return nil }, DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { @@ -120,7 +145,25 @@ var DocMediaInsert = common.Shortcut{ } else { createBlockData["index"] = "" } - batchUpdateData := buildBatchUpdateData("", mediaType, "", runtime.Str("align"), caption) + // Best-effort dimension computation for dry-run. + dryWidth := runtime.Int("width") + dryHeight := runtime.Int("height") + widthChanged := runtime.Changed("width") + heightChanged := runtime.Changed("height") + + if (widthChanged || heightChanged) && !(widthChanged && heightChanged) { + if filePath == "" { + fmt.Fprintf(runtime.IO().ErrOut, "Note: cannot detect clipboard image dimensions in dry-run; provide both --width and --height for accurate preview\n") + } else if nativeW, nativeH, err := detectImageDimensionsFromPath(runtime.FileIO(), filePath); err == nil { + dims := computeMissingDimension(dryWidth, dryHeight, nativeW, nativeH) + dryWidth = dims.width + dryHeight = dims.height + } else { + fmt.Fprintf(runtime.IO().ErrOut, "Note: unable to detect image dimensions from %s; provide both --width and --height to avoid failure at execution time\n", filePath) + } + } + + batchUpdateData := buildBatchUpdateData("", mediaType, "", runtime.Str("align"), caption, dryWidth, dryHeight) d := common.NewDryRunAPI() totalSteps := 4 @@ -188,6 +231,9 @@ var DocMediaInsert = common.Shortcut{ if runtime.Bool("from-clipboard") { d.Set("upload_size_note", "clipboard size unknown; single-part vs multipart decision deferred to runtime") } + if runtime.Bool("from-clipboard") && (widthChanged || heightChanged) && !(widthChanged && heightChanged) { + d.Set("dimension_note", "clipboard dimensions unknown; aspect-ratio calculation deferred to runtime") + } return d }, Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { @@ -314,6 +360,42 @@ var DocMediaInsert = common.Shortcut{ // interface stays a true nil for the --file path. Passing a typed-nil // *bytes.Reader here would make the downstream `if cfg.Content != nil` // check incorrectly take the clipboard branch and crash on Read. + // Resolve display dimensions before upload to fail fast on unreadable images. + var finalWidth, finalHeight int + if mediaType == "image" { + userWidth := runtime.Int("width") + userHeight := runtime.Int("height") + widthChanged := runtime.Changed("width") + heightChanged := runtime.Changed("height") + + if widthChanged && heightChanged { + finalWidth = userWidth + finalHeight = userHeight + } else if widthChanged || heightChanged { + var nativeW, nativeH int + var dimErr error + if clipboardContent != nil { + nativeW, nativeH, dimErr = detectImageDimensions(bytes.NewReader(clipboardContent)) + } else { + f, openErr := runtime.FileIO().Open(filePath) + if openErr != nil { + return withRollbackWarning(output.ErrValidation( + "unable to detect image dimensions from %s for aspect-ratio calculation; provide both --width and --height", fileName)) + } + nativeW, nativeH, dimErr = detectImageDimensions(f) + f.Close() + } + if dimErr != nil { + return withRollbackWarning(output.ErrValidation( + "unable to detect image dimensions from %s for aspect-ratio calculation; provide both --width and --height", fileName)) + } + dims := computeMissingDimension(userWidth, userHeight, nativeW, nativeH) + finalWidth = dims.width + finalHeight = dims.height + fmt.Fprintf(runtime.IO().ErrOut, "Image dimensions: %dx%d (native: %dx%d)\n", finalWidth, finalHeight, nativeW, nativeH) + } + } + uploadCfg := UploadDocMediaFileConfig{ FilePath: filePath, FileName: fileName, @@ -337,16 +419,23 @@ var DocMediaInsert = common.Shortcut{ if _, err := runtime.CallAPI("PATCH", fmt.Sprintf("/open-apis/docx/v1/documents/%s/blocks/batch_update", validate.EncodePathSegment(documentID)), - nil, buildBatchUpdateData(replaceBlockID, mediaType, fileToken, alignStr, caption)); err != nil { + nil, buildBatchUpdateData(replaceBlockID, mediaType, fileToken, alignStr, caption, finalWidth, finalHeight)); err != nil { return withRollbackWarning(err) } - runtime.Out(map[string]interface{}{ + outData := map[string]interface{}{ "document_id": documentID, "block_id": blockId, "file_token": fileToken, "type": mediaType, - }, nil) + } + if finalWidth > 0 { + outData["width"] = finalWidth + } + if finalHeight > 0 { + outData["height"] = finalHeight + } + runtime.Out(outData, nil) return nil }, } @@ -453,7 +542,51 @@ func resolveDocxDocumentID(runtime *common.RuntimeContext, input string) (string } } -func buildBatchUpdateData(blockID, mediaType, fileToken, alignStr, caption string) map[string]interface{} { +type imageDimensions struct { + width int + height int +} + +func computeMissingDimension(userWidth, userHeight, nativeWidth, nativeHeight int) imageDimensions { + if nativeWidth <= 0 || nativeHeight <= 0 { + return imageDimensions{width: userWidth, height: userHeight} + } + if userWidth > 0 && userHeight == 0 { + return imageDimensions{ + width: userWidth, + height: (userWidth*nativeHeight + nativeWidth/2) / nativeWidth, + } + } + if userHeight > 0 && userWidth == 0 { + return imageDimensions{ + width: (userHeight*nativeWidth + nativeHeight/2) / nativeHeight, + height: userHeight, + } + } + return imageDimensions{width: userWidth, height: userHeight} +} + +func detectImageDimensions(r io.Reader) (width, height int, err error) { + cfg, _, err := image.DecodeConfig(r) + if err != nil { + return 0, 0, err + } + return cfg.Width, cfg.Height, nil +} + +func detectImageDimensionsFromPath(fio fileio.FileIO, filePath string) (int, int, error) { + if _, err := validate.SafeInputPath(filePath); err != nil { + return 0, 0, err + } + f, err := fio.Open(filePath) + if err != nil { + return 0, 0, err + } + defer f.Close() + return detectImageDimensions(f) +} + +func buildBatchUpdateData(blockID, mediaType, fileToken, alignStr, caption string, width, height int) map[string]interface{} { request := map[string]interface{}{ "block_id": blockID, } @@ -465,6 +598,12 @@ func buildBatchUpdateData(blockID, mediaType, fileToken, alignStr, caption strin replaceImage := map[string]interface{}{ "token": fileToken, } + if width > 0 { + replaceImage["width"] = width + } + if height > 0 { + replaceImage["height"] = height + } if alignVal, ok := alignMap[alignStr]; ok { replaceImage["align"] = alignVal } diff --git a/shortcuts/doc/doc_media_insert_test.go b/shortcuts/doc/doc_media_insert_test.go index 71d211f75..19574423b 100644 --- a/shortcuts/doc/doc_media_insert_test.go +++ b/shortcuts/doc/doc_media_insert_test.go @@ -6,6 +6,7 @@ package doc import ( "context" "encoding/json" + "fmt" "reflect" "strings" "testing" @@ -176,7 +177,7 @@ func TestBuildDeleteBlockDataUsesHalfOpenInterval(t *testing.T) { func TestBuildBatchUpdateDataForImage(t *testing.T) { t.Parallel() - got := buildBatchUpdateData("blk_1", "image", "file_tok", "center", "caption text") + got := buildBatchUpdateData("blk_1", "image", "file_tok", "center", "caption text", 0, 0) want := map[string]interface{}{ "requests": []interface{}{ map[string]interface{}{ @@ -199,7 +200,7 @@ func TestBuildBatchUpdateDataForImage(t *testing.T) { func TestBuildBatchUpdateDataForFile(t *testing.T) { t.Parallel() - got := buildBatchUpdateData("blk_2", "file", "file_tok", "", "") + got := buildBatchUpdateData("blk_2", "file", "file_tok", "", "", 0, 0) want := map[string]interface{}{ "requests": []interface{}{ map[string]interface{}{ @@ -215,6 +216,48 @@ func TestBuildBatchUpdateDataForFile(t *testing.T) { } } +func TestBuildBatchUpdateDataForImageWithWidthHeight(t *testing.T) { + t.Parallel() + + got := buildBatchUpdateData("blk_1", "image", "file_tok", "center", "caption text", 800, 447) + want := map[string]interface{}{ + "requests": []interface{}{ + map[string]interface{}{ + "block_id": "blk_1", + "replace_image": map[string]interface{}{ + "token": "file_tok", + "width": 800, + "height": 447, + "align": 2, + "caption": map[string]interface{}{"content": "caption text"}, + }, + }, + }, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("buildBatchUpdateData(image, 800, 447) = %#v, want %#v", got, want) + } +} + +func TestBuildBatchUpdateDataForFileIgnoresWidthHeight(t *testing.T) { + t.Parallel() + + got := buildBatchUpdateData("blk_2", "file", "file_tok", "", "", 800, 600) + want := map[string]interface{}{ + "requests": []interface{}{ + map[string]interface{}{ + "block_id": "blk_2", + "replace_file": map[string]interface{}{ + "token": "file_tok", + }, + }, + }, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("buildBatchUpdateData(file, 800, 600) = %#v, want %#v", got, want) + } +} + func TestExtractAppendTargetUsesRootChildrenCount(t *testing.T) { t.Parallel() @@ -669,10 +712,202 @@ func newMediaInsertValidateRuntime(t *testing.T, doc, mediaType, fileView string return common.TestNewRuntimeContext(cmd, nil) } -// Validate is the real user-facing contract for --file-view: unknown -// values must be rejected, and passing the flag alongside --type!=file -// must also be rejected. buildCreateBlockData tests alone cannot catch -// regressions here, so lock the guard logic down explicitly. +func newMediaInsertValidateRuntimeWithSize(t *testing.T, doc, mediaType string, width, height int, setWidth, setHeight bool) *common.RuntimeContext { + t.Helper() + + cmd := &cobra.Command{Use: "docs +media-insert"} + cmd.Flags().String("file", "", "") + cmd.Flags().Bool("from-clipboard", false, "") + cmd.Flags().String("doc", "", "") + cmd.Flags().String("type", "", "") + cmd.Flags().String("file-view", "", "") + cmd.Flags().Int("width", 0, "") + cmd.Flags().Int("height", 0, "") + cmd.Flags().String("selection-with-ellipsis", "", "") + cmd.Flags().Bool("before", false, "") + if err := cmd.Flags().Set("file", "dummy.bin"); err != nil { + t.Fatalf("set --file: %v", err) + } + if err := cmd.Flags().Set("doc", doc); err != nil { + t.Fatalf("set --doc: %v", err) + } + if err := cmd.Flags().Set("type", mediaType); err != nil { + t.Fatalf("set --type: %v", err) + } + if setWidth { + if err := cmd.Flags().Set("width", fmt.Sprintf("%d", width)); err != nil { + t.Fatalf("set --width: %v", err) + } + } + if setHeight { + if err := cmd.Flags().Set("height", fmt.Sprintf("%d", height)); err != nil { + t.Fatalf("set --height: %v", err) + } + } + return common.TestNewRuntimeContext(cmd, nil) +} + +func TestDocMediaInsertValidateWidthHeightOnlyForImage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mediaType string + width int + height int + setWidth bool + setHeight bool + wantErr string + }{ + { + name: "width with file type is rejected", + mediaType: "file", + width: 800, + setWidth: true, + wantErr: "--width/--height only apply when --type=image", + }, + { + name: "height with file type is rejected", + mediaType: "file", + height: 600, + setHeight: true, + wantErr: "--width/--height only apply when --type=image", + }, + { + name: "explicit zero width is rejected", + mediaType: "image", + width: 0, + setWidth: true, + wantErr: "--width must be a positive integer", + }, + { + name: "negative width is rejected", + mediaType: "image", + width: -1, + setWidth: true, + wantErr: "--width must be a positive integer", + }, + { + name: "negative height is rejected", + mediaType: "image", + height: -5, + setHeight: true, + wantErr: "--height must be a positive integer", + }, + { + name: "valid width with image type is accepted", + mediaType: "image", + width: 800, + setWidth: true, + }, + { + name: "valid width and height with image type is accepted", + mediaType: "image", + width: 800, + height: 600, + setWidth: true, + setHeight: true, + }, + } + + for _, ttTemp := range tests { + tt := ttTemp + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + rt := newMediaInsertValidateRuntimeWithSize(t, "doxcnValidateSize", tt.mediaType, tt.width, tt.height, tt.setWidth, tt.setHeight) + err := DocMediaInsert.Validate(context.Background(), rt) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("Validate() unexpected error: %v", err) + } + return + } + if err == nil { + t.Fatalf("Validate() error = nil, want error containing %q", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Validate() error = %q, want substring %q", err.Error(), tt.wantErr) + } + }) + } +} + +func TestDocMediaInsertValidateNoWidthHeightIsValid(t *testing.T) { + t.Parallel() + + rt := newMediaInsertValidateRuntimeWithSize(t, "doxcnNoSize", "image", 0, 0, false, false) + err := DocMediaInsert.Validate(context.Background(), rt) + if err != nil { + t.Fatalf("Validate() unexpected error when neither --width nor --height passed: %v", err) + } +} + +func TestAutoAspectRatioFromWidth(t *testing.T) { + t.Parallel() + + // Native image: 1200x800 (3:2 ratio) + // User provides width=600 → expected height = 600 * 800 / 1200 = 400 + got := computeMissingDimension(600, 0, 1200, 800) + wantWidth, wantHeight := 600, 400 + if got.width != wantWidth || got.height != wantHeight { + t.Fatalf("computeMissingDimension(600, 0, 1200, 800) = (%d, %d), want (%d, %d)", got.width, got.height, wantWidth, wantHeight) + } +} + +func TestAutoAspectRatioFromHeight(t *testing.T) { + t.Parallel() + + // Native image: 1200x800 (3:2 ratio) + // User provides height=400 → expected width = 400 * 1200 / 800 = 600 + got := computeMissingDimension(0, 400, 1200, 800) + wantWidth, wantHeight := 600, 400 + if got.width != wantWidth || got.height != wantHeight { + t.Fatalf("computeMissingDimension(0, 400, 1200, 800) = (%d, %d), want (%d, %d)", got.width, got.height, wantWidth, wantHeight) + } +} + +func TestComputeMissingDimensionBothProvided(t *testing.T) { + t.Parallel() + got := computeMissingDimension(800, 600, 1200, 900) + if got.width != 800 || got.height != 600 { + t.Fatalf("computeMissingDimension(800, 600, 1200, 900) = (%d, %d), want (800, 600)", got.width, got.height) + } +} + +func TestComputeMissingDimensionNeitherProvided(t *testing.T) { + t.Parallel() + got := computeMissingDimension(0, 0, 1200, 900) + if got.width != 0 || got.height != 0 { + t.Fatalf("computeMissingDimension(0, 0, 1200, 900) = (%d, %d), want (0, 0)", got.width, got.height) + } +} + +func TestComputeMissingDimensionZeroNativeWidth(t *testing.T) { + t.Parallel() + got := computeMissingDimension(600, 0, 0, 800) + if got.width != 600 || got.height != 0 { + t.Fatalf("computeMissingDimension(600, 0, 0, 800) = (%d, %d), want (600, 0)", got.width, got.height) + } +} + +func TestComputeMissingDimensionZeroNativeHeight(t *testing.T) { + t.Parallel() + got := computeMissingDimension(0, 400, 1200, 0) + if got.width != 0 || got.height != 400 { + t.Fatalf("computeMissingDimension(0, 400, 1200, 0) = (%d, %d), want (0, 400)", got.width, got.height) + } +} + +func TestComputeMissingDimensionRounding(t *testing.T) { + t.Parallel() + got := computeMissingDimension(999, 0, 1000, 333) + want := (999*333 + 500) / 1000 + if got.height != want { + t.Fatalf("computeMissingDimension(999, 0, 1000, 333).height = %d, want %d (rounded)", got.height, want) + } +} + func TestDocMediaInsertValidateFileView(t *testing.T) { t.Parallel() diff --git a/shortcuts/doc/docs_update.go b/shortcuts/doc/docs_update.go index a4ed62afc..f8753cad6 100644 --- a/shortcuts/doc/docs_update.go +++ b/shortcuts/doc/docs_update.go @@ -6,6 +6,7 @@ package doc import ( "context" "fmt" + "regexp" "strings" "github.com/spf13/cobra" @@ -118,7 +119,7 @@ func validateUpdateV1(_ context.Context, runtime *common.RuntimeContext) error { } if needsSelectionV1[mode] && selEllipsis == "" && selTitle == "" { - return common.FlagErrorf("--%s mode requires --selection-with-ellipsis or --selection-by-title", mode) + return common.FlagErrorf(selectionRequiredMessageV1(mode)) } if err := validateSelectionByTitleV1(selTitle); err != nil { return err @@ -127,6 +128,14 @@ func validateUpdateV1(_ context.Context, runtime *common.RuntimeContext) error { return nil } +func selectionRequiredMessageV1(mode string) string { + msg := fmt.Sprintf("--%s mode requires --selection-with-ellipsis or --selection-by-title", mode) + if mode == "replace_all" { + msg += ". If you intended to replace the entire document body, use --mode overwrite instead." + } + return msg +} + func validateSelectionByTitleV1(title string) error { if title == "" { return nil @@ -160,6 +169,16 @@ func executeUpdateV1(_ context.Context, runtime *common.RuntimeContext) error { fmt.Fprintf(runtime.IO().ErrOut, "warning: %s\n", w) } + // Overwrite replaces the entire document, silently discarding any + // whiteboard or file-attachment blocks that cannot be re-created from + // Markdown. Pre-fetch the current content and warn when such blocks + // are present so the caller can take a backup before proceeding. + if runtime.Str("mode") == "overwrite" { + if w := warnOverwriteResourceBlocks(runtime); w != "" { + fmt.Fprintf(runtime.IO().ErrOut, "warning: %s\n", w) + } + } + // Surface callout type= hint so users know to switch to background-color/ // border-color when they want a colored callout. Non-blocking, advisory. if md := runtime.Str("markdown"); md != "" { @@ -197,3 +216,74 @@ func buildUpdateArgsV1(runtime *common.RuntimeContext) map[string]interface{} { } return args } + +// resourceBlockRe matches the opening of a or tag +// (followed by whitespace, > or /) to avoid false positives on tag names like +// or prose that merely mentions the word "whiteboard". +var resourceBlockRe = regexp.MustCompile(`<(whiteboard|file)[\s/>]`) + +// warnOverwriteResourceBlocks pre-fetches the current document and returns a +// non-empty warning string when the document contains whiteboard or file +// attachment blocks that would be permanently deleted by an overwrite. Returns +// an empty string (no warning) when the document is clean or the fetch fails +// (we never block the overwrite on a best-effort check). +// +// This function is not unit-tested because it depends on an external MCP call +// (fetch-doc). The pure detection logic lives in checkOverwriteResourceBlocks, +// which has full table-driven coverage. +// +// Performance: this adds one extra fetch-doc round-trip to every --mode overwrite +// call, even when the document has no resource blocks. The cost is intentional: +// the guard is best-effort and silent on failure, so the latency is bounded and +// the trade-off is acceptable to avoid silent data loss. +func warnOverwriteResourceBlocks(runtime *common.RuntimeContext) string { + args := map[string]interface{}{ + "doc_id": runtime.Str("doc"), + // skip_task_detail reduces response payload by omitting per-block task + // metadata, making the pre-fetch faster and cheaper. + "skip_task_detail": true, + } + result, err := common.CallMCPTool(runtime, "fetch-doc", args) + if err != nil { + // Fetch failed — silently skip the guard rather than blocking overwrite. + return "" + } + md, _ := result["markdown"].(string) + return checkOverwriteResourceBlocks(md) +} + +// checkOverwriteResourceBlocks scans Markdown for resource block tags that +// cannot survive an overwrite: and . Returns a +// warning string listing the counts if any are found, empty string otherwise. +func checkOverwriteResourceBlocks(markdown string) string { + matches := resourceBlockRe.FindAllStringSubmatch(markdown, -1) + whiteboards, files := 0, 0 + for _, m := range matches { + switch m[1] { + case "whiteboard": + whiteboards++ + case "file": + files++ + } + } + var found []string + if whiteboards == 1 { + found = append(found, "1 whiteboard block") + } else if whiteboards > 1 { + found = append(found, fmt.Sprintf("%d whiteboard blocks", whiteboards)) + } + if files == 1 { + found = append(found, "1 file attachment block") + } else if files > 1 { + found = append(found, fmt.Sprintf("%d file attachment blocks", files)) + } + if len(found) == 0 { + return "" + } + return fmt.Sprintf( + "the document contains %s that cannot be reconstructed from Markdown; "+ + "overwrite will permanently delete them. "+ + "Consider fetching a backup with `docs +fetch` before overwriting.", + strings.Join(found, " and "), + ) +} diff --git a/shortcuts/doc/docs_update_test.go b/shortcuts/doc/docs_update_test.go index 1da30b344..6ae06d277 100644 --- a/shortcuts/doc/docs_update_test.go +++ b/shortcuts/doc/docs_update_test.go @@ -4,6 +4,7 @@ package doc import ( "reflect" + "strings" "testing" ) @@ -32,6 +33,33 @@ func TestValidCommandsV2(t *testing.T) { // ── V1 tests ── +func TestSelectionRequiredMessageV1ReplaceAllSuggestsOverwrite(t *testing.T) { + t.Parallel() + + msg := selectionRequiredMessageV1("replace_all") + for _, needle := range []string{ + "--replace_all mode requires --selection-with-ellipsis or --selection-by-title", + "replace the entire document body", + "--mode overwrite", + } { + if !strings.Contains(msg, needle) { + t.Fatalf("message missing %q: %s", needle, msg) + } + } +} + +func TestSelectionRequiredMessageV1OtherModesDoNotSuggestOverwrite(t *testing.T) { + t.Parallel() + + msg := selectionRequiredMessageV1("replace_range") + if strings.Contains(msg, "--mode overwrite") { + t.Fatalf("replace_range message should not suggest overwrite: %s", msg) + } + if !strings.Contains(msg, "--replace_range mode requires --selection-with-ellipsis or --selection-by-title") { + t.Fatalf("unexpected message: %s", msg) + } +} + func TestIsWhiteboardCreateMarkdown(t *testing.T) { t.Run("blank whiteboard tags", func(t *testing.T) { markdown := "\n" @@ -55,6 +83,72 @@ func TestIsWhiteboardCreateMarkdown(t *testing.T) { }) } +func TestCheckOverwriteResourceBlocks(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + markdown string + wantWarn bool + wantSubs []string + }{ + { + name: "empty markdown is clean", + markdown: "", + wantWarn: false, + }, + { + name: "plain prose is clean", + markdown: "## Heading\n\nsome text", + wantWarn: false, + }, + { + name: "single whiteboard triggers warning", + markdown: ``, + wantWarn: true, + wantSubs: []string{"1 whiteboard block", "overwrite"}, + }, + { + name: "multiple whiteboards counted", + markdown: "\n", + wantWarn: true, + wantSubs: []string{"2 whiteboard blocks"}, + }, + { + name: "single file attachment triggers warning", + markdown: ``, + wantWarn: true, + wantSubs: []string{"1 file attachment block"}, + }, + { + name: "multiple file attachments counted", + markdown: "\n\n", + wantWarn: true, + wantSubs: []string{"3 file attachment blocks"}, + }, + { + name: "whiteboard and file together both counted", + markdown: "\n", + wantWarn: true, + wantSubs: []string{"1 whiteboard block", "1 file attachment block"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := checkOverwriteResourceBlocks(tt.markdown) + if (got != "") != tt.wantWarn { + t.Fatalf("checkOverwriteResourceBlocks(%q) = %q, wantWarn=%v", tt.markdown, got, tt.wantWarn) + } + for _, sub := range tt.wantSubs { + if !strings.Contains(got, sub) { + t.Errorf("expected warning to contain %q, got: %s", sub, got) + } + } + }) + } +} + func TestNormalizeWhiteboardResult(t *testing.T) { t.Run("adds empty board_tokens when whiteboard creation response omits it", func(t *testing.T) { result := map[string]interface{}{ @@ -101,3 +195,35 @@ func TestNormalizeWhiteboardResult(t *testing.T) { } }) } + +func TestValidateSelectionByTitleV1(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + title string + wantErr bool + errSub string + }{ + {name: "empty title is valid", title: "", wantErr: false}, + {name: "single heading is valid", title: "## Section", wantErr: false}, + {name: "h1 heading is valid", title: "# Top", wantErr: false}, + {name: "deep heading is valid", title: "### Sub-section", wantErr: false}, + {name: "missing hash prefix is invalid", title: "No hash", wantErr: true, errSub: "'#'"}, + {name: "multiline title is invalid", title: "## First\n## Second", wantErr: true, errSub: "single"}, + {name: "title with embedded carriage return is invalid", title: "## Title\r## Next", wantErr: true, errSub: "single"}, + {name: "leading-space heading is valid after trim", title: " ## Section", wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateSelectionByTitleV1(tt.title) + if (err != nil) != tt.wantErr { + t.Fatalf("validateSelectionByTitleV1(%q) error = %v, wantErr = %v", tt.title, err, tt.wantErr) + } + if tt.wantErr && tt.errSub != "" && !strings.Contains(err.Error(), tt.errSub) { + t.Errorf("expected error to contain %q, got: %v", tt.errSub, err) + } + }) + } +} diff --git a/shortcuts/drive/drive_export.go b/shortcuts/drive/drive_export.go index 533accce9..a588d01d5 100644 --- a/shortcuts/drive/drive_export.go +++ b/shortcuts/drive/drive_export.go @@ -12,6 +12,7 @@ import ( "time" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/validate" "github.com/larksuite/cli/shortcuts/common" ) @@ -25,6 +26,7 @@ var DriveExport = common.Shortcut{ Scopes: []string{ "docs:document.content:read", "docs:document:export", + "docx:document:readonly", "drive:drive.metadata:readonly", }, AuthTypes: []string{"user", "bot"}, @@ -52,16 +54,15 @@ var DriveExport = common.Shortcut{ FileExtension: runtime.Str("file-extension"), SubID: runtime.Str("sub-id"), } - // Markdown export is a special case: docx markdown comes from docs content - // directly instead of the Drive export task API. + // Markdown export is a special case: docx markdown comes from the V2 + // docs_ai fetch API directly instead of the Drive export task API. if spec.FileExtension == "markdown" { + apiPath := fmt.Sprintf("/open-apis/docs_ai/v1/documents/%s/fetch", validate.EncodePathSegment(spec.Token)) dr := common.NewDryRunAPI(). Desc("2-step orchestration: fetch docx markdown -> write local file"). - GET("/open-apis/docs/v1/content"). - Params(map[string]interface{}{ - "doc_token": spec.Token, - "doc_type": "docx", - "content_type": "markdown", + POST(apiPath). + Body(map[string]interface{}{ + "format": "markdown", }). Set("output_dir", runtime.Str("output-dir")) if name := strings.TrimSpace(runtime.Str("file-name")); name != "" { @@ -101,28 +102,38 @@ var DriveExport = common.Shortcut{ overwrite := runtime.Bool("overwrite") // Markdown export bypasses the async export task and writes the fetched - // markdown content directly to disk. + // markdown content directly to disk. Uses the V2 docs_ai fetch API for + // higher-quality Lark-flavored Markdown output. if spec.FileExtension == "markdown" { fmt.Fprintf(runtime.IO().ErrOut, "Exporting docx as markdown: %s\n", common.MaskToken(spec.Token)) - data, err := runtime.CallAPI( - "GET", - "/open-apis/docs/v1/content", + apiPath := fmt.Sprintf("/open-apis/docs_ai/v1/documents/%s/fetch", validate.EncodePathSegment(spec.Token)) + data, err := runtime.DoAPIJSONWithLogID( + "POST", + apiPath, + nil, map[string]interface{}{ - "doc_token": spec.Token, - "doc_type": "docx", - "content_type": "markdown", + "format": "markdown", }, - nil, ) if err != nil { return err } + // Extract content from the V2 response: data.document.content + doc, ok := data["document"].(map[string]interface{}) + if !ok { + return output.Errorf(output.ExitAPI, "api_error", "invalid markdown fetch response: missing document object") + } + content, ok := doc["content"].(string) + if !ok { + return output.Errorf(output.ExitAPI, "api_error", "invalid markdown fetch response: missing document.content") + } + fileName := preferredFileName if fileName == "" { // Prefer the remote title for the exported file name, but still fall // back to the token if metadata is empty. - title, err := fetchDriveMetaTitle(runtime, spec.Token, spec.DocType) + title, err := common.FetchDriveMetaTitle(runtime, spec.Token, spec.DocType) if err != nil { fmt.Fprintf(runtime.IO().ErrOut, "Title lookup failed, using token as filename: %v\n", err) title = spec.Token @@ -130,7 +141,7 @@ var DriveExport = common.Shortcut{ fileName = title } fileName = ensureExportFileExtension(sanitizeExportFileName(fileName, spec.Token), spec.FileExtension) - savedPath, err := saveContentToOutputDir(runtime.FileIO(), outputDir, fileName, []byte(common.GetString(data, "content")), overwrite) + savedPath, err := saveContentToOutputDir(runtime.FileIO(), outputDir, fileName, []byte(content), overwrite) if err != nil { return err } @@ -141,7 +152,7 @@ var DriveExport = common.Shortcut{ "file_extension": spec.FileExtension, "file_name": filepath.Base(savedPath), "saved_path": savedPath, - "size_bytes": len([]byte(common.GetString(data, "content"))), + "size_bytes": len(content), }, nil) return nil } diff --git a/shortcuts/drive/drive_export_common.go b/shortcuts/drive/drive_export_common.go index a9382c66f..e0c9331db 100644 --- a/shortcuts/drive/drive_export_common.go +++ b/shortcuts/drive/drive_export_common.go @@ -228,34 +228,6 @@ func parseDriveExportStatus(ticket string, data map[string]interface{}) driveExp return status } -// fetchDriveMetaTitle looks up the document title so exported files can use a -// human-readable default name when possible. -func fetchDriveMetaTitle(runtime *common.RuntimeContext, token, docType string) (string, error) { - data, err := runtime.CallAPI( - "POST", - "/open-apis/drive/v1/metas/batch_query", - nil, - map[string]interface{}{ - "request_docs": []map[string]interface{}{ - { - "doc_token": token, - "doc_type": docType, - }, - }, - }, - ) - if err != nil { - return "", err - } - - metas := common.GetSlice(data, "metas") - if len(metas) == 0 { - return "", nil - } - meta, _ := metas[0].(map[string]interface{}) - return common.GetString(meta, "title"), nil -} - // saveContentToOutputDir validates the target path, enforces overwrite policy, // and writes the payload atomically via FileIO.Save. func saveContentToOutputDir(fio fileio.FileIO, outputDir, fileName string, payload []byte, overwrite bool) (string, error) { diff --git a/shortcuts/drive/drive_export_test.go b/shortcuts/drive/drive_export_test.go index 780118204..8f277a9fa 100644 --- a/shortcuts/drive/drive_export_test.go +++ b/shortcuts/drive/drive_export_test.go @@ -81,16 +81,19 @@ func TestValidateDriveExportSpec(t *testing.T) { func TestDriveExportMarkdownWritesFile(t *testing.T) { f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) - reg.Register(&httpmock.Stub{ - Method: "GET", - URL: "/open-apis/docs/v1/content", + fetchStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/docs_ai/v1/documents/docx123/fetch", Body: map[string]interface{}{ "code": 0, "data": map[string]interface{}{ - "content": "# hello\n", + "document": map[string]interface{}{ + "content": "# hello\n", + }, }, }, - }) + } + reg.Register(fetchStub) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/drive/v1/metas/batch_query", @@ -118,6 +121,14 @@ func TestDriveExportMarkdownWritesFile(t *testing.T) { t.Fatalf("unexpected error: %v", err) } + var reqBody map[string]interface{} + if err := json.Unmarshal(fetchStub.CapturedBody, &reqBody); err != nil { + t.Fatalf("unmarshal docs_ai fetch body: %v", err) + } + if reqBody["format"] != "markdown" { + t.Fatalf("docs_ai fetch body format = %v, want %q", reqBody["format"], "markdown") + } + data, err := os.ReadFile(filepath.Join(tmpDir, "Weekly Notes.md")) if err != nil { t.Fatalf("ReadFile() error: %v", err) @@ -132,16 +143,19 @@ func TestDriveExportMarkdownWritesFile(t *testing.T) { func TestDriveExportMarkdownUsesProvidedFileName(t *testing.T) { f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) - reg.Register(&httpmock.Stub{ - Method: "GET", - URL: "/open-apis/docs/v1/content", + fetchStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/docs_ai/v1/documents/docx123/fetch", Body: map[string]interface{}{ "code": 0, "data": map[string]interface{}{ - "content": "# custom\n", + "document": map[string]interface{}{ + "content": "# custom\n", + }, }, }, - }) + } + reg.Register(fetchStub) tmpDir := t.TempDir() withDriveWorkingDir(t, tmpDir) @@ -158,6 +172,14 @@ func TestDriveExportMarkdownUsesProvidedFileName(t *testing.T) { t.Fatalf("unexpected error: %v", err) } + var reqBody map[string]interface{} + if err := json.Unmarshal(fetchStub.CapturedBody, &reqBody); err != nil { + t.Fatalf("unmarshal docs_ai fetch body: %v", err) + } + if reqBody["format"] != "markdown" { + t.Fatalf("docs_ai fetch body format = %v, want %q", reqBody["format"], "markdown") + } + data, err := os.ReadFile(filepath.Join(tmpDir, "custom-notes.md")) if err != nil { t.Fatalf("ReadFile() error: %v", err) @@ -179,7 +201,7 @@ func TestDriveExportDryRunIncludesLocalFileNameMetadata(t *testing.T) { }{ { name: "markdown", - wantURL: "/open-apis/docs/v1/content", + wantURL: "/open-apis/docs_ai/v1/documents/docx123/fetch", wantFileName: `"file_name": "notes.md"`, args: []string{ "+export", @@ -233,16 +255,19 @@ func TestDriveExportDryRunIncludesLocalFileNameMetadata(t *testing.T) { func TestDriveExportMarkdownFallsBackToTokenWhenTitleLookupFails(t *testing.T) { f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) - reg.Register(&httpmock.Stub{ - Method: "GET", - URL: "/open-apis/docs/v1/content", + fetchStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/docs_ai/v1/documents/docx123/fetch", Body: map[string]interface{}{ "code": 0, "data": map[string]interface{}{ - "content": "# fallback\n", + "document": map[string]interface{}{ + "content": "# fallback\n", + }, }, }, - }) + } + reg.Register(fetchStub) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/drive/v1/metas/batch_query", @@ -267,6 +292,14 @@ func TestDriveExportMarkdownFallsBackToTokenWhenTitleLookupFails(t *testing.T) { t.Fatalf("unexpected error: %v", err) } + var reqBody map[string]interface{} + if err := json.Unmarshal(fetchStub.CapturedBody, &reqBody); err != nil { + t.Fatalf("unmarshal docs_ai fetch body: %v", err) + } + if reqBody["format"] != "markdown" { + t.Fatalf("docs_ai fetch body format = %v, want %q", reqBody["format"], "markdown") + } + data, err := os.ReadFile(filepath.Join(tmpDir, "docx123.md")) if err != nil { t.Fatalf("ReadFile() error: %v", err) @@ -279,6 +312,76 @@ func TestDriveExportMarkdownFallsBackToTokenWhenTitleLookupFails(t *testing.T) { } } +func TestDriveExportMarkdownRejectsMissingDocumentObject(t *testing.T) { + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/docs_ai/v1/documents/docx123/fetch", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{}, + }, + }) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + err := mountAndRunDrive(t, DriveExport, []string{ + "+export", + "--token", "docx123", + "--doc-type", "docx", + "--file-extension", "markdown", + "--as", "bot", + }, f, nil) + if err == nil { + t.Fatal("expected error for missing document object, got nil") + } + + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured exit error, got %v", err) + } + if !strings.Contains(exitErr.Detail.Message, "missing document object") { + t.Fatalf("error message = %q, want mention of missing document object", exitErr.Detail.Message) + } +} + +func TestDriveExportMarkdownRejectsMissingDocumentContent(t *testing.T) { + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/docs_ai/v1/documents/docx123/fetch", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "document": map[string]interface{}{}, + }, + }, + }) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + err := mountAndRunDrive(t, DriveExport, []string{ + "+export", + "--token", "docx123", + "--doc-type", "docx", + "--file-extension", "markdown", + "--as", "bot", + }, f, nil) + if err == nil { + t.Fatal("expected error for missing document.content, got nil") + } + + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured exit error, got %v", err) + } + if !strings.Contains(exitErr.Detail.Message, "missing document.content") { + t.Fatalf("error message = %q, want mention of missing document.content", exitErr.Detail.Message) + } +} + func TestDriveExportAsyncSuccess(t *testing.T) { f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) reg.Register(&httpmock.Stub{ diff --git a/shortcuts/drive/drive_import.go b/shortcuts/drive/drive_import.go index 1b04fe7c2..330c119ea 100644 --- a/shortcuts/drive/drive_import.go +++ b/shortcuts/drive/drive_import.go @@ -31,6 +31,7 @@ var DriveImport = common.Shortcut{ {Name: "type", Desc: "target document type (docx, sheet, bitable)", Required: true}, {Name: "folder-token", Desc: "target folder token (omit for root folder; API accepts empty mount_key as root)"}, {Name: "name", Desc: "imported file name (default: local file name without extension)"}, + {Name: "target-token", Desc: "existing token to import data into (only for type=bitable); when set, data is mounted into this bitable instead of creating a new one"}, }, Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { return validateDriveImportSpec(driveImportSpec{ @@ -38,6 +39,7 @@ var DriveImport = common.Shortcut{ DocType: strings.ToLower(runtime.Str("type")), FolderToken: runtime.Str("folder-token"), Name: runtime.Str("name"), + TargetToken: runtime.Str("target-token"), }) }, DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { @@ -46,11 +48,15 @@ var DriveImport = common.Shortcut{ DocType: strings.ToLower(runtime.Str("type")), FolderToken: runtime.Str("folder-token"), Name: runtime.Str("name"), + TargetToken: runtime.Str("target-token"), } fileSize, err := preflightDriveImportFile(runtime.FileIO(), &spec) if err != nil { return common.NewDryRunAPI().Set("error", err.Error()) } + if valErr := validateDriveImportSpec(spec); valErr != nil { + return common.NewDryRunAPI().Set("error", valErr.Error()) + } dry := common.NewDryRunAPI() dry.Desc("Upload file (single-part or multipart) -> create import task -> poll status") @@ -76,6 +82,7 @@ var DriveImport = common.Shortcut{ DocType: strings.ToLower(runtime.Str("type")), FolderToken: runtime.Str("folder-token"), Name: runtime.Str("name"), + TargetToken: runtime.Str("target-token"), } if _, err := preflightDriveImportFile(runtime.FileIO(), &spec); err != nil { return err diff --git a/shortcuts/drive/drive_import_common.go b/shortcuts/drive/drive_import_common.go index 210bfb394..ed5fa318e 100644 --- a/shortcuts/drive/drive_import_common.go +++ b/shortcuts/drive/drive_import_common.go @@ -51,6 +51,7 @@ type driveImportSpec struct { DocType string FolderToken string Name string + TargetToken string // existing bitable token to import data into (only for type=bitable) } func (s driveImportSpec) FileExtension() string { @@ -67,7 +68,7 @@ func (s driveImportSpec) TargetFileName() string { // CreateTaskBody builds the request body expected by /drive/v1/import_tasks. func (s driveImportSpec) CreateTaskBody(fileToken string) map[string]interface{} { - return map[string]interface{}{ + body := map[string]interface{}{ "file_extension": s.FileExtension(), "file_token": fileToken, "type": s.DocType, @@ -79,6 +80,12 @@ func (s driveImportSpec) CreateTaskBody(fileToken string) map[string]interface{} "mount_key": s.FolderToken, }, } + + if s.DocType == "bitable" && s.TargetToken != "" { + body["token"] = s.TargetToken + } + + return body } // uploadMediaForImport uploads the source file to the temporary import media @@ -232,6 +239,15 @@ func validateDriveImportSpec(spec driveImportSpec) error { } } + if strings.TrimSpace(spec.TargetToken) != "" { + if spec.DocType != "bitable" { + return output.ErrValidation("--target-token is only supported when --type is bitable") + } + if err := validate.ResourceName(spec.TargetToken, "--target-token"); err != nil { + return output.ErrValidation("%s", err) + } + } + return nil } diff --git a/shortcuts/drive/drive_import_common_test.go b/shortcuts/drive/drive_import_common_test.go index 674b22832..2c786eed8 100644 --- a/shortcuts/drive/drive_import_common_test.go +++ b/shortcuts/drive/drive_import_common_test.go @@ -45,6 +45,19 @@ func TestValidateDriveImportSpec(t *testing.T) { spec: driveImportSpec{FilePath: "./data.rtf", DocType: "docx"}, wantErr: "unsupported file extension", }, + { + name: "target-token rejected for non-bitable type", + spec: driveImportSpec{FilePath: "./data.xlsx", DocType: "sheet", TargetToken: "bascnxxx"}, + wantErr: "--target-token is only supported when --type is bitable", + }, + { + name: "target-token accepted for bitable", + spec: driveImportSpec{FilePath: "./data.xlsx", DocType: "bitable", TargetToken: "bascnxxx"}, + }, + { + name: "target-token empty for bitable still ok", + spec: driveImportSpec{FilePath: "./data.xlsx", DocType: "bitable"}, + }, } for _, tt := range tests { diff --git a/shortcuts/drive/drive_import_test.go b/shortcuts/drive/drive_import_test.go index 1f8d1f704..c5b6aa10f 100644 --- a/shortcuts/drive/drive_import_test.go +++ b/shortcuts/drive/drive_import_test.go @@ -84,6 +84,7 @@ func TestDriveImportDryRunUsesExtensionlessDefaultName(t *testing.T) { cmd.Flags().String("type", "", "") cmd.Flags().String("folder-token", "", "") cmd.Flags().String("name", "", "") + cmd.Flags().String("target-token", "", "") if err := cmd.Flags().Set("file", "./base-import.xlsx"); err != nil { t.Fatalf("set --file: %v", err) } @@ -148,6 +149,7 @@ func TestDriveImportDryRunShowsMultipartUploadForLargeFile(t *testing.T) { cmd.Flags().String("type", "", "") cmd.Flags().String("folder-token", "", "") cmd.Flags().String("name", "", "") + cmd.Flags().String("target-token", "", "") if err := cmd.Flags().Set("file", "./large.xlsx"); err != nil { t.Fatalf("set --file: %v", err) } @@ -197,6 +199,7 @@ func TestDriveImportDryRunReturnsErrorForUnsafePath(t *testing.T) { cmd.Flags().String("type", "", "") cmd.Flags().String("folder-token", "", "") cmd.Flags().String("name", "", "") + cmd.Flags().String("target-token", "", "") if err := cmd.Flags().Set("file", "../outside.md"); err != nil { t.Fatalf("set --file: %v", err) } @@ -250,6 +253,7 @@ func TestDriveImportDryRunReturnsErrorForOversizedMarkdown(t *testing.T) { cmd.Flags().String("type", "", "") cmd.Flags().String("folder-token", "", "") cmd.Flags().String("name", "", "") + cmd.Flags().String("target-token", "", "") if err := cmd.Flags().Set("file", "./large.md"); err != nil { t.Fatalf("set --file: %v", err) } @@ -296,6 +300,7 @@ func TestDriveImportDryRunReturnsErrorForDirectoryInput(t *testing.T) { cmd.Flags().String("type", "", "") cmd.Flags().String("folder-token", "", "") cmd.Flags().String("name", "", "") + cmd.Flags().String("target-token", "", "") if err := cmd.Flags().Set("file", "./folder-input"); err != nil { t.Fatalf("set --file: %v", err) } @@ -366,6 +371,165 @@ func TestDriveImportCreateTaskBodyKeepsEmptyMountKeyForRoot(t *testing.T) { } } +func TestDriveImportCreateTaskBodyWithTargetToken(t *testing.T) { + t.Parallel() + + spec := driveImportSpec{ + FilePath: "/tmp/data.xlsx", + DocType: "bitable", + TargetToken: "bascnxxxxx", + } + + body := spec.CreateTaskBody("file_token_test") + + // point stays the same as default (mount_type=1) + point, ok := body["point"].(map[string]interface{}) + if !ok { + t.Fatalf("point = %#v, want map", body["point"]) + } + if mt := point["mount_type"]; mt != float64(1) && mt != 1 { + t.Fatalf("mount_type = %v (%T), want 1", mt, mt) + } + + // token is injected at body top-level + if tt, _ := body["token"].(string); tt != "bascnxxxxx" { + t.Fatalf("token = %q, want %q", tt, "bascnxxxxx") + } +} + +func TestDriveImportCreateTaskBodyTargetTokenIgnoredForNonBitable(t *testing.T) { + t.Parallel() + + spec := driveImportSpec{ + FilePath: "/tmp/data.xlsx", + DocType: "sheet", + TargetToken: "bascnxxxxx", + FolderToken: "fld_test", + } + + body := spec.CreateTaskBody("file_token_test") + point, ok := body["point"].(map[string]interface{}) + if !ok { + t.Fatalf("point = %#v, want map", body["point"]) + } + + // Non-bitable should use default folder mount (type=1), ignoring TargetToken + if mt := point["mount_type"]; mt != float64(1) && mt != 1 { + t.Fatalf("mount_type = %v (%T), want 1 (folder mount)", mt, mt) + } + if _, exists := point["target_token"]; exists { + t.Fatal("target_token should not be present for non-bitable type") + } +} + +func TestDriveImportDryRunWithTargetToken(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.WriteFile("data.xlsx", []byte("fake-xlsx"), 0644); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cmd := &cobra.Command{Use: "drive +import"} + cmd.Flags().String("file", "", "") + cmd.Flags().String("type", "", "") + cmd.Flags().String("folder-token", "", "") + cmd.Flags().String("name", "", "") + cmd.Flags().String("target-token", "", "") + if err := cmd.Flags().Set("file", "./data.xlsx"); err != nil { + t.Fatalf("set --file: %v", err) + } + if err := cmd.Flags().Set("type", "bitable"); err != nil { + t.Fatalf("set --type: %v", err) + } + if err := cmd.Flags().Set("target-token", "bascntarget123"); err != nil { + t.Fatalf("set --target-token: %v", err) + } + + runtime := common.TestNewRuntimeContextWithCtx(context.Background(), cmd, nil) + dry := DriveImport.DryRun(context.Background(), runtime) + if dry == nil { + t.Fatal("DryRun returned nil") + } + + data, err := json.Marshal(dry) + if err != nil { + t.Fatalf("marshal dry run: %v", err) + } + + var got struct { + API []struct { + URL string `json:"url"` + Body map[string]interface{} `json:"body"` + } `json:"api"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal dry run json: %v", err) + } + if len(got.API) != 3 { + t.Fatalf("expected 3 API calls, got %d", len(got.API)) + } + + // The import task body (API[1]) should contain target_token in point + importTaskBody := got.API[1].Body + point, ok := importTaskBody["point"].(map[string]interface{}) + if !ok { + t.Fatalf("point = %#v, want map", importTaskBody["point"]) + } + if mt := point["mount_type"]; mt != float64(1) && mt != 1 { + t.Fatalf("dry-run mount_type = %v (%T), want 1 (unchanged)", mt, mt) + } + if tt, _ := importTaskBody["token"].(string); tt != "bascntarget123" { + t.Fatalf("dry-run token = %q, want %q", tt, "bascntarget123") + } +} + +func TestDriveImportDryRunTargetTokenRejectedForSheet(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.WriteFile("data.xlsx", []byte("fake-xlsx"), 0644); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cmd := &cobra.Command{Use: "drive +import"} + cmd.Flags().String("file", "", "") + cmd.Flags().String("type", "", "") + cmd.Flags().String("folder-token", "", "") + cmd.Flags().String("name", "", "") + cmd.Flags().String("target-token", "", "") + if err := cmd.Flags().Set("file", "./data.xlsx"); err != nil { + t.Fatalf("set --file: %v", err) + } + if err := cmd.Flags().Set("type", "sheet"); err != nil { + t.Fatalf("set --type: %v", err) + } + if err := cmd.Flags().Set("target-token", "bascnxxx"); err != nil { + t.Fatalf("set --target-token: %v", err) + } + + runtime := common.TestNewRuntimeContextWithCtx(context.Background(), cmd, nil) + dry := DriveImport.DryRun(context.Background(), runtime) + if dry == nil { + t.Fatal("DryRun returned nil") + } + + data, err := json.Marshal(dry) + if err != nil { + t.Fatalf("marshal dry run: %v", err) + } + + var got struct { + Error string `json:"error"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got.Error == "" || !strings.Contains(got.Error, "--target-token is only supported when --type is bitable") { + t.Fatalf("dry-run error = %q, want target-token validation error", got.Error) + } +} + // driveImportMockEnv mounts the three stubs needed for a full +import run: // media upload_all -> import_tasks (create) -> import_tasks/ (poll). // Returns nothing; caller asserts on stdout via decodeDriveEnvelope. diff --git a/shortcuts/drive/drive_inspect.go b/shortcuts/drive/drive_inspect.go new file mode 100644 index 000000000..7941d6b48 --- /dev/null +++ b/shortcuts/drive/drive_inspect.go @@ -0,0 +1,183 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package drive + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/shortcuts/common" +) + +var DriveInspect = common.Shortcut{ + Service: "drive", + Command: "+inspect", + Description: "Inspect a Lark document URL to get its type, title, and canonical token (with wiki unwrapping)", + Risk: "read", + Scopes: []string{"drive:drive.metadata:readonly"}, + ConditionalScopes: []string{"wiki:node:retrieve"}, + AuthTypes: []string{"user", "bot"}, + HasFormat: true, + Flags: []common.Flag{ + { + Name: "url", + Desc: "Lark/Feishu document URL (docx, doc, sheet, bitable, wiki, file, folder, mindnote, slides)", + Required: true, + }, + { + Name: "type", + Desc: "document type (required when --url is a bare token; auto-detected for URLs)", + Enum: []string{"doc", "docx", "sheet", "bitable", "wiki", "file", "folder", "mindnote", "slides"}, + }, + }, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + raw := strings.TrimSpace(runtime.Str("url")) + if raw == "" { + return output.ErrValidation("--url cannot be empty") + } + + _, ok := common.ParseResourceURL(raw) + if !ok { + // Not a recognized URL pattern. + if strings.Contains(raw, "://") { + return output.ErrValidation("unsupported --url %q: use a recognized Lark document URL or a bare token with --type", raw) + } + // Bare token: --type is required. + if strings.TrimSpace(runtime.Str("type")) == "" { + return output.ErrValidation("--type is required when --url is a bare token (allowed: doc, docx, sheet, bitable, wiki, file, folder, mindnote, slides)") + } + } + return nil + }, + DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + raw := strings.TrimSpace(runtime.Str("url")) + ref, ok := common.ParseResourceURL(raw) + if !ok { + ref = common.ResourceRef{ + Type: strings.TrimSpace(runtime.Str("type")), + Token: raw, + } + } + + dry := common.NewDryRunAPI() + + if ref.Type == "wiki" { + dry.Desc("2-step: inspect wiki node, then batch query metadata") + dry.GET("/open-apis/wiki/v2/spaces/get_node"). + Desc("[1] Inspect wiki node to get underlying document"). + Params(map[string]interface{}{"token": ref.Token}) + dry.POST("/open-apis/drive/v1/metas/batch_query"). + Desc("[2] Batch query document metadata (title)"). + Body(map[string]interface{}{ + "request_docs": []map[string]interface{}{ + {"doc_token": "", "doc_type": ""}, + }, + }) + return dry + } + + dry.Desc("1-step: batch query document metadata") + dry.POST("/open-apis/drive/v1/metas/batch_query"). + Body(map[string]interface{}{ + "request_docs": []map[string]interface{}{ + {"doc_token": ref.Token, "doc_type": ref.Type}, + }, + }) + return dry + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + raw := strings.TrimSpace(runtime.Str("url")) + + // Step 1: Parse URL to extract {type, token}. + ref, ok := common.ParseResourceURL(raw) + if !ok { + // Bare token: use --type. + ref = common.ResourceRef{ + Type: strings.TrimSpace(runtime.Str("type")), + Token: raw, + } + } + + inputURL := raw + docType := ref.Type + docToken := ref.Token + + var wikiNode map[string]interface{} + + // Step 2: If type is "wiki", unwrap via get_node API. + if docType == "wiki" { + fmt.Fprintf(runtime.IO().ErrOut, "Inspecting wiki node: %s\n", common.MaskToken(docToken)) + data, err := runtime.CallAPI( + "GET", + "/open-apis/wiki/v2/spaces/get_node", + map[string]interface{}{"token": docToken}, + nil, + ) + if err != nil { + return err + } + + node := common.GetMap(data, "node") + objType := common.GetString(node, "obj_type") + objToken := common.GetString(node, "obj_token") + spaceID := common.GetString(node, "space_id") + nodeToken := common.GetString(node, "node_token") + + if objType == "" || objToken == "" { + return output.Errorf(output.ExitAPI, "api_error", "wiki get_node returned incomplete node data (obj_type=%q, obj_token=%q)", objType, objToken) + } + + wikiNode = map[string]interface{}{ + "space_id": spaceID, + "node_token": nodeToken, + "obj_token": objToken, + "obj_type": objType, + } + + docType = objType + docToken = objToken + + fmt.Fprintf(runtime.IO().ErrOut, "Wiki unwrapped to %s: %s\n", docType, common.MaskToken(docToken)) + } + + // Step 3: Call batch_query to verify and get title. + title, err := common.FetchDriveMetaTitle(runtime, docToken, docType) + if err != nil { + return err + } + + // Step 4: Build the resolved URL. + resolvedURL := common.BuildResourceURL(runtime.Config.Brand, docType, docToken) + + // Step 5: Build output. + result := map[string]interface{}{ + "input_url": inputURL, + "type": docType, + "title": title, + "token": docToken, + "url": resolvedURL, + } + if wikiNode != nil { + result["wiki_node"] = wikiNode + } + + runtime.OutFormat(result, nil, func(w io.Writer) { + fmt.Fprintf(w, "Type: %s\n", docType) + if title != "" { + fmt.Fprintf(w, "Title: %s\n", title) + } + fmt.Fprintf(w, "Token: %s\n", docToken) + if resolvedURL != "" { + fmt.Fprintf(w, "URL: %s\n", resolvedURL) + } + if wikiNode != nil { + fmt.Fprintf(w, "Wiki: space_id=%s, node_token=%s\n", wikiNode["space_id"], wikiNode["node_token"]) + } + }) + return nil + }, +} diff --git a/shortcuts/drive/drive_inspect_test.go b/shortcuts/drive/drive_inspect_test.go new file mode 100644 index 000000000..a40cc399d --- /dev/null +++ b/shortcuts/drive/drive_inspect_test.go @@ -0,0 +1,466 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package drive + +import ( + "context" + "encoding/json" + "testing" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/httpmock" + "github.com/larksuite/cli/shortcuts/common" +) + +// --- Validate tests --- + +func TestDriveInspectValidate_EmptyURL(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + err := DriveInspect.Validate(context.Background(), runtime) + if err == nil { + t.Fatal("expected error for empty --url, got nil") + } +} + +func TestDriveInspectValidate_UnsupportedURL(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "https://google.com/some/page") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + err := DriveInspect.Validate(context.Background(), runtime) + if err == nil { + t.Fatal("expected error for unsupported URL, got nil") + } +} + +func TestDriveInspectValidate_NonLarkHostWithLarkPath(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "https://google.com/docx/doxcnLooksValid") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + err := DriveInspect.Validate(context.Background(), runtime) + if err != nil { + t.Fatalf("expected no error for non-Lark host with Lark-like path (host validation removed), got %v", err) + } +} + +func TestDriveInspectValidate_BareTokenWithoutType(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "doxcnBareToken") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + err := DriveInspect.Validate(context.Background(), runtime) + if err == nil { + t.Fatal("expected error for bare token without --type, got nil") + } +} + +func TestDriveInspectValidate_BareTokenWithType(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "doxcnBareToken") + _ = cmd.Flags().Set("type", "docx") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + err := DriveInspect.Validate(context.Background(), runtime) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestDriveInspectValidate_ValidDocxURL(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "https://xxx.feishu.cn/docx/doxcnABC") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + err := DriveInspect.Validate(context.Background(), runtime) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestDriveInspectValidate_ValidWikiURL(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "https://xxx.feishu.cn/wiki/wikcnABC") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + err := DriveInspect.Validate(context.Background(), runtime) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +// --- DryRun tests --- + +func TestDriveInspectDryRun_DocxURL(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "https://xxx.feishu.cn/docx/doxcnABC") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + dry := DriveInspect.DryRun(context.Background(), runtime) + if dry == nil { + t.Fatal("DryRun returned nil") + } + + data, err := json.Marshal(dry) + if err != nil { + t.Fatalf("marshal dry run: %v", err) + } + + var got struct { + API []struct { + URL string `json:"url"` + Method string `json:"method"` + Body map[string]interface{} `json:"body"` + } `json:"api"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal dry run: %v", err) + } + if len(got.API) != 1 { + t.Fatalf("expected 1 API step, got %d", len(got.API)) + } + if got.API[0].URL != "/open-apis/drive/v1/metas/batch_query" { + t.Errorf("API URL = %q, want /open-apis/drive/v1/metas/batch_query", got.API[0].URL) + } + // Verify body contains request_docs with the correct token and type. + reqDocs, ok := got.API[0].Body["request_docs"].([]interface{}) + if !ok || len(reqDocs) != 1 { + t.Fatalf("expected request_docs with 1 entry, got %v", got.API[0].Body["request_docs"]) + } + doc, _ := reqDocs[0].(map[string]interface{}) + if doc["doc_token"] != "doxcnABC" { + t.Errorf("doc_token = %v, want doxcnABC", doc["doc_token"]) + } + if doc["doc_type"] != "docx" { + t.Errorf("doc_type = %v, want docx", doc["doc_type"]) + } +} + +func TestDriveInspectDryRun_WikiURL(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "https://xxx.feishu.cn/wiki/wikcnABC") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + dry := DriveInspect.DryRun(context.Background(), runtime) + if dry == nil { + t.Fatal("DryRun returned nil") + } + + data, err := json.Marshal(dry) + if err != nil { + t.Fatalf("marshal dry run: %v", err) + } + + var got struct { + API []struct { + URL string `json:"url"` + Params map[string]interface{} `json:"params"` + Body map[string]interface{} `json:"body"` + } `json:"api"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal dry run: %v", err) + } + if len(got.API) != 2 { + t.Fatalf("expected 2 API steps, got %d", len(got.API)) + } + if got.API[0].URL != "/open-apis/wiki/v2/spaces/get_node" { + t.Errorf("step 1 URL = %q, want /open-apis/wiki/v2/spaces/get_node", got.API[0].URL) + } + // Verify step 1 params contain the wiki token. + if got.API[0].Params["token"] != "wikcnABC" { + t.Errorf("step 1 params.token = %v, want wikcnABC", got.API[0].Params["token"]) + } + if got.API[1].URL != "/open-apis/drive/v1/metas/batch_query" { + t.Errorf("step 2 URL = %q, want /open-apis/drive/v1/metas/batch_query", got.API[1].URL) + } + // Verify step 2 body contains request_docs placeholder. + if got.API[1].Body["request_docs"] == nil { + t.Error("step 2 body should contain request_docs") + } +} + +func TestDriveInspectDryRun_BareTokenWithType(t *testing.T) { + cmd := &cobra.Command{Use: "drive +inspect"} + cmd.Flags().String("url", "", "") + cmd.Flags().String("type", "", "") + _ = cmd.Flags().Set("url", "doxcnBareToken") + _ = cmd.Flags().Set("type", "docx") + + runtime := common.TestNewRuntimeContext(cmd, &core.CliConfig{}) + dry := DriveInspect.DryRun(context.Background(), runtime) + if dry == nil { + t.Fatal("DryRun returned nil") + } + + data, err := json.Marshal(dry) + if err != nil { + t.Fatalf("marshal dry run: %v", err) + } + + var got struct { + API []struct { + URL string `json:"url"` + } `json:"api"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal dry run: %v", err) + } + if len(got.API) != 1 { + t.Fatalf("expected 1 API step, got %d", len(got.API)) + } +} + +// --- Execute tests --- + +func TestDriveInspectExecute_DocxURL(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + cfg := driveTestConfig() + f, stdout, _, reg := cmdutil.TestFactory(t, cfg) + + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "metas": []map[string]interface{}{ + {"doc_token": "doxcnABC", "doc_type": "docx", "title": "Test Doc"}, + }, + }, + }, + }) + + err := mountAndRunDrive(t, DriveInspect, []string{ + "+inspect", + "--url", "https://xxx.feishu.cn/docx/doxcnABC", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data := decodeDriveEnvelope(t, stdout) + if data["type"] != "docx" { + t.Errorf("type = %v, want docx", data["type"]) + } + if data["token"] != "doxcnABC" { + t.Errorf("token = %v, want doxcnABC", data["token"]) + } + if data["title"] != "Test Doc" { + t.Errorf("title = %v, want Test Doc", data["title"]) + } + if _, ok := data["wiki_node"]; ok { + t.Error("wiki_node should not be present for non-wiki URL") + } +} + +func TestDriveInspectExecute_WikiURL(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + cfg := driveTestConfig() + f, stdout, _, reg := cmdutil.TestFactory(t, cfg) + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/wiki/v2/spaces/get_node", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "node": map[string]interface{}{ + "obj_type": "docx", + "obj_token": "doxcnUnwrapped", + "space_id": "space123", + "node_token": "wikcnNodeToken", + "title": "Wiki Doc", + "node_type": "origin", + }, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "metas": []map[string]interface{}{ + {"doc_token": "doxcnUnwrapped", "doc_type": "docx", "title": "Wiki Doc"}, + }, + }, + }, + }) + + err := mountAndRunDrive(t, DriveInspect, []string{ + "+inspect", + "--url", "https://xxx.feishu.cn/wiki/wikcnABC", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data := decodeDriveEnvelope(t, stdout) + if data["type"] != "docx" { + t.Errorf("type = %v, want docx (unwrapped from wiki)", data["type"]) + } + if data["token"] != "doxcnUnwrapped" { + t.Errorf("token = %v, want doxcnUnwrapped", data["token"]) + } + if data["title"] != "Wiki Doc" { + t.Errorf("title = %v, want Wiki Doc", data["title"]) + } + wikiNode, ok := data["wiki_node"].(map[string]interface{}) + if !ok { + t.Fatal("wiki_node should be present for wiki URL") + } + if wikiNode["space_id"] != "space123" { + t.Errorf("wiki_node.space_id = %v, want space123", wikiNode["space_id"]) + } +} + +func TestDriveInspectExecute_WikiGetNodeIncompleteData(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + cfg := driveTestConfig() + f, stdout, _, reg := cmdutil.TestFactory(t, cfg) + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/wiki/v2/spaces/get_node", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "node": map[string]interface{}{ + "obj_type": "", + "obj_token": "", + }, + }, + }, + }) + + err := mountAndRunDrive(t, DriveInspect, []string{ + "+inspect", + "--url", "https://xxx.feishu.cn/wiki/wikcnABC", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatal("expected error for incomplete wiki node data, got nil") + } +} + +func TestDriveInspectExecute_BareTokenWithType(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + cfg := driveTestConfig() + f, stdout, _, reg := cmdutil.TestFactory(t, cfg) + + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "metas": []map[string]interface{}{ + {"doc_token": "doxcnBare", "doc_type": "docx", "title": "Bare Doc"}, + }, + }, + }, + }) + + err := mountAndRunDrive(t, DriveInspect, []string{ + "+inspect", + "--url", "doxcnBare", + "--type", "docx", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data := decodeDriveEnvelope(t, stdout) + if data["type"] != "docx" { + t.Errorf("type = %v, want docx", data["type"]) + } + if data["token"] != "doxcnBare" { + t.Errorf("token = %v, want doxcnBare", data["token"]) + } +} + +func TestDriveInspectExecute_BatchQueryError(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + cfg := driveTestConfig() + f, stdout, _, reg := cmdutil.TestFactory(t, cfg) + + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 99991668, + "msg": "permission denied", + }, + }) + + err := mountAndRunDrive(t, DriveInspect, []string{ + "+inspect", + "--url", "https://xxx.feishu.cn/docx/doxcnABC", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatal("expected error for batch_query failure, got nil") + } +} + +func TestDriveInspectExecute_PrettyFormat(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + cfg := driveTestConfig() + f, stdout, _, reg := cmdutil.TestFactory(t, cfg) + + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/metas/batch_query", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "metas": []map[string]interface{}{ + {"doc_token": "doxcnABC", "doc_type": "docx", "title": "Test Doc"}, + }, + }, + }, + }) + + err := mountAndRunDrive(t, DriveInspect, []string{ + "+inspect", + "--url", "https://xxx.feishu.cn/docx/doxcnABC", + "--format", "pretty", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Pretty format outputs to stdout as text, not JSON envelope. + // Just verify it didn't error. + _ = stdout +} diff --git a/shortcuts/drive/drive_push.go b/shortcuts/drive/drive_push.go index bc790653c..78c34e67c 100644 --- a/shortcuts/drive/drive_push.go +++ b/shortcuts/drive/drive_push.go @@ -275,7 +275,14 @@ var DrivePush = common.Shortcut{ skipped++ continue } - token, version, upErr := drivePushUploadFile(ctx, runtime, localFile, entry.FileToken, folderToken) + parentToken, parentErr := drivePushEnsureParentToken(ctx, runtime, folderToken, rel, folderCache) + if parentErr != nil { + items = append(items, drivePushItem{RelPath: rel, FileToken: entry.FileToken, Action: "failed", SizeBytes: localFile.Size, Error: parentErr.Error()}) + failed++ + uploadFailed = true + continue + } + token, version, upErr := drivePushUploadFile(ctx, runtime, localFile, entry.FileToken, parentToken) if upErr != nil { // Token contract on overwrite failure: an in-place // overwrite preserves the file's token, so the @@ -580,6 +587,10 @@ func drivePushEnsureFolder(ctx context.Context, runtime *common.RuntimeContext, return token, nil } +func drivePushEnsureParentToken(ctx context.Context, runtime *common.RuntimeContext, rootFolderToken, relPath string, folderCache map[string]string) (string, error) { + return drivePushEnsureFolder(ctx, runtime, rootFolderToken, drivePushParentRel(relPath), folderCache) +} + // drivePushUploadFile uploads (or overwrites) a single local file. When // existingToken is non-empty, the request adds the file_token form field to // trigger overwrite-with-version semantics on the backend; the response is diff --git a/shortcuts/drive/drive_push_test.go b/shortcuts/drive/drive_push_test.go index ec71e4bfa..3d5654ca2 100644 --- a/shortcuts/drive/drive_push_test.go +++ b/shortcuts/drive/drive_push_test.go @@ -1296,6 +1296,130 @@ func TestDrivePushReusesExistingRemoteFolder(t *testing.T) { } } +// TestDrivePushOverwriteNestedFileUsesParentFolderToken verifies that +// overwriting an existing nested remote file keeps parent_node aligned with +// the file's actual parent folder instead of the root folder token. +func TestDrivePushOverwriteNestedFileUsesParentFolderToken(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll(filepath.Join("local", "sub"), 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join("local", "sub", "keep.txt"), []byte("local"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "fld_existing_sub", "name": "sub", "type": "folder"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=fld_existing_sub", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_keep_nested", "name": "keep.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + uploadStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "file_token": "tok_keep_nested", + "version": "v2", + }, + }, + } + reg.Register(uploadStub) + + err := mountAndRunDrive(t, DrivePush, []string{ + "+push", + "--local-dir", "local", + "--folder-token", "folder_root", + "--if-exists", "overwrite", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + body := decodeDriveMultipartBody(t, uploadStub) + if got := body.Fields["file_token"]; got != "tok_keep_nested" { + t.Fatalf("upload_all file_token = %q, want tok_keep_nested", got) + } + if got := body.Fields["parent_node"]; got != "fld_existing_sub" { + t.Fatalf("upload_all parent_node = %q, want fld_existing_sub", got) + } +} + +func TestDrivePushOverwriteNestedFileReportsParentEnsureFailure(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll(filepath.Join("local", "sub"), 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join("local", "sub", "keep.txt"), []byte("local"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_keep_nested", "name": "sub/keep.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/create_folder", + Body: map[string]interface{}{ + "code": 9999, + "msg": "create parent failed", + }, + }) + + err := mountAndRunDrive(t, DrivePush, []string{ + "+push", + "--local-dir", "local", + "--folder-token", "folder_root", + "--if-exists", "overwrite", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected parent ensure failure\nstdout: %s", stdout.String()) + } + if !strings.Contains(stdout.String(), `"action": "failed"`) || !strings.Contains(stdout.String(), "create parent failed") { + t.Fatalf("expected failed item with create_folder error, got: %s", stdout.String()) + } +} + // TestDrivePushMirrorsEmptyDirectories confirms the gap codex review // flagged: a local directory with no files inside must still surface on // Drive as a created sub-folder, not be silently dropped because the diff --git a/shortcuts/drive/drive_search.go b/shortcuts/drive/drive_search.go index 4f34f49f3..f71be3478 100644 --- a/shortcuts/drive/drive_search.go +++ b/shortcuts/drive/drive_search.go @@ -77,8 +77,8 @@ var DriveSearch = common.Shortcut{ Flags: []common.Flag{ {Name: "query", Desc: "search keyword (may be empty to browse by filter only)"}, - {Name: "mine", Type: "bool", Desc: "restrict to docs I created (uses current user's open_id)"}, - {Name: "creator-ids", Desc: "comma-separated creator open_ids; mutually exclusive with --mine"}, + {Name: "mine", Type: "bool", Desc: "restrict to docs I own (server-side owner semantic, NOT original creator; uses current user's open_id)"}, + {Name: "creator-ids", Desc: "comma-separated owner open_ids (API field is creator_ids but matched by owner); mutually exclusive with --mine"}, {Name: "edited-since", Desc: "start of [my edited] time window (e.g. 7d, 1m, 1y, 2026-04-01, RFC3339, unix seconds)"}, {Name: "edited-until", Desc: "end of [my edited] time window"}, @@ -108,7 +108,7 @@ var DriveSearch = common.Shortcut{ Tips: []string{ "Time flags accept relative (e.g. 7d, 1m, 1y), absolute (2026-04-01, RFC3339), or unix seconds.", "my_edit_time and my_comment_time are hour-aggregated server-side; sub-hour inputs are snapped and a notice is printed to stderr.", - "Use --mine for a quick \"docs I created\" filter. For other people, use --creator-ids ou_xxx,ou_yyy.", + "Use --mine for a quick \"docs I own\" filter (owner semantic, not original creator). For other people, use --creator-ids ou_xxx,ou_yyy.", "--folder-tokens limits to doc-only search; --space-ids limits to wiki-only. They cannot be combined.", }, DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { diff --git a/shortcuts/drive/drive_status_test.go b/shortcuts/drive/drive_status_test.go index 89c1e42fe..303aeac11 100644 --- a/shortcuts/drive/drive_status_test.go +++ b/shortcuts/drive/drive_status_test.go @@ -17,6 +17,8 @@ import ( "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/httpmock" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/shortcuts/common" + "github.com/spf13/cobra" ) // driveStatusScopedTokenResolver returns a token with caller-controlled scopes @@ -804,3 +806,59 @@ func TestDriveStatusRejectsMalformedFolderToken(t *testing.T) { t.Fatalf("error must reference --folder-token, got: %v", err) } } + +func TestWalkLocalForStatusMissingRootReturnsInternalError(t *testing.T) { + missingRoot := filepath.Join(t.TempDir(), "does-not-exist") + + _, err := walkLocalForStatus(missingRoot, t.TempDir()) + if err == nil { + t.Fatal("expected walkLocalForStatus() to fail for missing root") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected structured ExitError, got %T", err) + } + if exitErr.Detail == nil || exitErr.Detail.Type != "io" { + t.Fatalf("expected io error detail, got %#v", exitErr.Detail) + } + if !strings.Contains(err.Error(), "walk") { + t.Fatalf("expected walk-related error, got: %v", err) + } +} + +func TestHashLocalForStatusWrapsOpenError(t *testing.T) { + config := driveTestConfig() + f, _, _, _ := cmdutil.TestFactory(t, config) + runtime := common.TestNewRuntimeContext(&cobra.Command{Use: "drive"}, config) + runtime.Factory = f + + _, err := hashLocalForStatus(runtime, "missing.txt") + if err == nil { + t.Fatal("expected hashLocalForStatus() to fail for missing file") + } + if !strings.Contains(err.Error(), "missing.txt") { + t.Fatalf("expected error to mention the missing file, got: %v", err) + } +} + +func TestHashRemoteForStatusReturnsNetworkErrorWhenDownloadFails(t *testing.T) { + config := driveTestConfig() + f, _, _, _ := cmdutil.TestFactory(t, config) + runtime := common.TestNewRuntimeContextWithCtx(context.Background(), &cobra.Command{Use: "drive"}, config) + runtime.Factory = f + + _, err := hashRemoteForStatus(context.Background(), runtime, "tok_missing") + if err == nil { + t.Fatal("expected hashRemoteForStatus() to fail when the download request has no stub") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected structured ExitError, got %T", err) + } + if exitErr.Detail == nil || exitErr.Detail.Type != "network" { + t.Fatalf("expected network detail, got %#v", exitErr.Detail) + } + if !strings.Contains(err.Error(), "download") { + t.Fatalf("expected download-related error, got: %v", err) + } +} diff --git a/shortcuts/drive/drive_sync.go b/shortcuts/drive/drive_sync.go new file mode 100644 index 000000000..3c512cecf --- /dev/null +++ b/shortcuts/drive/drive_sync.go @@ -0,0 +1,650 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package drive + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/shortcuts/common" +) + +const ( + driveSyncOnConflictLocalWins = "local-wins" + driveSyncOnConflictRemoteWins = "remote-wins" + driveSyncOnConflictKeepBoth = "keep-both" + driveSyncOnConflictAsk = "ask" +) + +type driveSyncItem struct { + RelPath string `json:"rel_path"` + FileToken string `json:"file_token,omitempty"` + Action string `json:"action"` + Direction string `json:"direction,omitempty"` // "pull" or "push" + Error string `json:"error,omitempty"` +} + +// DriveSync performs a two-way sync between a local directory and a Drive +// folder. It computes a diff (like +status), then: +// - new_remote → pull (download to local) +// - new_local → push (upload to Drive) +// - modified → resolve by --on-conflict strategy: +// local-wins: push local over remote; +// remote-wins: pull remote over local; +// keep-both: rename the local file with a hash suffix and pull the remote; +// ask: prompt the user per conflict. +var DriveSync = common.Shortcut{ + Service: "drive", + Command: "+sync", + Description: "Two-way sync between a local directory and a Drive folder", + Risk: "write", + Scopes: []string{"drive:drive.metadata:readonly"}, + ConditionalScopes: []string{ + "drive:file:download", + "drive:file:upload", + "space:folder:create", + }, + AuthTypes: []string{"user", "bot"}, + Flags: []common.Flag{ + {Name: "local-dir", Desc: "local root directory (relative to cwd)", Required: true}, + {Name: "folder-token", Desc: "Drive folder token", Required: true}, + {Name: "on-conflict", Desc: "conflict resolution when both sides modified a file", Default: driveSyncOnConflictRemoteWins, Enum: []string{driveSyncOnConflictLocalWins, driveSyncOnConflictRemoteWins, driveSyncOnConflictKeepBoth, driveSyncOnConflictAsk}}, + {Name: "on-duplicate-remote", Desc: "policy when multiple remote Drive entries map to the same rel_path", Default: driveDuplicateRemoteFail, Enum: []string{driveDuplicateRemoteFail, driveDuplicateRemoteNewest, driveDuplicateRemoteOldest}}, + {Name: "quick", Type: "bool", Desc: "use best-effort modified_time comparison instead of SHA-256 hash; mismatched timestamps can still trigger real sync writes"}, + }, + Tips: []string{ + "Two-way sync: new remote files are pulled, new local files are pushed, and conflicts (both sides modified) are resolved by --on-conflict.", + "Default --on-conflict=remote-wins pulls the remote version when both sides changed a file. Use local-wins to push instead, keep-both to rename and keep both copies, or ask for interactive resolution.", + "Pass --quick for faster best-effort diff detection using modified_time instead of SHA-256 hash (no remote file downloads needed during diffing).", + "Because +sync acts on the diff, --quick can still pull, overwrite, or rename files when timestamps differ even if file contents are actually unchanged.", + "Only entries with type=file are synced; online docs (docx, sheet, bitable, mindnote, slides) and shortcuts are skipped.", + }, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + localDir := strings.TrimSpace(runtime.Str("local-dir")) + folderToken := strings.TrimSpace(runtime.Str("folder-token")) + if localDir == "" { + return common.FlagErrorf("--local-dir is required") + } + if folderToken == "" { + return common.FlagErrorf("--folder-token is required") + } + if err := validate.ResourceName(folderToken, "--folder-token"); err != nil { + return output.ErrValidation("%s", err) + } + if _, err := validate.SafeLocalFlagPath("--local-dir", localDir); err != nil { + return output.ErrValidation("%s", err) + } + info, err := runtime.FileIO().Stat(localDir) + if err != nil { + return common.WrapInputStatError(err) + } + if !info.IsDir() { + return output.ErrValidation("--local-dir is not a directory: %s", localDir) + } + return nil + }, + DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + return common.NewDryRunAPI(). + Desc("Compute diff between --local-dir and --folder-token, then pull new/modified-remote files, push new/modified-local files, and resolve conflicts by --on-conflict strategy."). + GET("/open-apis/drive/v1/files"). + Set("folder_token", runtime.Str("folder-token")) + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + localDir := strings.TrimSpace(runtime.Str("local-dir")) + folderToken := strings.TrimSpace(runtime.Str("folder-token")) + onConflict := strings.TrimSpace(runtime.Str("on-conflict")) + if onConflict == "" { + onConflict = driveSyncOnConflictRemoteWins + } + duplicateRemote := strings.TrimSpace(runtime.Str("on-duplicate-remote")) + if duplicateRemote == "" { + duplicateRemote = driveDuplicateRemoteFail + } + quick := runtime.Bool("quick") + if !quick { + if err := runtime.EnsureScopes([]string{"drive:file:download"}); err != nil { + return err + } + } + + safeRoot, err := validate.SafeInputPath(localDir) + if err != nil { + return output.ErrValidation("--local-dir: %s", err) + } + cwdCanonical, err := validate.SafeInputPath(".") + if err != nil { + return output.ErrValidation("could not resolve cwd: %s", err) + } + rootRelToCwd, err := filepath.Rel(cwdCanonical, safeRoot) + if err != nil { + return output.ErrValidation("--local-dir resolves outside cwd: %s", err) + } + + // --- Phase 1: Compute diff (same logic as +status) --- + fmt.Fprintf(runtime.IO().ErrOut, "Walking local: %s\n", localDir) + localFiles, err := walkLocalForStatus(safeRoot, cwdCanonical) + if err != nil { + return err + } + + fmt.Fprintf(runtime.IO().ErrOut, "Listing Drive folder: %s\n", common.MaskToken(folderToken)) + entries, err := listRemoteFolderEntries(ctx, runtime, folderToken, "") + if err != nil { + return err + } + if duplicates := blockingRemotePathConflicts(entries, duplicateRemote); len(duplicates) > 0 { + return duplicateRemotePathError(duplicates) + } + + // A local regular file at the same rel_path as a remote + // folder/docx/shortcut is a type conflict: +sync would + // classify it as new_local and attempt to upload, which either + // fails at the API or leaves the remote in a broken state + // (same rel_path with mixed types). Detect early and hard-fail. + // Symmetrically, a local directory at the same rel_path as a + // remote file/docx/shortcut would attempt create_folder and + // produce the same broken mixed-type state. + var typeConflicts []string + for _, entry := range entries { + if entry.Type == driveTypeFile { + continue + } + if _, hasLocal := localFiles[entry.RelPath]; hasLocal { + typeConflicts = append(typeConflicts, fmt.Sprintf("%q: local file vs remote %s", entry.RelPath, entry.Type)) + } + } + // Check local directories vs remote non-folder entries. + // localDirs is not available yet (walked later), so check + // the filesystem directly for the subset of remote paths + // that are non-folder. + for _, entry := range entries { + if entry.Type == driveTypeFolder { + continue + } + dirPath := filepath.Join(safeRoot, filepath.FromSlash(entry.RelPath)) + if info, err := os.Stat(dirPath); err == nil && info.IsDir() { //nolint:forbidigo // shortcuts cannot import internal/vfs (depguard rule shortcuts-no-vfs); safeRoot is validated. + typeConflicts = append(typeConflicts, fmt.Sprintf("%q: local directory vs remote %s", entry.RelPath, entry.Type)) + } + } + if len(typeConflicts) > 0 { + return output.ErrValidation("+sync cannot proceed: path type conflict — %s; remove the local entry or the remote entry and retry", strings.Join(typeConflicts, "; ")) + } + + // Build the exact remote-file views that later execution will use so the + // diff phase classifies files against the same duplicate-resolution choice. + pullRemoteFiles, _, err := drivePullRemoteViews(entries, duplicateRemote) + if err != nil { + return output.Errorf(output.ExitInternal, "internal", "%s", err) + } + remoteEntriesForPush, remoteFolders, _, err := drivePushRemoteViews(entries, duplicateRemote) + if err != nil { + return output.Errorf(output.ExitInternal, "internal", "%s", err) + } + + remoteFiles := driveSyncStatusRemoteFiles(pullRemoteFiles) + + paths := mergeStatusPaths(localFiles, remoteFiles) + + var newLocal, newRemote, modified []driveStatusEntry + var unchanged []driveStatusEntry + for _, relPath := range paths { + localFile, hasLocal := localFiles[relPath] + remoteFile, hasRemote := remoteFiles[relPath] + switch { + case hasLocal && !hasRemote: + newLocal = append(newLocal, driveStatusEntry{RelPath: relPath}) + case !hasLocal && hasRemote: + newRemote = append(newRemote, driveStatusEntry{RelPath: relPath, FileToken: remoteFile.FileToken}) + default: + entry := driveStatusEntry{RelPath: relPath, FileToken: remoteFile.FileToken} + if quick { + if driveStatusShouldTreatAsUnchangedQuick(remoteFile.ModifiedTime, localFile.ModTime) { + unchanged = append(unchanged, entry) + } else { + modified = append(modified, entry) + } + continue + } + localHash, err := hashLocalForStatus(runtime, localFile.PathToCwd) + if err != nil { + return err + } + remoteHash, err := hashRemoteForStatus(ctx, runtime, remoteFile.FileToken) + if err != nil { + return err + } + if localHash == remoteHash { + unchanged = append(unchanged, entry) + } else { + modified = append(modified, entry) + } + } + } + + detection := driveStatusDetectionExact + if quick { + detection = driveStatusDetectionQuick + } + + fmt.Fprintf(runtime.IO().ErrOut, "Diff: %d new_local, %d new_remote, %d modified, %d unchanged (detection=%s)\n", + len(newLocal), len(newRemote), len(modified), len(unchanged), detection) + + conflictResolutions := make(map[string]string, len(modified)) + if onConflict == driveSyncOnConflictAsk && len(modified) > 0 && runtime.IO().In == nil { + return output.ErrValidation("--on-conflict=ask requires interactive stdin when modified files exist") + } + for _, entry := range modified { + resolved := onConflict + if resolved == driveSyncOnConflictAsk { + resolved, err = driveSyncAskConflict(entry.RelPath, runtime) + if err != nil { + payload := map[string]interface{}{ + "detection": detection, + "diff": map[string]interface{}{ + "new_local": emptyIfNil(newLocal), + "new_remote": emptyIfNil(newRemote), + "modified": emptyIfNil(modified), + "unchanged": emptyIfNil(unchanged), + }, + "summary": map[string]interface{}{ + "pulled": 0, + "pushed": 0, + "skipped": 0, + "failed": 1, + }, + "items": []driveSyncItem{{ + RelPath: entry.RelPath, + FileToken: entry.FileToken, + Action: "failed", + Direction: "conflict", + Error: err.Error(), + }}, + } + return &output.ExitError{ + Code: output.ExitAPI, + Detail: &output.ErrDetail{ + Type: "partial_failure", + Message: fmt.Sprintf("cannot collect conflict decisions for +sync: %v", err), + Detail: payload, + }, + } + } + } + conflictResolutions[entry.RelPath] = resolved + } + + // --- Phase 2: Execute sync operations --- + var pulled, pushed, skipped, failed int + items := make([]driveSyncItem, 0) + + if quick && driveSyncNeedsDownloadScope(newRemote, modified, conflictResolutions) { + if err := runtime.EnsureScopes([]string{"drive:file:download"}); err != nil { + return err + } + } + plannedUploads := driveSyncPlannedUploadPaths(newLocal, modified, conflictResolutions) + if len(plannedUploads) > 0 { + if err := runtime.EnsureScopes([]string{"drive:file:upload"}); err != nil { + return err + } + } + + // Build push infrastructure: local walk for push + remote views + folder cache. + folderCache := map[string]string{"": folderToken} + for relDir, entry := range remoteFolders { + folderCache[relDir] = entry.FileToken + } + + // Walk local filesystem early so we can include empty directories + // in the scope preflight (they also need space:folder:create). + pushLocalFiles, localDirs, err := drivePushWalkLocal(safeRoot, cwdCanonical) + if err != nil { + return err + } + + if driveSyncNeedsCreateScope(plannedUploads, localDirs, folderCache) { + if err := runtime.EnsureScopes([]string{"space:folder:create"}); err != nil { + return err + } + } + + // Mirror local directory structure first (same as +push), so + // empty local directories are not silently dropped. + for _, relDir := range localDirs { + if _, alreadyRemote := folderCache[relDir]; alreadyRemote { + continue + } + if _, ensureErr := drivePushEnsureFolder(ctx, runtime, folderToken, relDir, folderCache); ensureErr != nil { + items = append(items, driveSyncItem{RelPath: relDir, Action: "failed", Direction: "push", Error: ensureErr.Error()}) + failed++ + continue + } + items = append(items, driveSyncItem{RelPath: relDir, FileToken: folderCache[relDir], Action: "folder_created", Direction: "push"}) + pushed++ + } + + // 2a. Pull new_remote files. + for _, entry := range newRemote { + targetFile, ok := pullRemoteFiles[entry.RelPath] + if !ok { + // Non-file type (doc, shortcut, etc.) — skip. + continue + } + target := filepath.Join(rootRelToCwd, entry.RelPath) + if err := drivePullDownload(ctx, runtime, targetFile.DownloadToken, target, targetFile.ModifiedTime); err != nil { + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: entry.FileToken, Action: "failed", Direction: "pull", Error: err.Error()}) + failed++ + continue + } + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: entry.FileToken, Action: "downloaded", Direction: "pull"}) + pulled++ + } + + // 2b. Push new_local files. + for _, entry := range newLocal { + localFile, ok := pushLocalFiles[entry.RelPath] + if !ok { + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "skipped", Direction: "push", Error: "local file disappeared during sync"}) + skipped++ + continue + } + parentRel := drivePushParentRel(entry.RelPath) + parentToken, ensureErr := drivePushEnsureFolder(ctx, runtime, folderToken, parentRel, folderCache) + if ensureErr != nil { + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "failed", Direction: "push", Error: ensureErr.Error()}) + failed++ + continue + } + token, _, upErr := drivePushUploadFile(ctx, runtime, localFile, "", parentToken) + if upErr != nil { + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "failed", Direction: "push", Error: upErr.Error()}) + failed++ + continue + } + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: token, Action: "uploaded", Direction: "push"}) + pushed++ + } + + // 2c. Resolve modified files by --on-conflict strategy. + for _, entry := range modified { + remoteFile := remoteFiles[entry.RelPath] + localFile, hasLocal := pushLocalFiles[entry.RelPath] + if !hasLocal { + // Should not happen — modified means both sides exist. + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "skipped", Direction: "conflict", Error: "local file disappeared during sync"}) + skipped++ + continue + } + + resolved := conflictResolutions[entry.RelPath] + if resolved == "" { + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "skipped", Direction: "conflict", Error: "user skipped"}) + skipped++ + continue + } + + switch resolved { + case driveSyncOnConflictRemoteWins: + // Pull remote over local. + targetFile, ok := pullRemoteFiles[entry.RelPath] + if !ok { + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "failed", Direction: "pull", Error: "remote file not found in pull views"}) + failed++ + continue + } + target := filepath.Join(rootRelToCwd, entry.RelPath) + if err := drivePullDownload(ctx, runtime, targetFile.DownloadToken, target, targetFile.ModifiedTime); err != nil { + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: entry.FileToken, Action: "failed", Direction: "pull", Error: err.Error()}) + failed++ + continue + } + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: entry.FileToken, Action: "downloaded", Direction: "pull"}) + pulled++ + + case driveSyncOnConflictLocalWins: + // Push local over remote. + existingToken := remoteFile.FileToken + if existingToken == "" { + if chosen, ok := remoteEntriesForPush[entry.RelPath]; ok { + existingToken = chosen.FileToken + } + } + parentToken, parentErr := drivePushEnsureFolder(ctx, runtime, folderToken, drivePushParentRel(entry.RelPath), folderCache) + if parentErr != nil { + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: existingToken, Action: "failed", Direction: "push", Error: parentErr.Error()}) + failed++ + continue + } + token, _, upErr := drivePushUploadFile(ctx, runtime, localFile, existingToken, parentToken) + if upErr != nil { + // Token contract on overwrite failure (same as +push): + // a partial-success response can return a non-empty + // file_token alongside an error. Prefer the freshly + // returned token when one was produced, fall back to + // existingToken otherwise. + failedToken := token + if failedToken == "" { + failedToken = existingToken + } + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: failedToken, Action: "failed", Direction: "push", Error: upErr.Error()}) + failed++ + continue + } + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: token, Action: "overwritten", Direction: "push"}) + pushed++ + + case driveSyncOnConflictKeepBoth: + // Rename the local file with a hash suffix, then pull the remote. + // Use the remote file token to generate a stable suffix (same + // pattern as +pull --on-duplicate-remote=rename). + occupied := occupiedRemotePaths(entries) + // Add current local paths to occupied set so the renamed + // local file doesn't collide with an existing file or directory. + for p := range pushLocalFiles { + occupied[p] = struct{}{} + } + for _, relDir := range localDirs { + occupied[relDir] = struct{}{} + } + suffixedRel, err := relPathWithUniqueFileTokenSuffix(entry.RelPath, remoteFile.FileToken, occupied) + if err != nil { + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "failed", Direction: "conflict", Error: err.Error()}) + failed++ + continue + } + // Rename the local file. + oldAbsPath := filepath.Join(safeRoot, filepath.FromSlash(entry.RelPath)) + newAbsPath := filepath.Join(safeRoot, filepath.FromSlash(suffixedRel)) + if err := os.Rename(oldAbsPath, newAbsPath); err != nil { //nolint:forbidigo // shortcuts cannot import internal/vfs (depguard rule shortcuts-no-vfs); safeRoot is validated. + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "failed", Direction: "conflict", Error: fmt.Sprintf("rename local: %s", err)}) + failed++ + continue + } + occupied[suffixedRel] = struct{}{} + // Now pull the remote version to the original path. + targetFile, ok := pullRemoteFiles[entry.RelPath] + if !ok { + rollbackErr := driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath) + errMsg := "remote file not found in pull views after rename" + if rollbackErr != nil { + errMsg += "; rollback failed: " + rollbackErr.Error() + } + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "failed", Direction: "pull", Error: errMsg}) + failed++ + continue + } + target := filepath.Join(rootRelToCwd, entry.RelPath) + if err := drivePullDownload(ctx, runtime, targetFile.DownloadToken, target, targetFile.ModifiedTime); err != nil { + rollbackErr := driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath) + errMsg := err.Error() + if rollbackErr != nil { + errMsg += "; rollback failed: " + rollbackErr.Error() + } + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: entry.FileToken, Action: "failed", Direction: "pull", Error: errMsg}) + failed++ + continue + } + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "renamed_local", Direction: "conflict"}) + items = append(items, driveSyncItem{RelPath: entry.RelPath, FileToken: entry.FileToken, Action: "downloaded", Direction: "pull"}) + pulled++ + + default: + items = append(items, driveSyncItem{RelPath: entry.RelPath, Action: "skipped", Direction: "conflict", Error: fmt.Sprintf("unknown conflict strategy: %s", resolved)}) + skipped++ + } + } + + payload := map[string]interface{}{ + "detection": detection, + "diff": map[string]interface{}{ + "new_local": emptyIfNil(newLocal), + "new_remote": emptyIfNil(newRemote), + "modified": emptyIfNil(modified), + "unchanged": emptyIfNil(unchanged), + }, + "summary": map[string]interface{}{ + "pulled": pulled, + "pushed": pushed, + "skipped": skipped, + "failed": failed, + }, + "items": items, + } + + if failed > 0 { + msg := fmt.Sprintf("%d item(s) failed during +sync", failed) + return &output.ExitError{ + Code: output.ExitAPI, + Detail: &output.ErrDetail{ + Type: "partial_failure", + Message: msg, + Detail: payload, + }, + } + } + + runtime.Out(payload, nil) + return nil + }, +} + +func driveSyncStatusRemoteFiles(pullRemoteFiles map[string]drivePullTarget) map[string]driveStatusRemoteFile { + remoteFiles := make(map[string]driveStatusRemoteFile, len(pullRemoteFiles)) + for relPath, target := range pullRemoteFiles { + fileToken := target.ItemFileToken + if fileToken == "" { + fileToken = target.DownloadToken + } + remoteFiles[relPath] = driveStatusRemoteFile{FileToken: fileToken, ModifiedTime: target.ModifiedTime} + } + return remoteFiles +} + +// driveSyncAskConflict prompts the user for a conflict resolution strategy +// for a single file. Returns the strategy string, or empty string if the +// user chose to skip. +func driveSyncAskConflict(relPath string, runtime *common.RuntimeContext) (string, error) { + fmt.Fprintf(runtime.IO().ErrOut, "CONFLICT: both sides modified %q. Choose: [R]emote-wins / [L]ocal-wins / [K]eep-both / [S]kip (default: R): ", relPath) + if runtime.IO().In == nil { + return "", output.ErrValidation("cannot resolve conflict for %q with --on-conflict=ask: stdin is not available", relPath) + } + reader, ok := runtime.IO().In.(*bufio.Reader) + if !ok { + reader = bufio.NewReader(runtime.IO().In) + runtime.IO().In = reader + } + line, err := reader.ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return "", output.ErrValidation("cannot read conflict choice for %q: %s", relPath, err) + } + answer := strings.TrimSpace(strings.ToLower(line)) + if answer == "" { + if errors.Is(err, io.EOF) { + return "", output.ErrValidation("cannot resolve conflict for %q with --on-conflict=ask: stdin reached EOF before any choice was provided", relPath) + } + return driveSyncOnConflictRemoteWins, nil + } + switch answer { + case "l", "local", "local-wins": + return driveSyncOnConflictLocalWins, nil + case "k", "keep", "keep-both": + return driveSyncOnConflictKeepBoth, nil + case "s", "skip": + return "", nil + case "r", "remote", "remote-wins": + return driveSyncOnConflictRemoteWins, nil + default: + return "", output.ErrValidation("invalid conflict choice for %q: %q (expected one of remote/local/keep/skip)", relPath, strings.TrimSpace(line)) + } +} + +func driveSyncNeedsDownloadScope(newRemote, modified []driveStatusEntry, conflictResolutions map[string]string) bool { + if len(newRemote) > 0 { + return true + } + for _, entry := range modified { + switch conflictResolutions[entry.RelPath] { + case driveSyncOnConflictRemoteWins, driveSyncOnConflictKeepBoth: + return true + } + } + return false +} + +func driveSyncPlannedUploadPaths(newLocal, modified []driveStatusEntry, conflictResolutions map[string]string) []string { + planned := make([]string, 0, len(newLocal)+len(modified)) + for _, entry := range newLocal { + planned = append(planned, entry.RelPath) + } + for _, entry := range modified { + if conflictResolutions[entry.RelPath] == driveSyncOnConflictLocalWins { + planned = append(planned, entry.RelPath) + } + } + return planned +} + +func driveSyncNeedsCreateScope(uploadPaths []string, localDirs []string, folderCache map[string]string) bool { + for _, relPath := range uploadPaths { + parentRel := drivePushParentRel(relPath) + if parentRel == "" { + continue + } + if _, ok := folderCache[parentRel]; !ok { + return true + } + } + // Empty local directories also need create_folder if not already on Drive. + for _, relDir := range localDirs { + if _, ok := folderCache[relDir]; !ok { + return true + } + } + return false +} + +func driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath string) error { + if info, err := os.Stat(oldAbsPath); err == nil { //nolint:forbidigo // shortcuts cannot import internal/vfs (depguard rule shortcuts-no-vfs); safeRoot is validated. + if info.IsDir() { + return output.Errorf(output.ExitInternal, "rollback", "original path became a directory during rollback: %s", oldAbsPath) + } + if err := os.Remove(oldAbsPath); err != nil { //nolint:forbidigo // shortcuts cannot import internal/vfs (depguard rule shortcuts-no-vfs); safeRoot is validated. + return output.Errorf(output.ExitInternal, "rollback", "remove partial restored path %q: %s", oldAbsPath, err) + } + } else if !os.IsNotExist(err) { + return output.Errorf(output.ExitInternal, "rollback", "stat original path %q during rollback: %s", oldAbsPath, err) + } + if err := os.Rename(newAbsPath, oldAbsPath); err != nil { //nolint:forbidigo // shortcuts cannot import internal/vfs (depguard rule shortcuts-no-vfs); safeRoot is validated. + return output.Errorf(output.ExitInternal, "rollback", "restore renamed local file %q: %s", oldAbsPath, err) + } + return nil +} diff --git a/shortcuts/drive/drive_sync_test.go b/shortcuts/drive/drive_sync_test.go new file mode 100644 index 000000000..7364397eb --- /dev/null +++ b/shortcuts/drive/drive_sync_test.go @@ -0,0 +1,3097 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package drive + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/larksuite/cli/extension/fileio" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" + "github.com/larksuite/cli/internal/httpmock" + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/shortcuts/common" + "github.com/spf13/cobra" +) + +func newDriveSyncRuntime(t *testing.T, localDir, folderToken string) (*common.RuntimeContext, *cmdutil.Factory) { + t.Helper() + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + runtime := newDriveSyncRuntimeWithFactory(t, f, localDir, folderToken) + return runtime, f +} + +func newDriveSyncRuntimeWithFactory(t *testing.T, f *cmdutil.Factory, localDir, folderToken string) *common.RuntimeContext { + t.Helper() + cmd := &cobra.Command{Use: "drive +sync"} + cmd.Flags().String("local-dir", "", "") + cmd.Flags().String("folder-token", "", "") + cmd.Flags().String("on-conflict", "", "") + cmd.Flags().String("on-duplicate-remote", "", "") + cmd.Flags().Bool("quick", false, "") + if localDir != "" { + if err := cmd.Flags().Set("local-dir", localDir); err != nil { + t.Fatalf("set --local-dir: %v", err) + } + } + if folderToken != "" { + if err := cmd.Flags().Set("folder-token", folderToken); err != nil { + t.Fatalf("set --folder-token: %v", err) + } + } + runtime := common.TestNewRuntimeContextWithCtx(context.Background(), cmd, driveTestConfig()) + runtime.Factory = f + return runtime +} + +type failSaveProvider struct { + inner fileio.Provider + failSuffix string + err error +} + +func (p *failSaveProvider) Name() string { return "fail-save" } + +func (p *failSaveProvider) ResolveFileIO(ctx context.Context) fileio.FileIO { + return &failSaveFileIO{inner: p.inner.ResolveFileIO(ctx), failSuffix: p.failSuffix, err: p.err} +} + +type failSaveFileIO struct { + inner fileio.FileIO + failSuffix string + err error +} + +func (f *failSaveFileIO) Open(name string) (fileio.File, error) { return f.inner.Open(name) } +func (f *failSaveFileIO) Stat(name string) (fileio.FileInfo, error) { return f.inner.Stat(name) } +func (f *failSaveFileIO) ResolvePath(path string) (string, error) { return f.inner.ResolvePath(path) } + +func (f *failSaveFileIO) Save(path string, opts fileio.SaveOptions, body io.Reader) (fileio.SaveResult, error) { + if strings.HasSuffix(path, f.failSuffix) { + return nil, f.err + } + return f.inner.Save(path, opts, body) +} + +type deleteOnCloseProvider struct { + inner fileio.Provider + targetPath string + deletePath string +} + +func (p *deleteOnCloseProvider) Name() string { return "delete-on-close" } + +func (p *deleteOnCloseProvider) ResolveFileIO(ctx context.Context) fileio.FileIO { + return &deleteOnCloseFileIO{inner: p.inner.ResolveFileIO(ctx), targetPath: p.targetPath, deletePath: p.deletePath} +} + +type deleteOnCloseFileIO struct { + inner fileio.FileIO + targetPath string + deletePath string +} + +func (f *deleteOnCloseFileIO) Open(name string) (fileio.File, error) { + file, err := f.inner.Open(name) + if err != nil { + return nil, err + } + if name != f.targetPath { + return file, nil + } + return &deleteOnCloseFile{File: file, deletePath: f.deletePath}, nil +} + +func (f *deleteOnCloseFileIO) Stat(name string) (fileio.FileInfo, error) { return f.inner.Stat(name) } +func (f *deleteOnCloseFileIO) ResolvePath(path string) (string, error) { + return f.inner.ResolvePath(path) +} +func (f *deleteOnCloseFileIO) Save(path string, opts fileio.SaveOptions, body io.Reader) (fileio.SaveResult, error) { + return f.inner.Save(path, opts, body) +} + +type deleteOnCloseFile struct { + fileio.File + deletePath string +} + +func (f *deleteOnCloseFile) Close() error { + err := f.File.Close() + _ = os.Remove(f.deletePath) + return err +} + +type failAfterSaveProvider struct { + inner fileio.Provider + failSuffix string + err error + afterSave func(path string) +} + +func (p *failAfterSaveProvider) Name() string { return "fail-after-save" } + +func (p *failAfterSaveProvider) ResolveFileIO(ctx context.Context) fileio.FileIO { + return &failAfterSaveFileIO{inner: p.inner.ResolveFileIO(ctx), failSuffix: p.failSuffix, err: p.err, afterSave: p.afterSave} +} + +type failAfterSaveFileIO struct { + inner fileio.FileIO + failSuffix string + err error + afterSave func(path string) +} + +func (f *failAfterSaveFileIO) Open(name string) (fileio.File, error) { return f.inner.Open(name) } +func (f *failAfterSaveFileIO) Stat(name string) (fileio.FileInfo, error) { return f.inner.Stat(name) } +func (f *failAfterSaveFileIO) ResolvePath(path string) (string, error) { + return f.inner.ResolvePath(path) +} + +func (f *failAfterSaveFileIO) Save(path string, opts fileio.SaveOptions, body io.Reader) (fileio.SaveResult, error) { + res, err := f.inner.Save(path, opts, body) + if strings.HasSuffix(path, f.failSuffix) { + if f.afterSave != nil { + f.afterSave(path) + } + return res, f.err + } + return res, err +} + +type driveSyncReadThenError struct { + stage int +} + +func (r *driveSyncReadThenError) Read(p []byte) (int, error) { + if r.stage == 0 { + r.stage++ + copy(p, []byte("local ")) + return 6, nil + } + return 0, fmt.Errorf("read failure") +} + +// TestDriveSyncRemoteWinsPullsNewRemoteAndPushesNewLocal verifies the basic +// two-way sync flow: new_remote files are pulled, new_local files are pushed, +// and modified files use --on-conflict=remote-wins (the default) to pull the +// remote version. +func TestDriveSyncRemoteWinsPullsNewRemoteAndPushesNewLocal(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-remote-wins", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + // Local layout: + // local/b.txt — only local → push + // local/a.txt — both sides, different content → conflict (remote-wins → pull) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + if err := os.WriteFile("local/b.txt", []byte("local-b"), 0o644); err != nil { + t.Fatalf("WriteFile b.txt: %v", err) + } + + // Remote listing: a.txt (modified), d.txt (new_remote) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + map[string]interface{}{"token": "tok_d", "name": "d.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + // Download a.txt for hash comparison (exact mode) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + // Download d.txt (new_remote → pull) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_d/download", + Status: 200, + Body: []byte("remote-d"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + // Download a.txt again (conflict: remote-wins → pull remote over local) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + // Upload b.txt (new_local → push) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "file_token": "tok_b_uploaded", + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "remote-wins", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"action": "downloaded"`) { + t.Errorf("output missing downloaded action\noutput: %s", out) + } + if !strings.Contains(out, `"action": "uploaded"`) { + t.Errorf("output missing uploaded action\noutput: %s", out) + } + if !strings.Contains(out, `"direction": "pull"`) { + t.Errorf("output missing pull direction\noutput: %s", out) + } + if !strings.Contains(out, `"direction": "push"`) { + t.Errorf("output missing push direction\noutput: %s", out) + } + + // Verify local file was overwritten with remote content + data, err := os.ReadFile("local/a.txt") + if err != nil { + t.Fatalf("ReadFile a.txt: %v", err) + } + if string(data) != "remote-a" { + t.Errorf("a.txt content = %q, want %q", string(data), "remote-a") + } + + // Verify d.txt was downloaded + data, err = os.ReadFile("local/d.txt") + if err != nil { + t.Fatalf("ReadFile d.txt: %v", err) + } + if string(data) != "remote-d" { + t.Errorf("d.txt content = %q, want %q", string(data), "remote-d") + } +} + +// TestDriveSyncLocalWinsPushesOverRemote verifies that --on-conflict=local-wins +// pushes the local version over the remote file. +func TestDriveSyncLocalWinsPushesOverRemote(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-local-wins", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + // Download a.txt for hash comparison (exact mode) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + // Upload a.txt with overwrite (local-wins → push over remote) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "file_token": "tok_a", + "version": "v2", + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "local-wins", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"action": "overwritten"`) { + t.Errorf("output missing overwritten action\noutput: %s", out) + } + if !strings.Contains(out, `"direction": "push"`) { + t.Errorf("output missing push direction\noutput: %s", out) + } +} + +// TestDriveSyncKeepBothRenamesLocalAndPullsRemote verifies that +// --on-conflict=keep-both renames the local file with a hash suffix +// and then downloads the remote version to the original path. +func TestDriveSyncKeepBothRenamesLocalAndPullsRemote(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-keep-both", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + // Download a.txt for hash comparison + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + // Download a.txt again (keep-both: pull remote to original path after rename) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "keep-both", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"action": "renamed_local"`) { + t.Errorf("output missing renamed_local action\noutput: %s", out) + } + if !strings.Contains(out, `"action": "downloaded"`) { + t.Errorf("output missing downloaded action\noutput: %s", out) + } + + // Original path should now have remote content + data, err := os.ReadFile("local/a.txt") + if err != nil { + t.Fatalf("ReadFile a.txt: %v", err) + } + if string(data) != "remote-a" { + t.Errorf("a.txt content = %q, want %q", string(data), "remote-a") + } + + // There should be a renamed file with __lark_ suffix + entries, err := os.ReadDir("local") + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + found := false + for _, e := range entries { + if strings.Contains(e.Name(), "__lark_") && strings.HasSuffix(e.Name(), ".txt") { + found = true + renamedData, err := os.ReadFile("local/" + e.Name()) + if err != nil { + t.Fatalf("ReadFile renamed: %v", err) + } + if string(renamedData) != "local-a" { + t.Errorf("renamed file content = %q, want %q", string(renamedData), "local-a") + } + } + } + if !found { + t.Errorf("expected a file with __lark_ suffix in local/, got entries: %v", entries) + } +} + +// TestDriveSyncKeepBothRollsBackRenameOnPullFailure verifies that keep-both +// restores the original local path if the remote download fails after the +// local file has been renamed. +func TestDriveSyncKeepBothRollsBackRenameOnPullFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-keep-both-rollback", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + // Download a.txt for the exact diff phase. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "keep-both", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected +sync keep-both to fail when the post-rename pull has no stub\nstdout: %s", stdout.String()) + } + + data, readErr := os.ReadFile("local/a.txt") + if readErr != nil { + t.Fatalf("ReadFile a.txt after rollback: %v", readErr) + } + if string(data) != "local-a" { + t.Fatalf("a.txt content after rollback = %q, want %q", string(data), "local-a") + } + + entries, readDirErr := os.ReadDir("local") + if readDirErr != nil { + t.Fatalf("ReadDir local: %v", readDirErr) + } + if len(entries) != 1 || entries[0].Name() != "a.txt" { + t.Fatalf("expected rollback to restore only local/a.txt, got entries: %v", entries) + } +} + +// TestDriveSyncAskConflictFailsBeforeWritesWithoutStdin verifies that +// --on-conflict=ask fails before any sync writes start when stdin is not +// available and the diff contains modified entries. +func TestDriveSyncAskConflictFailsBeforeWritesWithoutStdin(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-ask-eof", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + if err := os.WriteFile("local/b.txt", []byte("local-b"), 0o644); err != nil { + t.Fatalf("WriteFile b.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + map[string]interface{}{"token": "tok_d", "name": "d.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + // Download a.txt for the exact diff phase. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "ask", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected +sync --on-conflict=ask to fail on EOF\nstdout: %s", stdout.String()) + } + if !strings.Contains(err.Error(), "interactive stdin") { + t.Fatalf("expected interactive stdin validation error, got: %v", err) + } + + data, readErr := os.ReadFile("local/a.txt") + if readErr != nil { + t.Fatalf("ReadFile a.txt after ask failure: %v", readErr) + } + if string(data) != "local-a" { + t.Fatalf("a.txt content after ask failure = %q, want %q", string(data), "local-a") + } + if _, statErr := os.Stat("local/d.txt"); !os.IsNotExist(statErr) { + t.Fatalf("new_remote download should not start before ask preflight; stat err=%v", statErr) + } +} + +func TestDriveSyncFailsOnDuplicateRemoteFiles(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + registerDuplicateRemoteFiles(reg) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + assertDuplicateRemotePathError(t, err, "dup.txt", duplicateRemoteFileIDFirst, duplicateRemoteFileIDSecond) + if stdout.Len() != 0 { + t.Fatalf("stdout should be empty on duplicate_remote_path, got: %s", stdout.String()) + } +} + +// TestDriveSyncUsesResolvedDuplicateTargetForDiff verifies that +sync computes +// the diff against the same duplicate-remote selection used during execution. +func TestDriveSyncUsesResolvedDuplicateTargetForDiff(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-duplicate-resolution", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("same-as-oldest"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_old", "name": "a.txt", "type": "file", "created_time": "100", "modified_time": "100"}, + map[string]interface{}{"token": "tok_new", "name": "a.txt", "type": "file", "created_time": "200", "modified_time": "200"}, + }, + "has_more": false, + }, + }, + }) + + // The chosen --on-duplicate-remote=oldest target is tok_old. The test omits + // any tok_new download stub so a stale last-seen overwrite bug would fail. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_old/download", + Status: 200, + Body: []byte("same-as-oldest"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-duplicate-remote", "oldest", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"pushed": 0`) || !strings.Contains(out, `"pulled": 0`) { + t.Fatalf("expected unchanged duplicate target to produce no sync actions\noutput: %s", out) + } + if !strings.Contains(out, `"file_token": "tok_old"`) { + t.Fatalf("expected diff to reference the oldest duplicate target token\noutput: %s", out) + } +} + +// TestDriveSyncLocalWinsNestedFileUsesParentFolderToken verifies that local-wins +// overwrites on nested files keep parent_node aligned with the file's parent. +func TestDriveSyncLocalWinsNestedFileUsesParentFolderToken(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-local-wins-nested", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local/sub", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/sub/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "fld_sub", "name": "sub", "type": "folder"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=fld_sub", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + // Diff phase exact hash download. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + uploadStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "file_token": "tok_a", + "version": "v2", + }, + }, + } + reg.Register(uploadStub) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "local-wins", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + body := decodeDriveMultipartBody(t, uploadStub) + if got := body.Fields["file_token"]; got != "tok_a" { + t.Fatalf("upload_all file_token = %q, want tok_a", got) + } + if got := body.Fields["parent_node"]; got != "fld_sub" { + t.Fatalf("upload_all parent_node = %q, want fld_sub", got) + } +} + +// TestDriveSyncNewLocalDisappearanceIsReported verifies that files discovered +// during diff but removed before the push phase are surfaced as skipped items +// instead of being silently dropped. +func TestDriveSyncNewLocalDisappearanceIsReported(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-new-local-disappeared", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/ephemeral.txt", []byte("temp"), 0o644); err != nil { + t.Fatalf("WriteFile ephemeral.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + OnMatch: func(_ *http.Request) { + if err := os.Remove("local/ephemeral.txt"); err != nil && !os.IsNotExist(err) { + t.Fatalf("Remove ephemeral.txt in OnMatch: %v", err) + } + }, + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{}, + "has_more": false, + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"skipped": 1`) { + t.Fatalf("expected skipped=1 when new_local disappears during execution\noutput: %s", out) + } + if !strings.Contains(out, `"rel_path": "ephemeral.txt"`) || !strings.Contains(out, `"local file disappeared during sync"`) { + t.Fatalf("expected vanished new_local file to be reported in items\noutput: %s", out) + } +} + +// TestDriveSyncQuickModeUsesModifiedTime verifies that --quick mode +// classifies files by modified_time instead of SHA-256 hash. +func TestDriveSyncQuickModeUsesModifiedTime(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-quick", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + if err := os.WriteFile("local/b.txt", []byte("local-b"), 0o644); err != nil { + t.Fatalf("WriteFile b.txt: %v", err) + } + + // Set a.txt mtime to match remote → unchanged in quick mode + matchTime := time.Unix(1715594880, 0) + if err := os.Chtimes("local/a.txt", matchTime, matchTime); err != nil { + t.Fatalf("Chtimes a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file", "modified_time": "1715594880"}, + map[string]interface{}{"token": "tok_d", "name": "d.txt", "type": "file", "modified_time": "1715595000"}, + }, + "has_more": false, + }, + }, + }) + + // Download d.txt (new_remote → pull) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_d/download", + Status: 200, + Body: []byte("remote-d"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + // Upload b.txt (new_local → push) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "file_token": "tok_b_uploaded", + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--quick", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"detection": "quick"`) { + t.Errorf("output missing detection=quick\noutput: %s", out) + } + // a.txt should be unchanged (mtime matches), not downloaded or uploaded + // It should appear in diff.unchanged but NOT in items[] with a pull/push action + itemsSection := out[strings.Index(out, `"items"`):] + if strings.Contains(itemsSection, `"rel_path": "a.txt"`) { + t.Errorf("a.txt should not appear in items[] (mtime matches remote, should be unchanged)\noutput: %s", out) + } +} + +// TestDriveSyncQuickModeMTimeMismatchStillTriggersWrites verifies the best-effort +// nature of --quick: a timestamp mismatch alone is enough to drive a real sync +// action even when the file bytes are already identical. +func TestDriveSyncQuickModeMTimeMismatchStillTriggersWrites(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-quick-mismatch", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("same-content"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + localTime := time.Unix(1715594880, 0) + if err := os.Chtimes("local/a.txt", localTime, localTime); err != nil { + t.Fatalf("Chtimes a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file", "modified_time": "1715594999"}, + }, + "has_more": false, + }, + }, + }) + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("same-content"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--quick", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"detection": "quick"`) { + t.Fatalf("expected detection=quick\noutput: %s", out) + } + if !strings.Contains(out, `"modified":`) || !strings.Contains(out, `"action": "downloaded"`) { + t.Fatalf("expected quick mtime mismatch to trigger a real pull action\noutput: %s", out) + } +} + +// TestDriveSyncNoChangesReportsEmptyItems verifies that when local and remote +// are identical, +sync reports zero pulled/pushed items. +func TestDriveSyncNoChangesReportsEmptyItems(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-no-changes", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("same"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + // Download a.txt for hash comparison → same content → unchanged + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("same"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"pulled": 0`) { + t.Errorf("expected pulled=0\noutput: %s", out) + } + if !strings.Contains(out, `"pushed": 0`) { + t.Errorf("expected pushed=0\noutput: %s", out) + } + if !strings.Contains(out, `"failed": 0`) { + t.Errorf("expected failed=0\noutput: %s", out) + } +} + +func TestDriveSyncValidateRejectsInvalidInputs(t *testing.T) { + t.Run("missing local-dir", func(t *testing.T) { + runtime, _ := newDriveSyncRuntime(t, "", "folder_root") + err := DriveSync.Validate(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "--local-dir is required") { + t.Fatalf("Validate() error = %v, want missing --local-dir", err) + } + }) + + t.Run("missing folder-token", func(t *testing.T) { + runtime, _ := newDriveSyncRuntime(t, "local", "") + err := DriveSync.Validate(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "--folder-token is required") { + t.Fatalf("Validate() error = %v, want missing --folder-token", err) + } + }) + + t.Run("malformed folder-token", func(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + runtime, _ := newDriveSyncRuntime(t, "local", "tok\nwithnewline") + err := DriveSync.Validate(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "--folder-token") { + t.Fatalf("Validate() error = %v, want malformed folder-token error", err) + } + }) + + t.Run("absolute local-dir", func(t *testing.T) { + runtime, _ := newDriveSyncRuntime(t, "/etc", "folder_root") + err := DriveSync.Validate(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "--local-dir") { + t.Fatalf("Validate() error = %v, want invalid local-dir error", err) + } + }) + + t.Run("missing local-dir path", func(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + runtime, _ := newDriveSyncRuntime(t, "missing", "folder_root") + err := DriveSync.Validate(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "missing") { + t.Fatalf("Validate() error = %v, want missing-path error", err) + } + }) + + t.Run("local-dir is file", func(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.WriteFile("not-a-dir.txt", []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + runtime, _ := newDriveSyncRuntime(t, "not-a-dir.txt", "folder_root") + err := DriveSync.Validate(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "not a directory") { + t.Fatalf("Validate() error = %v, want not-a-directory error", err) + } + }) +} + +func TestDriveSyncDryRunUsesFolderToken(t *testing.T) { + runtime, _ := newDriveSyncRuntime(t, "local", "folder_root") + dry := DriveSync.DryRun(context.Background(), runtime) + if dry == nil { + t.Fatal("DryRun returned nil") + } + + data, err := json.Marshal(dry) + if err != nil { + t.Fatalf("marshal dry run: %v", err) + } + if !strings.Contains(string(data), `"folder_token":"folder_root"`) { + t.Fatalf("dry run missing folder_token, got: %s", string(data)) + } +} + +func TestDriveSyncExecuteRejectsUnsafeLocalDir(t *testing.T) { + runtime, _ := newDriveSyncRuntime(t, "/etc", "folder_root") + err := DriveSync.Execute(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "--local-dir") { + t.Fatalf("Execute() error = %v, want unsafe local-dir validation error", err) + } +} + +func TestDriveSyncAskConflictParsesChoices(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr string + }{ + {name: "blank line defaults remote wins", input: "\n", want: driveSyncOnConflictRemoteWins}, + {name: "local short form", input: "L\n", want: driveSyncOnConflictLocalWins}, + {name: "keep both long form", input: "keep-both\n", want: driveSyncOnConflictKeepBoth}, + {name: "skip returns empty resolution", input: "skip\n", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + f.IOStreams.In = strings.NewReader(tt.input) + + runtime := common.TestNewRuntimeContext(&cobra.Command{Use: "drive"}, driveTestConfig()) + runtime.Factory = f + + got, err := driveSyncAskConflict("a.txt", runtime) + if tt.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("driveSyncAskConflict() error = %v, want substring %q", err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("driveSyncAskConflict() unexpected error: %v", err) + } + if got != tt.want { + t.Fatalf("driveSyncAskConflict() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestDriveSyncAskConflictRejectsMissingStdin(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + runtime := common.TestNewRuntimeContext(&cobra.Command{Use: "drive"}, driveTestConfig()) + runtime.Factory = f + + _, err := driveSyncAskConflict("a.txt", runtime) + if err == nil || !strings.Contains(err.Error(), "stdin is not available") { + t.Fatalf("driveSyncAskConflict() error = %v, want stdin availability error", err) + } +} + +func TestDriveSyncAskConflictHandlesEOFAndReadErrors(t *testing.T) { + t.Run("blank EOF without answer fails", func(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + f.IOStreams.In = strings.NewReader("") + + runtime := common.TestNewRuntimeContext(&cobra.Command{Use: "drive"}, driveTestConfig()) + runtime.Factory = f + + _, err := driveSyncAskConflict("a.txt", runtime) + if err == nil || !strings.Contains(err.Error(), "stdin reached EOF") { + t.Fatalf("driveSyncAskConflict() error = %v, want EOF failure", err) + } + }) + + t.Run("partial token before EOF is still accepted", func(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + f.IOStreams.In = strings.NewReader("local") + + runtime := common.TestNewRuntimeContext(&cobra.Command{Use: "drive"}, driveTestConfig()) + runtime.Factory = f + + got, err := driveSyncAskConflict("a.txt", runtime) + if err != nil { + t.Fatalf("driveSyncAskConflict() unexpected error: %v", err) + } + if got != driveSyncOnConflictLocalWins { + t.Fatalf("driveSyncAskConflict() = %q, want %q", got, driveSyncOnConflictLocalWins) + } + }) + + t.Run("unknown answer returns validation error", func(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + f.IOStreams.In = strings.NewReader("what\n") + + runtime := common.TestNewRuntimeContext(&cobra.Command{Use: "drive"}, driveTestConfig()) + runtime.Factory = f + + _, err := driveSyncAskConflict("a.txt", runtime) + if err == nil || !strings.Contains(err.Error(), "invalid conflict choice") { + t.Fatalf("driveSyncAskConflict() error = %v, want invalid-choice failure", err) + } + }) + + t.Run("non EOF read failure returns wrapped error", func(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + f.IOStreams.In = bufio.NewReader(&driveSyncReadThenError{}) + + runtime := common.TestNewRuntimeContext(&cobra.Command{Use: "drive"}, driveTestConfig()) + runtime.Factory = f + + _, err := driveSyncAskConflict("a.txt", runtime) + if err == nil || !strings.Contains(err.Error(), "cannot read conflict choice") { + t.Fatalf("driveSyncAskConflict() error = %v, want wrapped read failure", err) + } + }) +} + +func TestDriveSyncRollbackRenamedLocalRestoresRenamedFile(t *testing.T) { + tmpDir := t.TempDir() + oldAbsPath := tmpDir + "/a.txt" + newAbsPath := tmpDir + "/a__lark.txt" + + if err := os.WriteFile(oldAbsPath, []byte("partial remote"), 0o644); err != nil { + t.Fatalf("WriteFile oldAbsPath: %v", err) + } + if err := os.WriteFile(newAbsPath, []byte("original local"), 0o644); err != nil { + t.Fatalf("WriteFile newAbsPath: %v", err) + } + + if err := driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath); err != nil { + t.Fatalf("driveSyncRollbackRenamedLocal() error = %v", err) + } + + data, err := os.ReadFile(oldAbsPath) + if err != nil { + t.Fatalf("ReadFile restored oldAbsPath: %v", err) + } + if got := string(data); got != "original local" { + t.Fatalf("restored content = %q, want %q", got, "original local") + } + if _, err := os.Stat(newAbsPath); !os.IsNotExist(err) { + t.Fatalf("expected renamed path to be removed after rollback, stat err = %v", err) + } +} + +func TestDriveSyncRollbackRenamedLocalWithoutPartialRestore(t *testing.T) { + tmpDir := t.TempDir() + oldAbsPath := tmpDir + "/a.txt" + newAbsPath := tmpDir + "/a__lark.txt" + + if err := os.WriteFile(newAbsPath, []byte("original local"), 0o644); err != nil { + t.Fatalf("WriteFile newAbsPath: %v", err) + } + + if err := driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath); err != nil { + t.Fatalf("driveSyncRollbackRenamedLocal() error = %v", err) + } + + data, err := os.ReadFile(oldAbsPath) + if err != nil { + t.Fatalf("ReadFile restored oldAbsPath: %v", err) + } + if got := string(data); got != "original local" { + t.Fatalf("restored content = %q, want %q", got, "original local") + } +} + +func TestDriveSyncRollbackRenamedLocalRejectsDirectoryAtOriginalPath(t *testing.T) { + tmpDir := t.TempDir() + oldAbsPath := tmpDir + "/a.txt" + newAbsPath := tmpDir + "/a__lark.txt" + + if err := os.Mkdir(oldAbsPath, 0o755); err != nil { + t.Fatalf("Mkdir oldAbsPath: %v", err) + } + if err := os.WriteFile(newAbsPath, []byte("original local"), 0o644); err != nil { + t.Fatalf("WriteFile newAbsPath: %v", err) + } + + err := driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath) + if err == nil || !strings.Contains(err.Error(), "became a directory") { + t.Fatalf("driveSyncRollbackRenamedLocal() error = %v, want directory error", err) + } +} + +func TestDriveSyncRollbackRenamedLocalSurfacesRenameFailure(t *testing.T) { + tmpDir := t.TempDir() + oldAbsPath := tmpDir + "/a.txt" + newAbsPath := tmpDir + "/missing.txt" + + err := driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath) + if err == nil || !strings.Contains(err.Error(), "restore renamed local file") { + t.Fatalf("driveSyncRollbackRenamedLocal() error = %v, want rename failure", err) + } +} + +func TestDriveSyncRollbackRenamedLocalSurfacesRemoveFailure(t *testing.T) { + tmpDir := t.TempDir() + oldAbsPath := filepath.Join(tmpDir, "a.txt") + newAbsPath := filepath.Join(tmpDir, "a__lark.txt") + + if err := os.WriteFile(oldAbsPath, []byte("partial remote"), 0o644); err != nil { + t.Fatalf("WriteFile oldAbsPath: %v", err) + } + if err := os.WriteFile(newAbsPath, []byte("original local"), 0o644); err != nil { + t.Fatalf("WriteFile newAbsPath: %v", err) + } + if err := os.Chmod(tmpDir, 0o555); err != nil { + t.Fatalf("Chmod read-only dir: %v", err) + } + defer func() { + _ = os.Chmod(tmpDir, 0o755) + }() + + err := driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath) + if err == nil || !strings.Contains(err.Error(), "remove partial restored path") { + t.Fatalf("driveSyncRollbackRenamedLocal() error = %v, want remove failure", err) + } +} + +func TestDriveSyncRollbackRenamedLocalSurfacesStatFailure(t *testing.T) { + tmpDir := t.TempDir() + blockedDir := filepath.Join(tmpDir, "blocked") + oldAbsPath := filepath.Join(blockedDir, "a.txt") + newAbsPath := filepath.Join(blockedDir, "a__lark.txt") + + if err := os.MkdirAll(blockedDir, 0o755); err != nil { + t.Fatalf("MkdirAll blockedDir: %v", err) + } + if err := os.WriteFile(newAbsPath, []byte("original local"), 0o644); err != nil { + t.Fatalf("WriteFile newAbsPath: %v", err) + } + if err := os.Chmod(blockedDir, 0o000); err != nil { + t.Fatalf("Chmod blockedDir: %v", err) + } + defer func() { + _ = os.Chmod(blockedDir, 0o755) + }() + + err := driveSyncRollbackRenamedLocal(oldAbsPath, newAbsPath) + if err == nil || !strings.Contains(err.Error(), "stat original path") { + t.Fatalf("driveSyncRollbackRenamedLocal() error = %v, want stat failure", err) + } +} + +func TestDriveSyncAskConflictEOFDuringExecuteReportsFailedItem(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-ask-exec-eof", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + f.IOStreams.In = strings.NewReader("") + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "ask", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected EOF failure during ask execution\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || !strings.Contains(items[0].Error, "stdin reached EOF") { + t.Fatalf("expected failed ask item, got detail: %#v", exitErr.Detail.Detail) + } + data, readErr := os.ReadFile("local/a.txt") + if readErr != nil { + t.Fatalf("ReadFile a.txt: %v", readErr) + } + if string(data) != "local-a" { + t.Fatalf("a.txt content = %q, want local-a", string(data)) + } +} + +func TestDriveSyncAskConflictEOFDuringPlanningPreventsAnyWrites(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-ask-plan-eof", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + f.IOStreams.In = strings.NewReader("") + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + if err := os.WriteFile("local/b.txt", []byte("local-b"), 0o644); err != nil { + t.Fatalf("WriteFile b.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + map[string]interface{}{"token": "tok_d", "name": "d.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "ask", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected EOF failure during ask planning\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + if exitErr.Detail.Type != "partial_failure" || !strings.Contains(exitErr.Error(), "stdin reached EOF") { + t.Fatalf("expected planning failure detail mentioning EOF, got: %#v", exitErr.Detail) + } + if data, readErr := os.ReadFile("local/a.txt"); readErr != nil || string(data) != "local-a" { + t.Fatalf("a.txt should remain untouched, readErr=%v content=%q", readErr, string(data)) + } + if data, readErr := os.ReadFile("local/b.txt"); readErr != nil || string(data) != "local-b" { + t.Fatalf("b.txt should remain untouched, readErr=%v content=%q", readErr, string(data)) + } + if _, statErr := os.Stat("local/d.txt"); !os.IsNotExist(statErr) { + t.Fatalf("new_remote file must not be downloaded before ask decisions, stat err=%v", statErr) + } +} + +func TestDriveSyncDryRunQuickAcceptsMetadataOnlyScope(t *testing.T) { + f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + f.Credential = credential.NewCredentialProvider(nil, nil, &driveStatusScopedTokenResolver{scopes: "drive:drive.metadata:readonly"}, nil) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--quick", + "--dry-run", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("expected quick dry-run to succeed without write scopes, got: %v\nstdout: %s", err, stdout.String()) + } + if strings.Contains(strings.ToLower(stdout.String()), "missing_scope") { + t.Fatalf("dry-run should not surface missing_scope, got: %s", stdout.String()) + } +} + +func TestDriveSyncExactRemoteWinsAcceptsDownloadOnlyScope(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-download-scope-only", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + f.Credential = credential.NewCredentialProvider(nil, nil, &driveStatusScopedTokenResolver{scopes: "drive:drive.metadata:readonly drive:file:download"}, nil) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "remote-wins", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("expected exact remote-wins to succeed with download-only scope, got: %v\nstdout: %s", err, stdout.String()) + } + if strings.Contains(strings.ToLower(stdout.String()), "missing_scope") { + t.Fatalf("should not surface missing_scope, got: %s", stdout.String()) + } +} + +func TestDriveSyncAskConflictSkipReportsSkippedItem(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-ask-skip", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + f.IOStreams.In = strings.NewReader("skip\n") + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "ask", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + out := stdout.String() + if !strings.Contains(out, `"action": "skipped"`) || !strings.Contains(out, "user skipped") { + t.Fatalf("expected skipped conflict item, got: %s", out) + } + if !strings.Contains(out, `"skipped": 1`) { + t.Fatalf("expected skipped summary count, got: %s", out) + } +} + +func TestDriveSyncReportsNewRemoteDownloadFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-new-remote-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + f.FileIOProvider = &failSaveProvider{inner: f.FileIOProvider, failSuffix: filepath.Join("local", "d.txt"), err: fmt.Errorf("save failed")} + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_d", "name": "d.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_d/download", + Status: 200, + Body: []byte("remote-d"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "remote-wins", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected download failure\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || items[0].Direction != "pull" || !strings.Contains(items[0].Error, "save failed") { + t.Fatalf("expected failed pull item, got detail: %#v", exitErr.Detail.Detail) + } +} + +func TestDriveSyncReportsNewLocalEnsureFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-new-local-ensure-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll(filepath.Join("local", "sub"), 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join("local", "sub", "a.txt"), []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{"files": []interface{}{}, "has_more": false}, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/create_folder", + Body: map[string]interface{}{ + "code": 9999, + "msg": "create parent failed", + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected ensure failure\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || items[0].Direction != "push" || !strings.Contains(items[0].Error, "create parent failed") { + t.Fatalf("expected failed push item, got detail: %#v", exitErr.Detail.Detail) + } +} + +func TestDriveSyncReportsNewLocalUploadFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-new-local-upload-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/b.txt", []byte("local-b"), 0o644); err != nil { + t.Fatalf("WriteFile b.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{"files": []interface{}{}, "has_more": false}, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 9999, + "msg": "upload failed", + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected upload failure\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || items[0].Direction != "push" || !strings.Contains(items[0].Error, "upload failed") { + t.Fatalf("expected failed upload item, got detail: %#v", exitErr.Detail.Detail) + } +} + +func TestDriveSyncLocalWinsReportsUploadFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-local-wins-upload-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 9999, + "msg": "overwrite failed", + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "local-wins", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected local-wins upload failure\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || items[0].Direction != "push" || !strings.Contains(items[0].Error, "overwrite failed") { + t.Fatalf("expected failed overwrite item, got detail: %#v", exitErr.Detail.Detail) + } +} + +func TestDriveSyncKeepBothReportsRenameFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-keep-both-rename-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + // Exhaust all possible suffixed paths so that + // relPathWithUniqueFileTokenSuffix cannot find a free name. + // The function tries 12-char, 24-char, 64-char hash prefixes, + // then _2 through _N sequential suffixes. + // We create local blocker files at each candidate path; they become + // new_local items (uploaded via the reusable stub) and occupy the + // suffixed names in the keep-both occupied map. + tokenHash := stableTokenHash("tok_a") + candidates := []string{ + relPathWithSuffix("a.txt", "__lark_"+tokenHash[:12]), + relPathWithSuffix("a.txt", "__lark_"+tokenHash[:24]), + relPathWithSuffix("a.txt", "__lark_"+tokenHash), + } + for i := 2; i <= driveUniqueSuffixMaxSeq; i++ { + candidates = append(candidates, relPathWithSuffix("a.txt", "__lark_"+tokenHash+"_"+strconv.Itoa(i))) + } + for _, c := range candidates { + full := filepath.Join("local", filepath.FromSlash(c)) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatalf("MkdirAll parent of %s: %v", c, err) + } + if err := os.WriteFile(full, []byte("blocker"), 0o644); err != nil { + t.Fatalf("WriteFile %s: %v", c, err) + } + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + // Reusable upload stub: all blocker files (new_local) upload successfully. + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Reusable: true, + Body: map[string]interface{}{ + "code": 0, + "msg": "ok", + "data": map[string]interface{}{ + "file_token": "tok_blocker", + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "keep-both", + "--quick", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected keep-both suffix exhaustion error\nstdout: %s", stdout.String()) + } + // The error may be a plain ExitError (no Detail.Detail) or a + // partial_failure with items. Either way it must mention the + // suffix exhaustion. + errMsg := err.Error() + // The suffix exhaustion message may be in the top-level error or + // inside a partial_failure detail item. Check both. + foundSuffixError := strings.Contains(errMsg, "could not generate a unique rel_path") + if !foundSuffixError { + var exitErr *output.ExitError + if errors.As(err, &exitErr) && exitErr.Detail != nil { + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + for _, item := range items { + if strings.Contains(item.Error, "could not generate a unique rel_path") { + foundSuffixError = true + break + } + } + if !foundSuffixError { + t.Fatalf("expected suffix exhaustion error, got: %s; detail: %#v", errMsg, exitErr.Detail.Detail) + } + } else { + t.Fatalf("expected suffix exhaustion error, got: %s", errMsg) + } + } +} + +func TestDriveSyncExecuteReturnsRemoteListError(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + runtime, _ := newDriveSyncRuntime(t, "local", "folder_root") + + err := DriveSync.Execute(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "API call failed") { + t.Fatalf("Execute() error = %v, want remote list error", err) + } +} + +func TestDriveSyncExecuteReturnsLocalWalkError(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + runtime, _ := newDriveSyncRuntime(t, "local", "folder_root") + if err := os.RemoveAll("local"); err != nil { + t.Fatalf("RemoveAll local: %v", err) + } + + err := DriveSync.Execute(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "walk") { + t.Fatalf("Execute() error = %v, want local walk error", err) + } +} + +func TestDriveSyncExecuteWrapsInvalidDuplicateStrategyForPullViews(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + runtime := newDriveSyncRuntimeWithFactory(t, f, "local", "folder_root") + if err := runtime.Cmd.Flags().Set("on-duplicate-remote", "invalid-strategy"); err != nil { + t.Fatalf("set --on-duplicate-remote: %v", err) + } + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + map[string]interface{}{"token": "tok_b", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + err := DriveSync.Execute(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "unsupported duplicate remote strategy") { + t.Fatalf("Execute() error = %v, want pull views strategy error", err) + } +} + +func TestDriveSyncExecuteWrapsUnsupportedPushDuplicateStrategy(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + runtime := newDriveSyncRuntimeWithFactory(t, f, "local", "folder_root") + if err := runtime.Cmd.Flags().Set("on-duplicate-remote", driveDuplicateRemoteRename); err != nil { + t.Fatalf("set --on-duplicate-remote: %v", err) + } + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + map[string]interface{}{"token": "tok_b", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + err := DriveSync.Execute(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "unsupported duplicate remote strategy") { + t.Fatalf("Execute() error = %v, want push views strategy error", err) + } +} + +func TestDriveSyncExecuteSurfacesHashLocalError(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o000); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + defer func() { _ = os.Chmod("local/a.txt", 0o644) }() + + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + runtime := newDriveSyncRuntimeWithFactory(t, f, "local", "folder_root") + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + err := DriveSync.Execute(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "cannot read file") { + t.Fatalf("Execute() error = %v, want hashLocal error", err) + } +} + +func TestDriveSyncExecuteSurfacesHashRemoteError(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + runtime := newDriveSyncRuntimeWithFactory(t, f, "local", "folder_root") + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + err := DriveSync.Execute(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "download") { + t.Fatalf("Execute() error = %v, want hashRemote error", err) + } +} + +func TestDriveSyncExecuteReturnsPushWalkErrorAfterDiff(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + runtime := newDriveSyncRuntimeWithFactory(t, f, "local", "folder_root") + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + OnMatch: func(req *http.Request) { + _ = os.RemoveAll("local") + }, + }) + + err := DriveSync.Execute(context.Background(), runtime) + if err == nil || !strings.Contains(err.Error(), "walk") { + t.Fatalf("Execute() error = %v, want push walk error", err) + } +} + +func TestDriveSyncExecuteUnknownConflictStrategySkipsModifiedFile(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + runtime := newDriveSyncRuntimeWithFactory(t, f, "local", "folder_root") + if err := runtime.Cmd.Flags().Set("on-conflict", "mystery-mode"); err != nil { + t.Fatalf("set --on-conflict: %v", err) + } + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := DriveSync.Execute(context.Background(), runtime) + if err != nil { + t.Fatalf("Execute() unexpected error: %v", err) + } +} + +func TestDriveSyncModifiedFileDisappearingBeforeExecuteIsSkipped(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-modified-disappears", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + f.FileIOProvider = &deleteOnCloseProvider{ + inner: f.FileIOProvider, + targetPath: filepath.Join("local", "a.txt"), + deletePath: filepath.Join("local", "a.txt"), + } + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "remote-wins", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + out := stdout.String() + if !strings.Contains(out, `"direction": "conflict"`) || !strings.Contains(out, "local file disappeared during sync") { + t.Fatalf("expected modified file disappearance to be reported, got: %s", out) + } + if !strings.Contains(out, `"skipped": 1`) { + t.Fatalf("expected skipped summary count, got: %s", out) + } +} + +func TestDriveSyncRemoteWinsReportsModifiedPullFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-remote-wins-pull-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + f.FileIOProvider = &failSaveProvider{inner: f.FileIOProvider, failSuffix: filepath.Join("local", "a.txt"), err: fmt.Errorf("save failed")} + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + Reusable: true, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "remote-wins", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected modified pull failure\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || items[0].Direction != "pull" || !strings.Contains(items[0].Error, "save failed") { + t.Fatalf("expected failed modified pull item, got detail: %#v", exitErr.Detail.Detail) + } +} + +func TestDriveSyncKeepBothReportsRollbackFailureAfterPullError(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-keep-both-rollback-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + f.FileIOProvider = &failAfterSaveProvider{ + inner: f.FileIOProvider, + failSuffix: filepath.Join("local", "a.txt"), + err: fmt.Errorf("save failed"), + afterSave: func(path string) { + _ = os.Chmod(filepath.Dir(path), 0o555) + }, + } + defer func() { + _ = os.Chmod(filepath.Join(tmpDir, "local"), 0o755) + }() + + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + Reusable: true, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "keep-both", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected keep-both rollback failure\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || !strings.Contains(items[0].Error, "rollback failed") { + t.Fatalf("expected rollback failure in item error, got detail: %#v", exitErr.Detail.Detail) + } +} + +func TestDriveSyncStatusRemoteFilesUsesStableTokens(t *testing.T) { + remoteFiles := driveSyncStatusRemoteFiles(map[string]drivePullTarget{ + "item-token.txt": { + DownloadToken: "download_token_should_not_win", + ItemFileToken: "item_file_token", + ModifiedTime: "111", + }, + "download-token.txt": { + DownloadToken: "download_only_token", + ModifiedTime: "222", + }, + }) + + if got := remoteFiles["item-token.txt"].FileToken; got != "item_file_token" { + t.Fatalf("item-token.txt file_token = %q, want item_file_token", got) + } + if got := remoteFiles["download-token.txt"].FileToken; got != "download_only_token" { + t.Fatalf("download-token.txt file_token = %q, want download_only_token", got) + } + if got := remoteFiles["download-token.txt"].ModifiedTime; got != "222" { + t.Fatalf("download-token.txt modified_time = %q, want 222", got) + } +} + +func TestDriveSyncLocalWinsNestedFileReportsParentEnsureFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-local-wins-parent-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll(filepath.Join("local", "sub"), 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join("local", "sub", "a.txt"), []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_nested", "name": "sub/a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_nested/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/create_folder", + Body: map[string]interface{}{ + "code": 9999, + "msg": "create parent failed", + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "local-wins", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected parent ensure failure\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || !strings.Contains(items[0].Error, "create parent failed") { + t.Fatalf("expected failed item with create_folder error, got detail: %#v", exitErr.Detail.Detail) + } +} + +// TestDriveSyncSkipsNonFileRemoteEntries verifies that new_remote entries +// whose rel_path is not in pullRemoteFiles (non-file types like docx, +// shortcuts) are silently skipped rather than causing a panic or error. +func TestDriveSyncSkipsNonFileRemoteEntries(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-skip-nonfile", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + // Remote has a docx and a shortcut — both should be skipped in pull. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_doc", "name": "notes.docx", "type": "docx"}, + map[string]interface{}{"token": "tok_sc", "name": "link.lnk", "type": "shortcut"}, + }, + "has_more": false, + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"pulled": 0`) { + t.Fatalf("expected pulled=0 (non-file entries skipped), got: %s", out) + } + if !strings.Contains(out, `"pushed": 0`) { + t.Fatalf("expected pushed=0, got: %s", out) + } +} + +// TestDriveSyncAskConflictRemoteShortForms verifies the "r", "remote", +// and "remote-wins" input variants all resolve to remote-wins. +func TestDriveSyncAskConflictRemoteShortForms(t *testing.T) { + tests := []struct { + name string + input string + }{ + {name: "r", input: "r\n"}, + {name: "remote", input: "remote\n"}, + {name: "remote-wins", input: "remote-wins\n"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + f.IOStreams.In = strings.NewReader(tt.input) + + runtime := common.TestNewRuntimeContext(&cobra.Command{Use: "drive"}, driveTestConfig()) + runtime.Factory = f + + got, err := driveSyncAskConflict("a.txt", runtime) + if err != nil { + t.Fatalf("driveSyncAskConflict() unexpected error: %v", err) + } + if got != driveSyncOnConflictRemoteWins { + t.Fatalf("driveSyncAskConflict() = %q, want %q", got, driveSyncOnConflictRemoteWins) + } + }) + } +} + +// TestDriveSyncNeedsDownloadScopeReturnsFalseForLocalWinsOnly verifies +// that driveSyncNeedsDownloadScope returns false when there are no +// new_remote entries and all modified entries resolve to local-wins. +func TestDriveSyncNeedsDownloadScopeReturnsFalseForLocalWinsOnly(t *testing.T) { + modified := []driveStatusEntry{{RelPath: "a.txt"}, {RelPath: "b.txt"}} + resolutions := map[string]string{"a.txt": driveSyncOnConflictLocalWins, "b.txt": driveSyncOnConflictLocalWins} + + if driveSyncNeedsDownloadScope(nil, modified, resolutions) { + t.Fatal("expected false when no new_remote and all conflicts are local-wins") + } +} + +// TestDriveSyncNeedsDownloadScopeReturnsTrueForKeepBoth verifies that +// driveSyncNeedsDownloadScope returns true when a modified entry resolves +// to keep-both (which requires pulling the remote version). +func TestDriveSyncNeedsDownloadScopeReturnsTrueForKeepBoth(t *testing.T) { + modified := []driveStatusEntry{{RelPath: "a.txt"}} + resolutions := map[string]string{"a.txt": driveSyncOnConflictKeepBoth} + + if !driveSyncNeedsDownloadScope(nil, modified, resolutions) { + t.Fatal("expected true when a conflict resolves to keep-both") + } +} + +// TestDriveSyncRemoteWinsReportsMissingPullView verifies that when a +// modified file's rel_path is not in pullRemoteFiles during the +// remote-wins branch, a failed item is reported instead of a panic. +// This can happen when duplicate remote entries are resolved differently +// between pull and status views. +func TestDriveSyncRemoteWinsReportsMissingPullView(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + runtime := newDriveSyncRuntimeWithFactory(t, f, "local", "folder_root") + if err := runtime.Cmd.Flags().Set("on-duplicate-remote", "invalid-strategy"); err != nil { + t.Fatalf("set --on-duplicate-remote: %v", err) + } + // Two remote files with the same name — the invalid duplicate strategy + // will cause drivePullRemoteViews to return an error, which is wrapped + // as an internal error before we even reach the remote-wins branch. + // To test the "remote file not found in pull views" branch directly, + // we use a unit-level approach instead. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + map[string]interface{}{"token": "tok_b", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + err := DriveSync.Execute(context.Background(), runtime) + if err == nil { + t.Fatalf("expected error for invalid duplicate strategy\nstdout: %s", err) + } + if !strings.Contains(err.Error(), "unsupported duplicate remote strategy") { + t.Fatalf("expected strategy error, got: %v", err) + } +} + +// TestDriveSyncKeepBothReportsSuffixError verifies that keep-both reports +// a failed item when relPathWithUniqueFileTokenSuffix cannot find a +// unique name because all candidates are already occupied. +func TestDriveSyncKeepBothReportsSuffixError(t *testing.T) { + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + // Pre-occupy all possible suffixed names for a.txt with token tok_a. + // This forces relPathWithUniqueFileTokenSuffix to exhaust all attempts. + occupied := map[string]struct{}{"a.txt": {}} + // Generate the same suffixes the function would try. + tokenHash := stableTokenHash("tok_a") + suffixes := []string{ + "__lark_" + tokenHash[:12], + "__lark_" + tokenHash[:24], + "__lark_" + tokenHash, + } + for _, suffix := range suffixes { + occupied[relPathWithSuffix("a.txt", suffix)] = struct{}{} + } + for attempt := 2; attempt <= driveUniqueSuffixMaxSeq; attempt++ { + occupied[relPathWithSuffix("a.txt", "__lark_"+tokenHash+"_"+strconv.Itoa(attempt))] = struct{}{} + } + + // Verify the function actually fails with this occupied set. + _, err := relPathWithUniqueFileTokenSuffix("a.txt", "tok_a", occupied) + if err == nil { + t.Fatal("expected relPathWithUniqueFileTokenSuffix to fail when all names are occupied") + } +} + +// TestDriveSyncKeepBothRollbackSucceedsOnPullFailure verifies the full +// keep-both rollback path: when the pull download fails after the local +// file has been renamed, the rollback restores the original file and +// the error is reported as a partial_failure. +func TestDriveSyncKeepBothRollbackSucceedsOnPullFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-keep-both-rollback-pull-fail", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + f.FileIOProvider = &failSaveProvider{inner: f.FileIOProvider, failSuffix: filepath.Join("local", "a.txt"), err: fmt.Errorf("save failed")} + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + // Diff phase: download for hash comparison. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + // Pull phase: download for keep-both pull (will fail at Save). + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + Reusable: true, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "keep-both", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected keep-both pull failure with rollback\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 || !strings.Contains(items[0].Error, "save failed") { + t.Fatalf("expected save failure in item, got detail: %#v", exitErr.Detail.Detail) + } + + // Rollback should have restored the original file. + data, readErr := os.ReadFile("local/a.txt") + if readErr != nil { + t.Fatalf("ReadFile a.txt after rollback: %v", readErr) + } + if string(data) != "local-a" { + t.Fatalf("a.txt content after rollback = %q, want local-a", string(data)) + } +} + +// TestDriveSyncLocalWinsFallbackToRemoteEntriesForPush verifies that +// when remoteFile.FileToken is empty in the local-wins branch, the code +// falls back to remoteEntriesForPush to find the existing token. +func TestDriveSyncLocalWinsFallbackToRemoteEntriesForPush(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-local-wins-fallback", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + // Two remote files with the same name (duplicate). Using --on-duplicate-remote=newest + // resolves to tok_new. The diff phase uses driveSyncStatusRemoteFiles which builds + // FileToken from pullRemoteFiles — but the local-wins branch reads remoteFile.FileToken + // from the status remoteFiles map. When the status map's FileToken differs from the + // push view's FileToken, the fallback to remoteEntriesForPush kicks in. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_old", "name": "a.txt", "type": "file", "created_time": "100", "modified_time": "100"}, + map[string]interface{}{"token": "tok_new", "name": "a.txt", "type": "file", "created_time": "200", "modified_time": "200"}, + }, + "has_more": false, + }, + }, + }) + // Diff phase: download tok_new (the newest duplicate) for hash comparison. + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_new/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + // Upload with overwrite — the file_token in the upload should come from + // the push view's resolved duplicate (tok_new via newest strategy). + uploadStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "file_token": "tok_new", + "version": "v2", + }, + }, + } + reg.Register(uploadStub) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "local-wins", + "--on-duplicate-remote", "newest", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"action": "overwritten"`) { + t.Fatalf("expected overwritten action, got: %s", out) + } +} + +// TestDriveSyncCreatesEmptyLocalDirectoriesOnDrive verifies that empty local +// directories are created on Drive during +sync, mirroring +push behavior. +func TestDriveSyncCreatesEmptyLocalDirectoriesOnDrive(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-empty-dirs", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + // local/empty_sub/ is an empty directory — should be created on Drive. + if err := os.MkdirAll(filepath.Join("local", "empty_sub"), 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{}, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/create_folder", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "token": "fld_empty_sub", + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v\nstdout: %s", err, stdout.String()) + } + + out := stdout.String() + if !strings.Contains(out, `"action": "folder_created"`) { + t.Fatalf("expected folder_created action for empty directory, got: %s", out) + } + if !strings.Contains(out, `"rel_path": "empty_sub"`) { + t.Fatalf("expected empty_sub in items, got: %s", out) + } +} + +// TestDriveSyncLocalWinsUsesReturnedTokenOnUploadFailure verifies that +// when local-wins upload fails with a partial-success response (new +// file_token returned alongside error), the reported item uses the +// freshly returned token rather than the stale existingToken. +func TestDriveSyncLocalWinsUsesReturnedTokenOnUploadFailure(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-local-wins-partial-token", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile("local/a.txt", []byte("local-a"), 0o644); err != nil { + t.Fatalf("WriteFile a.txt: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_a", "name": "a.txt", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/tok_a/download", + Status: 200, + Body: []byte("remote-a"), + Headers: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }) + // Partial-success upload: returns a new file_token alongside an error code. + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/upload_all", + Body: map[string]interface{}{ + "code": 9999, + "msg": "partial write", + "data": map[string]interface{}{ + "file_token": "tok_a_new", + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--on-conflict", "local-wins", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected local-wins upload failure\nstdout: %s", stdout.String()) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + t.Fatalf("expected structured ExitError, got: %v", err) + } + detailMap, _ := exitErr.Detail.Detail.(map[string]interface{}) + items, _ := detailMap["items"].([]driveSyncItem) + if len(items) == 0 { + t.Fatalf("expected failed item, got detail: %#v", exitErr.Detail.Detail) + } + // The reported token should be the new one from the partial-success + // response, not the stale existingToken ("tok_a"). + if items[0].FileToken != "tok_a_new" { + t.Fatalf("expected FileToken=tok_a_new from partial-success, got %q", items[0].FileToken) + } +} + +// TestDriveSyncRejectsPathTypeConflict verifies that +sync hard-fails when a +// local regular file shares a rel_path with a remote non-file entry (folder, +// docx, shortcut, etc.) instead of silently attempting to upload and leaving +// the remote in a broken mixed-type state. +func TestDriveSyncRejectsPathTypeConflict(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-type-conflict", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + // Local has a regular file "report" at the same path as a remote docx. + if err := os.WriteFile("local/report", []byte("local-content"), 0o644); err != nil { + t.Fatalf("WriteFile report: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_doc", "name": "report", "type": "docx"}, + }, + "has_more": false, + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected type conflict error\nstdout: %s", stdout.String()) + } + if !strings.Contains(err.Error(), "path type conflict") { + t.Fatalf("expected path type conflict error, got: %v\nstdout: %s", err, stdout.String()) + } + if !strings.Contains(err.Error(), "docx") { + t.Fatalf("error should mention remote type docx, got: %v", err) + } +} + +// TestDriveSyncRejectsLocalDirVsRemoteFileTypeConflict verifies that +sync +// hard-fails when a local directory shares a rel_path with a remote file, +// which would otherwise attempt create_folder and leave the remote in a +// broken mixed-type state. +func TestDriveSyncRejectsLocalDirVsRemoteFileTypeConflict(t *testing.T) { + syncTestConfig := &core.CliConfig{ + AppID: "drive-sync-dir-vs-file-conflict", AppSecret: "test-secret", Brand: core.BrandFeishu, + } + f, stdout, _, reg := cmdutil.TestFactory(t, syncTestConfig) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("local", 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + // Local has a directory "report" at the same path as a remote file. + if err := os.Mkdir(filepath.Join("local", "report"), 0o755); err != nil { + t.Fatalf("Mkdir report: %v", err) + } + + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "folder_token=folder_root", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "files": []interface{}{ + map[string]interface{}{"token": "tok_file", "name": "report", "type": "file"}, + }, + "has_more": false, + }, + }, + }) + + err := mountAndRunDrive(t, DriveSync, []string{ + "+sync", + "--local-dir", "local", + "--folder-token", "folder_root", + "--as", "bot", + }, f, stdout) + if err == nil { + t.Fatalf("expected type conflict error\nstdout: %s", stdout.String()) + } + if !strings.Contains(err.Error(), "path type conflict") { + t.Fatalf("expected path type conflict error, got: %v\nstdout: %s", err, stdout.String()) + } + if !strings.Contains(err.Error(), "local directory") { + t.Fatalf("error should mention local directory, got: %v", err) + } +} diff --git a/shortcuts/drive/drive_task_result.go b/shortcuts/drive/drive_task_result.go index 5fc971404..d506e1b17 100644 --- a/shortcuts/drive/drive_task_result.go +++ b/shortcuts/drive/drive_task_result.go @@ -20,7 +20,7 @@ import ( var DriveTaskResult = common.Shortcut{ Service: "drive", Command: "+task_result", - Description: "Poll async task result for import, export, drive move/delete, wiki move, or wiki delete-space operations", + Description: "Poll async task result for import, export, drive move/delete, wiki move, wiki delete-space, or wiki delete-node operations", Risk: "read", // This shortcut multiplexes multiple backend APIs with different scope // requirements, so scenario-specific prechecks are handled in Validate. @@ -28,8 +28,8 @@ var DriveTaskResult = common.Shortcut{ AuthTypes: []string{"user", "bot"}, Flags: []common.Flag{ {Name: "ticket", Desc: "async task ticket (for import/export tasks)", Required: false}, - {Name: "task-id", Desc: "async task ID (for drive task_check, wiki_move, or wiki_delete_space tasks)", Required: false}, - {Name: "scenario", Desc: "task scenario: import, export, task_check, wiki_move, or wiki_delete_space", Required: true}, + {Name: "task-id", Desc: "async task ID (for drive task_check, wiki_move, wiki_delete_space, or wiki_delete_node tasks)", Required: false}, + {Name: "scenario", Desc: "task scenario: import, export, task_check, wiki_move, wiki_delete_space, or wiki_delete_node", Required: true}, {Name: "file-token", Desc: "source document token used for export task status lookup", Required: false}, }, Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { @@ -40,9 +40,10 @@ var DriveTaskResult = common.Shortcut{ "task_check": true, "wiki_move": true, "wiki_delete_space": true, + "wiki_delete_node": true, } if !validScenarios[scenario] { - return output.ErrValidation("unsupported scenario: %s. Supported scenarios: import, export, task_check, wiki_move, wiki_delete_space", scenario) + return output.ErrValidation("unsupported scenario: %s. Supported scenarios: import, export, task_check, wiki_move, wiki_delete_space, wiki_delete_node", scenario) } // Validate required params based on scenario @@ -54,7 +55,7 @@ var DriveTaskResult = common.Shortcut{ if err := validate.ResourceName(runtime.Str("ticket"), "--ticket"); err != nil { return output.ErrValidation("%s", err) } - case "task_check", "wiki_move", "wiki_delete_space": + case "task_check", "wiki_move", "wiki_delete_space", "wiki_delete_node": if runtime.Str("task-id") == "" { return output.ErrValidation("--task-id is required for %s scenario", scenario) } @@ -108,6 +109,11 @@ var DriveTaskResult = common.Shortcut{ Desc("[1] Query wiki delete-space task result"). Set("task_id", taskID). Params(map[string]interface{}{"task_type": "delete_space"}) + case "wiki_delete_node": + dry.GET("/open-apis/wiki/v2/tasks/:task_id"). + Desc("[1] Query wiki delete-node task result"). + Set("task_id", taskID). + Params(map[string]interface{}{"task_type": "delete_node"}) } return dry @@ -136,6 +142,8 @@ var DriveTaskResult = common.Shortcut{ result, err = queryWikiMoveTask(runtime, taskID) case "wiki_delete_space": result, err = queryWikiDeleteSpaceTask(runtime, taskID) + case "wiki_delete_node": + result, err = queryWikiDeleteNodeTask(runtime, taskID) } if err != nil { @@ -236,7 +244,7 @@ func validateDriveTaskResultScopes(ctx context.Context, runtime *common.RuntimeC switch scenario { case "import", "export", "task_check": required = []string{"drive:drive.metadata:readonly"} - case "wiki_move", "wiki_delete_space": + case "wiki_move", "wiki_delete_space", "wiki_delete_node": required = []string{"wiki:space:read"} } @@ -540,3 +548,64 @@ func queryWikiDeleteSpaceTask(runtime *common.RuntimeContext, taskID string) (ma "status_msg": label, }, nil } + +// queryWikiDeleteNodeTask returns the normalized status of an async wiki +// delete-node task. For historical reasons the gateway stashes delete-node +// status under the generic `simple_task_result` key (NOT `delete_node_result`), +// and that object only carries `status` — there is no `status_msg`, so the +// label falls back to the status code. Mirrors queryWikiDeleteSpaceTask; +// intentionally duplicated here (rather than importing the wiki package) to +// keep drive from depending on shortcuts/wiki. +func queryWikiDeleteNodeTask(runtime *common.RuntimeContext, taskID string) (map[string]interface{}, error) { + if err := validate.ResourceName(taskID, "--task-id"); err != nil { + return nil, output.ErrValidation("%s", err) + } + + data, err := runtime.CallAPI( + "GET", + fmt.Sprintf("/open-apis/wiki/v2/tasks/%s", validate.EncodePathSegment(taskID)), + map[string]interface{}{"task_type": "delete_node"}, + nil, + ) + if err != nil { + return nil, err + } + + task := common.GetMap(data, "task") + if task == nil { + return nil, output.Errorf(output.ExitAPI, "api_error", "wiki task response missing task") + } + + resolvedTaskID := common.GetString(task, "task_id") + if resolvedTaskID == "" { + resolvedTaskID = taskID + } + + result := common.GetMap(task, "simple_task_result") + var status string + if result != nil { + status = common.GetString(result, "status") + } + + // Keep in sync with wiki.parseWikiAsyncTaskStatus / wikiAsyncTaskStatus + // classification (intentionally duplicated to avoid a drive→wiki import — + // see the doc comment above). If the success/failed/processing rules change + // there, mirror the change here. + lowered := strings.ToLower(strings.TrimSpace(status)) + ready := lowered == "success" + failed := lowered == "failure" || lowered == "failed" + + resolvedStatus := strings.TrimSpace(status) + if resolvedStatus == "" { + resolvedStatus = "processing" + } + + return map[string]interface{}{ + "scenario": "wiki_delete_node", + "task_id": resolvedTaskID, + "ready": ready, + "failed": failed, + "status": resolvedStatus, + "status_msg": resolvedStatus, + }, nil +} diff --git a/shortcuts/drive/drive_task_result_test.go b/shortcuts/drive/drive_task_result_test.go index 69bb91cf5..79e43d76d 100644 --- a/shortcuts/drive/drive_task_result_test.go +++ b/shortcuts/drive/drive_task_result_test.go @@ -417,10 +417,10 @@ func TestDriveTaskResultWikiMoveIncludesFlattenedNodeFields(t *testing.T) { func TestValidateDriveTaskResultScopesWikiScenariosRequireWikiScope(t *testing.T) { t.Parallel() - // wiki_move and wiki_delete_space both read wiki task status, so both must - // require wiki:space:read. A single table keeps this invariant explicit - // without duplicating near-identical test functions per scenario. - for _, scenario := range []string{"wiki_move", "wiki_delete_space"} { + // wiki_move, wiki_delete_space and wiki_delete_node all read wiki task + // status, so all must require wiki:space:read. A single table keeps this + // invariant explicit without duplicating near-identical test functions. + for _, scenario := range []string{"wiki_move", "wiki_delete_space", "wiki_delete_node"} { t.Run(scenario+"/rejects missing scope", func(t *testing.T) { t.Parallel() runtime := newDriveTaskResultRuntimeWithScopes(t, core.AsUser, "drive:drive.metadata:readonly") @@ -518,6 +518,105 @@ func TestDriveTaskResultWikiDeleteSpaceSuccess(t *testing.T) { } } +func TestDriveTaskResultDryRunWikiDeleteNodeIncludesTaskTypeParam(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "drive +task_result"} + cmd.Flags().String("scenario", "", "") + cmd.Flags().String("ticket", "", "") + cmd.Flags().String("task-id", "", "") + cmd.Flags().String("file-token", "", "") + if err := cmd.Flags().Set("scenario", "wiki_delete_node"); err != nil { + t.Fatalf("set --scenario: %v", err) + } + if err := cmd.Flags().Set("task-id", "task_del_node_1"); err != nil { + t.Fatalf("set --task-id: %v", err) + } + + runtime := common.TestNewRuntimeContext(cmd, nil) + dry := DriveTaskResult.DryRun(context.Background(), runtime) + if dry == nil { + t.Fatal("DryRun returned nil") + } + + data, err := json.Marshal(dry) + if err != nil { + t.Fatalf("marshal dry run: %v", err) + } + + var got struct { + API []struct { + Params map[string]interface{} `json:"params"` + } `json:"api"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal dry run json: %v", err) + } + if len(got.API) != 1 { + t.Fatalf("expected 1 API call, got %d", len(got.API)) + } + if got.API[0].Params["task_type"] != "delete_node" { + t.Fatalf("wiki delete-node params = %#v, want task_type=delete_node", got.API[0].Params) + } +} + +func TestDriveTaskResultWikiDeleteNodeSuccess(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/wiki/v2/tasks/task_del_node_1", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "task": map[string]interface{}{ + // Gateway returns delete-node status under the generic + // simple_task_result key (NOT delete_node_result), and it + // carries only `status` (no status_msg). + "simple_task_result": map[string]interface{}{ + "status": "success", + }, + }, + }, + }, + }) + + err := mountAndRunDrive(t, DriveTaskResult, []string{ + "+task_result", + "--scenario", "wiki_delete_node", + "--task-id", "task_del_node_1", + "--as", "user", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data := decodeDriveEnvelope(t, stdout) + if data["scenario"] != "wiki_delete_node" || data["task_id"] != "task_del_node_1" { + t.Fatalf("unexpected wiki_delete_node envelope: %#v", data) + } + if data["ready"] != true || data["failed"] != false || data["status"] != "success" { + t.Fatalf("unexpected readiness fields: %#v", data) + } + // simple_task_result has no status_msg; label must fall back to status. + if data["status_msg"] != "success" { + t.Fatalf("status_msg = %#v, want fallback to status", data["status_msg"]) + } +} + +func TestDriveTaskResultRejectsUnknownScenarioListsWikiDeleteNode(t *testing.T) { + f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + + err := mountAndRunDrive(t, DriveTaskResult, []string{ + "+task_result", + "--scenario", "bogus", + "--task-id", "task_x", + "--as", "user", + }, f, stdout) + if err == nil || !strings.Contains(err.Error(), "wiki_delete_node") { + t.Fatalf("expected unsupported-scenario error listing wiki_delete_node, got %v", err) + } +} + func TestValidateDriveTaskResultScopesDriveScenariosRequireDriveScope(t *testing.T) { t.Parallel() diff --git a/shortcuts/drive/drive_version.go b/shortcuts/drive/drive_version.go new file mode 100644 index 000000000..bcf3fcf0c --- /dev/null +++ b/shortcuts/drive/drive_version.go @@ -0,0 +1,454 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package drive + +import ( + "context" + "fmt" + "io" + "math" + "net/http" + "path/filepath" + "regexp" + "strconv" + "strings" + + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + + "github.com/larksuite/cli/extension/fileio" + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/util" + "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/shortcuts/common" +) + +var driveVersionNumberRe = regexp.MustCompile(`^\d{1,19}$`) + +type driveVersionHistorySpec struct { + FileToken string + Limit int + Cursor string +} + +func validateDriveNumericValue(value, flagName, valueLabel string) error { + value = strings.TrimSpace(value) + if value == "" { + return output.ErrValidation("%s cannot be empty", flagName) + } + if !driveVersionNumberRe.MatchString(value) { + return output.ErrValidation("%s must be a numeric %s", flagName, valueLabel) + } + return nil +} + +func validateDriveVersionValue(value, flagName string) error { + return validateDriveNumericValue(value, flagName, "version string") +} + +func validateDriveCursorValue(value, flagName string) error { + return validateDriveNumericValue(value, flagName, "pagination cursor") +} + +func validateDriveVersionHistorySpec(spec driveVersionHistorySpec) error { + if err := validate.ResourceName(spec.FileToken, "--file-token"); err != nil { + return output.ErrValidation("%s", err) + } + if spec.Limit < 1 || spec.Limit > 200 { + return output.ErrValidation("invalid --limit %d: must be between 1 and 200", spec.Limit) + } + if spec.Cursor != "" { + if err := validateDriveCursorValue(spec.Cursor, "--cursor"); err != nil { + return err + } + } + return nil +} + +func driveVersionHistoryParams(spec driveVersionHistorySpec) map[string]interface{} { + params := map[string]interface{}{ + "only_tag": true, + "page_size": spec.Limit, + } + if spec.Cursor != "" { + params["last_edit_time"] = spec.Cursor + } + return params +} + +func driveVersionActionTypeLabel(raw int) string { + switch raw { + case 1: + return "upload" + case 2: + return "rename" + case 3: + return "delete_version" + case 4: + return "revert" + default: + return fmt.Sprintf("type_%d", raw) + } +} + +func driveVersionFieldString(m map[string]interface{}, key string) string { + if m == nil { + return "" + } + if s := common.GetString(m, key); s != "" { + return s + } + f, ok := util.ToFloat64(m[key]) + if !ok || math.IsInf(f, 0) || math.IsNaN(f) { + return "" + } + if math.Trunc(f) == f { + return strconv.FormatInt(int64(f), 10) + } + return strconv.FormatFloat(f, 'f', -1, 64) +} + +func transformDriveVersionHistory(items []interface{}) []map[string]interface{} { + versions := make([]map[string]interface{}, 0, len(items)) + for _, item := range items { + m, ok := item.(map[string]interface{}) + if !ok { + continue + } + version := common.GetString(m, "version") + if version == "" { + continue + } + versions = append(versions, map[string]interface{}{ + "version": version, + "name": common.GetString(m, "name"), + "edited_at": driveVersionFieldString(m, "edit_time"), + "edited_by": common.GetString(m, "edit_user_id"), + "size_bytes": int64(common.GetFloat(m, "size")), + "action_type": driveVersionActionTypeLabel(int(common.GetFloat(m, "type"))), + "is_deleted": common.GetBool(m, "is_deleted"), + "tag": int(common.GetFloat(m, "tag")), + }) + } + return versions +} + +func nextDriveVersionCursor(items []interface{}, hasMore bool) string { + if !hasMore || len(items) == 0 { + return "" + } + last, _ := items[len(items)-1].(map[string]interface{}) + return driveVersionFieldString(last, "edit_time") +} + +var DriveVersionHistory = common.Shortcut{ + Service: "drive", + Command: "+version-history", + Description: "List the version history of a Drive file", + Risk: "read", + Scopes: []string{"drive:file:download"}, + AuthTypes: []string{"user", "bot"}, + HasFormat: true, + Flags: []common.Flag{ + {Name: "file-token", Desc: "target file token", Required: true}, + {Name: "limit", Desc: "max versions to return (1-200)", Type: "int", Default: "20"}, + {Name: "cursor", Desc: "pagination cursor from the previous page's next_cursor"}, + }, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + return validateDriveVersionHistorySpec(driveVersionHistorySpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Limit: runtime.Int("limit"), + Cursor: strings.TrimSpace(runtime.Str("cursor")), + }) + }, + DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + spec := driveVersionHistorySpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Limit: runtime.Int("limit"), + Cursor: strings.TrimSpace(runtime.Str("cursor")), + } + return common.NewDryRunAPI(). + Desc("Query version history with only_tag=true and optional pagination cursor"). + GET("/open-apis/drive/v1/files/:file_token/history"). + Set("file_token", spec.FileToken). + Params(driveVersionHistoryParams(spec)) + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + spec := driveVersionHistorySpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Limit: runtime.Int("limit"), + Cursor: strings.TrimSpace(runtime.Str("cursor")), + } + + data, err := runtime.CallAPI( + http.MethodGet, + fmt.Sprintf("/open-apis/drive/v1/files/%s/history", validate.EncodePathSegment(spec.FileToken)), + driveVersionHistoryParams(spec), + nil, + ) + if err != nil { + return err + } + + items := common.GetSlice(data, "items") + hasMore := common.GetBool(data, "has_more") + out := map[string]interface{}{ + "versions": transformDriveVersionHistory(items), + "has_more": hasMore, + } + if nextCursor := nextDriveVersionCursor(items, hasMore); nextCursor != "" { + out["next_cursor"] = nextCursor + } + + runtime.OutFormat(out, nil, nil) + return nil + }, +} + +type driveVersionGetSpec struct { + FileToken string + Version string + Output string + Overwrite bool +} + +func validateDriveVersionGetSpec(runtime *common.RuntimeContext, spec driveVersionGetSpec) error { + if err := validate.ResourceName(spec.FileToken, "--file-token"); err != nil { + return output.ErrValidation("%s", err) + } + if err := validateDriveVersionValue(spec.Version, "--version"); err != nil { + return err + } + if spec.Output == "" { + return nil + } + if _, err := validate.SafeOutputPath(spec.Output); err != nil { + return output.ErrValidation("unsafe output path: %s", err) + } + return nil +} + +func driveVersionGetOutputIsDirectory(runtime *common.RuntimeContext, outputPath string) bool { + if strings.HasSuffix(outputPath, "/") || strings.HasSuffix(outputPath, "\\") { + return true + } + info, err := runtime.FileIO().Stat(outputPath) + return err == nil && info.IsDir() +} + +func prettyPrintDriveVersionSavedFile(w io.Writer, data map[string]interface{}) { + fmt.Fprintf(w, "file_token: %s\n", common.GetString(data, "file_token")) + fmt.Fprintf(w, "version: %s\n", common.GetString(data, "version")) + fmt.Fprintf(w, "file_name: %s\n", common.GetString(data, "file_name")) + fmt.Fprintf(w, "saved_path: %s\n", common.GetString(data, "saved_path")) + fmt.Fprintf(w, "size_bytes: %d\n", int64(common.GetFloat(data, "size_bytes"))) +} + +var DriveVersionGet = common.Shortcut{ + Service: "drive", + Command: "+version-get", + Description: "Download a specific version of a Drive file", + Risk: "read", + Scopes: []string{"drive:file:download"}, + AuthTypes: []string{"user", "bot"}, + HasFormat: true, + Flags: []common.Flag{ + {Name: "file-token", Desc: "target file token", Required: true}, + {Name: "version", Desc: "version from drive +version-history (not tag)", Required: true}, + {Name: "output", Desc: "local save path or directory; omit to save into the current directory using the server filename"}, + {Name: "overwrite", Type: "bool", Desc: "overwrite existing output file"}, + }, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + return validateDriveVersionGetSpec(runtime, driveVersionGetSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + Output: strings.TrimSpace(runtime.Str("output")), + Overwrite: runtime.Bool("overwrite"), + }) + }, + DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + spec := driveVersionGetSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + Output: strings.TrimSpace(runtime.Str("output")), + } + outputPath := spec.Output + if outputPath == "" { + outputPath = "." + } + return common.NewDryRunAPI(). + Desc("Download a specific file version; when --output is omitted the CLI saves into the current directory using the server filename"). + GET("/open-apis/drive/v1/files/:file_token/download"). + Set("file_token", spec.FileToken). + Set("output", outputPath). + Params(map[string]interface{}{"version": spec.Version}) + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + spec := driveVersionGetSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + Output: strings.TrimSpace(runtime.Str("output")), + Overwrite: runtime.Bool("overwrite"), + } + + resp, err := runtime.DoAPIStream(ctx, &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: fmt.Sprintf("/open-apis/drive/v1/files/%s/download", validate.EncodePathSegment(spec.FileToken)), + QueryParams: larkcore.QueryParams{ + "version": []string{spec.Version}, + }, + }) + if err != nil { + return output.ErrNetwork("download failed: %s", err) + } + defer resp.Body.Close() + + fileName := common.ResolveDownloadFileName(resp.Header, spec.FileToken) + fileName, _ = common.AutoAppendDownloadExtension(fileName, resp.Header, "") + outputPath := spec.Output + if outputPath == "" { + outputPath = "." + } + if driveVersionGetOutputIsDirectory(runtime, outputPath) { + outputPath = filepath.Join(outputPath, fileName) + } else { + outputPath, _ = common.AutoAppendDownloadExtension(outputPath, resp.Header, "") + } + if _, resolveErr := runtime.ResolveSavePath(outputPath); resolveErr != nil { + return output.ErrValidation("unsafe output path: %s", resolveErr) + } + if _, statErr := runtime.FileIO().Stat(outputPath); statErr == nil && !spec.Overwrite { + return output.ErrValidation("output file already exists: %s (use --overwrite to replace)", outputPath) + } + + result, err := runtime.FileIO().Save(outputPath, fileio.SaveOptions{ + ContentType: resp.Header.Get("Content-Type"), + ContentLength: resp.ContentLength, + }, resp.Body) + if err != nil { + return common.WrapSaveErrorByCategory(err, "io") + } + + savedPath, _ := runtime.ResolveSavePath(outputPath) + if savedPath == "" { + savedPath = outputPath + } + out := map[string]interface{}{ + "file_token": spec.FileToken, + "version": spec.Version, + "file_name": filepath.Base(outputPath), + "saved_path": savedPath, + "size_bytes": result.Size(), + } + runtime.OutFormat(out, nil, func(w io.Writer) { + prettyPrintDriveVersionSavedFile(w, out) + }) + return nil + }, +} + +type driveVersionMutationSpec struct { + FileToken string + Version string +} + +func validateDriveVersionMutationSpec(spec driveVersionMutationSpec) error { + if err := validate.ResourceName(spec.FileToken, "--file-token"); err != nil { + return output.ErrValidation("%s", err) + } + return validateDriveVersionValue(spec.Version, "--version") +} + +var DriveVersionRevert = common.Shortcut{ + Service: "drive", + Command: "+version-revert", + Description: "Revert a Drive file to a specific historical version", + Risk: "write", + Scopes: []string{"drive:file:upload"}, + AuthTypes: []string{"user", "bot"}, + Flags: []common.Flag{ + {Name: "file-token", Desc: "target file token", Required: true}, + {Name: "version", Desc: "version from drive +version-history to revert to (not tag)", Required: true}, + }, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + return validateDriveVersionMutationSpec(driveVersionMutationSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + }) + }, + DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + spec := driveVersionMutationSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + } + return common.NewDryRunAPI(). + Desc("Revert the current file to a specified historical version"). + POST("/open-apis/drive/v1/files/:file_token/revert"). + Set("file_token", spec.FileToken). + Body(map[string]interface{}{"version": spec.Version}) + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + spec := driveVersionMutationSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + } + if _, err := runtime.CallAPI( + http.MethodPost, + fmt.Sprintf("/open-apis/drive/v1/files/%s/revert", validate.EncodePathSegment(spec.FileToken)), + nil, + map[string]interface{}{"version": spec.Version}, + ); err != nil { + return err + } + + runtime.Out(map[string]interface{}{}, nil) + return nil + }, +} + +var DriveVersionDelete = common.Shortcut{ + Service: "drive", + Command: "+version-delete", + Description: "Delete a specific historical version of a Drive file", + Risk: "high-risk-write", + Scopes: []string{"drive:file:upload"}, + AuthTypes: []string{"user", "bot"}, + Flags: []common.Flag{ + {Name: "file-token", Desc: "target file token", Required: true}, + {Name: "version", Desc: "version from drive +version-history to delete (not tag)", Required: true}, + }, + Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { + return validateDriveVersionMutationSpec(driveVersionMutationSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + }) + }, + DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { + spec := driveVersionMutationSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + } + return common.NewDryRunAPI(). + Desc("Permanently delete a historical file version"). + POST("/open-apis/drive/v1/files/:file_token/version_del"). + Set("file_token", spec.FileToken). + Body(map[string]interface{}{"version": spec.Version}) + }, + Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + spec := driveVersionMutationSpec{ + FileToken: strings.TrimSpace(runtime.Str("file-token")), + Version: strings.TrimSpace(runtime.Str("version")), + } + if _, err := runtime.CallAPI( + http.MethodPost, + fmt.Sprintf("/open-apis/drive/v1/files/%s/version_del", validate.EncodePathSegment(spec.FileToken)), + nil, + map[string]interface{}{"version": spec.Version}, + ); err != nil { + return err + } + + runtime.Out(map[string]interface{}{}, nil) + return nil + }, +} diff --git a/shortcuts/drive/drive_version_test.go b/shortcuts/drive/drive_version_test.go new file mode 100644 index 000000000..309243ad4 --- /dev/null +++ b/shortcuts/drive/drive_version_test.go @@ -0,0 +1,546 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package drive + +import ( + "encoding/json" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/httpmock" + "github.com/larksuite/cli/shortcuts/common" +) + +func TestValidateDriveVersionHistorySpec(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + spec driveVersionHistorySpec + wantErr string + }{ + { + name: "ok", + spec: driveVersionHistorySpec{FileToken: "box123", Limit: 20, Cursor: "1777013761763"}, + }, + { + name: "bad limit", + spec: driveVersionHistorySpec{FileToken: "box123", Limit: 0}, + wantErr: "invalid --limit", + }, + { + name: "bad cursor", + spec: driveVersionHistorySpec{FileToken: "box123", Limit: 20, Cursor: "abc"}, + wantErr: "--cursor must be a numeric pagination cursor", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateDriveVersionHistorySpec(tt.spec) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + return + } + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got %v", tt.wantErr, err) + } + }) + } +} + +func TestDriveVersionHistoryExecuteTransformsResponse(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/box_hist/history", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "items": []map[string]interface{}{ + { + "version": "7633658129540910621", + "name": "report.md", + "edit_time": 1777013761763, + "edit_user_id": "ou_hist_1", + "size": "12345", + "type": 1, + "is_deleted": false, + "tag": 7, + }, + { + "version": "7633658129540910622", + "name": "report.md", + "edit_time": 1777013770000, + "edit_user_id": "ou_hist_2", + "size": "12346", + "type": 4, + "is_deleted": true, + "tag": 8, + }, + }, + "has_more": true, + }, + }, + }) + + err := mountAndRunDrive(t, DriveVersionHistory, []string{ + "+version-history", + "--file-token", "box_hist", + "--limit", "2", + "--cursor", "1777013000000", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var envelope struct { + Data map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(stdout.Bytes(), &envelope); err != nil { + t.Fatalf("unmarshal stdout: %v", err) + } + + if got := common.GetBool(envelope.Data, "has_more"); !got { + t.Fatalf("has_more = %v, want true", got) + } + if got := common.GetString(envelope.Data, "next_cursor"); got != "1777013770000" { + t.Fatalf("next_cursor = %q, want %q", got, "1777013770000") + } + + versions, _ := envelope.Data["versions"].([]interface{}) + if len(versions) != 2 { + t.Fatalf("len(versions) = %d, want 2", len(versions)) + } + first, _ := versions[0].(map[string]interface{}) + if got := common.GetString(first, "version"); got != "7633658129540910621" { + t.Fatalf("first.version = %q", got) + } + if got := common.GetString(first, "edited_at"); got != "1777013761763" { + t.Fatalf("first.edited_at = %q, want %q", got, "1777013761763") + } + if got := common.GetString(first, "action_type"); got != "upload" { + t.Fatalf("first.action_type = %q, want upload", got) + } + if got := common.GetBool(first, "is_deleted"); got { + t.Fatalf("first.is_deleted = %v, want false", got) + } + second, _ := versions[1].(map[string]interface{}) + if got := common.GetString(second, "action_type"); got != "revert" { + t.Fatalf("second.action_type = %q, want revert", got) + } + if got := common.GetBool(second, "is_deleted"); !got { + t.Fatalf("second.is_deleted = %v, want true", got) + } +} + +func TestDriveVersionGetWritesSpecificVersion(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/box_ver/download?version=7633658129540910621", + Status: 200, + RawBody: []byte("versioned-data"), + Headers: http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Disposition": []string{`attachment; filename="report-v7.md"`}, + }, + }) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + err := mountAndRunDrive(t, DriveVersionGet, []string{ + "+version-get", + "--file-token", "box_ver", + "--version", "7633658129540910621", + "--output", "version.bin", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(filepath.Join(tmpDir, "version.bin")) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(data) != "versioned-data" { + t.Fatalf("downloaded content = %q", string(data)) + } + if !strings.Contains(stdout.String(), `"version": "7633658129540910621"`) { + t.Fatalf("stdout missing version: %s", stdout.String()) + } + if !strings.Contains(stdout.String(), `"saved_path":`) { + t.Fatalf("stdout missing saved_path: %s", stdout.String()) + } +} + +func TestDriveVersionGetSavesToCurrentDirectoryWhenOutputIsOmitted(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/box_ver/download?version=7633658129540910621", + Status: 200, + RawBody: []byte("# hello\n"), + Headers: http.Header{ + "Content-Type": []string{"text/plain; charset=utf-8"}, + "Content-Disposition": []string{`attachment; filename="report-v7.md"`}, + }, + }) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + err := mountAndRunDrive(t, DriveVersionGet, []string{ + "+version-get", + "--file-token", "box_ver", + "--version", "7633658129540910621", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(filepath.Join(tmpDir, "report-v7.md")) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(data) != "# hello\n" { + t.Fatalf("downloaded content = %q", string(data)) + } + if !strings.Contains(stdout.String(), `"file_name": "report-v7.md"`) { + t.Fatalf("stdout missing file_name: %s", stdout.String()) + } + if strings.Contains(stdout.String(), `"content":`) { + t.Fatalf("stdout unexpectedly contains content payload: %s", stdout.String()) + } +} + +func TestDriveVersionGetRejectsExistingFileWithoutOverwrite(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/box_ver/download?version=7633658129540910621", + Status: 200, + RawBody: []byte("versioned-data"), + Headers: http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Disposition": []string{`attachment; filename="report-v7.md"`}, + }, + }) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.WriteFile("version.bin", []byte("existing"), 0o644); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + err := mountAndRunDrive(t, DriveVersionGet, []string{ + "+version-get", + "--file-token", "box_ver", + "--version", "7633658129540910621", + "--output", "version.bin", + "--as", "bot", + }, f, stdout) + if err == nil || !strings.Contains(err.Error(), "output file already exists") { + t.Fatalf("expected output exists error, got %v", err) + } +} + +func TestDriveVersionGetOverwritesExistingFileWhenRequested(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/box_ver/download?version=7633658129540910621", + Status: 200, + RawBody: []byte("versioned-data"), + Headers: http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Disposition": []string{`attachment; filename="report-v7.md"`}, + }, + }) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.WriteFile("version.bin", []byte("existing"), 0o644); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + err := mountAndRunDrive(t, DriveVersionGet, []string{ + "+version-get", + "--file-token", "box_ver", + "--version", "7633658129540910621", + "--output", "version.bin", + "--overwrite", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(filepath.Join(tmpDir, "version.bin")) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(data) != "versioned-data" { + t.Fatalf("downloaded content = %q", string(data)) + } +} + +func TestDriveVersionGetSavesUsingRemoteNameWhenOutputIsExistingDirectory(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/box_ver/download?version=7633658129540910621", + Status: 200, + RawBody: []byte("versioned-data"), + Headers: http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Disposition": []string{`attachment; filename="report-v7.md"`}, + }, + }) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + if err := os.MkdirAll("downloads", 0o755); err != nil { + t.Fatalf("MkdirAll() error: %v", err) + } + + err := mountAndRunDrive(t, DriveVersionGet, []string{ + "+version-get", + "--file-token", "box_ver", + "--version", "7633658129540910621", + "--output", "downloads", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(filepath.Join("downloads", "report-v7.md")) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(data) != "versioned-data" { + t.Fatalf("downloaded content = %q", string(data)) + } +} + +func TestDriveVersionGetAppendsExtensionFromContentDispositionFilename(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + reg.Register(&httpmock.Stub{ + Method: "GET", + URL: "/open-apis/drive/v1/files/box_ver/download?version=7633658129540910621", + Status: 200, + RawBody: []byte("versioned-data"), + Headers: http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Disposition": []string{`attachment; filename="report-v7.md"`}, + }, + }) + + tmpDir := t.TempDir() + withDriveWorkingDir(t, tmpDir) + + err := mountAndRunDrive(t, DriveVersionGet, []string{ + "+version-get", + "--file-token", "box_ver", + "--version", "7633658129540910621", + "--output", "artifact", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(filepath.Join(tmpDir, "artifact.md")) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(data) != "versioned-data" { + t.Fatalf("downloaded content = %q", string(data)) + } + if !strings.Contains(stdout.String(), `"file_name": "artifact.md"`) { + t.Fatalf("stdout missing local file_name: %s", stdout.String()) + } +} + +func TestDriveVersionRevertPostsVersionAndReturnsEmptyData(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + revertStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/box_rev/revert", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{}, + }, + } + reg.Register(revertStub) + + err := mountAndRunDrive(t, DriveVersionRevert, []string{ + "+version-revert", + "--file-token", "box_rev", + "--version", "7633658129540910621", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + body := decodeCapturedJSONBody(t, revertStub) + if got := common.GetString(body, "version"); got != "7633658129540910621" { + t.Fatalf("body.version = %q, want 7633658129540910621", got) + } + if !strings.Contains(stdout.String(), `"data": {}`) { + t.Fatalf("stdout = %s, want empty data object", stdout.String()) + } +} + +func TestDriveVersionDeletePostsVersionAndReturnsEmptyData(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) + deleteStub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/drive/v1/files/box_del/version_del", + Body: map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{}, + }, + } + reg.Register(deleteStub) + + err := mountAndRunDrive(t, DriveVersionDelete, []string{ + "+version-delete", + "--file-token", "box_del", + "--version", "7633658129540910621", + "--yes", + "--as", "bot", + }, f, stdout) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + body := decodeCapturedJSONBody(t, deleteStub) + if got := common.GetString(body, "version"); got != "7633658129540910621" { + t.Fatalf("body.version = %q, want 7633658129540910621", got) + } + if !strings.Contains(stdout.String(), `"data": {}`) { + t.Fatalf("stdout = %s, want empty data object", stdout.String()) + } +} + +func TestDriveVersionRevertDoesNotAcceptYes(t *testing.T) { + t.Parallel() + + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + + err := mountAndRunDrive(t, DriveVersionRevert, []string{ + "+version-revert", + "--file-token", "box_rev", + "--version", "7633658129540910621", + "--yes", + "--as", "bot", + }, f, nil) + if err == nil { + t.Fatal("expected unknown flag error, got nil") + } + if !strings.Contains(err.Error(), "unknown flag: --yes") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestDriveVersionDeleteRequiresYes(t *testing.T) { + t.Parallel() + + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + + err := mountAndRunDrive(t, DriveVersionDelete, []string{ + "+version-delete", + "--file-token", "box_del", + "--version", "7633658129540910621", + "--as", "bot", + }, f, nil) + if err == nil { + t.Fatal("expected confirmation error, got nil") + } + if !strings.Contains(err.Error(), "requires confirmation") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestDriveVersionShortcutsSupportUserDryRun(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + shortcut common.Shortcut + args []string + }{ + { + name: "history", + shortcut: DriveVersionHistory, + args: []string{ + "+version-history", + "--file-token", "box_hist", + "--limit", "2", + "--cursor", "1777013000000", + "--as", "user", + "--dry-run", + }, + }, + { + name: "get", + shortcut: DriveVersionGet, + args: []string{ + "+version-get", + "--file-token", "box_get", + "--version", "7633658129540910621", + "--output", "version.bin", + "--as", "user", + "--dry-run", + }, + }, + { + name: "revert", + shortcut: DriveVersionRevert, + args: []string{ + "+version-revert", + "--file-token", "box_rev", + "--version", "7633658129540910621", + "--as", "user", + "--dry-run", + }, + }, + { + name: "delete", + shortcut: DriveVersionDelete, + args: []string{ + "+version-delete", + "--file-token", "box_del", + "--version", "7633658129540910621", + "--as", "user", + "--dry-run", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, driveTestConfig()) + + err := mountAndRunDrive(t, tt.shortcut, tt.args, f, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} diff --git a/shortcuts/drive/shortcuts.go b/shortcuts/drive/shortcuts.go index fcd3d805e..dcf231e7c 100644 --- a/shortcuts/drive/shortcuts.go +++ b/shortcuts/drive/shortcuts.go @@ -16,13 +16,19 @@ func Shortcuts() []common.Shortcut { DriveExport, DriveExportDownload, DriveImport, + DriveVersionHistory, + DriveVersionGet, + DriveVersionRevert, + DriveVersionDelete, DriveMove, DriveDelete, DriveStatus, DrivePush, DrivePull, + DriveSync, DriveTaskResult, DriveApplyPermission, DriveSearch, + DriveInspect, } } diff --git a/shortcuts/drive/shortcuts_test.go b/shortcuts/drive/shortcuts_test.go index 3116c0c5a..3707fc096 100644 --- a/shortcuts/drive/shortcuts_test.go +++ b/shortcuts/drive/shortcuts_test.go @@ -15,6 +15,10 @@ func TestShortcutsIncludesExpectedCommands(t *testing.T) { "+create-folder", "+create-shortcut", "+download", + "+version-history", + "+version-get", + "+version-revert", + "+version-delete", "+add-comment", "+export", "+export-download", @@ -24,9 +28,11 @@ func TestShortcutsIncludesExpectedCommands(t *testing.T) { "+status", "+push", "+pull", + "+sync", "+task_result", "+apply-permission", "+search", + "+inspect", } if len(got) != len(want) { diff --git a/shortcuts/mail/body_file.go b/shortcuts/mail/body_file.go new file mode 100644 index 000000000..2270e6700 --- /dev/null +++ b/shortcuts/mail/body_file.go @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package mail + +import ( + "io" + "strings" + + "github.com/larksuite/cli/extension/fileio" + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/shortcuts/common" +) + +// bodyFileFlag is the shared `--body-file` flag declaration reused by every +// compose shortcut (+send / +draft-create / +reply / +reply-all / +forward). +// All six shortcuts honour the same mutual-exclusion contract with `--body` +// and the cwd-subtree path safety rule. The flag is intentionally NOT +// shared with `+lint-html` because that command's description differs +// ("HTML to lint" vs "email body") in a way that is more readable when +// authored per-shortcut. `+draft-edit` does not expose `--body-file` either +// — its body ops flow through `--patch-file` JSON whose `value` field is +// the natural file-based entry point for large bodies. +var bodyFileFlag = common.Flag{ + Name: "body-file", + Desc: "Path (relative, within cwd subtree) to a file containing the email body HTML. Mutually exclusive with --body. Size capped at 32 MB.", + Input: []string{common.File}, +} + +// maxBodyFileSize caps the size of a `--body-file` HTML input. The compose +// path's downstream EML limit is 25 MB (helpers.go MAX_EML_BYTES); we allow a +// bit more headroom here (32 MB) so a body close to the limit still loads +// before the downstream check fires with a clearer error message. The cap +// prevents an `io.ReadAll` from blowing memory on a misdirected gigabyte +// file. +const maxBodyFileSize = 32 * 1024 * 1024 // 32 MB + +// validateBodyFileMutex enforces the `--body` / `--body-file` mutual +// exclusion + cwd-subtree path safety. Compose shortcuts call this in +// their Validate phase so AI / users see a clear error before any work +// runs. Pass the shortcut's RuntimeContext-resolved flag values directly: +// `bodyFlag` is the `--body` value (may be empty), `bodyFile` is the +// trimmed `--body-file` value, and `validatePath` is the +// runtime.ValidatePath bound function used to enforce the relative-path +// rule (cwd-subtree only; no absolute / `..` traversal). +// +// Returns an ErrValidation error when either invariant is violated, nil +// otherwise. The "exactly one of {--body, --body-file}" check is +// shortcut-specific (some shortcuts allow neither, e.g. `+forward` with +// no explicit body) and is therefore left to the caller. +func validateBodyFileMutex(bodyFlag, bodyFile string, validatePath func(string) error) error { + bodyEmpty := strings.TrimSpace(bodyFlag) == "" + if !bodyEmpty && bodyFile != "" { + return output.ErrValidation("--body and --body-file are mutually exclusive; pass exactly one") + } + if bodyFile != "" { + if err := validatePath(bodyFile); err != nil { + return output.ErrValidation("--body-file: %v", err) + } + } + return nil +} + +// resolveBodyFromFlags returns the body content from --body or --body-file. +// Validate has already enforced mutual exclusion via validateBodyFileMutex, +// so exactly one is set (or neither when a template / parent message +// supplies the body). Returns ("", nil) when neither flag is set so +// downstream code can decide whether the empty body is allowed. +func resolveBodyFromFlags(runtime *common.RuntimeContext) (string, error) { + if body := runtime.Str("body"); strings.TrimSpace(body) != "" { + return body, nil + } + path := strings.TrimSpace(runtime.Str("body-file")) + if path == "" { + return "", nil + } + return readBodyFile(runtime.FileIO(), path) +} + +// readBodyFile loads --body-file content with a size cap. Returns an +// ErrValidation error if the file exceeds maxBodyFileSize or any IO error +// occurs. The size check uses io.LimitReader(maxBodyFileSize+1) so any +// over-cap byte is observable without reading the whole file. +// +// Callers MUST have run runtime.ValidatePath(path) on `path` first — the +// helper only opens the file via the supplied FileIO and does not repeat +// the cwd-subtree safety check. +func readBodyFile(fio fileio.FileIO, path string) (string, error) { + f, err := fio.Open(path) + if err != nil { + return "", output.ErrValidation("open --body-file %s: %v", path, err) + } + defer f.Close() + buf, err := io.ReadAll(io.LimitReader(f, maxBodyFileSize+1)) + if err != nil { + return "", output.ErrValidation("read --body-file %s: %v", path, err) + } + if len(buf) > maxBodyFileSize { + return "", output.ErrValidation("--body-file: file exceeds %d MB limit", maxBodyFileSize/1024/1024) + } + return string(buf), nil +} diff --git a/shortcuts/mail/draft/model.go b/shortcuts/mail/draft/model.go index e4f190fb8..d0c89f4b7 100644 --- a/shortcuts/mail/draft/model.go +++ b/shortcuts/mail/draft/model.go @@ -166,6 +166,7 @@ type DraftProjection struct { LargeAttachmentsSummary []LargeAttachmentSummary `json:"large_attachments_summary,omitempty"` InlineSummary []PartSummary `json:"inline_summary,omitempty"` Warnings []string `json:"warnings,omitempty"` + Priority string `json:"priority"` } type Patch struct { diff --git a/shortcuts/mail/draft/projection.go b/shortcuts/mail/draft/projection.go index 277d68bc5..13bfac261 100644 --- a/shortcuts/mail/draft/projection.go +++ b/shortcuts/mail/draft/projection.go @@ -140,9 +140,53 @@ func Project(snapshot *DraftSnapshot) DraftProjection { proj.LargeAttachmentsSummary = projectLargeAttachments(snapshot.Headers, htmlBody) + proj.Priority = parsePriorityFromHeaders(snapshot.Headers) + return proj } +// parsePriorityFromHeaders derives the read-side priority projection from +// EML headers. It mirrors the write-side helper helpers.go:parsePriority +// (which translates --set-priority high|normal|low into set_header / +// remove_header X-Cli-Priority ops). Lookup order is case-insensitive +// via headerValue: +// 1. X-Cli-Priority (CLI/OAPI-specific header recognised by +// mail-data-access headersToPbBodyExtra) +// 2. X-Priority (RFC standard, fallback for IMAP-回灌 historical drafts) +// +// When neither header is present (including after the write-side translates +// --set-priority normal into remove_header X-Cli-Priority), this returns +// "normal" — absence of a priority header is the standard email convention +// for normal priority. Agents cannot distinguish "explicitly normal" from +// "never set" — known limitation. +func parsePriorityFromHeaders(headers []Header) string { + if v := headerValue(headers, "X-Cli-Priority"); v != "" { + return mapPriorityValue(v) + } + if v := headerValue(headers, "X-Priority"); v != "" { + return mapPriorityValue(v) + } + return "normal" +} + +// mapPriorityValue normalises a raw priority header value to the projection +// vocabulary {"high","normal","low","unknown"}. The accepted input table is +// kept in sync with backend gopkg/mail_priority.PriorityValueToType so that +// CLI read-side projection observes the same set of values the server +// recognises on write. +func mapPriorityValue(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "high", "1 (highest)": + return "high" + case "3", "normal", "3 (normal)": + return "normal" + case "5", "low", "5 (lowest)": + return "low" + default: + return "unknown" + } +} + // projectLargeAttachments extracts large attachment info from the draft. // It first tries the server-format header (X-Lark-Large-Attachment) which // carries filename and size directly. Falls back to merging CLI-format diff --git a/shortcuts/mail/draft/projection_test.go b/shortcuts/mail/draft/projection_test.go index 3fe197eaf..499b2bec2 100644 --- a/shortcuts/mail/draft/projection_test.go +++ b/shortcuts/mail/draft/projection_test.go @@ -178,6 +178,170 @@ func TestSplitAtQuoteFalsePositivePlainText(t *testing.T) { } } +// --------------------------------------------------------------------------- +// Priority projection (X-Cli-Priority primary, X-Priority fallback) +// --------------------------------------------------------------------------- + +func TestProjectPriorityXCliPriorityHigh(t *testing.T) { + snapshot := mustParseFixtureDraft(t, `Subject: priority high +From: Alice +To: Bob +X-Cli-Priority: 1 +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "high" { + t.Fatalf("Priority = %q, want %q", proj.Priority, "high") + } +} + +func TestProjectPriorityFallbackXPriorityLow(t *testing.T) { + // Only the standard X-Priority header is present (e.g. an IMAP-回灌 + // historical draft). The fallback path should kick in. + snapshot := mustParseFixtureDraft(t, `Subject: priority low (fallback) +From: Alice +To: Bob +X-Priority: 5 +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "low" { + t.Fatalf("Priority = %q, want %q", proj.Priority, "low") + } +} + +func TestProjectPriorityBothAbsentNormal(t *testing.T) { + // Neither header is present — default priority is normal. + snapshot := mustParseFixtureDraft(t, `Subject: no priority +From: Alice +To: Bob +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "normal" { + t.Fatalf("Priority = %q, want %q", proj.Priority, "normal") + } +} + +func TestProjectPriorityXCliPriorityOutlookStyleHigh(t *testing.T) { + // X-Cli-Priority set to the Outlook-style string "high" (any case). + snapshot := mustParseFixtureDraft(t, `Subject: priority high (string) +From: Alice +To: Bob +X-Cli-Priority: High +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "high" { + t.Fatalf("Priority = %q, want %q", proj.Priority, "high") + } +} + +func TestProjectPriorityUnmappedValueUnknown(t *testing.T) { + // Value outside the recognised mapping table (e.g. "urgent") falls + // back to "unknown". + snapshot := mustParseFixtureDraft(t, `Subject: priority urgent +From: Alice +To: Bob +X-Cli-Priority: urgent +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "unknown" { + t.Fatalf("Priority = %q, want %q", proj.Priority, "unknown") + } +} + +func TestProjectPriorityXCliPriorityWinsOverXPriority(t *testing.T) { + // X-Cli-Priority must take precedence over X-Priority when both are + // set (defensive: agent or upstream may write both). + snapshot := mustParseFixtureDraft(t, `Subject: both headers +From: Alice +To: Bob +X-Cli-Priority: 1 +X-Priority: 5 +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "high" { + t.Fatalf("Priority = %q, want %q (X-Cli-Priority must win)", proj.Priority, "high") + } +} + +func TestProjectPriorityNormalThree(t *testing.T) { + // X-Cli-Priority=3 → "normal" (rare in CLI write path since + // `--set-priority normal` actually removes the header, but this case + // covers e.g. a draft set by another OAPI client that wrote 3). + snapshot := mustParseFixtureDraft(t, `Subject: priority three +From: Alice +To: Bob +X-Cli-Priority: 3 +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "normal" { + t.Fatalf("Priority = %q, want %q", proj.Priority, "normal") + } +} + +func TestProjectPriorityFallbackXPriorityNormalString(t *testing.T) { + // IMAP-回灌 / external client writes the RFC-standard `X-Priority: Normal` + // string. The fallback path must project this as "normal" — symmetric with + // how `X-Priority: High` / `Low` are already handled. + snapshot := mustParseFixtureDraft(t, `Subject: priority normal (fallback) +From: Alice +To: Bob +X-Priority: Normal +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "normal" { + t.Fatalf("Priority = %q, want %q", proj.Priority, "normal") + } +} + +func TestProjectPriorityOutlookStyleThreeNormal(t *testing.T) { + // Outlook-style `3 (Normal)` parenthesised form — symmetric with the + // already-supported `1 (Highest)` / `5 (Lowest)`. + snapshot := mustParseFixtureDraft(t, `Subject: priority three (normal) +From: Alice +To: Bob +X-Priority: 3 (Normal) +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 + +hello +`) + proj := Project(snapshot) + if proj.Priority != "normal" { + t.Fatalf("Priority = %q, want %q", proj.Priority, "normal") + } +} + func TestParseMissingInlineCIDReportedAsProjectionWarning(t *testing.T) { // Missing CID references should NOT prevent parsing; they are reported // as warnings in Project() instead. diff --git a/shortcuts/mail/helpers.go b/shortcuts/mail/helpers.go index f34eac8e1..275b0698f 100644 --- a/shortcuts/mail/helpers.go +++ b/shortcuts/mail/helpers.go @@ -2602,3 +2602,14 @@ func buildCalendarBody(runtime *common.RuntimeContext, senderEmail string, toAdd senderEmail, toAddrs, ccAddrs, ) } + +// validateBotMailboxNotMe rejects the combination of bot identity with --mailbox me. +// bot uses tenant access token; "me" cannot be resolved to a user mailbox under TAT. +func validateBotMailboxNotMe(runtime *common.RuntimeContext) error { + if runtime.IsBot() && runtime.Str("mailbox") == "me" { + return output.ErrValidation( + "--as bot does not support --mailbox me: bot identity uses a tenant token and cannot resolve \"me\" to a user mailbox; " + + "pass an explicit email address, e.g. --mailbox alice@example.com") + } + return nil +} diff --git a/shortcuts/mail/lint/linter.go b/shortcuts/mail/lint/linter.go new file mode 100644 index 000000000..55a4a5b17 --- /dev/null +++ b/shortcuts/mail/lint/linter.go @@ -0,0 +1,1090 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package lint + +import ( + "bytes" + "fmt" + "hash/fnv" + "strings" + + xhtml "golang.org/x/net/html" + "golang.org/x/net/html/atom" +) + +// MaxExcerptBytes caps the raw-HTML excerpt embedded in a Finding.Excerpt so +// a single offending tag with megabyte content can't bloat the envelope JSON. +// Lint operates on bytes only, but the excerpt representation must not be +// size-amplifying. +const MaxExcerptBytes = 200 + +// Run lints the given HTML body and returns a structured Report. +// Report.CleanedHTML contains the rewritten HTML (warnings rewritten + errors +// deleted) — the autofix is unconditional. +// +// IMPORTANT: when the input is empty or plain-text (no HTML markup detected +// by the cli's existing `bodyIsHTML` heuristic), callers should short-circuit +// with EmptyReport(html) instead of paying the parse cost. Run still handles +// this gracefully — html.Parse on plain text wraps the input in +// ..., and the lib's pass-through +// rendering will reproduce the original text — but the round-trip is wasteful +// and produces no findings. +func Run(html string, opts Options) Report { + if html == "" { + return EmptyReport("") + } + + rep := Report{ + Applied: []Finding{}, + Blocked: []Finding{}, + } + + // We use html.ParseFragment so users authoring fragment-style snippets + // (the canonical compose-5 input shape — `
...
` rather than a + // full document) don't get implicit wrappers + // re-rendered. The "body" insertion mode matches what html.Parse would + // have done internally for a fragment but skips the structural wrappers + // at render time. + bodyContext := &xhtml.Node{Type: xhtml.ElementNode, DataAtom: atom.Body, Data: "body"} + nodes, err := xhtml.ParseFragment(strings.NewReader(html), bodyContext) + if err != nil { + // Parser failure is exceptional (the parser is permissive by design); + // fall back to the original input so we don't lose user content. + return EmptyReport(html) + } + + // Wrap fragment nodes in a synthetic root so the recursive walker has a + // uniform parent pointer to mutate. + root := &xhtml.Node{Type: xhtml.DocumentNode} + for _, n := range nodes { + root.AppendChild(n) + } + + walk(root, &rep) + // nativeCtx tracks per-Run() state so positional ids (e.g. data-ol-id) + // are deterministic across multiple Run() calls on the same input — + // keying off the document-traversal order rather than heap pointers, + // so cleaned_html is byte-stable and amenable to golden-file tests / CI + // diff / cache-key reuse. + nctx := &nativeCtx{olIDs: map[*xhtml.Node]string{}} + applyFeishuNativeStyles(root, &rep, nctx) + + rep.HasErrorFindings = len(rep.Blocked) > 0 + rep.HasWarningFindings = len(rep.Applied) > 0 + rep.CleanedHTML = renderFragment(root) + + return rep +} + +// walk visits every element node under parent, applying tag/attr/style +// classification. Children are iterated via the next-sibling pointer because +// we mutate the tree in place (replace / remove nodes). +// +// The walker is iterative-style via explicit recursion because the html +// parser's typical nesting depth (≤ 256 by default) is well below Go's +// goroutine stack limit; the existing draft package's plainTextFromHTML +// (mail/draft/htmltext.go) similarly recurses for the same reason. +func walk(parent *xhtml.Node, rep *Report) { + child := parent.FirstChild + for child != nil { + next := child.NextSibling + if child.Type == xhtml.ElementNode { + processElement(parent, child, rep) + } + // child may have been removed/replaced by processElement; recurse + // only if it still has the original parent (i.e. wasn't deleted). + // The html parser sets Parent on every node, so a removed-then- + // reattached node still recurses correctly via its new Parent. + if child.Parent != nil { + walk(child, rep) + } + child = next + } +} + +// processElement applies the element-level classification cascade: +// 1. tag → allow / warn-rewrite / error-delete +// 2. attributes → on*-handlers, URL-bearing attrs (scheme allow-list), +// style attribute (CSS property allow-list) +func processElement(parent, n *xhtml.Node, rep *Report) { + tagName := strings.ToLower(n.Data) + kind, ruleID := classifyTag(tagName) + + switch kind { + case "error": + rep.Blocked = append(rep.Blocked, Finding{ + RuleID: ruleID, + Severity: SeverityError, + TagOrAttr: tagName, + Excerpt: excerptOf(n), + Hint: hintForBlockedTag(tagName), + }) + // Always remove blocked tags — the writing-path safety floor has no + // opt-out; `--no-lint` is not provided. + parent.RemoveChild(n) + return + + case "warn": + // Always rewrite (e.g. ) and surface the finding. + rep.Applied = append(rep.Applied, Finding{ + RuleID: ruleID, + Severity: SeverityWarning, + TagOrAttr: tagName, + Excerpt: excerptOf(n), + Hint: hintForWarnTag(tagName), + }) + rewriteWarnTag(n, tagName) + // Recurse into the rewritten node by falling through; the rewrite + // preserved children as-is. + // fall through to attribute scan + case "allow": + // no-op + } + + // Attribute scan: build a new attribute slice, dropping/sanitising as we + // go and surfacing findings. + if len(n.Attr) > 0 { + processAttributes(n, rep) + } +} + +// processAttributes walks the attribute list and: +// - drops on*-handlers (always; surfaced as error) +// - drops URL-bearing attrs whose value uses a forbidden scheme +// - filters the `style` attribute property-by-property against the allow-list +// +// Other attributes pass through unchanged. The cli's existing +// `validateInlineCIDs` (helpers.go:2226) handles `cid:`-specific checks; +// the lint must not duplicate that responsibility. +func processAttributes(n *xhtml.Node, rep *Report) { + keep := n.Attr[:0] + for _, attr := range n.Attr { + name := strings.ToLower(attr.Key) + + // 1. on*-handlers → always drop, error-tier. + if isEventHandlerAttr(name) { + rep.Blocked = append(rep.Blocked, Finding{ + RuleID: RuleAttrEventHandlerBlocked, + Severity: SeverityError, + TagOrAttr: name, + Excerpt: truncateExcerpt(attr.Key + "=\"" + attr.Val + "\""), + Hint: "Removed event handler attribute (on*)", + }) + continue + } + + // 2. URL-bearing attrs → check scheme allow-list. + if urlAttributes[name] { + kind, ruleID := classifyURLValue(attr.Val) + switch kind { + case "error": + rep.Blocked = append(rep.Blocked, Finding{ + RuleID: ruleID, + Severity: SeverityError, + TagOrAttr: name, + Excerpt: truncateExcerpt(attr.Key + "=\"" + attr.Val + "\""), + Hint: "Removed dangerous URL scheme (allowed: http/https/mailto/cid/data:image/*)", + }) + continue + case "warn": + rep.Blocked = append(rep.Blocked, Finding{ + RuleID: ruleID, + Severity: SeverityError, + TagOrAttr: name, + Excerpt: truncateExcerpt(attr.Key + "=\"" + attr.Val + "\""), + Hint: "Removed URL with unrecognised scheme (allowed: http/https/mailto/cid/data:image/*)", + }) + // Always drop the attribute — writing-path safety floor (the + // URL would not render correctly anyway). + continue + } + } + + // 3. `style` attribute → property-by-property allow-list. + if name == "style" { + cleaned, dropped := sanitiseStyleAttr(attr.Val) + for _, prop := range dropped { + rep.Applied = append(rep.Applied, Finding{ + RuleID: RuleStylePropertyDropped, + Severity: SeverityWarning, + TagOrAttr: "style." + prop, + Excerpt: truncateExcerpt(prop), + Hint: "Removed CSS property not in allowlist (see references/lark-mail-html.md)", + }) + } + if len(dropped) == 0 { + // Byte-stable when no property was dropped: leave the + // attribute exactly as authored so lint round-trips are + // idempotent on clean input. + keep = append(keep, attr) + continue + } + if cleaned == "" { + // All properties dropped — remove the attribute entirely. + continue + } + attr.Val = cleaned + keep = append(keep, attr) + continue + } + + // 4. Pass-through. + keep = append(keep, attr) + } + n.Attr = keep +} + +// rewriteWarnTag replaces a warning-tier tag with its Feishu-native +// equivalent in place: with color/face/size +// distilled into inline style;
; +// / (text-only, animation discarded — collapsing +// to a span keeps the children but drops the deprecated animation effect). +func rewriteWarnTag(n *xhtml.Node, tagName string) { + switch tagName { + case "font": + // Distill . + var styles []string + var keepAttrs []xhtml.Attribute + for _, attr := range n.Attr { + switch strings.ToLower(attr.Key) { + case "color": + if v := strings.TrimSpace(attr.Val); v != "" { + styles = append(styles, "color:"+v) + } + case "face": + if v := strings.TrimSpace(attr.Val); v != "" { + styles = append(styles, "font-family:"+v) + } + case "size": + if v := mapFontSize(attr.Val); v != "" { + styles = append(styles, "font-size:"+v) + } + default: + keepAttrs = append(keepAttrs, attr) + } + } + // Merge any existing style attribute already present on the + // (rare but possible). + if len(styles) > 0 { + merged := strings.Join(styles, ";") + styleIdx := -1 + for i, attr := range keepAttrs { + if strings.ToLower(attr.Key) == "style" { + styleIdx = i + break + } + } + if styleIdx >= 0 { + existing := strings.TrimRight(keepAttrs[styleIdx].Val, "; ") + if existing != "" { + merged = existing + ";" + merged + } + keepAttrs[styleIdx].Val = merged + } else { + keepAttrs = append(keepAttrs, xhtml.Attribute{Key: "style", Val: merged}) + } + } + n.Data = "span" + n.DataAtom = atom.Span + n.Attr = keepAttrs + + case "center": + //
. Existing style attr + // (if any) is merged with text-align prepended. + styleIdx := -1 + for i, attr := range n.Attr { + if strings.ToLower(attr.Key) == "style" { + styleIdx = i + break + } + } + newStyle := "text-align:center" + if styleIdx >= 0 { + existing := strings.TrimRight(n.Attr[styleIdx].Val, "; ") + if existing != "" { + newStyle = newStyle + ";" + existing + } + n.Attr[styleIdx].Val = newStyle + } else { + n.Attr = append(n.Attr, xhtml.Attribute{Key: "style", Val: newStyle}) + } + n.Data = "div" + n.DataAtom = atom.Div + + case "marquee", "blink": + // Both deprecated; collapse to so children survive. + n.Data = "span" + n.DataAtom = atom.Span + // Strip marquee-specific attributes (direction, scrollamount, ...) + // so the rewritten span is plain. + var keepAttrs []xhtml.Attribute + for _, attr := range n.Attr { + if strings.ToLower(attr.Key) == "style" || strings.ToLower(attr.Key) == "class" || strings.ToLower(attr.Key) == "id" { + keepAttrs = append(keepAttrs, attr) + } + } + n.Attr = keepAttrs + } +} + +// mapFontSize maps the legacy values (1..7) to a CSS px +// equivalent, matching the mapping used by Feishu mail-editor's renderer. +// Out-of-range values fall through to the empty string so the property is +// dropped (better than emitting an arbitrary value). +func mapFontSize(raw string) string { + switch strings.TrimSpace(raw) { + case "1": + return "10px" + case "2": + return "13px" + case "3": + return "16px" + case "4": + return "18px" + case "5": + return "24px" + case "6": + return "32px" + case "7": + return "48px" + default: + return "" + } +} + +// sanitiseStyleAttr filters a `style="prop1:val; prop2:val"` declaration +// against the property allow-list. Returns the cleaned style text (joined +// with "; " separators) and a slice of dropped property names (lower-case) +// so the caller can surface STYLE_PROPERTY_DROPPED findings. +// +// NOTE: We do NOT validate property values — only property names. The style +// attribute is filtered by CSS property allow-list; value-level validation +// (e.g. URL safety inside `background-image: url(...)`) is delegated to the +// urlAttributes path because such values typically appear in `src` / `href` +// attrs in compose-5 templates. Users authoring `background-image: url(http:...)` +// in inline style will see the property pass — the URL inside is not a +// security concern at the inline-style level since URL fetching from style +// is restricted by the rendering layer's CSP regardless. +func sanitiseStyleAttr(raw string) (cleaned string, dropped []string) { + if strings.TrimSpace(raw) == "" { + return "", nil + } + parts := strings.Split(raw, ";") + keep := make([]string, 0, len(parts)) + for _, part := range parts { + decl := strings.TrimSpace(part) + if decl == "" { + continue + } + colon := strings.IndexByte(decl, ':') + if colon < 0 { + // Malformed declaration; drop and surface as a finding so the + // user notices. + dropped = append(dropped, decl) + continue + } + name := strings.ToLower(strings.TrimSpace(decl[:colon])) + if !classifyStyleProperty(name) { + dropped = append(dropped, name) + continue + } + keep = append(keep, decl) + } + cleaned = strings.Join(keep, "; ") + return cleaned, dropped +} + +// hintForBlockedTag returns a hint for an error-blocked tag (matching +// the `output.ErrWithHint` convention used elsewhere in the cli). +func hintForBlockedTag(tag string) string { + switch tag { + case "script": + return "Removed whole tag (XSS risk)" + case "iframe", "object", "embed": + return "Removed whole tag (external embeds not allowed; use or a body link for rich media)" + case "form", "input", "select", "option", "button": + return "Removed whole tag (forms not allowed in email body)" + case "link": + return "Removed (external CSS / resources not allowed)" + case "meta": + return "Removed (viewport / refresh declarations not allowed)" + case "base": + return "Removed (URL base rewrites not allowed)" + default: + return "Removed whole tag (tag not allowed)" + } +} + +// hintForWarnTag returns a hint for a warning-tier tag. +func hintForWarnTag(tag string) string { + switch tag { + case "font": + return "Rewritten as (modern HTML expresses size / color via inline style)" + case "center": + return "Rewritten as
(deprecated
tag)" + case "marquee", "blink": + return "Rewritten as (animations not supported; text preserved)" + default: + return "Rewritten in modern HTML shape" + } +} + +// excerptOf renders the offending node's open-tag header into a short string +// suitable for surfacing in a Finding.Excerpt. We render only the tag header +// (not the full subtree) so a single offending

after

`, Options{}) + if len(rep.Blocked) != 1 { + t.Fatalf("expected 1 blocked finding, got %d", len(rep.Blocked)) + } + if rep.Blocked[0].RuleID != RuleTagScriptBlocked { + t.Errorf("rule = %s, want %s", rep.Blocked[0].RuleID, RuleTagScriptBlocked) + } + if strings.Contains(rep.CleanedHTML, " content should be deleted, cleaned=%q", rep.CleanedHTML) + } + if !strings.Contains(rep.CleanedHTML, "safe") || !strings.Contains(rep.CleanedHTML, "after") { + t.Errorf("surrounding content lost, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_BlockedTagsRemoved iterates all error-tier tags. +func TestRun_BlockedTagsRemoved(t *testing.T) { + cases := map[string]string{ + ``: RuleTagIframeBlocked, + ``: RuleTagObjectBlocked, + ``: RuleTagEmbedBlocked, + `
`: RuleTagFormBlocked, + ``: RuleTagLinkBlocked, + ``: RuleTagMetaBlocked, + ``: RuleTagBaseBlocked, + } + for input, wantRule := range cases { + t.Run(input[:min(len(input), 30)], func(t *testing.T) { + rep := Run(input, Options{}) + found := false + for _, f := range rep.Blocked { + if f.RuleID == wantRule { + found = true + break + } + } + if !found { + t.Errorf("expected rule %s, got %+v", wantRule, rep.Blocked) + } + }) + } +} + +// TestRun_EventHandlerAttrBlocked verifies on*-handlers are stripped (spec +// §4.4 — "属性 on*(onclick 等)"). +func TestRun_EventHandlerAttrBlocked(t *testing.T) { + rep := Run(`

x

`, Options{}) + if len(rep.Blocked) != 1 { + t.Fatalf("expected 1 blocked finding, got %d", len(rep.Blocked)) + } + if rep.Blocked[0].RuleID != RuleAttrEventHandlerBlocked { + t.Errorf("rule = %s, want %s", rep.Blocked[0].RuleID, RuleAttrEventHandlerBlocked) + } + if strings.Contains(rep.CleanedHTML, "onclick") { + t.Errorf("onclick should be stripped, cleaned=%q", rep.CleanedHTML) + } + if !strings.Contains(rep.CleanedHTML, `id="ok"`) { + t.Errorf("non-handler attrs should survive, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_OnErrorAttrBlocked tests one of the more common XSS vectors. +func TestRun_OnErrorAttrBlocked(t *testing.T) { + rep := Run(``, Options{}) + hasErr := false + for _, f := range rep.Blocked { + if f.RuleID == RuleAttrEventHandlerBlocked && f.TagOrAttr == "onerror" { + hasErr = true + } + } + if !hasErr { + t.Errorf("onerror should fire, got %+v", rep.Blocked) + } +} + +// ===================================================================== +// URL scheme allow-list. +// ===================================================================== + +// TestRun_JavaScriptURLBlocked verifies javascript: hrefs are stripped. +func TestRun_JavaScriptURLBlocked(t *testing.T) { + rep := Run(`click`, Options{}) + hasErr := false + for _, f := range rep.Blocked { + if f.RuleID == RuleAttrJSURLBlocked { + hasErr = true + } + } + if !hasErr { + t.Errorf("javascript: URL should fire ATTR_JS_URL_BLOCKED, got %+v", rep.Blocked) + } + if strings.Contains(rep.CleanedHTML, "javascript:") { + t.Errorf("javascript: should be stripped, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_VBScriptURLBlocked verifies vbscript: is rejected. +func TestRun_VBScriptURLBlocked(t *testing.T) { + rep := Run(`x`, Options{}) + if len(rep.Blocked) == 0 { + t.Errorf("expected vbscript: to be blocked, got 0 findings") + } +} + +// TestRun_DataNonImageURLBlocked verifies data:text/html is rejected +// (only data:image/* is allowed). +func TestRun_DataNonImageURLBlocked(t *testing.T) { + rep := Run(``, Options{}) + if len(rep.Blocked) == 0 { + t.Errorf("expected data:text/html to be blocked") + } +} + +// TestRun_DataImageAllowed verifies data:image/png passes. +func TestRun_DataImageAllowed(t *testing.T) { + rep := Run(``, Options{}) + for _, f := range rep.Blocked { + if f.RuleID == RuleAttrJSURLBlocked { + t.Errorf("data:image/* should pass, got %+v", f) + } + } +} + +// TestRun_RelativeURLAllowed verifies relative URLs (no scheme) pass. +func TestRun_RelativeURLAllowed(t *testing.T) { + rep := Run(`x`, Options{}) + for _, f := range rep.Blocked { + if f.RuleID == RuleAttrJSURLBlocked || f.RuleID == RuleAttrUnsafeSchemeBlocked { + t.Errorf("relative URL should pass, got %+v", f) + } + } +} + +// ===================================================================== +// Style property allow-list. +// ===================================================================== + +// TestRun_StylePropertyDropped verifies non-allow-list properties drop. +func TestRun_StylePropertyDropped(t *testing.T) { + rep := Run(`

x

`, Options{}) + dropped := []string{} + for _, f := range rep.Applied { + if f.RuleID == RuleStylePropertyDropped { + dropped = append(dropped, f.TagOrAttr) + } + } + if !sliceContains(dropped, "style.position") { + t.Errorf("expected position to be dropped, got %v", dropped) + } + if !sliceContains(dropped, "style.z-index") { + t.Errorf("expected z-index to be dropped, got %v", dropped) + } + if strings.Contains(rep.CleanedHTML, "position:") || strings.Contains(rep.CleanedHTML, "z-index:") { + t.Errorf("dropped properties should be removed from cleaned style, cleaned=%q", rep.CleanedHTML) + } + if !strings.Contains(rep.CleanedHTML, "color:red") { + t.Errorf("allowed property should survive, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_StyleBorderPrefixAllowed verifies the border-* prefix rule. +func TestRun_StyleBorderPrefixAllowed(t *testing.T) { + rep := Run(`

x

`, Options{}) + for _, f := range rep.Applied { + if f.RuleID == RuleStylePropertyDropped { + t.Errorf("border-* should pass, got %+v", f) + } + } +} + +// TestRun_FeishuListShorthandMarginPreserved guards the nested-list indent +// regression: when a user writes shorthand `margin:0 0 0 24px` on an inner +//
    (mail-editor's own native nested-list shape), the Feishu-list autofix +// must NOT clobber it by appending `margin-left:0`. ensureInlineStyleProps +// is supposed to skip props the user already declared, but earlier +// hasInlineStyleProp was only matching longhand `margin-left:` literally +// and missed the shorthand form, causing 24px indents to be reset to 0. +func TestRun_FeishuListShorthandMarginPreserved(t *testing.T) { + in := `
    • indented
    ` + rep := Run(in, Options{}) + cleaned := rep.CleanedHTML + // Extract just the
      opening tag's style attr (li has its own + // independent margin-left:0 longhand which is correct — list indent + // belongs on the container, not the item). + ulOpen := cleaned + if i := strings.Index(ulOpen, ">"); i >= 0 { + ulOpen = ulOpen[:i] + } + if !strings.Contains(ulOpen, "margin:0px 0px 0px 24px") { + t.Errorf("shorthand margin with 24px left should survive on
        , ulOpen=%q", ulOpen) + } + // The bug signature: extra `margin-left:` appended after the shorthand + // on the
          element itself (CSS rule says the later one wins, so any + // margin-left:0 after the shorthand resets the indent to 0). + if strings.Contains(ulOpen, "margin-left") { + t.Errorf("autofix must not append margin-left longhand onto
            when shorthand already declares it, ulOpen=%q", ulOpen) + } +} + +// ===================================================================== +// CleanedHTML output / contract guarantees. +// ===================================================================== + +// TestRun_EmptyArraysAlwaysPresent verifies the report has non-nil empty +// slices when nothing is found (the JSON envelope contract requires `[]`, +// not `null`). +func TestRun_EmptyArraysAlwaysPresent(t *testing.T) { + // Use
            instead of

            to avoid the Feishu-native paragraph + // rewrite autofix, which would surface a finding even on otherwise + // clean input. + rep := Run(`

            nothing here
            `, Options{}) + if rep.Applied == nil || rep.Blocked == nil { + t.Errorf("Applied/Blocked must be non-nil; got applied=%v blocked=%v", rep.Applied, rep.Blocked) + } + if len(rep.Applied) != 0 || len(rep.Blocked) != 0 { + t.Errorf("expected empty findings, got applied=%d blocked=%d", len(rep.Applied), len(rep.Blocked)) + } +} + +// TestEmptyReport_HasContractFields covers the helper used by compose 5's +// plain-text branch. +func TestEmptyReport_HasContractFields(t *testing.T) { + rep := EmptyReport(`plain text`) + if rep.Applied == nil { + t.Error("Applied must be non-nil") + } + if rep.Blocked == nil { + t.Error("Blocked must be non-nil") + } + if rep.CleanedHTML != "plain text" { + t.Errorf("CleanedHTML = %q, want %q", rep.CleanedHTML, "plain text") + } +} + +// TestRun_CleanedHTMLPreservesStructure verifies that the round-trip through +// the parser doesn't accidentally lose user content. +func TestRun_CleanedHTMLPreservesStructure(t *testing.T) { + html := `

            title

            body bold end

            • a
            • b
            ` + rep := Run(html, Options{}) + if len(rep.Blocked) != 0 { + t.Fatalf("unexpected blocked: %+v", rep.Blocked) + } + // Feishu-native autofix expected to fire on

            ,

              ,
            • — content + // must still survive untouched even though structure is augmented. + for _, want := range []string{"line-height:1.6", "

              ", "title", "", "bold", ""} { + if !strings.Contains(rep.CleanedHTML, want) { + t.Errorf("expected %q in cleaned, got %q", want, rep.CleanedHTML) + } + } +} + +// TestRun_EmptyInput verifies the lib short-circuits cleanly on empty input. +func TestRun_EmptyInput(t *testing.T) { + rep := Run("", Options{}) + if rep.CleanedHTML != "" { + t.Errorf("CleanedHTML = %q, want empty", rep.CleanedHTML) + } + if len(rep.Applied) != 0 || len(rep.Blocked) != 0 { + t.Errorf("empty input must produce empty findings") + } +} + +// TestRun_HasErrorFindingsFlag verifies the flag tracks blocked findings. +func TestRun_HasErrorFindingsFlag(t *testing.T) { + rep := Run(``, Options{}) + if !rep.HasErrorFindings { + t.Error("expected HasErrorFindings=true") + } + clean := Run(`

              safe

              `, Options{}) + if clean.HasErrorFindings { + t.Error("expected HasErrorFindings=false on clean HTML") + } +} + +// TestRun_HasWarningFindingsFlag verifies the flag tracks warnings. +func TestRun_HasWarningFindingsFlag(t *testing.T) { + rep := Run(`x`, Options{}) + if !rep.HasWarningFindings { + t.Error("expected HasWarningFindings=true") + } +} + +// ===================================================================== +// Excerpt cap. +// ===================================================================== + +// TestTruncateExcerpt_RespectsCap verifies the per-finding excerpt cap. +func TestTruncateExcerpt_RespectsCap(t *testing.T) { + long := strings.Repeat("x", MaxExcerptBytes+50) + got := truncateExcerpt(long) + if len(got) > MaxExcerptBytes { + t.Errorf("excerpt len %d exceeds cap %d", len(got), MaxExcerptBytes) + } + if !strings.HasSuffix(got, " ...") { + t.Errorf("expected truncation suffix, got %q", got[len(got)-10:]) + } +} + +// TestRun_ExcerptCappedForLargeOffender verifies large blocked content +// produces a short excerpt (envelope size protection). +func TestRun_ExcerptCappedForLargeOffender(t *testing.T) { + bigAttr := strings.Repeat("a", MaxExcerptBytes*2) + rep := Run(`x`, Options{}) + if len(rep.Blocked) == 0 { + t.Fatal("expected blocked finding") + } + for _, f := range rep.Blocked { + if len(f.Excerpt) > MaxExcerptBytes { + t.Errorf("excerpt len %d exceeds cap %d", len(f.Excerpt), MaxExcerptBytes) + } + } +} + +// ===================================================================== +// Helpers. +// ===================================================================== + +func sliceContains(haystack []string, needle string) bool { + for _, s := range haystack { + if s == needle { + return true + } + } + return false +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// ===================================================================== +// Additional coverage to lift the package to ≥ 90% line coverage +// (sprint requirement S4 item 7). +// ===================================================================== + +// TestMapFontSize_ExhaustiveSpan covers every mapping +// + invalid values fall through to "" so the property is dropped. +func TestMapFontSize_ExhaustiveSpan(t *testing.T) { + cases := map[string]string{ + "1": "10px", + "2": "13px", + "3": "16px", + "4": "18px", + "5": "24px", + "6": "32px", + "7": "48px", + "": "", + "8": "", + "abc": "", + "3.5": "", + " 3 ": "16px", + } + for raw, want := range cases { + got := mapFontSize(raw) + if got != want { + t.Errorf("mapFontSize(%q) = %q, want %q", raw, got, want) + } + } +} + +// TestRun_FontTagWithFaceMappedToFontFamily ensures → +// font-family inline style. +func TestRun_FontTagWithFaceMappedToFontFamily(t *testing.T) { + rep := Run(`x`, Options{}) + if !strings.Contains(rep.CleanedHTML, "font-family:Arial") { + t.Errorf("expected font-family preserved, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_FontTagWithExistingStyleMerged ensures distillation merges with an +// existing style attribute on the same element. +func TestRun_FontTagWithExistingStyleMerged(t *testing.T) { + rep := Run(`x`, Options{}) + if !strings.Contains(rep.CleanedHTML, "line-height:1.6") { + t.Errorf("expected line-height retained, cleaned=%q", rep.CleanedHTML) + } + if !strings.Contains(rep.CleanedHTML, "color:red") { + t.Errorf("expected color merged, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_CenterTagWithExistingStyleMerged ensures
              's style merge. +func TestRun_CenterTagWithExistingStyleMerged(t *testing.T) { + rep := Run(`
              x
              `, Options{}) + if !strings.Contains(rep.CleanedHTML, "text-align:center") { + t.Errorf("expected text-align:center, cleaned=%q", rep.CleanedHTML) + } + if !strings.Contains(rep.CleanedHTML, "line-height:1.6") { + t.Errorf("expected line-height preserved, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_MarqueeRetainsClassAndID verifies marquee → span keeps class/id. +func TestRun_MarqueeRetainsClassAndID(t *testing.T) { + rep := Run(`y`, Options{}) + if !strings.Contains(rep.CleanedHTML, `class="cls"`) { + t.Errorf("expected class preserved, cleaned=%q", rep.CleanedHTML) + } + if strings.Contains(rep.CleanedHTML, `direction`) { + t.Errorf("expected marquee-specific attrs stripped, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_UnknownSchemeBlocked verifies an unknown URL scheme produces a +// blocked (error) finding and the attribute is dropped. +func TestRun_UnknownSchemeBlocked(t *testing.T) { + rep := Run(`x`, Options{}) + gotBlocked := false + for _, f := range rep.Blocked { + if f.RuleID == RuleAttrUnsafeSchemeBlocked { + gotBlocked = true + } + } + if !gotBlocked { + t.Errorf("expected ATTR_UNSAFE_SCHEME_BLOCKED in Blocked, got blocked=%+v applied=%+v", rep.Blocked, rep.Applied) + } + if strings.Contains(rep.CleanedHTML, "webcal:") { + t.Errorf("expected unknown scheme stripped, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_WhitespaceObfuscatedJavaScriptScheme verifies "java\tscript:..." +// is still caught after control-byte stripping in classifyURLValue. +func TestRun_WhitespaceObfuscatedJavaScriptScheme(t *testing.T) { + rep := Run("x", Options{}) + gotErr := false + for _, f := range rep.Blocked { + if f.RuleID == RuleAttrJSURLBlocked { + gotErr = true + } + } + if !gotErr { + t.Errorf("expected obfuscated javascript: to be caught, got %+v", rep.Blocked) + } +} + +// TestRun_FileSchemeBlocked verifies file: URLs are rejected. +func TestRun_FileSchemeBlocked(t *testing.T) { + rep := Run(`x`, Options{}) + if len(rep.Blocked) == 0 { + t.Error("expected file: to be blocked") + } +} + +// TestRun_StyleMalformedDeclarationDropped verifies a property without a +// colon delimiter is treated as malformed and dropped. +func TestRun_StyleMalformedDeclarationDropped(t *testing.T) { + rep := Run(`

              x

              `, Options{}) + gotMalformed := false + for _, f := range rep.Applied { + if f.RuleID == RuleStylePropertyDropped && f.TagOrAttr == "style.malformed" { + gotMalformed = true + } + } + if !gotMalformed { + t.Errorf("expected malformed declaration to be dropped, got %+v", rep.Applied) + } + if !strings.Contains(rep.CleanedHTML, "color:red") || !strings.Contains(rep.CleanedHTML, "line-height:1.6") { + t.Errorf("valid declarations should survive, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_StyleAllPropertiesDroppedRemovesAttribute verifies the style +// attribute is removed entirely when every property is invalid. +func TestRun_StyleAllPropertiesDroppedRemovesAttribute(t *testing.T) { + // Use
              to avoid the Feishu-native paragraph autofix, which adds + // a fresh style attribute on the rewritten outer wrapper. + rep := Run(`
              x
              `, Options{}) + if strings.Contains(rep.CleanedHTML, "style=") { + t.Errorf("style attribute should be removed when all props invalid, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_StyleEmptyValuePassThrough verifies an empty style attr passes. +func TestRun_StyleEmptyValuePassThrough(t *testing.T) { + // Use
              to avoid the Feishu-native paragraph autofix. + rep := Run(`
              x
              `, Options{}) + if len(rep.Applied) != 0 { + t.Errorf("empty style attr should not produce findings, got %+v", rep.Applied) + } +} + +// TestRun_HintsForAllBlockedTags verifies every blocked-tag rule has a +// non-empty hint (consumer contract). +func TestRun_HintsForAllBlockedTags(t *testing.T) { + cases := []string{ + ``, ``, + ``, ``, `
              `, + ``, ``, ``, + ``, ``, + } + for _, html := range cases { + rep := Run(html, Options{}) + for _, f := range rep.Blocked { + if f.Hint == "" { + t.Errorf("blocked rule %s missing hint for %q", f.RuleID, html) + } + } + } +} + +// TestRun_HintsForAllWarnTags verifies every warn-tag rule has a non-empty hint. +func TestRun_HintsForAllWarnTags(t *testing.T) { + cases := []string{ + `x`, `
              x
              `, + `x`, `x`, + } + for _, html := range cases { + rep := Run(html, Options{}) + for _, f := range rep.Applied { + if f.Hint == "" { + t.Errorf("warn rule %s missing hint for %q", f.RuleID, html) + } + } + } +} + +// TestClassifyTag_Coverage exercises classifyTag with every category. +func TestClassifyTag_Coverage(t *testing.T) { + if k, _ := classifyTag("p"); k != "allow" { + t.Errorf("p classified as %q", k) + } + if k, id := classifyTag("script"); k != "error" || id != RuleTagScriptBlocked { + t.Errorf("script classified as %q/%q", k, id) + } + if k, id := classifyTag("font"); k != "warn" || id != RuleTagFontToSpan { + t.Errorf("font classified as %q/%q", k, id) + } + // Niche tag passes silently (e.g.
              ). + if k, _ := classifyTag("details"); k != "allow" { + t.Errorf("niche tag
              should pass through, got %q", k) + } + // Case-insensitive. + if k, _ := classifyTag("SCRIPT"); k != "error" { + t.Errorf("SCRIPT (uppercase) should still classify as error") + } +} + +// TestClassifyURLValue_CoverageEdges covers empty, whitespace-only, +// no-scheme variants. +func TestClassifyURLValue_CoverageEdges(t *testing.T) { + cases := map[string]string{ + "": "ok", + " ": "ok", + "https://x": "ok", + "https://x/path?q=1": "ok", + "#fragment": "ok", + "/relative": "ok", + "javascript:alert(1)": "error", + "vbscript:msgbox 1": "error", + "data:image/png;base64,XYZ": "ok", + "data:text/html,x` + + `

              y

              ` + rep := Run(html, Options{}) + if len(rep.Blocked) < 4 { + t.Errorf("expected ≥4 errors, got %d: %+v", len(rep.Blocked), rep.Blocked) + } +} + +// TestRun_NestedStructurePreserved verifies deep nesting passes through. +func TestRun_NestedStructurePreserved(t *testing.T) { + html := `

              deep

              ` + rep := Run(html, Options{}) + if len(rep.Blocked) != 0 { + t.Errorf("nested allowed tags should pass, got %+v", rep.Blocked) + } + if !strings.Contains(rep.CleanedHTML, "deep") { + t.Errorf("inner text lost, cleaned=%q", rep.CleanedHTML) + } +} + +// TestRun_BlockedInsideAllowedRemovedNotParent verifies that removing a +// blocked tag inside an allowed parent leaves the parent intact. +func TestRun_BlockedInsideAllowedRemovedNotParent(t *testing.T) { + html := `
              beforeafter
              ` + rep := Run(html, Options{}) + if !strings.Contains(rep.CleanedHTML, "before") || !strings.Contains(rep.CleanedHTML, "after") { + t.Errorf("parent text should survive, cleaned=%q", rep.CleanedHTML) + } + if strings.Contains(rep.CleanedHTML, "
                nested +// directly without an
              • wrapper triggers LIST_DIRECT_CHILD_NON_LI and +// the inner
                  ends up wrapped in a synthetic
                • . Same for
                      . +func TestRun_ListDirectChildNonLIWrapped(t *testing.T) { + cases := []struct { + name string + html string + }{ + {"ul wraps ul", `
                        • x
                      `}, + {"ol wraps ol", `
                        1. x
                      `}, + {"ul wraps div", `
                        orphan
                      • real
                      `}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + rep := Run(tc.html, Options{}) + gotRule := false + for _, f := range rep.Applied { + if f.RuleID == RuleListDirectChildNonLI { + gotRule = true + break + } + } + if !gotRule { + t.Errorf("expected LIST_DIRECT_CHILD_NON_LI, got %+v", rep.Applied) + } + // The cleaned HTML should not have a direct ul>ul or ol>ol or + // ul>div sequence anymore. + if strings.Contains(rep.CleanedHTML, "
                        wrapper, cleaned=%q", rep.CleanedHTML) + } + }) + } +} diff --git a/shortcuts/mail/lint/rules.go b/shortcuts/mail/lint/rules.go new file mode 100644 index 000000000..e99b17023 --- /dev/null +++ b/shortcuts/mail/lint/rules.go @@ -0,0 +1,354 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package lint + +import "strings" + +// Rule IDs surfaced through Finding.RuleID. UPPER_SNAKE_CASE naming is the +// contract for the stdout envelope. New rules MUST keep this naming convention +// so AI / test consumers can pattern-match reliably. +const ( + // Tag-level rules. + RuleTagFontToSpan = "TAG_FONT_TO_SPAN" + RuleTagCenterToDiv = "TAG_CENTER_TO_DIV" + RuleTagMarqueeToText = "TAG_MARQUEE_TO_TEXT" + RuleTagBlinkToText = "TAG_BLINK_TO_TEXT" + RuleTagScriptBlocked = "TAG_SCRIPT_BLOCKED" + RuleTagIframeBlocked = "TAG_IFRAME_BLOCKED" + RuleTagObjectBlocked = "TAG_OBJECT_BLOCKED" + RuleTagEmbedBlocked = "TAG_EMBED_BLOCKED" + RuleTagFormBlocked = "TAG_FORM_BLOCKED" + RuleTagInputBlocked = "TAG_INPUT_BLOCKED" + RuleTagLinkBlocked = "TAG_LINK_BLOCKED" + RuleTagMetaBlocked = "TAG_META_BLOCKED" + RuleTagBaseBlocked = "TAG_BASE_BLOCKED" + RuleTagUnknownStripped = "TAG_UNKNOWN_STRIPPED" + + // Attribute-level rules. + RuleAttrEventHandlerBlocked = "ATTR_EVENT_HANDLER_BLOCKED" + RuleAttrJSURLBlocked = "ATTR_JS_URL_BLOCKED" + RuleAttrUnsafeSchemeBlocked = "ATTR_UNSAFE_SCHEME_BLOCKED" + + // Style-level rules. + RuleStylePropertyDropped = "STYLE_PROPERTY_DROPPED" + + // Feishu-native autofix rules. These autofix the inline style / + // class / nesting shape of common elements so AI-authored HTML + // matches what Feishu mail-editor itself emits, fixing the visual + // "extra blank line between blocks", "list bullets/numbers missing", + // "link color wrong" etc. classes of issues. The rewrite is purely + // additive — user-supplied inline styles take precedence; the lib + // only fills the missing properties. + RuleStyleListNative = "STYLE_LIST_NATIVE_INLINE_APPLIED" + RuleStyleListItemNative = "STYLE_LIST_ITEM_NATIVE_INLINE_APPLIED" + RuleStyleBlockquoteNative = "STYLE_BLOCKQUOTE_NATIVE_INLINE_APPLIED" + RuleStyleLinkNative = "STYLE_LINK_NATIVE_INLINE_APPLIED" + RuleStyleParaWrapper = "STYLE_PARA_WRAPPER_REWRITTEN" + + // RuleListDirectChildNonLI fires when a
                          or
                            has a non-
                          1. + // element child (e.g. nested
                                ). HTML spec requires list children + // to be
                              • ; browsers silently hoist the nested list out and the visual + // nesting falls apart. The lib autofixes by wrapping the offending child + // in a synthetic
                              • . + RuleListDirectChildNonLI = "LIST_DIRECT_CHILD_NON_LI" +) + +// Tag classification ---------------------------------------------------------- + +// allowedTags enumerates tags that pass through verbatim (tag classification row "通过"). +// Lower-case canonical names; the parser normalises tag names so we don't need +// case-insensitive comparison at lookup time. +var allowedTags = map[string]bool{ + "p": true, + "div": true, + "span": true, + "br": true, + "hr": true, + "a": true, + "img": true, + "table": true, + "thead": true, + "tbody": true, + "tfoot": true, + "tr": true, + "td": true, + "th": true, + "ul": true, + "ol": true, + "li": true, + "blockquote": true, + "pre": true, + "code": true, + "b": true, + "i": true, + "em": true, + "strong": true, + "u": true, + "s": true, + "strike": true, + "h1": true, + "h2": true, + "h3": true, + "h4": true, + "h5": true, + "h6": true, + "sub": true, + "sup": true, + "section": true, + "article": true, + "header": true, + "footer": true, + "nav": true, + "main": true, + "figure": true, + "figcaption": true, + "caption": true, + "colgroup": true, + "col": true, + // Document structural tags (golang.org/x/net/html always wraps fragments + // in ); we treat them as transparent so the wrapper + // nodes the parser inserts don't generate spurious findings. + "html": true, + "head": true, + "body": true, +} + +// blockedTags enumerates tags whose content is removed in full and a +// SeverityError finding is emitted (tag classification row "错误(删除)"). Each entry +// maps to the rule id surfaced in Finding.RuleID. +var blockedTags = map[string]string{ + "script": RuleTagScriptBlocked, + "iframe": RuleTagIframeBlocked, + "object": RuleTagObjectBlocked, + "embed": RuleTagEmbedBlocked, + "form": RuleTagFormBlocked, + "input": RuleTagInputBlocked, + "select": RuleTagInputBlocked, + "option": RuleTagInputBlocked, + "button": RuleTagInputBlocked, + "link": RuleTagLinkBlocked, + "meta": RuleTagMetaBlocked, + "base": RuleTagBaseBlocked, +} + +// warnAutofixTags enumerates tags rewritten when AutoFix is true (tag +// classification row "警告 + 自动修复"). The replacement strategy is per-tag. +var warnAutofixTags = map[string]string{ + "font": RuleTagFontToSpan, + "center": RuleTagCenterToDiv, + "marquee": RuleTagMarqueeToText, + "blink": RuleTagBlinkToText, +} + +// classifyTag returns the rule kind for the given lower-case tag name. +// +// kind is one of "allow", "warn", "error", "unknown". For "warn" / "error", +// ruleID names the firing rule; for "unknown", the caller falls back to +// allow-list-by-default but emits a hint via RuleTagUnknownStripped only when +// the tag is structurally suspect (e.g. -like). The cli's existing +// `htmlTagRe` regex is the de-facto allow-list shipping with the codebase, so +// we don't aggressively flag anything outside `allowedTags` — drop-through +// preserves user intent for niche tags (e.g. `
                                ` / ``) that +// browsers + Feishu native renderer already handle. +func classifyTag(tag string) (kind, ruleID string) { + tag = strings.ToLower(tag) + if allowedTags[tag] { + return "allow", "" + } + if id, ok := blockedTags[tag]; ok { + return "error", id + } + if id, ok := warnAutofixTags[tag]; ok { + return "warn", id + } + // Unknown / niche tags: pass through silently. The cli's existing + // `htmlTagRe` (mail_quote.go:333) tolerates them too. Users authoring + // HTML in Feishu native classes (`adit-html-block*`, `history-quote-*`, + // `lark-mail-doc-quote`) hit this path — they MUST pass through unchanged + // so reply / forward quote markup survives lint round-trips. (Spec §4.4 + // "通过" row second sentence: "飞书原生 quote 体系...class".) + return "allow", "" +} + +// Attribute / URL / style classification -------------------------------------- + +// allowedURLSchemes lists URL schemes that pass through hyperlink-bearing +// attrs (`href`, `src`, `cite`, `formaction` etc.). Allowed: http(s), mailto, +// cid, data:image/*; everything else (notably javascript: and vbscript:) is +// blocked. Empty / relative URLs (no scheme) are always +// allowed because they resolve relatively at render time and pose no +// injection vector. +var allowedURLSchemes = map[string]bool{ + "http": true, + "https": true, + "mailto": true, + "cid": true, +} + +// blockedURLSchemes is the explicit deny-list. data:image/* is special-cased +// in classifyURLValue. +var blockedURLSchemes = map[string]bool{ + "javascript": true, + "vbscript": true, + "file": true, +} + +// classifyURLValue returns ("ok", "") if the URL value is acceptable, or +// ("error", ruleID) when it must be removed (javascript:/vbscript:/file:), +// or ("warn", ruleID) when the scheme is unrecognised but not actively +// dangerous. Empty values pass through (browsers ignore them). +func classifyURLValue(raw string) (kind, ruleID string) { + value := strings.TrimSpace(raw) + if value == "" { + return "ok", "" + } + // Strip leading whitespace + control bytes that could obscure the + // scheme (e.g. "java\tscript:..."). The html-parser already strips + // stray whitespace at attribute boundaries; this is defence-in-depth + // for older clients that paste from Word with U+0009 / U+0020 inside + // the scheme prefix. + value = strings.Map(func(r rune) rune { + if r < 0x20 || r == 0x7F { + return -1 + } + return r + }, value) + + // Find the colon delimiter; everything before it is the scheme. + colon := strings.IndexByte(value, ':') + if colon < 0 { + // No scheme → relative URL → allow. + return "ok", "" + } + scheme := strings.ToLower(value[:colon]) + rest := value[colon+1:] + + switch { + case allowedURLSchemes[scheme]: + return "ok", "" + case scheme == "data": + // data:image/* is whitelisted; anything else (e.g. data:text/html;...) + // is rejected. The check tolerates any subtype under image/* (png / + // jpeg / gif / svg+xml / webp) so users embedding base64 thumbnails + // don't trip the rule. + rest = strings.TrimSpace(rest) + if strings.HasPrefix(strings.ToLower(rest), "image/") { + return "ok", "" + } + return "error", RuleAttrJSURLBlocked + case blockedURLSchemes[scheme]: + return "error", RuleAttrJSURLBlocked + default: + // Unknown scheme: surface a warning so users see it but don't + // drop legitimate webcal:/tel: / similar in case downstream + // renders eventually support them. + return "warn", RuleAttrUnsafeSchemeBlocked + } +} + +// urlAttributes lists attributes whose value is a URL and must therefore +// pass classifyURLValue. Lower-case canonical names. +var urlAttributes = map[string]bool{ + "href": true, + "src": true, + "cite": true, + "formaction": true, + "action": true, + "background": true, + "poster": true, +} + +// allowedStyleProps enumerates CSS property names that pass through the +// inline `style="..."` attribute. Everything else is removed from the +// property list and surfaced via STYLE_PROPERTY_DROPPED. +// +// `border-*` / `padding-*` / `margin-*` are treated as prefix matches by +// classifyStyleProperty so the four directional variants (border-top etc.) +// are all admitted without enumerating each. +var allowedStyleProps = map[string]bool{ + "color": true, + "background-color": true, + "font-size": true, + "font-weight": true, + "font-style": true, + "text-align": true, + "text-decoration": true, + "line-height": true, + "padding": true, + "margin": true, + "border": true, + "width": true, + "height": true, + "display": true, + "text-indent": true, + // Quote-block / native Feishu styles (tag classification "通过"). + // Whitespace + word-break are part of the existing `
                                ` / quote
                                +	// wrapper styles in mail_quote.go (e.g. `bodyDivStyle`).
                                +	"white-space":     true,
                                +	"word-break":      true,
                                +	"word-wrap":       true,
                                +	"overflow":        true,
                                +	"overflow-wrap":   true,
                                +	"vertical-align":  true,
                                +	"list-style":          true,
                                +	"list-style-type":     true,
                                +	"list-style-position": true,
                                +	"transition":          true,
                                +	"font-family":         true,
                                +	"text-transform":  true,
                                +	"hyphens":         true,
                                +	"max-width":       true,
                                +	"min-width":       true,
                                +	"max-height":      true,
                                +	"min-height":      true,
                                +	"border-radius":   true,
                                +	"box-sizing":      true,
                                +	"opacity":         true,
                                +	"cursor":          true,
                                +}
                                +
                                +// stylePropAllowedPrefixes enumerates property name prefixes treated as
                                +// allowed regardless of suffix (e.g. "border-*"). A trailing "-" makes the
                                +// prefix self-documenting.
                                +var stylePropAllowedPrefixes = []string{
                                +	"border-",
                                +	"padding-",
                                +	"margin-",
                                +}
                                +
                                +// classifyStyleProperty reports whether the given lower-case property name
                                +// is in the allow-list (incl. prefix matches).
                                +func classifyStyleProperty(name string) bool {
                                +	name = strings.ToLower(strings.TrimSpace(name))
                                +	if name == "" {
                                +		return false
                                +	}
                                +	if allowedStyleProps[name] {
                                +		return true
                                +	}
                                +	for _, p := range stylePropAllowedPrefixes {
                                +		if strings.HasPrefix(name, p) {
                                +			return true
                                +		}
                                +	}
                                +	return false
                                +}
                                +
                                +// isEventHandlerAttr reports whether the attribute name is a DOM event
                                +// handler (`on*`). The lib removes every such attribute regardless of its
                                +// value (tag classification row "错误(删除)" + the well-known XSS vector).
                                +func isEventHandlerAttr(name string) bool {
                                +	name = strings.ToLower(strings.TrimSpace(name))
                                +	if !strings.HasPrefix(name, "on") {
                                +		return false
                                +	}
                                +	if len(name) <= 2 {
                                +		return false
                                +	}
                                +	// Defence-in-depth: avoid matching legitimate attrs whose name happens
                                +	// to begin with "on" (e.g. `onerror`-like attrs all start "on" + ascii
                                +	// letter). The `>= 'a'` check filters out "on-something" with hyphens.
                                +	c := name[2]
                                +	return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9')
                                +}
                                diff --git a/shortcuts/mail/lint/types.go b/shortcuts/mail/lint/types.go
                                new file mode 100644
                                index 000000000..383b1e6db
                                --- /dev/null
                                +++ b/shortcuts/mail/lint/types.go
                                @@ -0,0 +1,92 @@
                                +// Copyright (c) 2026 Lark Technologies Pte. Ltd.
                                +// SPDX-License-Identifier: MIT
                                +
                                +// Package lint implements the mail-domain HTML lint lib used by `+lint-html`
                                +// and the writing-path internals of the compose 5 shortcuts (`+send`,
                                +// `+draft-create`, `+reply`, `+reply-all`, `+forward`) and `+draft-edit` body
                                +// ops. The lib classifies HTML tags / attributes / inline styles into three
                                +// tiers (pass / warn-and-autofix / error-delete) following the three-tier tag
                                +// classification. `
                                +
                                +
                                + +
                                +

                                [调研主题] 市场调研报告

                                +
                                [YYYY-MM-DD] | 调研者:[姓名] · [团队] | [关联系统 / 版本]
                                +
                                + +
                                +

                                调研背景

                                +
                                [一段话描述:本轮调研聚焦的赛道 / 行业背景 / 触发动机]。本轮调研覆盖 [N] 类玩家([类别 1] / [类别 2] / [类别 3] / [类别 4]),重点评估 [自家产品 / 团队] 在 [赛道名] 的位置、对外摩擦点,以及结合 [关联工作 / PR / 本期目标] 的待补能力。所有结论基于 [数据来源 1:公开资料 / 厂商文档 / 行业报告] + [数据来源 2:自有实测 / 内部调研笔记] + [数据来源 3:访谈 / 体验]。
                                +
                                + +
                                +
                                +
                                [N]
                                +
                                调研对象
                                +
                                +
                                +
                                [N]
                                +
                                已就绪能力
                                +
                                +
                                +
                                [N]
                                +
                                明确缺口
                                +
                                +
                                +
                                [N]
                                +
                                高优待办
                                +
                                +
                                + +
                                +

                                1. [章节标题:例 "全球市场态势"]

                                +
                                [一句话描述本节切分维度,例 "把市场按 '为谁设计' 切四象限"]
                                + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
                                玩家 / 对象定位 / 类型[关键评分维度]关键观察
                                [玩家 1][类别][标签][一句话观察]
                                [玩家 2][类别][标签][一句话观察]
                                [玩家 3][类别][标签][一句话观察]
                                [玩家 4][类别][标签][一句话观察]
                                +
                                + +
                                +

                                2. [章节标题:例 "接入摩擦点"] ⚠️ 风险

                                +
                                [一句话描述:从哪里观察 / 案例 / 数据来源]
                                + + + + + + + + + + + + + + + + + + + + + + + +
                                摩擦类型 / 维度具体表现业务影响
                                [摩擦 1][具体表现 / 案例][对业务 / 团队的影响]
                                [摩擦 2][具体表现][影响]
                                [摩擦 3][具体表现][影响]
                                +
                                + +
                                +

                                3. [章节标题:例 "新势力玩家详情" / "重点对象详细比较"]

                                +
                                +
                                +
                                [玩家 / 对象 1]
                                +
                                [一句话产品定位 / 核心能力 / 差异化]
                                +
                                关键差异:[一句话提炼]
                                +
                                +
                                +
                                [玩家 / 对象 2]
                                +
                                [产品定位]
                                +
                                关键差异:[一句话]
                                +
                                +
                                +
                                [玩家 / 对象 3]
                                +
                                [产品定位]
                                +
                                关键差异:[一句话]
                                +
                                +
                                +
                                [小结一句话:玩家共性 / 自家路线对比]
                                +
                                + +
                                +

                                4. [章节标题:例 "安全风险全景" / "潜在隐患"] ⚠️ 高危

                                +
                                [一句话描述:风险来源 / 关联前期工作]
                                + + + + + + + + + + + + + + + + + + + + + + + +
                                威胁 / 风险案例 / 来源自家现状
                                [风险 1][案例 / 来源链接 / 引用前期报告][标签]
                                [风险 2][案例 / 来源][标签]
                                [风险 3](重点)[案例 / 来源][标签]
                                +
                                + 结论:[一段话,提炼本章节最关键的判断 / 行动建议] +
                                +
                                + +
                                +

                                5. [章节标题:例 "自家已就绪能力"] ✓ 优势

                                +
                                [一句话描述:基于哪些 PR / 已交付的工作得出]
                                +
                                • [能力 1] — [简述 + 关联 PR / 文档链接]
                                • [能力 2] — [简述]
                                • [能力 3] — [简述]
                                • [能力 4] — [简述]
                                +
                                + +
                                +

                                6. [章节标题:例 "待补能力 / 机会清单"]

                                +
                                [一句话描述:清单口径 / 优先级判定依据]
                                + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
                                #优先级能力 / 缺口建议落地
                                1P0[能力 / 缺口 1][具体落地路径 / Owner / 估算]
                                2P0[能力 / 缺口 2][具体落地路径]
                                3P1[能力 / 缺口 3][具体落地路径]
                                4P1[能力 / 缺口 4][具体落地路径]
                                5P2[能力 / 缺口 5][具体落地路径]
                                +
                                + +
                                +

                                关联工作产出佐证

                                +
                                本调研报告中部分章节的依据来自下列在执行中的工作:
                                + +
                                + +
                                +

                                建议与下一步

                                +
                                1. [行动 1] — [具体路径 + 时间窗 + Owner]
                                2. [行动 2] — [具体路径 + 时间窗]
                                3. [行动 3] — [具体路径]
                                4. [行动 4] — [具体路径]
                                +
                                + +
                                +
                                调研者:[your@email] · [团队]|整合于 [YYYY-MM-DD]
                                +
                                关联材料:[文档 / 笔记路径 / 前期报告]
                                +
                                + +
                                diff --git a/skills/lark-mail/assets/templates/weekly--personal-report.html b/skills/lark-mail/assets/templates/weekly--personal-report.html new file mode 100644 index 000000000..2e1bd7567 --- /dev/null +++ b/skills/lark-mail/assets/templates/weekly--personal-report.html @@ -0,0 +1,43 @@ + +
                                [姓名] 个人工作周报 · [YYYY 第 NN 周]
                                +
                                [团队] · [角色]|周期 [YYYY-MM-DD] ~ [YYYY-MM-DD]
                                + +
                                本周工作内容
                                + +
                                1. [项目 / 主任务名称]已完成 · 📄 文档 · PR 链接
                                +
                                • [子项 1.1:动作描述,附数据 / 链接]
                                • [子项 1.2:动作描述]
                                • [子项 1.3:动作描述,含具体数字 / 占比 / 时长]
                                + +
                                2. [项目 / 主任务名称]进行中 · 📄 文档
                                +
                                • [子项 2.1:动作 + 当前进度 + 数据]
                                • [子项 2.2:动作 + 当前进度]
                                + +
                                3. [项目 / 主任务名称]已完成
                                +
                                • [子项 3.1]
                                • [子项 3.2]
                                + +
                                下周工作内容
                                + +
                                1. [项目 / 主任务名称]P0 · 预计 [YYYY-MM-DD]
                                +
                                • [子项 1.1:具体动作 + 推进方式,例「先 spike POC,再发 RFC 同协作方对齐方案」]
                                • [子项 1.2:里程碑 / 关键产出 + 完成方式]
                                • [子项 1.3:依赖 / 协作方 / 验收标准]
                                + +
                                2. [项目 / 主任务名称]P0 · 预计 [YYYY-MM-DD]
                                +
                                • [子项 2.1:动作 + 推进方式]
                                • [子项 2.2:里程碑 / 关键产出]
                                • [子项 2.3:依赖 / 验收]
                                + +
                                3. [项目 / 主任务名称]P1 · 预计 [YYYY-MM-DD]
                                +
                                • [子项 3.1:动作 + 推进方式]
                                • [子项 3.2:里程碑]
                                • [子项 3.3:协作方]
                                + +
                                4. [项目 / 主任务名称]P2 · 预计 [YYYY-MM-DD]
                                +
                                • [子项 4.1:动作 + 推进方式]
                                • [子项 4.2:依赖 / 关键产出]
                                + +
                                风险与疑问
                                +
                                • [风险 / 疑问 1] — [背景:描述风险来源 / 触发场景];[影响:会延期 / 阻塞哪些工作];[建议:希望得到的支持 / 决策方向 / 期望响应方(@姓名 / 团队)]
                                • [风险 / 疑问 2] — [背景];[影响];[建议]
                                • [风险 / 疑问 3] — [背景];[影响];[建议]
                                +
                                (若本周无风险 / 疑问,整段替换为:。)
                                + +
                                — [姓名] / [团队] / [日期]|[your@email]
                                diff --git a/skills/lark-mail/assets/templates/weekly--team-report.html b/skills/lark-mail/assets/templates/weekly--team-report.html new file mode 100644 index 000000000..6d26a90c0 --- /dev/null +++ b/skills/lark-mail/assets/templates/weekly--team-report.html @@ -0,0 +1,9 @@ + +
                                本周工作
                                +
                                1. [项目 / 事件 1 名称]@[姓名 a]@[姓名 b]
                                  文档:[文档名]
                                2. [项目 / 事件 2 名称]@[姓名 g]
                                  技术方案:[文档名] · 设计稿:[设计稿名]
                                  • [子项 2.1:含孙子项的动作主题]
                                    • [孙子项 2.1.1:必要时再细分一层;不需要可整段删除]@[姓名 h]
                                    • [孙子项 2.1.2]
                                  • [子项 2.2]@[姓名 i],进行中
                                  • [子项 2.3]@[姓名 j],评审中
                                3. [项目 / 事件 3 名称]@[姓名 k]@[姓名 l]阻塞
                                  阻塞分析:[文档名]
                                +
                                下周工作
                                +
                                1. [重点 1:项目 / 事件名]@[姓名 o],预计 [YYYY-MM-DD]
                                2. [重点 2:含子重点的项目]
                                  1. [子重点 a:动作 / 推进方式]@[姓名 p]
                                  2. [子重点 b:动作]@[姓名 q]
                                3. [重点 3:项目 / 事件名]@[姓名 r]@[姓名 s],预计 [YYYY-MM-DD]
                                4. [重点 4:项目 / 事件名]@[姓名 t],预计 [YYYY-MM-DD]
                                +
                                — [姓名] / [团队] / [日期]|[your@email]
                                diff --git a/skills/lark-mail/references/lark-mail-draft-create.md b/skills/lark-mail/references/lark-mail-draft-create.md index eeb016af9..36b6682dd 100644 --- a/skills/lark-mail/references/lark-mail-draft-create.md +++ b/skills/lark-mail/references/lark-mail-draft-create.md @@ -8,6 +8,8 @@ 如需修改已有草稿,不要使用此命令,请使用 `lark-cli mail +draft-edit`。 +**CRITICAL - 编辑邮件内容前 MUST 先用 Read 工具读取 [references/lark-mail-html.md](references/lark-mail-html.md),其中包含邮件书写规范** + ## 安全约束 此命令创建草稿——**不会**发送邮件。用户可以在飞书邮件 UI 中打开草稿查看详情,确认后再进入后续操作。因此: @@ -44,7 +46,8 @@ lark-cli mail +draft-create --to alice@example.com --subject '测试' --body 'te |------|------|------| | `--to ` | 否 | 完整收件人列表,多个用逗号分隔。支持 `Alice ` 格式。省略时草稿不带收件人(之后可通过 `+draft-edit` 添加) | | `--subject ` | 是 | 草稿主题 | -| `--body ` | 是 | 邮件正文。推荐使用 HTML 获得富文本排版;也支持纯文本(自动检测)。使用 `--plain-text` 可强制纯文本模式。支持 `` 相对路径自动解析为内嵌图片(仅支持相对路径,不支持绝对路径) | +| `--body ` | 二选一 | 邮件正文。推荐使用 HTML 获得富文本排版;也支持纯文本(自动检测)。使用 `--plain-text` 可强制纯文本模式。支持 `` 相对路径自动解析为内嵌图片(仅支持相对路径,不支持绝对路径)。与 `--body-file` 互斥 | +| `--body-file ` | 二选一 | 从文件读取邮件正文 HTML(相对路径,仅限 cwd 子树)。与 `--body` 互斥。文件大小上限 32 MB | | `--from ` | 否 | 发件人邮箱地址(EML From 头)。使用别名(send_as)发信时,设为别名地址并配合 `--mailbox` 指定所属邮箱。省略时使用邮箱主地址 | | `--mailbox ` | 否 | 邮箱地址,指定草稿所属的邮箱(默认回退到 `--from`,再回退到 `me`)。当发件人(`--from`)与邮箱不同时使用,如通过别名或 send_as 地址发信。可通过 `accessible_mailboxes` 查询可用邮箱 | | `--cc ` | 否 | 完整抄送列表,多个用逗号分隔 | diff --git a/skills/lark-mail/references/lark-mail-draft-edit.md b/skills/lark-mail/references/lark-mail-draft-edit.md index 366c5cf82..4dee1e5b8 100644 --- a/skills/lark-mail/references/lark-mail-draft-edit.md +++ b/skills/lark-mail/references/lark-mail-draft-edit.md @@ -10,7 +10,9 @@ - `--set-cc` - `--set-bcc` -**正文编辑和其他高级操作必须通过 `--patch-file`**。没有 `--set-body` flag。 +**正文整体替换的快捷方式:** `--body ` / `--body-file `(二选一互斥)会自动展开为 `set_body` op。如果只想做整段正文替换且不需要保留引用区,用这两个 flag 即可,无需写 patch-file。要保留引用区或做更精细的 op 组合,仍走 `--patch-file`。两个入口与 `--patch-file` 内的 `set_body` / `set_reply_body` 互斥。 + +**CRITICAL - 编辑邮件内容前 MUST 先用 Read 工具读取 [references/lark-mail-html.md](references/lark-mail-html.md),其中包含邮件书写规范** ### 正文编辑:两个 op 的选择 @@ -72,6 +74,8 @@ lark-cli mail +draft-edit --draft-id --set-subject '测试' --dry-run | `--set-to ` | 否 | 用此处提供的地址替换整个 To 收件人列表 | | `--set-cc ` | 否 | 用此处提供的地址替换整个 Cc 抄送列表 | | `--set-bcc ` | 否 | 用此处提供的地址替换整个 Bcc 密送列表 | +| `--body ` | 否 | 整段替换正文(自动展开为 `set_body` op)。与 `--body-file` 互斥;与 `--patch-file` 内的 `set_body` / `set_reply_body` op 互斥 | +| `--body-file ` | 否 | 从文件读取正文 HTML(相对路径,仅限 cwd 子树)。与 `--body` 互斥。文件大小上限 32 MB | | `--set-priority ` | 否 | 设置邮件优先级:`high`、`normal`、`low`。设为 `normal` 会清除已有优先级 | | `--set-event-summary ` | 否 | 设置日程标题。需同时设置 `--set-event-start` 和 `--set-event-end` | | `--set-event-start