From ffa9a27bb02d28da4b3be3fa1781c8024a3014d0 Mon Sep 17 00:00:00 2001 From: Bhunter <180028024+bhunter234@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:27:52 +0100 Subject: [PATCH] fix: cache invalidation in file update --- pkg/services/file.go | 56 ++++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/pkg/services/file.go b/pkg/services/file.go index cea8189b..e042d263 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -123,7 +123,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap return nil, &apiError{err: errors.New("file not found"), code: 404} } - file := mapper.ToFileOut(res[0], true) + file := res[0] newIds := []api.Part{} @@ -138,7 +138,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap for _, part := range file.Parts { ids = append(ids, int(part.ID)) } - messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelId.Value) + messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelID) if err != nil { return err @@ -179,7 +179,11 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap } } - newIds = append(newIds, api.Part{ID: msg.ID, Salt: file.Parts[i].Salt}) + p := api.Part{ID: msg.ID} + if file.Parts[i].Salt.Value != "" { + p.Salt = file.Parts[i].Salt + } + newIds = append(newIds, p) } return nil @@ -189,6 +193,10 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap return nil, &apiError{err: err} } + if len(newIds) != len(file.Parts) { + return nil, &apiError{err: errors.New("failed to copy all file parts")} + } + var parentId string if !isUUID(req.Destination) { var destRes []models.File @@ -204,10 +212,12 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap dbFile := models.File{} dbFile.Name = req.NewName.Or(file.Name) - dbFile.Size = utils.Ptr(file.Size.Value) + dbFile.Size = file.Size dbFile.Type = string(file.Type) - dbFile.MimeType = file.MimeType.Or(defaultContentType) - dbFile.Parts = datatypes.NewJSONSlice(newIds) + dbFile.MimeType = file.MimeType + if len(newIds) > 0 { + dbFile.Parts = datatypes.NewJSONSlice(newIds) + } dbFile.UserID = userId dbFile.Status = "active" dbFile.ParentID = sql.NullString{ @@ -215,8 +225,15 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap Valid: true, } dbFile.ChannelID = &channelId - dbFile.Encrypted = file.Encrypted.Value - dbFile.Category = string(file.Category.Value) + dbFile.Encrypted = file.Encrypted + dbFile.Category = string(file.Category) + if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() { + dbFile.UpdatedAt = req.UpdatedAt.Value + dbFile.CreatedAt = req.UpdatedAt.Value + } else { + dbFile.UpdatedAt = time.Now().UTC() + dbFile.CreatedAt = time.Now().UTC() + } if err := a.db.Create(&dbFile).Error; err != nil { return nil, &apiError{err: err} @@ -236,8 +253,12 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi channelId int64 ) + if fileIn.Path.Value == "" && fileIn.ParentId.Value == "" { + return nil, &apiError{err: errors.New("parent id or path is required"), code: 409} + } + if fileIn.Path.Value != "" { - path = strings.ReplaceAll(path, "//", "/") + path = strings.ReplaceAll(fileIn.Path.Value, "//", "/") if path != "/" { path = strings.TrimSuffix(path, "/") } @@ -258,8 +279,6 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi Valid: true, } - } else { - return nil, &apiError{err: errors.New("parent id or path is required"), code: 409} } if fileIn.Type == "folder" { @@ -295,6 +314,13 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi fileDB.UserID = userId fileDB.Status = "active" fileDB.Encrypted = fileIn.Encrypted.Value + if fileIn.UpdatedAt.IsSet() && !fileIn.UpdatedAt.Value.IsZero() { + fileDB.UpdatedAt = fileIn.UpdatedAt.Value + fileDB.CreatedAt = fileIn.UpdatedAt.Value + } else { + fileDB.UpdatedAt = time.Now().UTC() + fileDB.CreatedAt = time.Now().UTC() + } if err := a.db.Create(&fileDB).Error; err != nil { if database.IsKeyConflictErr(err) { return nil, &apiError{err: errors.New("file already exists"), code: 409} @@ -491,9 +517,9 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param if req.Size.Value != 0 { updateDb.Size = utils.Ptr(req.Size.Value) } - if req.UpdatedAt.IsSet() { + if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() { updateDb.UpdatedAt = req.UpdatedAt.Value - } else { + } else if !req.UpdatedAt.IsSet() && params.Skiputs.Value == "0" { updateDb.UpdatedAt = time.Now().UTC() } @@ -578,9 +604,9 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd } client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...) tgc.DeleteMessages(ctx, client, *file.ChannelID, ids) - keys := []string{fmt.Sprintf("files:%s", params.ID), fmt.Sprintf("files:messages:%s:%d", params.ID, userId)} + keys := []string{fmt.Sprintf("files:%s", params.ID), fmt.Sprintf("files:messages:%s", params.ID)} for _, part := range file.Parts { - keys = append(keys, fmt.Sprintf("files:location:%d:%s:%d", userId, params.ID, part.ID)) + keys = append(keys, fmt.Sprintf("files:location:%s:%d", params.ID, part.ID)) } a.cache.Delete(keys...)