diff --git a/Contributing.md b/Contributing.md index 44c75433..19a59025 100644 --- a/Contributing.md +++ b/Contributing.md @@ -188,3 +188,7 @@ We also need to build the server docker image and push it to Gadget's container ```bash make upload-container-image version=0.0.x ``` + +### Getting PASETO tokens locally + +You can sign PASETO tokens locally with this handy online tool: https://token.dev/paseto/. Ensure you use the V2 algorithm in the public mode, and copy the PASTEO public and private key from the `development` folder. diff --git a/Dockerfile b/Dockerfile index 6f93e691..a276bba7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,7 +27,7 @@ RUN go mod download # copy everything else and build the project COPY . ./ -RUN make release/server_linux_$TARGETARCH +RUN make release/server_linux_$TARGETARCH release/cached_linux_$TARGETARCH FROM buildpack-deps:bullseye AS build-release-stage ARG TARGETARCH @@ -48,11 +48,14 @@ RUN mkdir -p /home/main/secrets VOLUME /home/main/secrets/tls VOLUME /home/main/secrets/paseto +COPY --from=build-stage /app/release/cached_linux_${TARGETARCH} cached COPY --from=build-stage /app/release/server_linux_${TARGETARCH} server + COPY migrations migrations COPY entrypoint.sh entrypoint.sh -# smoke test -- ensure the server command can run +# smoke test -- ensure the commands can run RUN ./server --help +RUN ./cached --help ENTRYPOINT ["./entrypoint.sh"] diff --git a/Makefile b/Makefile index 4978a14a..b120fcdc 100644 --- a/Makefile +++ b/Makefile @@ -8,11 +8,13 @@ DB_USER ?= postgres DB_PASS ?= password DB_URI := postgres://$(DB_USER):$(DB_PASS)@$(DB_HOST):5432/dl -GRPC_PORT ?= 5051 GRPC_HOST ?= localhost +GRPC_PORT ?= 5051 +GRPC_CACHED_PORT ?= 5053 -DEV_TOKEN_ADMIN ?= v2.public.eyJzdWIiOiJhZG1pbiIsImlhdCI6IjIwMjEtMTAtMTVUMTE6MjA6MDAuMDM0WiJ9WtEey8KfQQRy21xoHq1C5KQatEevk8RxS47k4bRfMwVCPHumZmVuk6ADcfDHTmSnMtEGfFXdxnYOhRP6Clb_Dw -DEV_TOKEN_PROJECT_1 ?= v2.public.eyJzdWIiOiIxIiwiaWF0IjoiMjAyMS0xMC0xNVQxMToyMDowMC4wMzVaIn2MQ14RfIGpoEycCuvRu9J3CZp6PppUXf5l5w8uKKydN3C31z6f6GgOEPNcnwODqBnX7Pjarpz4i2uzWEqLgQYD +DEV_TOKEN_ADMIN ?= v2.public.eyJzdWIiOiJhZG1pbiJ9yt40HNkcyOUtDeFa_WPS6vi0WiE4zWngDGJLh17TuYvssTudCbOdQEkVDRD-mSNTXLgSRDXUkO-AaEr4ZLO4BQ +DEV_TOKEN_PROJECT_1 ?= v2.public.eyJzdWIiOiIxIn2jV7FOdEXafKDtAnVyDgI4fmIbqU7C1iuhKiL0lDnG1Z5-j6_ObNDd75sZvLZ159-X98_mP4qvwzui0w8pjt8F +DEV_SHARED_READER_TOKEN ?= v2.public.eyJzdWIiOiJzaGFyZWQtcmVhZGVyIn1CxWdB02s9el0Wt7qReARZ-7JtIb4Zj3D4Oiji1yXHqj0orkpbcVlswVUiekECJC16d1NrHwD2FWSwRORZn8gK PKG_GO_FILES := $(shell find pkg/ -type f -name '*.go') INTERNAL_GO_FILES := $(shell find internal/ -type f -name '*.go') @@ -63,13 +65,11 @@ development/server.key: development/server.crt: development/server.key -build: internal/pb/fs.pb.go internal/pb/fs_grpc.pb.go bin/server bin/client development/server.crt +build: internal/pb/fs.pb.go internal/pb/fs_grpc.pb.go internal/pb/cache.pb.go internal/pb/cache_grpc.pb.go bin/server bin/client bin/cached development/server.crt lint: golangci-lint run - - release/%_linux_amd64: cmd/%/main.go $(PKG_GO_FILES) $(INTERNAL_GO_FILES) go.sum CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build $(BUILD_FLAGS) -o $@ $< @@ -86,8 +86,9 @@ release/migrations.tar.gz: migrations/* tar -zcf $@ migrations release: build -release: release/server_linux_amd64 release/server_macos_amd64 release/server_macos_arm64 -release: release/client_linux_amd64 release/client_macos_amd64 release/client_macos_arm64 +release: release/server_linux_amd64 release/server_macos_amd64 release/server_macos_arm64 release/server_linux_arm64 +release: release/client_linux_amd64 release/client_macos_amd64 release/client_macos_arm64 release/client_linux_arm64 +release: release/cached_linux_amd64 release/cached_macos_amd64 release/cached_macos_arm64 release/cached_linux_arm64 release: release/migrations.tar.gz test: export DB_URI = postgres://$(DB_USER):$(DB_PASS)@$(DB_HOST):5432/dl_tests @@ -121,6 +122,11 @@ server-profile: export DL_ENV=dev server-profile: internal/pb/fs.pb.go internal/pb/fs_grpc.pb.go go run cmd/server/main.go --dburi $(DB_URI) --port $(GRPC_PORT) --profile cpu.prof --log-level info +cached: export DL_ENV=dev +cached: export DL_TOKEN=$(DEV_SHARED_READER_TOKEN) +cached: internal/pb/cache.pb.go internal/pb/cache_grpc.pb.go + go run cmd/cached/main.go --upstream-host $(GRPC_HOST) --upstream-port $(GRPC_PORT) --port $(GRPC_CACHED_PORT) --staging-path tmp/cache-stage + client-update: export DL_TOKEN=$(DEV_TOKEN_PROJECT_1) client-update: export DL_SKIP_SSL_VERIFICATION=1 client-update: @@ -169,6 +175,11 @@ client-getcache: export DL_SKIP_SSL_VERIFICATION=1 client-getcache: go run cmd/client/main.go getcache --host $(GRPC_HOST) --path input/cache +client-getcached: export DL_TOKEN=$(DEV_TOKEN_ADMIN) +client-getcached: export DL_SKIP_SSL_VERIFICATION=1 +client-getcached: + go run cmd/client/main.go getcached --host $(GRPC_HOST) --port $(GRPC_CACHED_PORT) --path input/cache + client-gc-contents: export DL_TOKEN=$(DEV_TOKEN_ADMIN) client-gc-contents: export DL_SKIP_SSL_VERIFICATION=1 client-gc-contents: @@ -242,8 +253,8 @@ else cd js && npm install endif -js/src/pb: $(PROTO_FILES) - cd js && mkdir -p ./src/pb && npx protoc --experimental_allow_proto3_optional --ts_out ./src/pb --ts_opt long_type_bigint,ts_nocheck,eslint_disable,add_pb_suffix --proto_path ../internal/pb/ ../$^ +js/src/pb: internal/pb/fs.proto + cd js && mkdir -p ./src/pb && npx protoc --experimental_allow_proto3_optional --ts_out ./src/pb --ts_opt long_type_bigint,ts_nocheck,eslint_disable,add_pb_suffix --proto_path ../internal/pb/ ../internal/pb/fs.proto js/dist: js/node_modules js/src/pb cd js && npm run build diff --git a/cmd/cached/main.go b/cmd/cached/main.go new file mode 100644 index 00000000..20a7b13e --- /dev/null +++ b/cmd/cached/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/gadget-inc/dateilager/pkg/cli" + +func main() { + cli.CacheDaemonExecute() +} diff --git a/development/paseto.key b/development/paseto.key new file mode 100644 index 00000000..cf19319d --- /dev/null +++ b/development/paseto.key @@ -0,0 +1,3 @@ +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEILTL+0PfTOIQcn2VPkpxMwf6Gbt9n4UEFDjZ4RuUKjd0 +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/development/paseto.pub b/development/paseto.pub index aee98703..98bc10bf 100644 --- a/development/paseto.pub +++ b/development/paseto.pub @@ -1,3 +1,3 @@ -----BEGIN PUBLIC KEY----- -MCowBQYDK2VwAyEASKQkA/AxlNCdOHTnp5McesmQ+y756VTtGz8Xrt1G0fs= ------END PUBLIC KEY----- +MCowBQYDK2VwAyEAHrnbu7wEfAP9cGBOAHHwmH4Wsot1ciXBHwBBXQ4gsaI= +-----END PUBLIC KEY----- \ No newline at end of file diff --git a/flake.nix b/flake.nix index 07a0d818..8e5f810d 100644 --- a/flake.nix +++ b/flake.nix @@ -54,6 +54,9 @@ postgresql = pkgs.postgresql_14; + golangci-lint = pkgs.golangci-lint; + + glibcLocales = pkgs.glibcLocales; ## DateiLager outputs dateilager = callPackage ./. { @@ -72,6 +75,8 @@ flake.packages.postgresql flake.packages.dev flake.packages.clean + flake.packages.golangci-lint + flake.packages.glibcLocales git protobuf protoc-gen-go @@ -83,6 +88,9 @@ shellHook = '' # prepend the built binaries to the $PATH export PATH="./bin":$PATH + + # silence ginko deprecations -- they come from the csi test suite that we don't control + export ACK_GINKGO_DEPRECATIONS=1.16.5 ''; }; } diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 244a0e12..92f51876 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -18,9 +18,10 @@ const ( type Role int const ( - None Role = iota - Project - Admin + None Role = iota + Project // read and write to one project + Admin // read and write to any project + SharedReader // read the shared caches, but no specific project data ) type Auth struct { @@ -36,14 +37,17 @@ func (a Auth) String() string { return fmt.Sprintf("project[%d]", *a.Project) case Admin: return "admin" + case SharedReader: + return "sharedReader" default: return "unknown" } } var ( - noAuth = Auth{Role: None} - adminAuth = Auth{Role: Admin} + noAuth = Auth{Role: None} + adminAuth = Auth{Role: Admin} + sharedReaderAuth = Auth{Role: SharedReader} ) type AuthValidator struct { @@ -71,6 +75,10 @@ func (av *AuthValidator) Validate(ctx context.Context, token string) (Auth, erro return adminAuth, nil } + if payload.Subject == "shared-reader" { + return sharedReaderAuth, nil + } + project, err := strconv.ParseInt(payload.Subject, 10, 64) if err != nil { return noAuth, fmt.Errorf("parse Paseto subject %v: %w", payload.Subject, err) diff --git a/internal/key/key.go b/internal/key/key.go index 3d3a1d9b..70c90506 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -1,6 +1,8 @@ package key import ( + "time" + "github.com/gadget-inc/dateilager/pkg/stringutil" "go.opentelemetry.io/otel/attribute" "go.uber.org/zap" @@ -36,7 +38,9 @@ const ( Worker = IntKey("dl.worker") WorkerCount = IntKey("dl.worker_count") Ignores = StringSliceKey("dl.ignores") + DurationMS = DurationKey("dl.duration_ms") CloneToProject = Int64Key("dl.clone_to_project") + CachePath = StringKey("dl.cache_path") ) var ( @@ -148,3 +152,13 @@ func (isk Int64SliceKey) Field(value []int64) zap.Field { func (isk Int64SliceKey) Attribute(value []int64) attribute.KeyValue { return attribute.Int64Slice(string(isk), value) } + +type DurationKey string + +func (dk DurationKey) Field(value time.Duration) zap.Field { + return zap.Duration(string(dk), value) +} + +func (dk DurationKey) Attribute(value time.Duration) attribute.KeyValue { + return attribute.Float64(string(dk), float64(value.Milliseconds())) +} diff --git a/internal/pb/cache.pb.go b/internal/pb/cache.pb.go new file mode 100644 index 00000000..09453708 --- /dev/null +++ b/internal/pb/cache.pb.go @@ -0,0 +1,216 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.33.0 +// protoc v4.24.4 +// source: internal/pb/cache.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type PopulateDiskCacheRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` +} + +func (x *PopulateDiskCacheRequest) Reset() { + *x = PopulateDiskCacheRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_internal_pb_cache_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PopulateDiskCacheRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PopulateDiskCacheRequest) ProtoMessage() {} + +func (x *PopulateDiskCacheRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_pb_cache_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PopulateDiskCacheRequest.ProtoReflect.Descriptor instead. +func (*PopulateDiskCacheRequest) Descriptor() ([]byte, []int) { + return file_internal_pb_cache_proto_rawDescGZIP(), []int{0} +} + +func (x *PopulateDiskCacheRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +type PopulateDiskCacheResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Version int64 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` +} + +func (x *PopulateDiskCacheResponse) Reset() { + *x = PopulateDiskCacheResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_internal_pb_cache_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PopulateDiskCacheResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PopulateDiskCacheResponse) ProtoMessage() {} + +func (x *PopulateDiskCacheResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_pb_cache_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PopulateDiskCacheResponse.ProtoReflect.Descriptor instead. +func (*PopulateDiskCacheResponse) Descriptor() ([]byte, []int) { + return file_internal_pb_cache_proto_rawDescGZIP(), []int{1} +} + +func (x *PopulateDiskCacheResponse) GetVersion() int64 { + if x != nil { + return x.Version + } + return 0 +} + +var File_internal_pb_cache_proto protoreflect.FileDescriptor + +var file_internal_pb_cache_proto_rawDesc = []byte{ + 0x0a, 0x17, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x70, 0x62, 0x2f, 0x63, 0x61, + 0x63, 0x68, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, 0x22, 0x2e, 0x0a, + 0x18, 0x50, 0x6f, 0x70, 0x75, 0x6c, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x43, 0x61, 0x63, + 0x68, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x35, 0x0a, + 0x19, 0x50, 0x6f, 0x70, 0x75, 0x6c, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x43, 0x61, 0x63, + 0x68, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x32, 0x5a, 0x0a, 0x06, 0x43, 0x61, 0x63, 0x68, 0x65, 0x64, 0x12, 0x50, + 0x0a, 0x11, 0x50, 0x6f, 0x70, 0x75, 0x6c, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x43, 0x61, + 0x63, 0x68, 0x65, 0x12, 0x1c, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x6f, 0x70, 0x75, 0x6c, 0x61, 0x74, + 0x65, 0x44, 0x69, 0x73, 0x6b, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1d, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x6f, 0x70, 0x75, 0x6c, 0x61, 0x74, 0x65, 0x44, + 0x69, 0x73, 0x6b, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, + 0x61, 0x64, 0x67, 0x65, 0x74, 0x2d, 0x69, 0x6e, 0x63, 0x2f, 0x64, 0x61, 0x74, 0x65, 0x69, 0x6c, + 0x61, 0x67, 0x65, 0x72, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, +} + +var ( + file_internal_pb_cache_proto_rawDescOnce sync.Once + file_internal_pb_cache_proto_rawDescData = file_internal_pb_cache_proto_rawDesc +) + +func file_internal_pb_cache_proto_rawDescGZIP() []byte { + file_internal_pb_cache_proto_rawDescOnce.Do(func() { + file_internal_pb_cache_proto_rawDescData = protoimpl.X.CompressGZIP(file_internal_pb_cache_proto_rawDescData) + }) + return file_internal_pb_cache_proto_rawDescData +} + +var file_internal_pb_cache_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_internal_pb_cache_proto_goTypes = []interface{}{ + (*PopulateDiskCacheRequest)(nil), // 0: pb.PopulateDiskCacheRequest + (*PopulateDiskCacheResponse)(nil), // 1: pb.PopulateDiskCacheResponse +} +var file_internal_pb_cache_proto_depIdxs = []int32{ + 0, // 0: pb.Cached.PopulateDiskCache:input_type -> pb.PopulateDiskCacheRequest + 1, // 1: pb.Cached.PopulateDiskCache:output_type -> pb.PopulateDiskCacheResponse + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_internal_pb_cache_proto_init() } +func file_internal_pb_cache_proto_init() { + if File_internal_pb_cache_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_internal_pb_cache_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PopulateDiskCacheRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_internal_pb_cache_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PopulateDiskCacheResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_internal_pb_cache_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_internal_pb_cache_proto_goTypes, + DependencyIndexes: file_internal_pb_cache_proto_depIdxs, + MessageInfos: file_internal_pb_cache_proto_msgTypes, + }.Build() + File_internal_pb_cache_proto = out.File + file_internal_pb_cache_proto_rawDesc = nil + file_internal_pb_cache_proto_goTypes = nil + file_internal_pb_cache_proto_depIdxs = nil +} diff --git a/internal/pb/cache.proto b/internal/pb/cache.proto new file mode 100644 index 00000000..c4468590 --- /dev/null +++ b/internal/pb/cache.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package pb; + +option go_package = "github.com/gadget-inc/dateilager/pkg/pb"; + +service Cached { + rpc PopulateDiskCache(PopulateDiskCacheRequest) returns (PopulateDiskCacheResponse); +} + +message PopulateDiskCacheRequest { string path = 1; } + +message PopulateDiskCacheResponse { int64 version = 1; }; diff --git a/internal/pb/cache_grpc.pb.go b/internal/pb/cache_grpc.pb.go new file mode 100644 index 00000000..656b7015 --- /dev/null +++ b/internal/pb/cache_grpc.pb.go @@ -0,0 +1,109 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.3.0 +// - protoc v4.24.4 +// source: internal/pb/cache.proto + +package pb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +const ( + Cached_PopulateDiskCache_FullMethodName = "/pb.Cached/PopulateDiskCache" +) + +// CachedClient is the client API for Cached service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type CachedClient interface { + PopulateDiskCache(ctx context.Context, in *PopulateDiskCacheRequest, opts ...grpc.CallOption) (*PopulateDiskCacheResponse, error) +} + +type cachedClient struct { + cc grpc.ClientConnInterface +} + +func NewCachedClient(cc grpc.ClientConnInterface) CachedClient { + return &cachedClient{cc} +} + +func (c *cachedClient) PopulateDiskCache(ctx context.Context, in *PopulateDiskCacheRequest, opts ...grpc.CallOption) (*PopulateDiskCacheResponse, error) { + out := new(PopulateDiskCacheResponse) + err := c.cc.Invoke(ctx, Cached_PopulateDiskCache_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// CachedServer is the server API for Cached service. +// All implementations must embed UnimplementedCachedServer +// for forward compatibility +type CachedServer interface { + PopulateDiskCache(context.Context, *PopulateDiskCacheRequest) (*PopulateDiskCacheResponse, error) + mustEmbedUnimplementedCachedServer() +} + +// UnimplementedCachedServer must be embedded to have forward compatible implementations. +type UnimplementedCachedServer struct { +} + +func (UnimplementedCachedServer) PopulateDiskCache(context.Context, *PopulateDiskCacheRequest) (*PopulateDiskCacheResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method PopulateDiskCache not implemented") +} +func (UnimplementedCachedServer) mustEmbedUnimplementedCachedServer() {} + +// UnsafeCachedServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to CachedServer will +// result in compilation errors. +type UnsafeCachedServer interface { + mustEmbedUnimplementedCachedServer() +} + +func RegisterCachedServer(s grpc.ServiceRegistrar, srv CachedServer) { + s.RegisterService(&Cached_ServiceDesc, srv) +} + +func _Cached_PopulateDiskCache_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PopulateDiskCacheRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CachedServer).PopulateDiskCache(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Cached_PopulateDiskCache_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CachedServer).PopulateDiskCache(ctx, req.(*PopulateDiskCacheRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Cached_ServiceDesc is the grpc.ServiceDesc for Cached service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Cached_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "pb.Cached", + HandlerType: (*CachedServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "PopulateDiskCache", + Handler: _Cached_PopulateDiskCache_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "internal/pb/cache.proto", +} diff --git a/internal/testutil/context.go b/internal/testutil/context.go index 44e72252..238bc2d7 100644 --- a/internal/testutil/context.go +++ b/internal/testutil/context.go @@ -9,6 +9,7 @@ import ( "github.com/gadget-inc/dateilager/internal/db" "github.com/gadget-inc/dateilager/internal/environment" "github.com/gadget-inc/dateilager/pkg/api" + "github.com/gadget-inc/dateilager/pkg/client" "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" "go.uber.org/zap" @@ -34,6 +35,9 @@ func NewTestCtx(t *testing.T, role auth.Role, projects ...int64) TestCtx { Project: project, }) + log := zaptest.NewLogger(t) + zap.ReplaceGlobals(log) + dbConn, err := newDbTestConnector(ctx, os.Getenv("DB_URI")) require.NoError(t, err, "connecting to DB") @@ -42,7 +46,7 @@ func NewTestCtx(t *testing.T, role auth.Role, projects ...int64) TestCtx { return TestCtx{ t: t, - log: zaptest.NewLogger(t), + log: log, dbConn: dbConn, lookup: lookup, ctx: ctx, @@ -87,3 +91,11 @@ func (tc *TestCtx) FsApi() *api.Fs { ContentLookup: tc.ContentLookup(), } } + +func (tc *TestCtx) CachedApi(cl *client.Client, stagingPath string) *api.Cached { + return &api.Cached{ + Env: environment.Test, + Client: cl, + StagingPath: stagingPath, + } +} diff --git a/js/spec/util.ts b/js/spec/util.ts index 922fe249..a7581dcb 100644 --- a/js/spec/util.ts +++ b/js/spec/util.ts @@ -3,7 +3,7 @@ import path from "path"; import { DateiLagerBinaryClient, DateiLagerGrpcClient } from "../src"; export const devAdminToken = - "v2.public.eyJzdWIiOiJhZG1pbiIsImlhdCI6IjIwMjEtMTAtMTVUMTE6MjA6MDAuMDM0WiJ9WtEey8KfQQRy21xoHq1C5KQatEevk8RxS47k4bRfMwVCPHumZmVuk6ADcfDHTmSnMtEGfFXdxnYOhRP6Clb_Dw"; + "v2.public.eyJzdWIiOiJhZG1pbiJ9yt40HNkcyOUtDeFa_WPS6vi0WiE4zWngDGJLh17TuYvssTudCbOdQEkVDRD-mSNTXLgSRDXUkO-AaEr4ZLO4BQ"; export const grpcClient = new DateiLagerGrpcClient({ server: "localhost:5051", diff --git a/pkg/api/cached.go b/pkg/api/cached.go new file mode 100644 index 00000000..9c833d7b --- /dev/null +++ b/pkg/api/cached.go @@ -0,0 +1,113 @@ +package api + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "os" + "path" + "time" + + "github.com/gadget-inc/dateilager/internal/environment" + "github.com/gadget-inc/dateilager/internal/files" + "github.com/gadget-inc/dateilager/internal/key" + "github.com/gadget-inc/dateilager/internal/logger" + "github.com/gadget-inc/dateilager/internal/pb" + "github.com/gadget-inc/dateilager/pkg/client" + "golang.org/x/sys/unix" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type Cached struct { + pb.UnimplementedCachedServer + + Env environment.Env + Client *client.Client + StagingPath string + + // the current directory holding a fully formed downloaded cache + currentDir string + // the current version of the cache on disk at currentDir + currentVersion int64 +} + +func (c *Cached) PopulateDiskCache(ctx context.Context, req *pb.PopulateDiskCacheRequest) (*pb.PopulateDiskCacheResponse, error) { + if c.Env != environment.Dev && c.Env != environment.Test { + return nil, status.Errorf(codes.Unimplemented, "Cached populateDiskCache only implemented in dev and test environments") + } + + err := requireAdminAuth(ctx) + if err != nil { + return nil, err + } + + destination := req.Path + + version, err := c.WriteCache(destination) + if err != nil { + return nil, err + } + + return &pb.PopulateDiskCacheResponse{Version: version}, nil +} + +// check if the destination exists, and if so, if its writable +// hardlink the golden copy into this downstream's destination, creating it if need be +func (c *Cached) WriteCache(destination string) (int64, error) { + if c.currentDir == "" { + return -1, errors.New("no cache prepared, currentDir is nil") + } + + stat, err := os.Stat(destination) + if !os.IsNotExist(err) { + if err != nil { + return -1, fmt.Errorf("failed to stat cache destination %s: %v", destination, err) + } + + if !stat.IsDir() { + return -1, fmt.Errorf("failed to open cache destination %s for writing -- it is already a file", destination) + } + + if unix.Access(destination, unix.W_OK) != nil { + return -1, fmt.Errorf("failed to open cache destination %s for writing -- write permission denied", destination) + } + } + + err = files.HardlinkDir(c.currentDir, destination) + if err != nil { + return -1, fmt.Errorf("failed to hardlink cache to destination %s: %v", destination, err) + } + return c.currentVersion, nil +} + +// Fetch the cache into a spot in the staging dir +func (c *Cached) Prepare(ctx context.Context) error { + start := time.Now() + folderName, err := randomString() + if err != nil { + return err + } + newDir := path.Join(c.StagingPath, folderName) + version, count, err := c.Client.GetCache(ctx, newDir) + if err != nil { + return err + } + + c.currentDir = newDir + c.currentVersion = version + + logger.Info(ctx, "downloaded golden copy", key.Directory.Field(newDir), key.DurationMS.Field(time.Since(start)), key.Version.Field(version), key.Count.Field(int64(count))) + return nil +} + +func randomString() (string, error) { + // Generate a secure random string for the temporary directory name + randBytes := make([]byte, 10) // Adjust the size of the byte slice as needed + if _, err := rand.Read(randBytes); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(randBytes), nil +} diff --git a/pkg/api/fs.go b/pkg/api/fs.go index 79837125..a35f76de 100644 --- a/pkg/api/fs.go +++ b/pkg/api/fs.go @@ -46,6 +46,24 @@ func requireProjectAuth(ctx context.Context) (int64, error) { return -1, status.Errorf(codes.PermissionDenied, "FS endpoint requires project access") } +func requireSharedReaderAuth(ctx context.Context) error { + ctxAuth := ctx.Value(auth.AuthCtxKey).(auth.Auth) + + if ctxAuth.Role == auth.Admin { + return nil + } + + if ctxAuth.Role == auth.Project { + return nil + } + + if ctxAuth.Role == auth.SharedReader { + return nil + } + + return status.Errorf(codes.PermissionDenied, "FS endpoint requires shared reader access") +} + type Fs struct { pb.UnimplementedFsServer @@ -948,7 +966,7 @@ func (f *Fs) GetCache(req *pb.GetCacheRequest, stream pb.Fs_GetCacheServer) erro ctx := stream.Context() trace.SpanFromContext(ctx) - _, err := requireProjectAuth(ctx) + err := requireSharedReaderAuth(ctx) if err != nil { return err } diff --git a/pkg/cached/cached.go b/pkg/cached/cached.go new file mode 100644 index 00000000..3bc36b27 --- /dev/null +++ b/pkg/cached/cached.go @@ -0,0 +1,68 @@ +package cached + +import ( + "context" + "crypto/ed25519" + "crypto/tls" + "net" + + "github.com/gadget-inc/dateilager/internal/auth" + "github.com/gadget-inc/dateilager/internal/logger" + "github.com/gadget-inc/dateilager/internal/pb" + "github.com/gadget-inc/dateilager/pkg/api" + "github.com/gadget-inc/dateilager/pkg/server" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +type CacheServer struct { + Grpc *grpc.Server + Health *health.Server +} + +func NewServer(ctx context.Context, cert *tls.Certificate, pasetoKey ed25519.PublicKey) *CacheServer { + creds := credentials.NewServerTLSFromCert(cert) + validator := auth.NewAuthValidator(pasetoKey) + + grpcServer := grpc.NewServer( + grpc.UnaryInterceptor( + grpc_middleware.ChainUnaryServer( + grpc_recovery.UnaryServerInterceptor(), + otelgrpc.UnaryServerInterceptor(), + logger.UnaryServerInterceptor(), + server.ValidateTokenUnary(validator), + ), + ), + grpc.ReadBufferSize(server.BUFFER_SIZE), + grpc.WriteBufferSize(server.BUFFER_SIZE), + grpc.InitialConnWindowSize(server.INITIAL_CONN_WINDOW_SIZE), + grpc.InitialWindowSize(server.INITIAL_WINDOW_SIZE), + grpc.MaxRecvMsgSize(server.MAX_MESSAGE_SIZE), + grpc.MaxSendMsgSize(server.MAX_MESSAGE_SIZE), + grpc.Creds(creds), + ) + + logger.Info(ctx, "register HealthServer") + healthServer := health.NewServer() + healthpb.RegisterHealthServer(grpcServer, healthServer) + + server := &CacheServer{ + Grpc: grpcServer, + Health: healthServer, + } + + return server +} + +func (s *CacheServer) RegisterCachedServer(ctx context.Context, cached *api.Cached) { + pb.RegisterCachedServer(s.Grpc, cached) +} + +func (s *CacheServer) Serve(lis net.Listener) error { + return s.Grpc.Serve(lis) +} diff --git a/pkg/cli/cached.go b/pkg/cli/cached.go new file mode 100644 index 00000000..6958a31c --- /dev/null +++ b/pkg/cli/cached.go @@ -0,0 +1,185 @@ +package cli + +import ( + "context" + "crypto/tls" + "flag" + "fmt" + "net" + "os" + "os/signal" + "runtime/pprof" + "syscall" + + "github.com/gadget-inc/dateilager/internal/environment" + "github.com/gadget-inc/dateilager/internal/key" + "github.com/gadget-inc/dateilager/internal/logger" + "github.com/gadget-inc/dateilager/internal/telemetry" + "github.com/gadget-inc/dateilager/pkg/api" + "github.com/gadget-inc/dateilager/pkg/cached" + "github.com/gadget-inc/dateilager/pkg/client" + "github.com/gadget-inc/dateilager/pkg/version" + "github.com/spf13/cobra" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func NewCacheDaemonCommand() *cobra.Command { + var ( + profilerEnabled bool = false + shutdownTelemetry func() + ) + + var ( + level *zapcore.Level + encoding string + tracing bool + profilePath string + upstreamHost string + upstreamPort uint16 + certFile string + keyFile string + pasetoFile string + port int + timeout uint + headlessHost string + stagingPath string + ) + + cmd := &cobra.Command{ + Use: "cached", + Short: "DateiLager cache daemon", + DisableAutoGenTag: true, + Version: version.Version, + RunE: func(cmd *cobra.Command, _ []string) error { + cmd.SilenceUsage = true // silence usage when an error occurs after flags have been parsed + + env, err := environment.LoadEnvironment() + if err != nil { + return fmt.Errorf("could not load environment: %w", err) + } + + var config zap.Config + if env == environment.Prod { + config = zap.NewProductionConfig() + } else { + config = zap.NewDevelopmentConfig() + } + + config.Encoding = encoding + config.Level = zap.NewAtomicLevelAt(*level) + config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + err = logger.Init(config) + if err != nil { + return fmt.Errorf("could not initialize logger: %w", err) + } + + ctx := cmd.Context() + + if profilePath != "" { + file, err := os.Create(profilePath) + if err != nil { + return fmt.Errorf("cannot open profile path %s: %w", profilePath, err) + } + _ = pprof.StartCPUProfile(file) + profilerEnabled = true + } + + if tracing { + shutdownTelemetry = telemetry.Init(ctx, telemetry.Server) + } + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return fmt.Errorf("cannot open TLS cert and key files (%s, %s): %w", certFile, keyFile, err) + } + + pasetoKey, err := parsePublicKey(pasetoFile) + if err != nil { + return fmt.Errorf("cannot parse Paseto public key %s: %w", pasetoFile, err) + } + + cl, err := client.NewClient(ctx, upstreamHost, upstreamPort, client.WithheadlessHost(headlessHost)) + if err != nil { + return err + } + + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + return fmt.Errorf("failed to listen on TCP port %d: %w", port, err) + } + + s := cached.NewServer(ctx, &cert, pasetoKey) + logger.Info(ctx, "register Cached") + cached := &api.Cached{ + Env: env, + Client: cl, + StagingPath: stagingPath, + } + s.RegisterCachedServer(ctx, cached) + + osSignals := make(chan os.Signal, 1) + signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) + go func() { + <-osSignals + s.Grpc.GracefulStop() + }() + + err = cached.Prepare(ctx) + if err != nil { + return fmt.Errorf("failed to prepare cache daemon in %s: %w", stagingPath, err) + } + + logger.Info(ctx, "start cached server", key.Port.Field(port), key.Environment.Field(env.String())) + return s.Serve(listen) + }, + PostRunE: func(cmd *cobra.Command, _ []string) error { + if shutdownTelemetry != nil { + shutdownTelemetry() + } + + if profilerEnabled { + pprof.StopCPUProfile() + } + + return nil + }, + } + + flags := cmd.PersistentFlags() + + level = zap.LevelFlag("log-level", zap.DebugLevel, "Log level") + flags.AddGoFlag(flag.CommandLine.Lookup("log-level")) + flags.StringVar(&encoding, "log-encoding", "console", "Log encoding (console | json)") + flags.BoolVar(&tracing, "tracing", false, "Whether tracing is enabled") + flags.StringVar(&profilePath, "profile", "", "CPU profile output path (profiling enabled if set)") + + flags.IntVar(&port, "port", 5053, "cache API port") + flags.StringVar(&upstreamHost, "upstream-host", "localhost", "GRPC server hostname") + flags.Uint16Var(&upstreamPort, "upstream-port", 5051, "GRPC server port") + flags.StringVar(&headlessHost, "headless-host", "", "Alternative headless hostname to use for round robin connections") + flags.StringVar(&certFile, "cert", "development/server.crt", "TLS cert file") + flags.StringVar(&keyFile, "key", "development/server.key", "TLS key file") + flags.StringVar(&pasetoFile, "paseto", "development/paseto.pub", "Paseto public key file") + flags.UintVar(&timeout, "timeout", 0, "GRPC client timeout (ms)") + + flags.StringVar(&stagingPath, "staging-path", "", "path for staging downloaded caches") + _ = cmd.MarkPersistentFlagRequired("staging-path") + + return cmd +} + +func CacheDaemonExecute() { + ctx := context.Background() + cmd := NewCacheDaemonCommand() + + err := cmd.ExecuteContext(ctx) + + logger.Info(ctx, "shut down server") + _ = logger.Sync() + + if err != nil { + logger.Fatal(ctx, "server failed", zap.Error(err)) + } +} diff --git a/pkg/cli/client.go b/pkg/cli/client.go index 9b74b308..eb1bfe4e 100644 --- a/pkg/cli/client.go +++ b/pkg/cli/client.go @@ -5,6 +5,7 @@ import ( "encoding/json" "flag" "fmt" + "slices" "strings" "time" @@ -21,8 +22,9 @@ import ( ) var ( - shutdownTelemetry func() - span trace.Span + shutdownTelemetry func() + span trace.Span + requiresCachedClient = []string{"getcached"} ) func NewClientCommand() *cobra.Command { @@ -88,9 +90,16 @@ func NewClientCommand() *cobra.Command { if err != nil { return err } - ctx = client.IntoContext(ctx, cl) + if slices.Contains(requiresCachedClient, cmd.CalledAs()) { + cachedClient, err := client.NewCachedClient(ctx, host, port, client.WithheadlessHost(headlessHost)) + if err != nil { + return err + } + ctx = client.CachedIntoContext(ctx, cachedClient) + } + cmd.SetContext(ctx) return nil @@ -109,11 +118,14 @@ func NewClientCommand() *cobra.Command { flags.AddGoFlag(flag.CommandLine.Lookup("log-level")) flags.StringVar(&encoding, "log-encoding", "console", "Log encoding (console | json)") flags.BoolVar(&tracing, "tracing", false, "Whether tracing is enabled") + flags.StringVar(&otelContext, "otel-context", "", "Open Telemetry context") flags.StringVar(&host, "host", "", "GRPC server hostname") flags.Uint16Var(&port, "port", 5051, "GRPC server port") - flags.UintVar(&timeout, "timeout", 0, "GRPC client timeout (ms)") flags.StringVar(&headlessHost, "headless-host", "", "Alternative headless hostname to use for round robin connections") + flags.UintVar(&timeout, "timeout", 0, "GRPC client timeout (ms)") + + _ = cmd.MarkFlagRequired("host") cmd.AddCommand(NewCmdGet()) cmd.AddCommand(NewCmdInspect()) @@ -124,6 +136,7 @@ func NewClientCommand() *cobra.Command { cmd.AddCommand(NewCmdUpdate()) cmd.AddCommand(NewCmdGc()) cmd.AddCommand(NewCmdGetCache()) + cmd.AddCommand(NewCmdGetCacheFromDaemon()) return cmd } diff --git a/pkg/cli/cache.go b/pkg/cli/getcache.go similarity index 100% rename from pkg/cli/cache.go rename to pkg/cli/getcache.go diff --git a/pkg/cli/getcached.go b/pkg/cli/getcached.go new file mode 100644 index 00000000..2e810284 --- /dev/null +++ b/pkg/cli/getcached.go @@ -0,0 +1,37 @@ +package cli + +import ( + "github.com/gadget-inc/dateilager/internal/key" + "github.com/gadget-inc/dateilager/internal/logger" + "github.com/gadget-inc/dateilager/pkg/client" + "github.com/spf13/cobra" +) + +func NewCmdGetCacheFromDaemon() *cobra.Command { + var ( + path string + ) + + cmd := &cobra.Command{ + Use: "getcached", + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + c := client.CachedFromContext(ctx) + + version, err := c.PopulateDiskCache(ctx, path) + if err != nil { + return err + } + + logger.Info(ctx, "cache populated", key.Version.Field(version)) + + return nil + }, + } + + cmd.Flags().StringVar(&path, "path", "", "Cache directory") + + _ = cmd.MarkFlagRequired("path") + + return cmd +} diff --git a/pkg/cli/server.go b/pkg/cli/server.go index 94daa53d..c3d856ad 100644 --- a/pkg/cli/server.go +++ b/pkg/cli/server.go @@ -131,7 +131,7 @@ func NewServerCommand() *cobra.Command { s.Grpc.GracefulStop() }() - logger.Info(ctx, "start server", key.Port.Field(port), key.Environment.Field(env.String())) + logger.Info(ctx, "start fs server", key.Port.Field(port), key.Environment.Field(env.String())) return s.Serve(listen) }, PostRunE: func(cmd *cobra.Command, _ []string) error { diff --git a/pkg/client/client.go b/pkg/client/client.go index 805b0856..f2587add 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -50,10 +50,19 @@ type Client struct { fs pb.FsClient } +type CachedClient struct { + conn *grpc.ClientConn + cached pb.CachedClient +} + func NewClientConn(conn *grpc.ClientConn) *Client { return &Client{conn: conn, fs: pb.NewFsClient(conn)} } +func NewCachedClientConn(conn *grpc.ClientConn) *CachedClient { + return &CachedClient{conn: conn, cached: pb.NewCachedClient(conn)} +} + type options struct { headlessHost string token string @@ -71,12 +80,7 @@ func WithheadlessHost(host string) func(*options) { } } -func NewClient(ctx context.Context, host string, port uint16, opts ...func(*options)) (*Client, error) { - ctx, span := telemetry.Start(ctx, "client.new", trace.WithAttributes( - key.Server.Attribute(host), - )) - defer span.End() - +func grpcClientConn(ctx context.Context, host string, port uint16, opts ...func(*options)) (*grpc.ClientConn, error) { pool, err := x509.SystemCertPool() if err != nil { return nil, fmt.Errorf("load system cert pool: %w", err) @@ -115,7 +119,7 @@ func NewClient(ctx context.Context, host string, port uint16, opts ...func(*opti server = fmt.Sprintf("%s:%d", o.headlessHost, port) } - conn, err := grpc.DialContext(connectCtx, server, + return grpc.DialContext(connectCtx, server, grpc.WithTransportCredentials(creds), grpc.WithPerRPCCredentials(auth), grpc.WithReadBufferSize(BUFFER_SIZE), @@ -155,6 +159,15 @@ func NewClient(ctx context.Context, host string, port uint16, opts ...func(*opti } `), ) +} + +func NewClient(ctx context.Context, host string, port uint16, opts ...func(*options)) (*Client, error) { + ctx, span := telemetry.Start(ctx, "client.new", trace.WithAttributes( + key.Server.Attribute(host), + )) + defer span.End() + + conn, err := grpcClientConn(ctx, host, port, opts...) if err != nil { return nil, err } @@ -950,6 +963,45 @@ func (c *Client) CloneToProject(ctx context.Context, source int64, target int64, return &response.LatestVersion, nil } +func NewCachedClient(ctx context.Context, host string, port uint16, opts ...func(*options)) (*CachedClient, error) { + ctx, span := telemetry.Start(ctx, "cached-client.new", trace.WithAttributes( + key.Server.Attribute(host), + )) + defer span.End() + + conn, err := grpcClientConn(ctx, host, port, opts...) + if err != nil { + return nil, err + } + + return NewCachedClientConn(conn), nil +} + +func (c *CachedClient) Close() { + // Give a chance for the upstream socket to finish writing it's response + // https://github.com/grpc/grpc-go/issues/2869#issuecomment-503310136 + time.Sleep(1 * time.Millisecond) + c.conn.Close() +} + +func (c *CachedClient) PopulateDiskCache(ctx context.Context, destination string) (int64, error) { + ctx, span := telemetry.Start(ctx, "client.populate-disk-cache", trace.WithAttributes( + key.CachePath.Attribute(destination), + )) + defer span.End() + + request := &pb.PopulateDiskCacheRequest{ + Path: destination, + } + + response, err := c.cached.PopulateDiskCache(ctx, request) + if err != nil { + return 0, fmt.Errorf("populate disk cache for %s: %w", destination, err) + } + + return response.Version, nil +} + func parallelWorkerCount() int { envCount := os.Getenv("DL_WRITE_WORKERS") if envCount != "" { diff --git a/pkg/client/context.go b/pkg/client/context.go index 4f03ddcd..37d56f37 100644 --- a/pkg/client/context.go +++ b/pkg/client/context.go @@ -5,6 +5,7 @@ import ( ) type clientCtxKey struct{} +type cachedCtxKey struct{} func FromContext(ctx context.Context) *Client { client, ok := ctx.Value(clientCtxKey{}).(*Client) @@ -17,3 +18,15 @@ func FromContext(ctx context.Context) *Client { func IntoContext(ctx context.Context, client *Client) context.Context { return context.WithValue(ctx, clientCtxKey{}, client) } + +func CachedFromContext(ctx context.Context) *CachedClient { + client, ok := ctx.Value(cachedCtxKey{}).(*CachedClient) + if !ok { + return nil + } + return client +} + +func CachedIntoContext(ctx context.Context, client *CachedClient) context.Context { + return context.WithValue(ctx, cachedCtxKey{}, client) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 4483e0ad..d14d2bd6 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -130,7 +130,6 @@ type Server struct { func NewServer(ctx context.Context, dbConn *DbPoolConnector, cert *tls.Certificate, pasetoKey ed25519.PublicKey) *Server { creds := credentials.NewServerTLSFromCert(cert) - validator := auth.NewAuthValidator(pasetoKey) grpcServer := grpc.NewServer( @@ -139,7 +138,7 @@ func NewServer(ctx context.Context, dbConn *DbPoolConnector, cert *tls.Certifica grpc_recovery.UnaryServerInterceptor(), otelgrpc.UnaryServerInterceptor(), logger.UnaryServerInterceptor(), - validateTokenUnary(validator), + ValidateTokenUnary(validator), ), ), grpc.StreamInterceptor( @@ -210,7 +209,7 @@ func (s *Server) Serve(lis net.Listener) error { return s.Grpc.Serve(lis) } -func validateTokenUnary(validator *auth.AuthValidator) grpc.UnaryServerInterceptor { +func ValidateTokenUnary(validator *auth.AuthValidator) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { diff --git a/test/cached_test.go b/test/cached_test.go new file mode 100644 index 00000000..4c8ea025 --- /dev/null +++ b/test/cached_test.go @@ -0,0 +1,99 @@ +package test + +import ( + "fmt" + "os" + "path" + "testing" + + "github.com/gadget-inc/dateilager/internal/auth" + "github.com/gadget-inc/dateilager/internal/db" + util "github.com/gadget-inc/dateilager/internal/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPopulateCache(t *testing.T) { + tc := util.NewTestCtx(t, auth.Admin, 1) + defer tc.Close() + + writeProject(tc, 1, 2) + writeObject(tc, 1, 1, i(2), "a", "a v1") + aHash := writePackedFiles(tc, 1, 1, nil, "pack/a") + bHash := writePackedFiles(tc, 1, 1, nil, "pack/b") + version, err := db.CreateCache(tc.Context(), tc.Connect(), "", 100) + require.NoError(t, err) + + c, cached, close := createTestCachedClient(tc) + defer close() + + tmpDir := emptyTmpDir(t) + defer os.RemoveAll(tmpDir) + + require.NoError(t, cached.Prepare(tc.Context()), "cached.Prepare must succeed") + + _, err = c.PopulateDiskCache(tc.Context(), path.Join(tmpDir, "test")) + require.NoError(t, err, "Cached.PopulateDiskCache") + + verifyDir(t, path.Join(tmpDir, "test"), -1, map[string]expectedFile{ + fmt.Sprintf("objects/%v/pack/a/1", aHash): {content: "pack/a/1 v1"}, + fmt.Sprintf("objects/%v/pack/a/2", aHash): {content: "pack/a/2 v1"}, + fmt.Sprintf("objects/%v/pack/b/1", bHash): {content: "pack/b/1 v1"}, + fmt.Sprintf("objects/%v/pack/b/2", bHash): {content: "pack/b/2 v1"}, + "versions": {content: fmt.Sprintf("%v\n", version)}, + }) +} + +func TestPopulateEmptyCache(t *testing.T) { + tc := util.NewTestCtx(t, auth.Admin, 1) + defer tc.Close() + + writeProject(tc, 1, 2) + writeObject(tc, 1, 1, i(2), "a", "a v1") + // no packed files, so no cache + version, err := db.CreateCache(tc.Context(), tc.Connect(), "", 100) + require.NoError(t, err) + assert.NotEqual(t, int64(-1), version) + + c, cached, close := createTestCachedClient(tc) + defer close() + + tmpDir := emptyTmpDir(t) + defer os.RemoveAll(tmpDir) + + require.NoError(t, cached.Prepare(tc.Context()), "cached.Prepare must succeed") + + _, err = c.PopulateDiskCache(tc.Context(), path.Join(tmpDir, "test")) + require.NoError(t, err, "PopulateDiskCache must succeed") + + verifyDir(t, path.Join(tmpDir, "test"), -1, map[string]expectedFile{ + "objects/": {content: "", fileType: typeDirectory}, + }) +} + +func TestPopulateCacheToPathWithNoWritePermissions(t *testing.T) { + tc := util.NewTestCtx(t, auth.Admin, 1) + defer tc.Close() + + writeProject(tc, 1, 2) + writeObject(tc, 1, 1, i(2), "a", "a v1") + writePackedFiles(tc, 1, 1, nil, "pack/a") + writePackedFiles(tc, 1, 1, nil, "pack/b") + _, err := db.CreateCache(tc.Context(), tc.Connect(), "", 100) + require.NoError(t, err) + + c, cached, close := createTestCachedClient(tc) + defer close() + + tmpDir := emptyTmpDir(t) + defer os.RemoveAll(tmpDir) + + require.NoError(t, cached.Prepare(tc.Context()), "cached.Prepare must succeed") + + // Create a directory with no write permissions + err = os.Mkdir(path.Join(tmpDir, "test"), 0000) + require.NoError(t, err) + + _, err = c.PopulateDiskCache(tc.Context(), path.Join(tmpDir, "test")) + require.Error(t, err, "populating cache to a path with no write permissions must fail") +} diff --git a/test/client_rebuild_test.go b/test/client_rebuild_test.go index e8812865..fd13b03d 100644 --- a/test/client_rebuild_test.go +++ b/test/client_rebuild_test.go @@ -277,8 +277,8 @@ func TestRebuildWithCache(t *testing.T) { count: 2, }) - aCachePath := filepath.Join(client.CacheObjectsDir(cacheDir), ha.Hex(), "pack/a") - bCachePath := filepath.Join(client.CacheObjectsDir(cacheDir), hb.Hex(), "pack/b") + aCachePath := filepath.Join(client.CacheObjectsDir(cacheDir), ha, "pack/a") + bCachePath := filepath.Join(client.CacheObjectsDir(cacheDir), hb, "pack/b") verifyDir(t, tmpDir, 1, map[string]expectedFile{ "pack/a/1": {content: "pack/a/1 v1"}, diff --git a/test/shared_test.go b/test/shared_test.go index 18c8723a..7458b9b3 100644 --- a/test/shared_test.go +++ b/test/shared_test.go @@ -213,11 +213,12 @@ func writePackedObjects(tc util.TestCtx, project int64, start int64, stop *int64 return hash } -func writePackedFiles(tc util.TestCtx, project int64, start int64, stop *int64, path string) db.Hash { - return writePackedObjects(tc, project, start, stop, path, map[string]expectedObject{ +func writePackedFiles(tc util.TestCtx, project int64, start int64, stop *int64, path string) string { + hash := writePackedObjects(tc, project, start, stop, path, map[string]expectedObject{ filepath.Join(path, "1"): {content: fmt.Sprintf("%s v%d", filepath.Join(path, "1"), start)}, filepath.Join(path, "2"): {content: fmt.Sprintf("%s v%d", filepath.Join(path, "2"), start)}, }) + return hash.Hex() } func packObjects(tc util.TestCtx, objects map[string]expectedObject) []byte { @@ -338,11 +339,13 @@ func verifyDir(t *testing.T, dir string, version int64, files map[string]expecte dirEntries[fmt.Sprintf("%s/", *maybeEmptyDir)] = *maybeEmptyInfo } - fileVersion, err := client.ReadVersionFile(dir) - require.NoError(t, err, "read version file") + if version != -1 { + fileVersion, err := client.ReadVersionFile(dir) + require.NoError(t, err, "read version file") - assert.Equal(t, version, fileVersion, "expected file version %v", version) - assert.Equal(t, len(files), len(dirEntries), "expected %v files in %v", len(files), dir) + assert.Equal(t, version, fileVersion, "expected file version %v", version) + assert.Equal(t, len(files), len(dirEntries), "expected %v files in %v", len(files), dir) + } for name, file := range files { path := filepath.Join(dir, name) @@ -374,10 +377,7 @@ func verifyDir(t *testing.T, dir string, version int64, files map[string]expecte } } -func createTestClient(tc util.TestCtx) (*client.Client, *api.Fs, func()) { - fs := tc.FsApi() - reqAuth := tc.Context().Value(auth.AuthCtxKey).(auth.Auth) - +func createTestGRPCServer(tc util.TestCtx, reqAuth auth.Auth) (*bufconn.Listener, *grpc.Server, func() *grpc.ClientConn) { lis := bufconn.Listen(bufSize) s := grpc.NewServer( grpc.UnaryInterceptor( @@ -394,22 +394,54 @@ func createTestClient(tc util.TestCtx) (*client.Client, *api.Fs, func()) { ), ) + dialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + getConn := func() *grpc.ClientConn { + conn, err := grpc.DialContext(tc.Context(), "bufnet", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(tc.T(), err, "Failed to dial bufnet") + return conn + } + + return lis, s, getConn +} + +func createTestClient(tc util.TestCtx) (*client.Client, *api.Fs, func()) { + lis, s, getConn := createTestGRPCServer(tc, tc.Context().Value(auth.AuthCtxKey).(auth.Auth)) + + fs := tc.FsApi() pb.RegisterFsServer(s, fs) + go func() { err := s.Serve(lis) require.NoError(tc.T(), err, "Server exited") }() - dialer := func(context.Context, string) (net.Conn, error) { - return lis.Dial() - } + c := client.NewClientConn(getConn()) - conn, err := grpc.DialContext(tc.Context(), "bufnet", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials())) - require.NoError(tc.T(), err, "Failed to dial bufnet") + return c, fs, func() { c.Close(); s.Stop() } +} - c := client.NewClientConn(conn) +// Make a new client that connects to a test cached server +// Under the hood, this creates a test storage server and connects to that +func createTestCachedClient(tc util.TestCtx) (*client.CachedClient, *api.Cached, func()) { + lis, s, getConn := createTestGRPCServer(tc, tc.Context().Value(auth.AuthCtxKey).(auth.Auth)) - return c, fs, func() { c.Close(); s.Stop() } + cl, _, closeClient := createTestClient(tc) + stagingPath := emptyTmpDir(tc.T()) + + cached := tc.CachedApi(cl, stagingPath) + pb.RegisterCachedServer(s, cached) + + go func() { + err := s.Serve(lis) + require.NoError(tc.T(), err, "Server exited") + }() + + cachedClient := client.NewCachedClientConn(getConn()) + + return cachedClient, cached, func() { cachedClient.Close(); closeClient(); s.Stop() } } func rebuild(tc util.TestCtx, c *client.Client, project int64, toVersion *int64, dir string, cacheDir *string, expected expectedResponse) {