From dbb0d838b8d8712dda4e38acb017e3b708514314 Mon Sep 17 00:00:00 2001 From: Emilien Mantel Date: Thu, 12 Mar 2026 19:26:15 +0100 Subject: [PATCH 1/4] :hammer: Add make helpers for tests --- .gitignore | 4 ++++ Makefile | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7dfc8d5..0468fa9 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,10 @@ # Config /.retyc +# Test coverage +coverage.out +coverage.html + # Dependency cache vendor/ diff --git a/Makefile b/Makefile index 8b7b816..df091f8 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ BINARY := retyc VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") LDFLAGS := -X $(MODULE)/cmd.Version=$(VERSION) -.PHONY: all build build-prod test vet lint clean install +.PHONY: all build build-prod test test-coverage vet lint clean install ## Default target: dev build (config in .retyc/ relative to CWD) all: build @@ -20,6 +20,13 @@ build-prod: test: go test -race ./... +## Run tests with coverage report and generate HTML report +test-coverage: + go test -race -coverprofile=coverage.out ./... + go tool cover -func=coverage.out | tail -1 + go tool cover -html=coverage.out -o coverage.html + @echo "Coverage report: coverage.html" + ## Run go vet vet: go vet ./... From 55f097de944f8098bcd7feb67db0e1b8c634ef99 Mon Sep 17 00:00:00 2001 From: Emilien Mantel Date: Thu, 12 Mar 2026 19:27:06 +0100 Subject: [PATCH 2/4] :white_check_mark: Add tests globally --- cmd/helpers_test.go | 155 ++++++++++ cmd/transfer.go | 5 +- internal/api/client_test.go | 283 ++++++++++++++++++ internal/api/login_test.go | 134 +++++++++ internal/api/transfer_test.go | 252 ++++++++++++++++ internal/api/user_test.go | 121 ++++++++ internal/auth/oidc_test.go | 529 +++++++++++++++++++++++++++++++++ internal/config/config_test.go | 168 +++++++++++ internal/crypto/age_test.go | 410 +++++++++++++++++++++++++ internal/ui/format_test.go | 28 ++ 10 files changed, 2082 insertions(+), 3 deletions(-) create mode 100644 cmd/helpers_test.go create mode 100644 internal/api/client_test.go create mode 100644 internal/api/login_test.go create mode 100644 internal/api/transfer_test.go create mode 100644 internal/api/user_test.go create mode 100644 internal/auth/oidc_test.go create mode 100644 internal/config/config_test.go create mode 100644 internal/crypto/age_test.go create mode 100644 internal/ui/format_test.go diff --git a/cmd/helpers_test.go b/cmd/helpers_test.go new file mode 100644 index 0000000..7b8c0ff --- /dev/null +++ b/cmd/helpers_test.go @@ -0,0 +1,155 @@ +package cmd + +import ( + "testing" +) + +func TestPtrOr_Nil(t *testing.T) { + got := ptrOr(nil, "fallback") + if got != "fallback" { + t.Errorf("ptrOr(nil) = %q, want fallback", got) + } +} + +func TestPtrOr_NonNil(t *testing.T) { + s := "actual value" + + got := ptrOr(&s, "fallback") + if got != "actual value" { + t.Errorf("ptrOr(&s) = %q, want actual value", got) + } +} + +func TestFormatExpiry(t *testing.T) { + tests := []struct { + seconds int + expected string + }{ + {0, "never"}, + {60, "in 1m"}, + {300, "in 5m"}, + {3599, "in 59m"}, + {3600, "in 1h"}, + {3601, "in 1h"}, + {7200, "in 2h"}, + {86399, "in 23h"}, + {86400, "in 1d"}, + {172800, "in 2d"}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + got := formatExpiry(tt.seconds) + if got != tt.expected { + t.Errorf("formatExpiry(%d) = %q, want %q", tt.seconds, got, tt.expected) + } + }) + } +} + +func TestGenerateTransferPassphrase_Length(t *testing.T) { + p, err := generateTransferPassphrase() + if err != nil { + t.Fatalf("generateTransferPassphrase() error = %v", err) + } + + if len(p) != 32 { + t.Errorf("passphrase length = %d, want 32", len(p)) + } +} + +func TestGenerateTransferPassphrase_Charset(t *testing.T) { + p, err := generateTransferPassphrase() + if err != nil { + t.Fatalf("generateTransferPassphrase() error = %v", err) + } + + for i, c := range p { + if c < 0x21 || c > 0x7e { + t.Errorf("passphrase[%d] = %q (0x%02x), want printable ASCII (0x21–0x7e)", i, c, c) + } + } +} + +func TestGenerateTransferPassphrase_Unique(t *testing.T) { + p1, err := generateTransferPassphrase() + if err != nil { + t.Fatal(err) + } + + p2, err := generateTransferPassphrase() + if err != nil { + t.Fatal(err) + } + + if p1 == p2 { + t.Error("generateTransferPassphrase() returned identical passphrases on two consecutive calls") + } +} + +func TestRandomLetters_Length(t *testing.T) { + for _, n := range []int{0, 1, 8, 16, 32} { + got := randomLetters(n) + if len(got) != n { + t.Errorf("randomLetters(%d) length = %d, want %d", n, len(got), n) + } + } +} + +func TestRandomLetters_Charset(t *testing.T) { + s := randomLetters(64) + for i, c := range s { + if c < 'a' || c > 'z' { + t.Errorf("randomLetters[%d] = %q, want lowercase a-z", i, c) + } + } +} + +func TestRandomLetters_Unique(t *testing.T) { + a := randomLetters(16) + b := randomLetters(16) + + if a == b { + t.Error("randomLetters() returned identical strings on two consecutive calls") + } +} + +func TestIsOfflineToken_Valid(t *testing.T) { + // Manually crafted JWT with payload {"typ":"Offline"} (base64url-encoded). + // header.payload.signature — signature is irrelevant for this check. + //nolint:gosec // G101: test fixture JWT, not a real credential + jwt := "eyJhbGciOiJSUzI1NiJ9.eyJ0eXAiOiJPZmZsaW5lIn0.signature" + + if !isOfflineToken(jwt) { + t.Error("isOfflineToken() = false, want true for Offline typ") + } +} + +func TestIsOfflineToken_Regular(t *testing.T) { + // JWT payload {"typ":"Bearer"}. + //nolint:gosec // G101: test fixture JWT, not a real credential + jwt := "eyJhbGciOiJSUzI1NiJ9.eyJ0eXAiOiJCZWFyZXIifQ.signature" + + if isOfflineToken(jwt) { + t.Error("isOfflineToken() = true, want false for Bearer typ") + } +} + +func TestIsOfflineToken_NotJWT(t *testing.T) { + if isOfflineToken("not.a.jwt.at.all.parts") { + t.Error("isOfflineToken() = true, want false for non-JWT string") + } +} + +func TestIsOfflineToken_Empty(t *testing.T) { + if isOfflineToken("") { + t.Error("isOfflineToken() = true, want false for empty string") + } +} + +func TestIsOfflineToken_InvalidBase64(t *testing.T) { + // Three parts but invalid base64 in the payload segment. + if isOfflineToken("header.!!!invalid!!!.signature") { + t.Error("isOfflineToken() = true, want false for invalid base64 payload") + } +} diff --git a/cmd/transfer.go b/cmd/transfer.go index d1d3ae2..bebbeea 100644 --- a/cmd/transfer.go +++ b/cmd/transfer.go @@ -382,7 +382,6 @@ func ptrOr(s *string, fallback string) string { return *s } - // uploadChunkSize is the size of each plaintext chunk before encryption. const uploadChunkSize = 8 * 1024 * 1024 // 8 MB @@ -1122,10 +1121,10 @@ func downloadTransferFile( // crypto/rand.Int is used to avoid modulo bias. func generateTransferPassphrase() (string, error) { const chars = "!\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" - max := big.NewInt(int64(len(chars))) + maxchar := big.NewInt(int64(len(chars))) result := make([]byte, 32) for i := range result { - n, err := cryptorand.Int(cryptorand.Reader, max) + n, err := cryptorand.Int(cryptorand.Reader, maxchar) if err != nil { return "", err } diff --git a/internal/api/client_test.go b/internal/api/client_test.go new file mode 100644 index 0000000..3a7752d --- /dev/null +++ b/internal/api/client_test.go @@ -0,0 +1,283 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "golang.org/x/oauth2" +) + +func staticTokenSource() oauth2.TokenSource { + return oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "test-token", + TokenType: "Bearer", + Expiry: time.Now().Add(time.Hour), + }) +} + +func newTestClient(srv *httptest.Server) *Client { + return New(srv.URL, "retyc-test/1.0", staticTokenSource(), false, false) +} + +func TestClient_Get_Success(t *testing.T) { + type body struct { + Name string `json:"name"` + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("method = %s, want GET", r.Method) + } + _ = json.NewEncoder(w).Encode(body{Name: "test-value"}) + })) + defer srv.Close() + + var result body + if err := newTestClient(srv).Get(context.Background(), "/test", &result); err != nil { + t.Fatalf("Get() error = %v", err) + } + + if result.Name != "test-value" { + t.Errorf("Name = %q, want test-value", result.Name) + } +} + +func TestClient_Get_Non2xx(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + })) + defer srv.Close() + + err := newTestClient(srv).Get(context.Background(), "/missing", nil) + if err == nil { + t.Fatal("Get() should return error for 404") + } + + if !strings.Contains(err.Error(), "API error 404") { + t.Errorf("error %q should contain 'API error 404'", err.Error()) + } +} + +func TestClient_Get_InvalidJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "not valid json {{") + })) + defer srv.Close() + + var result struct{ Name string } + + err := newTestClient(srv).Get(context.Background(), "/bad", &result) + if err == nil { + t.Error("Get() should return error for invalid JSON response") + } +} + +func TestClient_Post_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method = %s, want POST", r.Method) + } + + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q, want application/json", ct) + } + + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]string{"id": "new-id"}) + })) + defer srv.Close() + + body := strings.NewReader(`{"name":"test"}`) + var result map[string]string + + if err := newTestClient(srv).Post(context.Background(), "/resource", body, &result); err != nil { + t.Fatalf("Post() error = %v", err) + } + + if result["id"] != "new-id" { + t.Errorf("id = %q, want new-id", result["id"]) + } +} + +func TestClient_Put_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + t.Errorf("method = %s, want PUT", r.Method) + } + + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q, want application/json", ct) + } + + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + body := strings.NewReader(`{"key":"val"}`) + if err := newTestClient(srv).Put(context.Background(), "/resource", body, nil); err != nil { + t.Fatalf("Put() error = %v", err) + } +} + +func TestClient_Put_NilBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if ct := r.Header.Get("Content-Type"); ct != "" { + t.Errorf("Content-Type = %q, want empty for nil body", ct) + } + + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + // nil body — no Content-Type should be set. + if err := newTestClient(srv).Put(context.Background(), "/re-enable", nil, nil); err != nil { + t.Fatalf("Put() nil body error = %v", err) + } +} + +func TestClient_Delete_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + t.Errorf("method = %s, want DELETE", r.Method) + } + + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + if err := newTestClient(srv).Delete(context.Background(), "/resource/123"); err != nil { + t.Fatalf("Delete() error = %v", err) + } +} + +func TestClient_GetBytes_Success(t *testing.T) { + expected := []byte{0xde, 0xad, 0xbe, 0xef} + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(expected) + })) + defer srv.Close() + + got, err := newTestClient(srv).GetBytes(context.Background(), "/file/chunk") + if err != nil { + t.Fatalf("GetBytes() error = %v", err) + } + + if string(got) != string(expected) { + t.Errorf("GetBytes() = %v, want %v", got, expected) + } +} + +func TestClient_GetBytes_Non2xx(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "server error", http.StatusInternalServerError) + })) + defer srv.Close() + + _, err := newTestClient(srv).GetBytes(context.Background(), "/fail") + if err == nil { + t.Error("GetBytes() should return error for non-2xx status") + } +} + +func TestClient_PostMultipartChunk(t *testing.T) { + chunkData := []byte("encrypted chunk payload") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { + t.Errorf("Content-Type = %q, want multipart/form-data", r.Header.Get("Content-Type")) + } + + if err := r.ParseMultipartForm(10 << 20); err != nil { //nolint:gosec // G120: test server, no DoS risk + t.Errorf("ParseMultipartForm() error = %v", err) + + return + } + + f, hdr, err := r.FormFile("upload_file") + if err != nil { + t.Errorf("FormFile(upload_file) error = %v", err) + http.Error(w, "missing field", http.StatusBadRequest) + + return + } + defer f.Close() //nolint:errcheck + + if hdr.Filename != "chunk.age" { + t.Errorf("filename = %q, want chunk.age", hdr.Filename) + } + + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + if err := newTestClient(srv).PostMultipartChunk(context.Background(), "/file/abc/0", chunkData); err != nil { + t.Fatalf("PostMultipartChunk() error = %v", err) + } +} + +func TestUserAgentTransport_SetsHeader(t *testing.T) { + var gotUA string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + tr := &UserAgentTransport{ + UserAgent: "retyc-test/9.9", + Base: http.DefaultTransport, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip() error = %v", err) + } + + defer resp.Body.Close() //nolint:errcheck + + if gotUA != "retyc-test/9.9" { + t.Errorf("User-Agent = %q, want retyc-test/9.9", gotUA) + } +} + +func TestUserAgentTransport_NilBase(t *testing.T) { + var gotUA string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + // Base is nil — should fall back to http.DefaultTransport. + tr := &UserAgentTransport{UserAgent: "nil-base-test/1.0"} + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip() with nil Base error = %v", err) + } + + defer resp.Body.Close() //nolint:errcheck + + if gotUA != "nil-base-test/1.0" { + t.Errorf("User-Agent = %q, want nil-base-test/1.0", gotUA) + } +} diff --git a/internal/api/login_test.go b/internal/api/login_test.go new file mode 100644 index 0000000..457b902 --- /dev/null +++ b/internal/api/login_test.go @@ -0,0 +1,134 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestFetchOIDCConfig_Success(t *testing.T) { + // srv is referenced inside the handler to build the issuer URL dynamically. + var srv *httptest.Server + + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/login/config/public": + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": srv.URL, + "client_id": "device", + "scopes": []string{"openid", "offline_access"}, + }) + + case "/.well-known/openid-configuration": + _ = json.NewEncoder(w).Encode(map[string]string{ + "device_authorization_endpoint": srv.URL + "/device", + "token_endpoint": srv.URL + "/token", + "end_session_endpoint": srv.URL + "/logout", + }) + + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + cfg, err := FetchOIDCConfig(context.Background(), srv.URL, http.DefaultClient) + if err != nil { + t.Fatalf("FetchOIDCConfig() error = %v", err) + } + + if cfg.Issuer != srv.URL { + t.Errorf("Issuer = %q, want %q", cfg.Issuer, srv.URL) + } + + if cfg.ClientID != "device" { + t.Errorf("ClientID = %q, want device", cfg.ClientID) + } + + if cfg.DeviceAuthURL != srv.URL+"/device" { + t.Errorf("DeviceAuthURL = %q, want %q", cfg.DeviceAuthURL, srv.URL+"/device") + } + + if cfg.TokenURL != srv.URL+"/token" { + t.Errorf("TokenURL = %q, want %q", cfg.TokenURL, srv.URL+"/token") + } + + if cfg.EndSessionURL != srv.URL+"/logout" { + t.Errorf("EndSessionURL = %q, want %q", cfg.EndSessionURL, srv.URL+"/logout") + } +} + +func TestFetchOIDCConfig_LoginEndpointError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "server error", http.StatusInternalServerError) + })) + defer srv.Close() + + _, err := FetchOIDCConfig(context.Background(), srv.URL, http.DefaultClient) + if err == nil { + t.Error("FetchOIDCConfig() should return error when login endpoint returns 500") + } +} + +func TestFetchOIDCConfig_InvalidLoginJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "not json") + })) + defer srv.Close() + + _, err := FetchOIDCConfig(context.Background(), srv.URL, http.DefaultClient) + if err == nil { + t.Error("FetchOIDCConfig() should return error for invalid login config JSON") + } +} + +func TestFetchOIDCConfig_DiscoveryError(t *testing.T) { + var srv *httptest.Server + + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/login/config/public": + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": srv.URL, + "client_id": "device", + "scopes": []string{"openid"}, + }) + + default: + http.Error(w, "discovery unavailable", http.StatusInternalServerError) + } + })) + defer srv.Close() + + _, err := FetchOIDCConfig(context.Background(), srv.URL, http.DefaultClient) + if err == nil { + t.Error("FetchOIDCConfig() should return error when OIDC discovery returns 500") + } +} + +func TestFetchOIDCConfig_InvalidDiscoveryJSON(t *testing.T) { + var srv *httptest.Server + + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/login/config/public": + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": srv.URL, + "client_id": "device", + "scopes": []string{"openid"}, + }) + + default: + fmt.Fprint(w, "not valid json {{") + } + })) + defer srv.Close() + + _, err := FetchOIDCConfig(context.Background(), srv.URL, http.DefaultClient) + if err == nil { + t.Error("FetchOIDCConfig() should return error for invalid discovery JSON") + } +} diff --git a/internal/api/transfer_test.go b/internal/api/transfer_test.go new file mode 100644 index 0000000..ac9c864 --- /dev/null +++ b/internal/api/transfer_test.go @@ -0,0 +1,252 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestListTransfers(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/share" { + t.Errorf("path = %q, want /share", r.URL.Path) + } + + if got := r.URL.Query().Get("list_type"); got != "sent" { + t.Errorf("list_type = %q, want sent", got) + } + + if got := r.URL.Query().Get("page"); got != "1" { + t.Errorf("page = %q, want 1", got) + } + + title := "My Transfer" + _ = json.NewEncoder(w).Encode(TransferPage{ + Items: []Transfer{{ID: "abc", Title: &title, Status: "active", CreatedAt: time.Now()}}, + Total: 1, + Page: 1, + Pages: 1, + }) + })) + defer srv.Close() + + page, err := newTestClient(srv).ListTransfers(context.Background(), "sent", 1) + if err != nil { + t.Fatalf("ListTransfers() error = %v", err) + } + + if page.Total != 1 { + t.Errorf("Total = %d, want 1", page.Total) + } + + if len(page.Items) != 1 || page.Items[0].ID != "abc" { + t.Errorf("Items[0].ID = %q, want abc", page.Items[0].ID) + } +} + +func TestGetTransferDetails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/share/share-id-123/details" { + t.Errorf("path = %q, want /share/share-id-123/details", r.URL.Path) + } + + webURL := "https://retyc.com/t/share-id-123" + _ = json.NewEncoder(w).Encode(TransferDetails{WebURL: webURL}) + })) + defer srv.Close() + + details, err := newTestClient(srv).GetTransferDetails(context.Background(), "share-id-123") + if err != nil { + t.Fatalf("GetTransferDetails() error = %v", err) + } + + if details.WebURL != "https://retyc.com/t/share-id-123" { + t.Errorf("WebURL = %q, want https://retyc.com/t/share-id-123", details.WebURL) + } +} + +func TestListFiles(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/share/share-abc/files" { + t.Errorf("path = %q, want /share/share-abc/files", r.URL.Path) + } + + if got := r.URL.Query().Get("page"); got != "2" { + t.Errorf("page = %q, want 2", got) + } + + _ = json.NewEncoder(w).Encode(TransferFilePage{ + Items: []TransferFile{{ID: "file-1", OriginalSize: 1024}}, + Total: 1, + Page: 2, + Pages: 2, + }) + })) + defer srv.Close() + + fp, err := newTestClient(srv).ListFiles(context.Background(), "share-abc", 2) + if err != nil { + t.Fatalf("ListFiles() error = %v", err) + } + + if len(fp.Items) != 1 || fp.Items[0].ID != "file-1" { + t.Errorf("Items[0].ID = %q, want file-1", fp.Items[0].ID) + } +} + +func TestCreateShare_WithEmails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("Decode() error = %v", err) + } + + emails, ok := body["emails"].([]any) + if !ok || len(emails) != 1 { + t.Errorf("emails = %v, want a slice with 1 item", body["emails"]) + } + + _ = json.NewEncoder(w).Encode(ShareCreateResponse{ID: "new-share-id", Slug: "abc123"}) + })) + defer srv.Close() + + title := "Test Transfer" + + resp, err := newTestClient(srv).CreateShare(context.Background(), 3600, &title, true, []string{"user@example.com"}) + if err != nil { + t.Fatalf("CreateShare() error = %v", err) + } + + if resp.ID != "new-share-id" { + t.Errorf("ID = %q, want new-share-id", resp.ID) + } +} + +func TestCreateShare_NilEmails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("Decode() error = %v", err) + } + + // nil emails should be normalised to an empty array, not JSON null. + emails, ok := body["emails"].([]any) + if !ok { + t.Errorf("emails should be a JSON array, got %T: %v", body["emails"], body["emails"]) + } + + if len(emails) != 0 { + t.Errorf("emails length = %d, want 0", len(emails)) + } + + _ = json.NewEncoder(w).Encode(ShareCreateResponse{ID: "share-id"}) + })) + defer srv.Close() + + _, err := newTestClient(srv).CreateShare(context.Background(), 3600, nil, false, nil) + if err != nil { + t.Fatalf("CreateShare() nil emails error = %v", err) + } +} + +func TestCompleteTransfer(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/share/share-xyz/complete" { + t.Errorf("path = %q, want /share/share-xyz/complete", r.URL.Path) + } + + if r.Method != http.MethodPut { + t.Errorf("method = %s, want PUT", r.Method) + } + + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + req := CompleteTransferRequest{ + SessionPrivateKeyEnc: "enc-priv-key", + SessionPublicKey: "pub-key", + } + + if err := newTestClient(srv).CompleteTransfer(context.Background(), "share-xyz", req); err != nil { + t.Fatalf("CompleteTransfer() error = %v", err) + } +} + +func TestDisableTransfer(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/share/to-disable" { + t.Errorf("path = %q, want /share/to-disable", r.URL.Path) + } + + if r.Method != http.MethodDelete { + t.Errorf("method = %s, want DELETE", r.Method) + } + + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + if err := newTestClient(srv).DisableTransfer(context.Background(), "to-disable"); err != nil { + t.Fatalf("DisableTransfer() error = %v", err) + } +} + +func TestEnableTransfer(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/share/to-enable/re-enable" { + t.Errorf("path = %q, want /share/to-enable/re-enable", r.URL.Path) + } + + if r.Method != http.MethodPut { + t.Errorf("method = %s, want PUT", r.Method) + } + + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + if err := newTestClient(srv).EnableTransfer(context.Background(), "to-enable"); err != nil { + t.Fatalf("EnableTransfer() error = %v", err) + } +} + +func TestUploadChunk(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/file/file-id-abc/3" { + t.Errorf("path = %q, want /file/file-id-abc/3", r.URL.Path) + } + + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + if err := newTestClient(srv).UploadChunk(context.Background(), "file-id-abc", 3, []byte("chunk")); err != nil { + t.Fatalf("UploadChunk() error = %v", err) + } +} + +func TestDownloadChunk(t *testing.T) { + expected := []byte("decrypted chunk data") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/file/file-id-xyz/0" { + t.Errorf("path = %q, want /file/file-id-xyz/0", r.URL.Path) + } + + _, _ = w.Write(expected) + })) + defer srv.Close() + + got, err := newTestClient(srv).DownloadChunk(context.Background(), "file-id-xyz", 0) + if err != nil { + t.Fatalf("DownloadChunk() error = %v", err) + } + + if string(got) != string(expected) { + t.Errorf("DownloadChunk() = %q, want %q", got, expected) + } +} diff --git a/internal/api/user_test.go b/internal/api/user_test.go new file mode 100644 index 0000000..2e530e1 --- /dev/null +++ b/internal/api/user_test.go @@ -0,0 +1,121 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetMe_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/user/me" { + t.Errorf("path = %q, want /user/me", r.URL.Path) + } + + name := "Alice Dupont" + _ = json.NewEncoder(w).Encode(userMeResponse{ + User: UserInfo{ + ID: "user-123", + Email: "alice@example.com", + FullName: &name, + }, + }) + })) + defer srv.Close() + + user, err := newTestClient(srv).GetMe(context.Background()) + if err != nil { + t.Fatalf("GetMe() error = %v", err) + } + + if user.ID != "user-123" { + t.Errorf("ID = %q, want user-123", user.ID) + } + + if user.Email != "alice@example.com" { + t.Errorf("Email = %q, want alice@example.com", user.Email) + } +} + +func TestGetActiveKey_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/user/me/key/active" { + t.Errorf("path = %q, want /user/me/key/active", r.URL.Path) + } + + _ = json.NewEncoder(w).Encode(UserKey{ + ID: "key-abc", + UserID: "user-123", + PublicKey: "age1pq1testpublickey", + PrivateKeyEnc: "AGE-ENCRYPTED-PRIVATE-KEY", + }) + })) + defer srv.Close() + + key, err := newTestClient(srv).GetActiveKey(context.Background()) + if err != nil { + t.Fatalf("GetActiveKey() error = %v", err) + } + + if key == nil { + t.Fatal("GetActiveKey() returned nil") + } + + if key.ID != "key-abc" { + t.Errorf("ID = %q, want key-abc", key.ID) + } +} + +func TestGetActiveKey_Null(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "null") + })) + defer srv.Close() + + key, err := newTestClient(srv).GetActiveKey(context.Background()) + if err != nil { + t.Fatalf("GetActiveKey() error = %v", err) + } + + if key != nil { + t.Errorf("GetActiveKey() = %v, want nil for null response", key) + } +} + +func TestGetQuota_Success(t *testing.T) { + maxCount := 10 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/user/quota" { + t.Errorf("path = %q, want /user/quota", r.URL.Path) + } + + _ = json.NewEncoder(w).Encode(UserQuota{ + CountShare: 3, + MaxCountShare: &maxCount, + UsedStorage: 1024 * 1024, + MaxStorage: 10 * 1024 * 1024, + }) + })) + defer srv.Close() + + quota, err := newTestClient(srv).GetQuota(context.Background()) + if err != nil { + t.Fatalf("GetQuota() error = %v", err) + } + + if quota.CountShare != 3 { + t.Errorf("CountShare = %d, want 3", quota.CountShare) + } + + if quota.MaxCountShare == nil || *quota.MaxCountShare != 10 { + t.Errorf("MaxCountShare = %v, want 10", quota.MaxCountShare) + } + + if quota.UsedStorage != 1024*1024 { + t.Errorf("UsedStorage = %d, want %d", quota.UsedStorage, 1024*1024) + } +} diff --git a/internal/auth/oidc_test.go b/internal/auth/oidc_test.go new file mode 100644 index 0000000..50982ed --- /dev/null +++ b/internal/auth/oidc_test.go @@ -0,0 +1,529 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/retyc/retyc-cli/internal/config" + "golang.org/x/oauth2" +) + +func TestTokenFromResponse_WithExpiry(t *testing.T) { + tr := TokenResponse{ + AccessToken: "access", + TokenType: "Bearer", + RefreshToken: "refresh", + ExpiresIn: 3600, + } + + before := time.Now() + tok := tokenFromResponse(tr) + after := time.Now() + + if tok.AccessToken != "access" { + t.Errorf("AccessToken = %q, want access", tok.AccessToken) + } + + if tok.TokenType != "Bearer" { + t.Errorf("TokenType = %q, want Bearer", tok.TokenType) + } + + if tok.RefreshToken != "refresh" { + t.Errorf("RefreshToken = %q, want refresh", tok.RefreshToken) + } + + low := before.Add(3600 * time.Second) + high := after.Add(3600 * time.Second) + + if tok.Expiry.Before(low) || tok.Expiry.After(high) { + t.Errorf("Expiry = %v, want between %v and %v", tok.Expiry, low, high) + } +} + +func TestTokenFromResponse_ZeroExpiry(t *testing.T) { + tr := TokenResponse{AccessToken: "access", ExpiresIn: 0} + + tok := tokenFromResponse(tr) + if !tok.Expiry.IsZero() { + t.Errorf("Expiry should be zero when ExpiresIn=0, got %v", tok.Expiry) + } +} + +func TestRequestDeviceCode_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(DeviceAuthResponse{ + DeviceCode: "dev-code-123", + UserCode: "ABCD-1234", + VerificationURIComplete: "https://example.com/activate?code=ABCD-1234", + ExpiresIn: 300, + Interval: 5, + }) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test-client", Scopes: []string{"openid"}, DeviceAuthURL: srv.URL} + + resp, err := requestDeviceCode(cfg, http.DefaultClient) + if err != nil { + t.Fatalf("requestDeviceCode() error = %v", err) + } + + if resp.DeviceCode != "dev-code-123" { + t.Errorf("DeviceCode = %q, want dev-code-123", resp.DeviceCode) + } + + if resp.UserCode != "ABCD-1234" { + t.Errorf("UserCode = %q, want ABCD-1234", resp.UserCode) + } +} + +func TestRequestDeviceCode_NonOK(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal server error", http.StatusInternalServerError) + })) + defer srv.Close() + + cfg := config.OIDCConfig{DeviceAuthURL: srv.URL} + + _, err := requestDeviceCode(cfg, http.DefaultClient) + if err == nil { + t.Error("requestDeviceCode() should return error for non-2xx status") + } +} + +func TestRequestDeviceCode_InvalidJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "not valid json") + })) + defer srv.Close() + + cfg := config.OIDCConfig{DeviceAuthURL: srv.URL} + + _, err := requestDeviceCode(cfg, http.DefaultClient) + if err == nil { + t.Error("requestDeviceCode() should return error for invalid JSON") + } +} + +func TestPollToken_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(TokenResponse{ //nolint:gosec // G117: test fixture, not real credentials + AccessToken: "new-access-token", + TokenType: "Bearer", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + }) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + tok, err := pollToken(cfg, "device-code", http.DefaultClient) + if err != nil { + t.Fatalf("pollToken() error = %v", err) + } + + if tok == nil { + t.Fatal("pollToken() returned nil token") + } + + if tok.AccessToken != "new-access-token" { + t.Errorf("AccessToken = %q, want new-access-token", tok.AccessToken) + } +} + +func TestPollToken_AuthorizationPending(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + tok, err := pollToken(cfg, "device-code", http.DefaultClient) + if err != nil { + t.Errorf("pollToken() error = %v, want nil for authorization_pending", err) + } + + if tok != nil { + t.Error("pollToken() should return nil token for authorization_pending") + } +} + +func TestPollToken_SlowDown(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "slow_down"}) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + tok, err := pollToken(cfg, "device-code", http.DefaultClient) + if err != nil || tok != nil { + t.Errorf("pollToken() = (%v, %v), want (nil, nil) for slow_down", tok, err) + } +} + +func TestPollToken_ExpiredToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "expired_token"}) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + _, err := pollToken(cfg, "device-code", http.DefaultClient) + if err == nil { + t.Error("pollToken() should return error for expired_token") + } +} + +func TestPollToken_AccessDenied(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "access_denied"}) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + _, err := pollToken(cfg, "device-code", http.DefaultClient) + if err == nil { + t.Error("pollToken() should return error for access_denied") + } +} + +func TestPollToken_UnknownError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "custom_error", + "error_description": "something went wrong", + }) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + _, err := pollToken(cfg, "device-code", http.DefaultClient) + if err == nil { + t.Error("pollToken() should return error for unknown error field") + } +} + +func TestPollToken_InvalidJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "not json at all") + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + _, err := pollToken(cfg, "device-code", http.DefaultClient) + if err == nil { + t.Error("pollToken() should return error for invalid JSON response") + } +} + +func TestRefresh_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(TokenResponse{ //nolint:gosec // G117: test fixture, not real credentials + AccessToken: "fresh-access", + TokenType: "Bearer", + RefreshToken: "new-refresh", + ExpiresIn: 3600, + }) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + tok, err := Refresh(context.Background(), cfg, "old-refresh-token", http.DefaultClient) + if err != nil { + t.Fatalf("Refresh() error = %v", err) + } + + if tok.AccessToken != "fresh-access" { + t.Errorf("AccessToken = %q, want fresh-access", tok.AccessToken) + } + + if tok.RefreshToken != "new-refresh" { + t.Errorf("RefreshToken = %q, want new-refresh", tok.RefreshToken) + } +} + +func TestRefresh_PreservesRefreshToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Server omits refresh_token — the original must be preserved. + _ = json.NewEncoder(w).Encode(TokenResponse{ //nolint:gosec // G117: test fixture, not real credentials + AccessToken: "fresh-access", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + tok, err := Refresh(context.Background(), cfg, "original-refresh", http.DefaultClient) + if err != nil { + t.Fatalf("Refresh() error = %v", err) + } + + if tok.RefreshToken != "original-refresh" { + t.Errorf("RefreshToken = %q, want original-refresh (should be preserved)", tok.RefreshToken) + } +} + +func TestRefresh_ErrorResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "invalid_grant", + "error_description": "token expired", + }) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + _, err := Refresh(context.Background(), cfg, "bad-token", http.DefaultClient) + if err == nil { + t.Error("Refresh() should return error for invalid_grant") + } +} + +func TestRevoke_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", EndSessionURL: srv.URL} + + if err := Revoke(context.Background(), cfg, "refresh-token", http.DefaultClient); err != nil { + t.Errorf("Revoke() error = %v", err) + } +} + +func TestRevoke_NoEndpointURL(t *testing.T) { + cfg := config.OIDCConfig{EndSessionURL: ""} + + err := Revoke(context.Background(), cfg, "token", http.DefaultClient) + if err == nil { + t.Error("Revoke() should return error when EndSessionURL is empty") + } +} + +func TestRevoke_HTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer srv.Close() + + cfg := config.OIDCConfig{ClientID: "test", EndSessionURL: srv.URL} + + if err := Revoke(context.Background(), cfg, "refresh-token", http.DefaultClient); err == nil { + t.Error("Revoke() should return error for non-2xx status") + } +} + +func TestRefreshingTokenSource_ValidToken(t *testing.T) { + validTok := &oauth2.Token{ + AccessToken: "valid-access", + TokenType: "Bearer", + RefreshToken: "some-refresh", + Expiry: time.Now().Add(time.Hour), + } + + src := NewRefreshingTokenSource(validTok, config.OIDCConfig{}, http.DefaultClient, false) + + got, err := src.Token() + if err != nil { + t.Fatalf("Token() error = %v", err) + } + + if got.AccessToken != "valid-access" { + t.Errorf("AccessToken = %q, want valid-access", got.AccessToken) + } +} + +func TestRefreshingTokenSource_ExpiredToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(TokenResponse{ //nolint:gosec // G117: test fixture, not real credentials + AccessToken: "refreshed-access", + TokenType: "Bearer", + RefreshToken: "new-refresh", + ExpiresIn: 3600, + }) + })) + defer srv.Close() + + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + + expiredTok := &oauth2.Token{ + AccessToken: "expired-access", + RefreshToken: "valid-refresh", + Expiry: time.Now().Add(-time.Hour), + } + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + src := NewRefreshingTokenSource(expiredTok, cfg, http.DefaultClient, true) + + got, err := src.Token() + if err != nil { + t.Fatalf("Token() error = %v", err) + } + + if got.AccessToken != "refreshed-access" { + t.Errorf("AccessToken = %q, want refreshed-access", got.AccessToken) + } +} + +func TestRefreshingTokenSource_NoRefreshToken(t *testing.T) { + expiredTok := &oauth2.Token{ + AccessToken: "expired", + Expiry: time.Now().Add(-time.Hour), + } + + src := NewRefreshingTokenSource(expiredTok, config.OIDCConfig{}, http.DefaultClient, false) + + _, err := src.Token() + if err == nil { + t.Fatal("Token() should return error when token expired and no refresh token available") + } + + if err != ErrNoRefreshToken { + t.Errorf("error = %v, want ErrNoRefreshToken", err) + } +} + +func TestGetValidToken_EnvToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(TokenResponse{ //nolint:gosec // G117: test fixture, not real credentials + AccessToken: "env-refreshed-token", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + })) + defer srv.Close() + + t.Setenv("RETYC_TOKEN", "offline-token") + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + tok, err := GetValidToken(context.Background(), cfg, http.DefaultClient) + if err != nil { + t.Fatalf("GetValidToken() error = %v", err) + } + + if tok.AccessToken != "env-refreshed-token" { + t.Errorf("AccessToken = %q, want env-refreshed-token", tok.AccessToken) + } +} + +func TestGetValidToken_ValidStored(t *testing.T) { + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + t.Setenv("RETYC_TOKEN", "") + + validTok := &oauth2.Token{ + AccessToken: "stored-valid-token", + TokenType: "Bearer", + RefreshToken: "refresh", + Expiry: time.Now().Add(time.Hour), + } + + if err := config.SaveToken(validTok); err != nil { + t.Fatalf("SaveToken() error = %v", err) + } + + tok, err := GetValidToken(context.Background(), config.OIDCConfig{}, http.DefaultClient) + if err != nil { + t.Fatalf("GetValidToken() error = %v", err) + } + + if tok.AccessToken != "stored-valid-token" { + t.Errorf("AccessToken = %q, want stored-valid-token", tok.AccessToken) + } +} + +func TestGetValidToken_ExpiredWithRefresh(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(TokenResponse{ //nolint:gosec // G117: test fixture, not real credentials + AccessToken: "refreshed-token", + TokenType: "Bearer", + RefreshToken: "new-refresh", + ExpiresIn: 3600, + }) + })) + defer srv.Close() + + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + t.Setenv("RETYC_TOKEN", "") + + expiredTok := &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "valid-refresh", + Expiry: time.Now().Add(-time.Hour), + } + + if err := config.SaveToken(expiredTok); err != nil { + t.Fatalf("SaveToken() error = %v", err) + } + + cfg := config.OIDCConfig{ClientID: "test", TokenURL: srv.URL} + + tok, err := GetValidToken(context.Background(), cfg, http.DefaultClient) + if err != nil { + t.Fatalf("GetValidToken() error = %v", err) + } + + if tok.AccessToken != "refreshed-token" { + t.Errorf("AccessToken = %q, want refreshed-token", tok.AccessToken) + } +} + +func TestGetValidToken_ExpiredNoRefresh(t *testing.T) { + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + t.Setenv("RETYC_TOKEN", "") + + expiredTok := &oauth2.Token{ + AccessToken: "expired", + Expiry: time.Now().Add(-time.Hour), + } + + if err := config.SaveToken(expiredTok); err != nil { + t.Fatalf("SaveToken() error = %v", err) + } + + _, err := GetValidToken(context.Background(), config.OIDCConfig{}, http.DefaultClient) + if err == nil { + t.Fatal("GetValidToken() should return error when token expired and no refresh token") + } + + if err != ErrNoRefreshToken { + t.Errorf("error = %v, want ErrNoRefreshToken", err) + } +} + +func TestGetValidToken_NoToken(t *testing.T) { + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + t.Setenv("RETYC_TOKEN", "") + + _, err := GetValidToken(context.Background(), config.OIDCConfig{}, http.DefaultClient) + if err == nil { + t.Fatal("GetValidToken() should return error when no token stored") + } + + if err != ErrNoToken { + t.Errorf("error = %v, want ErrNoToken", err) + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..229dc31 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,168 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/spf13/viper" + "golang.org/x/oauth2" +) + +// resetViper resets viper global state after the test to avoid cross-test pollution. +func resetViper(t *testing.T) { + t.Helper() + t.Cleanup(func() { viper.Reset() }) +} + +func TestSetDefaults(t *testing.T) { + resetViper(t) + SetDefaults() + + if got := viper.GetString("api.base_url"); got != defaultAPIBaseURL { + t.Errorf("api.base_url = %q, want %q", got, defaultAPIBaseURL) + } + + if got := viper.GetBool("keyring.enabled"); !got { + t.Error("keyring.enabled should be true by default") + } + + if got := viper.GetInt("keyring.ttl"); got != 60 { + t.Errorf("keyring.ttl = %d, want 60", got) + } +} + +func TestLoad_Defaults(t *testing.T) { + resetViper(t) + SetDefaults() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + + if cfg.API.BaseURL != defaultAPIBaseURL { + t.Errorf("API.BaseURL = %q, want %q", cfg.API.BaseURL, defaultAPIBaseURL) + } + + if !cfg.Keyring.Enabled { + t.Error("Keyring.Enabled should be true by default") + } + + if cfg.Keyring.TTL != 60 { + t.Errorf("Keyring.TTL = %d, want 60", cfg.Keyring.TTL) + } +} + +func TestSaveToken_LoadToken_RoundTrip(t *testing.T) { + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + + tok := &oauth2.Token{ + AccessToken: "access-token-value", + TokenType: "Bearer", + RefreshToken: "refresh-token-value", + Expiry: time.Now().Add(time.Hour), + } + + if err := SaveToken(tok); err != nil { + t.Fatalf("SaveToken() error = %v", err) + } + + got, err := LoadToken() + if err != nil { + t.Fatalf("LoadToken() error = %v", err) + } + + if got.AccessToken != tok.AccessToken { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, tok.AccessToken) + } + + if got.RefreshToken != tok.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", got.RefreshToken, tok.RefreshToken) + } +} + +func TestSaveToken_FilePermissions(t *testing.T) { + dir := t.TempDir() + t.Setenv("RETYC_CONFIG_DIR", dir) + + if err := SaveToken(&oauth2.Token{AccessToken: "test"}); err != nil { + t.Fatalf("SaveToken() error = %v", err) + } + + info, err := os.Stat(filepath.Join(dir, "token.json")) + if err != nil { + t.Fatalf("os.Stat() error = %v", err) + } + + if perm := info.Mode().Perm(); perm != 0600 { + t.Errorf("token.json permissions = %04o, want 0600", perm) + } +} + +func TestLoadToken_NoFile(t *testing.T) { + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + + _, err := LoadToken() + if err == nil { + t.Fatal("LoadToken() should return error when no token file exists") + } + + if err.Error() != "no stored token found" { + t.Errorf("error = %q, want 'no stored token found'", err.Error()) + } +} + +func TestLoadToken_CorruptJSON(t *testing.T) { + dir := t.TempDir() + t.Setenv("RETYC_CONFIG_DIR", dir) + + if err := os.WriteFile(filepath.Join(dir, "token.json"), []byte("not valid json {{{"), 0600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + _, err := LoadToken() + if err == nil { + t.Error("LoadToken() should return error for corrupt JSON") + } +} + +func TestDeleteToken(t *testing.T) { + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + + if err := SaveToken(&oauth2.Token{AccessToken: "test"}); err != nil { + t.Fatalf("SaveToken() error = %v", err) + } + + if err := DeleteToken(); err != nil { + t.Fatalf("DeleteToken() error = %v", err) + } + + _, err := LoadToken() + if err == nil { + t.Error("LoadToken() should fail after DeleteToken()") + } +} + +func TestDeleteToken_NoFile(t *testing.T) { + t.Setenv("RETYC_CONFIG_DIR", t.TempDir()) + + if err := DeleteToken(); err != nil { + t.Errorf("DeleteToken() should not error when no file exists, got: %v", err) + } +} + +func TestConfigDir_EnvOverride(t *testing.T) { + dir := t.TempDir() + t.Setenv("RETYC_CONFIG_DIR", dir) + + got, err := ConfigDir() + if err != nil { + t.Fatalf("ConfigDir() error = %v", err) + } + + if got != dir { + t.Errorf("ConfigDir() = %q, want %q", got, dir) + } +} diff --git a/internal/crypto/age_test.go b/internal/crypto/age_test.go new file mode 100644 index 0000000..524ccef --- /dev/null +++ b/internal/crypto/age_test.go @@ -0,0 +1,410 @@ +package crypto + +import ( + "bytes" + "strings" + "testing" +) + +func TestGenerateKeyPair(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatalf("GenerateKeyPair() error = %v", err) + } + + rec := identity.Recipient().String() + if !strings.HasPrefix(rec, "age1pq1") { + t.Errorf("recipient %q does not start with age1pq1", rec) + } +} + +func TestGenerateKeyPair_Uniqueness(t *testing.T) { + id1, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + id2, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + if id1.String() == id2.String() { + t.Error("GenerateKeyPair() returned identical private keys") + } +} + +func TestParseIdentity_Valid(t *testing.T) { + original, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + parsed, err := ParseIdentity(original.String()) + if err != nil { + t.Fatalf("ParseIdentity() error = %v", err) + } + + if parsed.String() != original.String() { + t.Error("ParseIdentity() round-trip changed the private key") + } +} + +func TestParseIdentity_Whitespace(t *testing.T) { + original, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + _, err = ParseIdentity(" " + original.String() + "\n") + if err != nil { + t.Errorf("ParseIdentity() should strip whitespace, got: %v", err) + } +} + +func TestParseIdentity_Invalid(t *testing.T) { + _, err := ParseIdentity("not-a-valid-private-key") + if err == nil { + t.Error("ParseIdentity() should return error for invalid key") + } +} + +func TestParseRecipient_Valid(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + pubKey := identity.Recipient().String() + + parsed, err := ParseRecipient(pubKey) + if err != nil { + t.Fatalf("ParseRecipient() error = %v", err) + } + + if parsed.String() != pubKey { + t.Error("ParseRecipient() round-trip changed the public key") + } +} + +func TestParseRecipient_Invalid(t *testing.T) { + _, err := ParseRecipient("not-a-valid-public-key") + if err == nil { + t.Error("ParseRecipient() should return error for invalid key") + } +} + +func TestEncryptDecrypt_RoundTrip(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + plaintext := []byte("hello, world — retyc test payload") + + ciphertext, err := Encrypt(plaintext, identity.Recipient()) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + got, err := Decrypt(ciphertext, identity) + if err != nil { + t.Fatalf("Decrypt() error = %v", err) + } + + if !bytes.Equal(got, plaintext) { + t.Errorf("Decrypt() = %q, want %q", got, plaintext) + } +} + +func TestEncryptDecrypt_EmptyPlaintext(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + ciphertext, err := Encrypt([]byte{}, identity.Recipient()) + if err != nil { + t.Fatalf("Encrypt() empty plaintext error = %v", err) + } + + got, err := Decrypt(ciphertext, identity) + if err != nil { + t.Fatalf("Decrypt() empty plaintext error = %v", err) + } + + if len(got) != 0 { + t.Errorf("Decrypt() = %q, want empty", got) + } +} + +func TestEncryptDecrypt_WrongKey(t *testing.T) { + identity1, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + identity2, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + ciphertext, err := Encrypt([]byte("secret"), identity1.Recipient()) + if err != nil { + t.Fatal(err) + } + + _, err = Decrypt(ciphertext, identity2) + if err == nil { + t.Error("Decrypt() should fail when using the wrong key") + } +} + +func TestDecrypt_InvalidCiphertext(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + _, err = Decrypt("this is not valid age ciphertext", identity) + if err == nil { + t.Error("Decrypt() should return error for invalid ciphertext") + } +} + +func TestEncryptToString_DecryptToString_RoundTrip(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + original := "plaintext string value" + + ciphertext, err := EncryptToString(original, identity.Recipient()) + if err != nil { + t.Fatalf("EncryptToString() error = %v", err) + } + + got, err := DecryptToString(ciphertext, identity) + if err != nil { + t.Fatalf("DecryptToString() error = %v", err) + } + + if got != original { + t.Errorf("DecryptToString() = %q, want %q", got, original) + } +} + +func TestEncryptWithPassphrase_DecryptWithPassphrase_RoundTrip(t *testing.T) { + passphrase := "correct-horse-battery-staple" + plaintext := []byte("super secret data") + + ciphertext, err := EncryptWithPassphrase(plaintext, passphrase) + if err != nil { + t.Fatalf("EncryptWithPassphrase() error = %v", err) + } + + got, err := DecryptWithPassphrase(ciphertext, passphrase) + if err != nil { + t.Fatalf("DecryptWithPassphrase() error = %v", err) + } + + if !bytes.Equal(got, plaintext) { + t.Errorf("DecryptWithPassphrase() = %q, want %q", got, plaintext) + } +} + +func TestDecryptWithPassphrase_WrongPassphrase(t *testing.T) { + ciphertext, err := EncryptWithPassphrase([]byte("secret"), "correct-pass") + if err != nil { + t.Fatal(err) + } + + _, err = DecryptWithPassphrase(ciphertext, "wrong-pass") + if err == nil { + t.Error("DecryptWithPassphrase() should fail with wrong passphrase") + } +} + +func TestDecryptToStringWithPassphrase_RoundTrip(t *testing.T) { + passphrase := "test-passphrase-123" + original := "value to encrypt and decrypt" + + ciphertext, err := EncryptWithPassphrase([]byte(original), passphrase) + if err != nil { + t.Fatalf("EncryptWithPassphrase() error = %v", err) + } + + got, err := DecryptToStringWithPassphrase(ciphertext, passphrase) + if err != nil { + t.Fatalf("DecryptToStringWithPassphrase() error = %v", err) + } + + if got != original { + t.Errorf("DecryptToStringWithPassphrase() = %q, want %q", got, original) + } +} + +func TestEncryptBinaryForKey_DecryptBinary_RoundTrip(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + plaintext := []byte{0x00, 0xff, 0xab, 0xcd, 0xef, 0x42, 0x00} + + encrypted, err := EncryptBinaryForKey(plaintext, identity.Recipient().String()) + if err != nil { + t.Fatalf("EncryptBinaryForKey() error = %v", err) + } + + got, err := DecryptBinary(encrypted, identity) + if err != nil { + t.Fatalf("DecryptBinary() error = %v", err) + } + + if !bytes.Equal(got, plaintext) { + t.Errorf("DecryptBinary() = %v, want %v", got, plaintext) + } +} + +func TestEncryptBinaryForKey_WrongKey(t *testing.T) { + identity1, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + identity2, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + encrypted, err := EncryptBinaryForKey([]byte("data"), identity1.Recipient().String()) + if err != nil { + t.Fatal(err) + } + + _, err = DecryptBinary(encrypted, identity2) + if err == nil { + t.Error("DecryptBinary() should fail with wrong key") + } +} + +func TestEncryptBinaryForKey_InvalidKey(t *testing.T) { + _, err := EncryptBinaryForKey([]byte("data"), "not-a-valid-public-key") + if err == nil { + t.Error("EncryptBinaryForKey() should return error for invalid key") + } +} + +func TestDecryptBinary_InvalidData(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + _, err = DecryptBinary([]byte("not valid binary age data"), identity) + if err == nil { + t.Error("DecryptBinary() should return error for invalid data") + } +} + +func TestEncryptStringForKeys_SingleKey(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + original := "test value" + + ciphertext, err := EncryptStringForKeys(original, []string{identity.Recipient().String()}) + if err != nil { + t.Fatalf("EncryptStringForKeys() error = %v", err) + } + + got, err := DecryptToString(ciphertext, identity) + if err != nil { + t.Fatalf("DecryptToString() error = %v", err) + } + + if got != original { + t.Errorf("got %q, want %q", got, original) + } +} + +func TestEncryptStringForKeys_MultipleKeys(t *testing.T) { + identity1, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + identity2, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + original := "multi-recipient secret" + + ciphertext, err := EncryptStringForKeys(original, []string{ + identity1.Recipient().String(), + identity2.Recipient().String(), + }) + if err != nil { + t.Fatalf("EncryptStringForKeys() error = %v", err) + } + + // Both recipients must be able to decrypt independently. + got1, err := DecryptToString(ciphertext, identity1) + if err != nil { + t.Fatalf("identity1 DecryptToString() error = %v", err) + } + + got2, err := DecryptToString(ciphertext, identity2) + if err != nil { + t.Fatalf("identity2 DecryptToString() error = %v", err) + } + + if got1 != original || got2 != original { + t.Errorf("got1=%q got2=%q, want %q", got1, got2, original) + } +} + +func TestEncryptStringForKeys_SkipsEmptyKeys(t *testing.T) { + identity, err := GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + + // Empty strings mixed with a valid key should be silently skipped. + ciphertext, err := EncryptStringForKeys("hello", []string{ + "", + identity.Recipient().String(), + "", + }) + if err != nil { + t.Fatalf("EncryptStringForKeys() error = %v", err) + } + + got, err := DecryptToString(ciphertext, identity) + if err != nil || got != "hello" { + t.Errorf("DecryptToString() = %q, err = %v", got, err) + } +} + +func TestEncryptStringForKeys_NoValidKeys(t *testing.T) { + _, err := EncryptStringForKeys("hello", []string{"", ""}) + if err == nil { + t.Error("EncryptStringForKeys() should return error with no valid recipients") + } + + if !strings.Contains(err.Error(), "no valid recipients") { + t.Errorf("error %q should mention 'no valid recipients'", err.Error()) + } +} + +func TestEncryptStringForKeys_InvalidKey(t *testing.T) { + _, err := EncryptStringForKeys("hello", []string{"not-a-valid-key"}) + if err == nil { + t.Error("EncryptStringForKeys() should return error for invalid key") + } +} diff --git a/internal/ui/format_test.go b/internal/ui/format_test.go new file mode 100644 index 0000000..c8f2a02 --- /dev/null +++ b/internal/ui/format_test.go @@ -0,0 +1,28 @@ +package ui + +import "testing" + +func TestFormatSize(t *testing.T) { + tests := []struct { + input int64 + expected string + }{ + {0, "0 B"}, + {1, "1 B"}, + {1023, "1023 B"}, + {1024, "1.0 KiB"}, + {1536, "1.5 KiB"}, + {1048576, "1.0 MiB"}, + {1073741824, "1.0 GiB"}, + {1099511627776, "1.0 TiB"}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + got := FormatSize(tt.input) + if got != tt.expected { + t.Errorf("FormatSize(%d) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} From 0f3e27bb42cb3bdc4d4b70a51096eaaad0a7e837 Mon Sep 17 00:00:00 2001 From: Emilien Mantel Date: Thu, 12 Mar 2026 19:37:10 +0100 Subject: [PATCH 3/4] :construction_worker: Add tests in CI --- .github/workflows/_ci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/_ci.yml b/.github/workflows/_ci.yml index e5a6c98..2526387 100644 --- a/.github/workflows/_ci.yml +++ b/.github/workflows/_ci.yml @@ -28,6 +28,9 @@ jobs: - name: Test run: go test -race -coverprofile=coverage.out ./... + - name: Coverage summary + run: go tool cover -func=coverage.out | tail -1 + - name: Build (dev) run: CGO_ENABLED=0 go build ./... From 2171cefe7133f5bc6673c7b7249488541ca660a9 Mon Sep 17 00:00:00 2001 From: Emilien Mantel Date: Thu, 12 Mar 2026 19:38:02 +0100 Subject: [PATCH 4/4] :art: Fix var name --- cmd/transfer.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/transfer.go b/cmd/transfer.go index bebbeea..285a2c9 100644 --- a/cmd/transfer.go +++ b/cmd/transfer.go @@ -1121,10 +1121,10 @@ func downloadTransferFile( // crypto/rand.Int is used to avoid modulo bias. func generateTransferPassphrase() (string, error) { const chars = "!\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" - maxchar := big.NewInt(int64(len(chars))) + maxChar := big.NewInt(int64(len(chars))) result := make([]byte, 32) for i := range result { - n, err := cryptorand.Int(cryptorand.Reader, maxchar) + n, err := cryptorand.Int(cryptorand.Reader, maxChar) if err != nil { return "", err }