diff --git a/drivers/115/util.go b/drivers/115/util.go
index 7298f565d0c..fc17fe3cebf 100644
--- a/drivers/115/util.go
+++ b/drivers/115/util.go
@@ -405,7 +405,7 @@ func (d *Pan115) UploadByMultipart(ctx context.Context, params *driver115.Upload
 			if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) {
 				continue
 			}
-			if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf)),
+			if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf)),
 				chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil {
 				break
 			}
diff --git a/drivers/123/driver.go b/drivers/123/driver.go
index 7d457138fde..32c053e22ab 100644
--- a/drivers/123/driver.go
+++ b/drivers/123/driver.go
@@ -2,11 +2,8 @@ package _123
 
 import (
 	"context"
-	"crypto/md5"
 	"encoding/base64"
-	"encoding/hex"
 	"fmt"
-	"io"
 	"net/http"
 	"net/url"
 	"sync"
@@ -18,6 +15,7 @@ import (
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/credentials"
@@ -187,25 +185,12 @@ func (d *Pan123) Remove(ctx context.Context, obj model.Obj) error {
 
 func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
 	etag := file.GetHash().GetHash(utils.MD5)
+	var err error
 	if len(etag) < utils.MD5.Width {
-		// const DEFAULT int64 = 10485760
-		h := md5.New()
-		// need to calculate md5 of the full content
-		tempFile, err := file.CacheFullInTempFile()
+		_, etag, err = stream.CacheFullInTempFileAndHash(file, utils.MD5)
 		if err != nil {
 			return err
 		}
-		defer func() {
-			_ = tempFile.Close()
-		}()
-		if _, err = utils.CopyWithBuffer(h, tempFile); err != nil {
-			return err
-		}
-		_, err = tempFile.Seek(0, io.SeekStart)
-		if err != nil {
-			return err
-		}
-		etag = hex.EncodeToString(h.Sum(nil))
 	}
 	data := base.Json{
 		"driveId": 0,
diff --git a/drivers/123/upload.go b/drivers/123/upload.go
index dc148c4c93f..b0482a9f4c9 100644
--- a/drivers/123/upload.go
+++ b/drivers/123/upload.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"fmt"
 	"io"
-	"math"
 	"net/http"
 	"strconv"
 
@@ -70,27 +69,33 @@ func (d *Pan123) completeS3(ctx context.Context, upReq *UploadResp, file model.F
 }
 
 func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, up driver.UpdateProgress) error {
-	chunkSize := int64(1024 * 1024 * 16)
+	tmpF, err := file.CacheFullInTempFile()
+	if err != nil {
+		return err
+	}
 	// fetch s3 pre signed urls
-	chunkCount := int(math.Ceil(float64(file.GetSize()) / float64(chunkSize)))
+	size := file.GetSize()
+	chunkSize := min(size, 16*utils.MB)
+	chunkCount := int(size / chunkSize)
+	lastChunkSize := size % chunkSize
+	if lastChunkSize > 0 {
+		chunkCount++
+	} else {
+		lastChunkSize = chunkSize
+	}
 	// only 1 batch is allowed
-	isMultipart := chunkCount > 1
 	batchSize := 1
 	getS3UploadUrl := d.getS3Auth
-	if isMultipart {
+	if chunkCount > 1 {
 		batchSize = 10
 		getS3UploadUrl = d.getS3PreSignedUrls
 	}
-	limited := driver.NewLimitedUploadStream(ctx, file)
 	for i := 1; i <= chunkCount; i += batchSize {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
 		}
 		start := i
-		end := i + batchSize
-		if end > chunkCount+1 {
-			end = chunkCount + 1
-		}
+		end := min(i+batchSize, chunkCount+1)
 		s3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, start, end)
 		if err != nil {
 			return err
@@ -102,9 +107,9 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi
 			}
 			curSize := chunkSize
 			if j == chunkCount {
-				curSize = file.GetSize() - (int64(chunkCount)-1)*chunkSize
+				curSize = lastChunkSize
 			}
-			err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(limited, chunkSize), curSize, false, getS3UploadUrl)
+			err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.NewSectionReader(tmpF, chunkSize*int64(j-1), curSize), curSize, false, getS3UploadUrl)
 			if err != nil {
 				return err
 			}
@@ -115,12 +120,12 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi
 	return d.completeS3(ctx, upReq, file, chunkCount > 1)
 }
 
-func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader io.Reader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error {
+func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader *io.SectionReader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error {
 	uploadUrl := s3PreSignedUrls.Data.PreSignedUrls[strconv.Itoa(cur)]
 	if uploadUrl == "" {
 		return fmt.Errorf("upload url is empty, s3PreSignedUrls: %+v", s3PreSignedUrls)
 	}
-	req, err := http.NewRequest("PUT", uploadUrl, reader)
+	req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, reader))
 	if err != nil {
 		return err
 	}
@@ -143,6 +148,7 @@ func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSign
 		}
 		s3PreSignedUrls.Data.PreSignedUrls = newS3PreSignedUrls.Data.PreSignedUrls
 		// retry
+		reader.Seek(0, io.SeekStart)
 		return d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, cur, end, reader, curSize, true, getS3UploadUrl)
 	}
 	if res.StatusCode != http.StatusOK {
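Reviewer note on the bytes.NewBuffer → bytes.NewReader swaps above (the same change recurs in the cloudreve, onedrive, and pikpak hunks below): a *bytes.Reader can be rewound, which is exactly what the retry path in uploadS3Chunk relies on when it calls reader.Seek(0, io.SeekStart). A minimal stdlib-only sketch; the readAll helper is illustrative and not part of the patch:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// readAll stands in for one upload attempt that consumes the body.
func readAll(r io.Reader) int64 {
	n, _ := io.Copy(io.Discard, r)
	return n
}

func main() {
	payload := []byte("chunk payload")

	// *bytes.Reader is an io.ReadSeeker: rewind and replay on retry.
	r := bytes.NewReader(payload)
	fmt.Println(readAll(r)) // 13
	_, _ = r.Seek(0, io.SeekStart)
	fmt.Println(readAll(r)) // 13

	// *bytes.Buffer is drained by the first read and cannot seek back.
	b := bytes.NewBuffer(payload)
	fmt.Println(readAll(b)) // 13
	fmt.Println(readAll(b)) // 0
}
```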
diff --git a/drivers/139/driver.go b/drivers/139/driver.go
index f367c431c1b..0af5a4f781a 100644
--- a/drivers/139/driver.go
+++ b/drivers/139/driver.go
@@ -2,20 +2,19 @@ package _139
 
 import (
 	"context"
-	"encoding/base64"
 	"encoding/xml"
 	"fmt"
 	"io"
 	"net/http"
 	"path"
 	"strconv"
-	"strings"
 	"time"
 
 	"github.com/alist-org/alist/v3/drivers/base"
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
+	streamPkg "github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/cron"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/alist-org/alist/v3/pkg/utils/random"
@@ -72,28 +71,29 @@ func (d *Yun139) Init(ctx context.Context) error {
 	default:
 		return errs.NotImplement
 	}
-	if d.ref != nil {
-		return nil
-	}
-	decode, err := base64.StdEncoding.DecodeString(d.Authorization)
-	if err != nil {
-		return err
-	}
-	decodeStr := string(decode)
-	splits := strings.Split(decodeStr, ":")
-	if len(splits) < 2 {
-		return fmt.Errorf("authorization is invalid, splits < 2")
-	}
-	d.Account = splits[1]
-	_, err = d.post("/orchestration/personalCloud/user/v1.0/qryUserExternInfo", base.Json{
-		"qryUserExternInfoReq": base.Json{
-			"commonAccountInfo": base.Json{
-				"account":     d.getAccount(),
-				"accountType": 1,
-			},
-		},
-	}, nil)
-	return err
+	// if d.ref != nil {
+	// 	return nil
+	// }
+	// decode, err := base64.StdEncoding.DecodeString(d.Authorization)
+	// if err != nil {
+	// 	return err
+	// }
+	// decodeStr := string(decode)
+	// splits := strings.Split(decodeStr, ":")
+	// if len(splits) < 2 {
+	// 	return fmt.Errorf("authorization is invalid, splits < 2")
+	// }
+	// d.Account = splits[1]
+	// _, err = d.post("/orchestration/personalCloud/user/v1.0/qryUserExternInfo", base.Json{
+	// 	"qryUserExternInfoReq": base.Json{
+	// 		"commonAccountInfo": base.Json{
+	// 			"account":     d.getAccount(),
+	// 			"accountType": 1,
+	// 		},
+	// 	},
+	// }, nil)
+	// return err
+	return nil
 }
 
 func (d *Yun139) InitReference(storage driver.Driver) error {
@@ -503,23 +503,15 @@ func (d *Yun139) Remove(ctx context.Context, obj model.Obj) error {
 	}
 }
 
-const (
-	_  = iota //ignore first value by assigning to blank identifier
-	KB = 1 << (10 * iota)
-	MB
-	GB
-	TB
-)
-
 func (d *Yun139) getPartSize(size int64) int64 {
 	if d.CustomUploadPartSize != 0 {
 		return d.CustomUploadPartSize
 	}
 	// 网盘对于分片数量存在上限
-	if size/GB > 30 {
-		return 512 * MB
+	if size/utils.GB > 30 {
+		return 512 * utils.MB
 	}
-	return 100 * MB
+	return 100 * utils.MB
 }
 
 func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
@@ -527,29 +519,28 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 	case MetaPersonalNew:
 		var err error
 		fullHash := stream.GetHash().GetHash(utils.SHA256)
-		if len(fullHash) <= 0 {
-			tmpF, err := stream.CacheFullInTempFile()
-			if err != nil {
-				return err
-			}
-			fullHash, err = utils.HashFile(utils.SHA256, tmpF)
+		if len(fullHash) != utils.SHA256.Width {
+			_, fullHash, err = streamPkg.CacheFullInTempFileAndHash(stream, utils.SHA256)
 			if err != nil {
 				return err
 			}
 		}
-		partInfos := []PartInfo{}
-		var partSize = d.getPartSize(stream.GetSize())
-		part := (stream.GetSize() + partSize - 1) / partSize
-		if part == 0 {
+		size := stream.GetSize()
+		var partSize = d.getPartSize(size)
+		part := size / partSize
+		if size%partSize > 0 {
+			part++
+		} else if part == 0 {
 			part = 1
 		}
+		partInfos := make([]PartInfo, 0, part)
 		for i := int64(0); i < part; i++ {
 			if utils.IsCanceled(ctx) {
 				return ctx.Err()
 			}
 			start := i * partSize
-			byteSize := stream.GetSize() - start
+			byteSize := size - start
 			if byteSize > partSize {
 				byteSize = partSize
 			}
@@ -577,7 +568,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 			"contentType":    "application/octet-stream",
 			"parallelUpload": false,
 			"partInfos":      firstPartInfos,
-			"size":           stream.GetSize(),
+			"size":           size,
 			"parentFileId":   dstDir.GetID(),
 			"name":           stream.GetName(),
 			"type":           "file",
@@ -630,7 +621,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 		}
 
 		// Progress
-		p := driver.NewProgress(stream.GetSize(), up)
+		p := driver.NewProgress(size, up)
 
 		rateLimited := driver.NewLimitedUploadStream(ctx, stream)
 		// 上传所有分片
@@ -790,12 +781,14 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 			return fmt.Errorf("get file upload url failed with result code: %s, message: %s", resp.Data.Result.ResultCode, resp.Data.Result.ResultDesc)
 		}
 
+		size := stream.GetSize()
 		// Progress
-		p := driver.NewProgress(stream.GetSize(), up)
-
-		var partSize = d.getPartSize(stream.GetSize())
-		part := (stream.GetSize() + partSize - 1) / partSize
-		if part == 0 {
+		p := driver.NewProgress(size, up)
+		var partSize = d.getPartSize(size)
+		part := size / partSize
+		if size%partSize > 0 {
+			part++
+		} else if part == 0 {
 			part = 1
 		}
 		rateLimited := driver.NewLimitedUploadStream(ctx, stream)
@@ -805,10 +798,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 			}
 
 			start := i * partSize
-			byteSize := stream.GetSize() - start
-			if byteSize > partSize {
-				byteSize = partSize
-			}
+			byteSize := min(size-start, partSize)
 
 			limitReader := io.LimitReader(rateLimited, byteSize)
 			// Update Progress
@@ -820,7 +810,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 			req = req.WithContext(ctx)
 			req.Header.Set("Content-Type", "text/plain;name="+unicode(stream.GetName()))
-			req.Header.Set("contentSize", strconv.FormatInt(stream.GetSize(), 10))
+			req.Header.Set("contentSize", strconv.FormatInt(size, 10))
 			req.Header.Set("range", fmt.Sprintf("bytes=%d-%d", start, start+byteSize-1))
 			req.Header.Set("uploadtaskID", resp.Data.UploadResult.UploadTaskID)
 			req.Header.Set("rangeType", "0")
diff --git a/drivers/139/util.go b/drivers/139/util.go
index 3e1a61edc81..53defef528e 100644
--- a/drivers/139/util.go
+++ b/drivers/139/util.go
@@ -67,6 +67,7 @@ func (d *Yun139) refreshToken() error {
 	if len(splits) < 3 {
 		return fmt.Errorf("authorization is invalid, splits < 3")
 	}
+	d.Account = splits[1]
 	strs := strings.Split(splits[2], "|")
 	if len(strs) < 4 {
 		return fmt.Errorf("authorization is invalid, strs < 4")
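Note on the part/chunk counting that replaces math.Ceil and the (size+partSize-1)/partSize form throughout this patch: the quotient/remainder version stays in integer arithmetic and also yields the size of the final short part, which the upload loops need anyway. A sketch under illustrative names (partCount is not a helper in the patch; the 139 hunks additionally force part = 1 when the stream size is 0):

```go
package main

import "fmt"

// partCount mirrors the pattern used in the 123/139/189pc/baidu hunks:
// full parts plus one short tail part when size is not a multiple.
func partCount(size, partSize int64) (count, lastPart int64) {
	count = size / partSize
	lastPart = size % partSize
	if lastPart > 0 {
		count++
	} else {
		lastPart = partSize // size divides evenly; the tail is a full part
	}
	return count, lastPart
}

func main() {
	fmt.Println(partCount(100, 16)) // 7 4
	fmt.Println(partCount(96, 16))  // 6 16
}
```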
diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go
index fb1a183ab38..c391f7e676f 100644
--- a/drivers/189pc/utils.go
+++ b/drivers/189pc/utils.go
@@ -3,16 +3,15 @@ package _189pc
 import (
 	"bytes"
 	"context"
-	"crypto/md5"
 	"encoding/base64"
 	"encoding/hex"
 	"encoding/xml"
 	"fmt"
 	"io"
-	"math"
 	"net/http"
 	"net/http/cookiejar"
 	"net/url"
+	"os"
 	"regexp"
 	"sort"
 	"strconv"
@@ -28,6 +27,7 @@ import (
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/op"
 	"github.com/alist-org/alist/v3/internal/setting"
+	"github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/errgroup"
 	"github.com/alist-org/alist/v3/pkg/utils"
 
@@ -473,12 +473,8 @@ func (y *Cloud189PC) refreshSession() (err error) {
 // 普通上传
 // 无法上传大小为0的文件
 func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
-	var sliceSize = partSize(file.GetSize())
-	count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize)))
-	lastPartSize := file.GetSize() % sliceSize
-	if file.GetSize() > 0 && lastPartSize == 0 {
-		lastPartSize = sliceSize
-	}
+	size := file.GetSize()
+	sliceSize := partSize(size)
 
 	params := Params{
 		"parentFolderId": dstDir.GetID(),
@@ -512,22 +508,29 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo
 		retry.DelayType(retry.BackOffDelay))
 
 	sem := semaphore.NewWeighted(3)
-	fileMd5 := md5.New()
-	silceMd5 := md5.New()
+	count := int(size / sliceSize)
+	lastPartSize := size % sliceSize
+	if lastPartSize > 0 {
+		count++
+	} else {
+		lastPartSize = sliceSize
+	}
+	fileMd5 := utils.MD5.NewFunc()
+	silceMd5 := utils.MD5.NewFunc()
 	silceMd5Hexs := make([]string, 0, count)
-
+	teeReader := io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5))
+	byteSize := sliceSize
 	for i := 1; i <= count; i++ {
 		if utils.IsCanceled(upCtx) {
 			break
 		}
-		byteData := make([]byte, sliceSize)
 		if i == count {
-			byteData = byteData[:lastPartSize]
+			byteSize = lastPartSize
 		}
-
+		byteData := make([]byte, byteSize)
 		// 读取块
 		silceMd5.Reset()
-		if _, err := io.ReadFull(io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5)), byteData); err != io.EOF && err != nil {
+		if _, err := io.ReadFull(teeReader, byteData); err != io.EOF && err != nil {
 			sem.Release(1)
 			return nil, err
 		}
@@ -607,24 +610,43 @@ func (y *Cloud189PC) RapidUpload(ctx context.Context, dstDir model.Obj, stream m
 
 // 快传
 func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
-	tempFile, err := file.CacheFullInTempFile()
-	if err != nil {
-		return nil, err
-	}
-
-	var sliceSize = partSize(file.GetSize())
-	count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize)))
-	lastSliceSize := file.GetSize() % sliceSize
-	if file.GetSize() > 0 && lastSliceSize == 0 {
+	var (
+		cache = file.GetFile()
+		tmpF  *os.File
+		err   error
+	)
+	size := file.GetSize()
+	if _, ok := cache.(io.ReaderAt); !ok && size > 0 {
+		tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
+		if err != nil {
+			return nil, err
+		}
+		defer func() {
+			_ = tmpF.Close()
+			_ = os.Remove(tmpF.Name())
+		}()
+		cache = tmpF
+	}
+	sliceSize := partSize(size)
+	count := int(size / sliceSize)
+	lastSliceSize := size % sliceSize
+	if lastSliceSize > 0 {
+		count++
+	} else {
 		lastSliceSize = sliceSize
 	}
 
 	//step.1 优先计算所需信息
 	byteSize := sliceSize
-	fileMd5 := md5.New()
-	silceMd5 := md5.New()
-	silceMd5Hexs := make([]string, 0, count)
+	fileMd5 := utils.MD5.NewFunc()
+	sliceMd5 := utils.MD5.NewFunc()
+	sliceMd5Hexs := make([]string, 0, count)
 	partInfos := make([]string, 0, count)
+	writers := []io.Writer{fileMd5, sliceMd5}
+	if tmpF != nil {
+		writers = append(writers, tmpF)
+	}
+	written := int64(0)
 	for i := 1; i <= count; i++ {
 		if utils.IsCanceled(ctx) {
 			return nil, ctx.Err()
@@ -634,19 +656,31 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
 			byteSize = lastSliceSize
 		}
 
-		silceMd5.Reset()
-		if _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5, silceMd5), tempFile, byteSize); err != nil && err != io.EOF {
+		n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), file, byteSize)
+		written += n
+		if err != nil && err != io.EOF {
 			return nil, err
 		}
-		md5Byte := silceMd5.Sum(nil)
-		silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte)))
+		md5Byte := sliceMd5.Sum(nil)
+		sliceMd5Hexs = append(sliceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte)))
 		partInfos = append(partInfos, fmt.Sprint(i, "-", base64.StdEncoding.EncodeToString(md5Byte)))
+		sliceMd5.Reset()
+	}
+
+	if tmpF != nil {
+		if size > 0 && written != size {
+			return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, size)
+		}
+		_, err = tmpF.Seek(0, io.SeekStart)
+		if err != nil {
+			return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
+		}
 	}
 	fileMd5Hex := strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil)))
 	sliceMd5Hex := fileMd5Hex
-	if file.GetSize() > sliceSize {
-		sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(silceMd5Hexs, "\n")))
+	if size > sliceSize {
+		sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(sliceMd5Hexs, "\n")))
 	}
 
 	fullUrl := UPLOAD_URL
@@ -712,7 +746,7 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
 			}
 
 			// step.4 上传切片
-			_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(tempFile, offset, byteSize), isFamily)
+			_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(cache, offset, byteSize), isFamily)
 			if err != nil {
 				return err
 			}
@@ -794,11 +828,7 @@ func (y *Cloud189PC) GetMultiUploadUrls(ctx context.Context, isFamily bool, uplo
 
 // 旧版本上传,家庭云不支持覆盖
 func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
-	tempFile, err := file.CacheFullInTempFile()
-	if err != nil {
-		return nil, err
-	}
-	fileMd5, err := utils.HashFile(utils.MD5, tempFile)
+	tempFile, fileMd5, err := stream.CacheFullInTempFileAndHash(file, utils.MD5)
 	if err != nil {
 		return nil, err
 	}
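The StreamUpload hunk above hoists a single io.TeeReader out of the read loop so the stream is consumed exactly once while both the whole-file and per-slice MD5s are fed. A self-contained sketch of that pattern (the input and slice size here are made up):

```go
package main

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
	"io"
	"strings"
)

func main() {
	src := strings.NewReader("0123456789abcdef0123456789abcdef") // 32 bytes
	const sliceSize = 16

	fileMd5 := md5.New()
	sliceMd5 := md5.New()
	// Everything read through tee is also written to both hashes.
	tee := io.TeeReader(src, io.MultiWriter(fileMd5, sliceMd5))

	buf := make([]byte, sliceSize)
	for {
		if _, err := io.ReadFull(tee, buf); err != nil {
			break // io.EOF or io.ErrUnexpectedEOF ends the loop
		}
		fmt.Println("slice md5:", hex.EncodeToString(sliceMd5.Sum(nil)))
		sliceMd5.Reset()
	}
	fmt.Println("file  md5:", hex.EncodeToString(fileMd5.Sum(nil)))
}
```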
diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go
index fb730de6966..4114c195182 100644
--- a/drivers/aliyundrive_open/upload.go
+++ b/drivers/aliyundrive_open/upload.go
@@ -1,7 +1,6 @@
 package aliyundrive_open
 
 import (
-	"bytes"
 	"context"
 	"encoding/base64"
 	"fmt"
@@ -15,6 +14,7 @@ import (
 	"github.com/alist-org/alist/v3/drivers/base"
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/model"
+	streamPkg "github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/http_range"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/avast/retry-go"
@@ -131,16 +131,19 @@ func (d *AliyundriveOpen) calProofCode(stream model.FileStreamer) (string, error
 		return "", err
 	}
 	length := proofRange.End - proofRange.Start
-	buf := bytes.NewBuffer(make([]byte, 0, length))
 	reader, err := stream.RangeRead(http_range.Range{Start: proofRange.Start, Length: length})
 	if err != nil {
 		return "", err
 	}
-	_, err = utils.CopyWithBufferN(buf, reader, length)
+	buf := make([]byte, length)
+	n, err := io.ReadFull(reader, buf)
+	if err == io.ErrUnexpectedEOF {
+		return "", fmt.Errorf("can't read data, expected=%d, got=%d", len(buf), n)
+	}
 	if err != nil {
 		return "", err
 	}
-	return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
+	return base64.StdEncoding.EncodeToString(buf), nil
 }
 
 func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
@@ -183,25 +186,18 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m
 	_, err, e := d.requestReturnErrResp("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) {
 		req.SetBody(createData).SetResult(&createResp)
 	})
-	var tmpF model.File
 	if err != nil {
 		if e.Code != "PreHashMatched" || !rapidUpload {
 			return nil, err
 		}
 		log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload")
 
-		hi := stream.GetHash()
-		hash := hi.GetHash(utils.SHA1)
-		if len(hash) <= 0 {
-			tmpF, err = stream.CacheFullInTempFile()
+		hash := stream.GetHash().GetHash(utils.SHA1)
+		if len(hash) != utils.SHA1.Width {
+			_, hash, err = streamPkg.CacheFullInTempFileAndHash(stream, utils.SHA1)
 			if err != nil {
 				return nil, err
 			}
-			hash, err = utils.HashFile(utils.SHA1, tmpF)
-			if err != nil {
-				return nil, err
-			}
 		}
 
 		delete(createData, "pre_hash")
diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go
index 3cc1ae9ed97..c33e0b32b05 100644
--- a/drivers/baidu_netdisk/driver.go
+++ b/drivers/baidu_netdisk/driver.go
@@ -6,8 +6,8 @@ import (
 	"encoding/hex"
 	"errors"
 	"io"
-	"math"
 	"net/url"
+	"os"
 	stdpath "path"
 	"strconv"
 	"time"
@@ -15,6 +15,7 @@ import (
 	"golang.org/x/sync/semaphore"
 
 	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/conf"
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
@@ -185,16 +186,30 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 		return newObj, nil
 	}
 
-	tempFile, err := stream.CacheFullInTempFile()
-	if err != nil {
-		return nil, err
+	var (
+		cache = stream.GetFile()
+		tmpF  *os.File
+		err   error
+	)
+	if _, ok := cache.(io.ReaderAt); !ok {
+		tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
+		if err != nil {
+			return nil, err
+		}
+		defer func() {
+			_ = tmpF.Close()
+			_ = os.Remove(tmpF.Name())
+		}()
+		cache = tmpF
 	}
 
 	streamSize := stream.GetSize()
 	sliceSize := d.getSliceSize(streamSize)
-	count := int(math.Max(math.Ceil(float64(streamSize)/float64(sliceSize)), 1))
+	count := int(streamSize / sliceSize)
 	lastBlockSize := streamSize % sliceSize
-	if streamSize > 0 && lastBlockSize == 0 {
+	if lastBlockSize > 0 {
+		count++
+	} else {
 		lastBlockSize = sliceSize
 	}
 
@@ -207,6 +222,11 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 	sliceMd5H := md5.New()
 	sliceMd5H2 := md5.New()
 	slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize)
+	writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write}
+	if tmpF != nil {
+		writers = append(writers, tmpF)
+	}
+	written := int64(0)
 
 	for i := 1; i <= count; i++ {
 		if utils.IsCanceled(ctx) {
@@ -215,13 +235,23 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 		if i == count {
 			byteSize = lastBlockSize
 		}
-		_, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize)
+		n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize)
+		written += n
 		if err != nil && err != io.EOF {
 			return nil, err
 		}
 		blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil)))
 		sliceMd5H.Reset()
 	}
+	if tmpF != nil {
+		if written != streamSize {
+			return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
+		}
+		_, err = tmpF.Seek(0, io.SeekStart)
+		if err != nil {
+			return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
+		}
+	}
 	contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil))
 	sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil))
 	blockListStr, _ := utils.Json.MarshalToString(blockList)
@@ -291,7 +321,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 			"partseq": strconv.Itoa(partseq),
 		}
 		err := d.uploadSlice(ctx, params, stream.GetName(),
-			driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize)))
+			driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
 		if err != nil {
 			return err
 		}
diff --git a/drivers/baidu_photo/driver.go b/drivers/baidu_photo/driver.go
index eeee746f71d..5a34fcb4639 100644
--- a/drivers/baidu_photo/driver.go
+++ b/drivers/baidu_photo/driver.go
@@ -7,7 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"math"
+	"os"
 	"regexp"
 	"strconv"
 	"strings"
@@ -16,6 +16,7 @@ import (
 	"golang.org/x/sync/semaphore"
 
 	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/conf"
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
@@ -241,11 +242,21 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
 
 	// TODO:
 	// 暂时没有找到妙传方式
-
-	// 需要获取完整文件md5,必须支持 io.Seek
-	tempFile, err := stream.CacheFullInTempFile()
-	if err != nil {
-		return nil, err
+	var (
+		cache = stream.GetFile()
+		tmpF  *os.File
+		err   error
+	)
+	if _, ok := cache.(io.ReaderAt); !ok {
+		tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
+		if err != nil {
+			return nil, err
+		}
+		defer func() {
+			_ = tmpF.Close()
+			_ = os.Remove(tmpF.Name())
+		}()
+		cache = tmpF
 	}
 
 	const DEFAULT int64 = 1 << 22
@@ -253,9 +264,11 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
 
 	// 计算需要的数据
 	streamSize := stream.GetSize()
-	count := int(math.Ceil(float64(streamSize) / float64(DEFAULT)))
+	count := int(streamSize / DEFAULT)
 	lastBlockSize := streamSize % DEFAULT
-	if lastBlockSize == 0 {
+	if lastBlockSize > 0 {
+		count++
+	} else {
 		lastBlockSize = DEFAULT
 	}
 
@@ -266,6 +279,11 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
 	sliceMd5H := md5.New()
 	sliceMd5H2 := md5.New()
 	slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize)
+	writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write}
+	if tmpF != nil {
+		writers = append(writers, tmpF)
+	}
+	written := int64(0)
 	for i := 1; i <= count; i++ {
 		if utils.IsCanceled(ctx) {
 			return nil, ctx.Err()
@@ -273,13 +291,23 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
 		if i == count {
 			byteSize = lastBlockSize
 		}
-		_, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize)
+		n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize)
+		written += n
 		if err != nil && err != io.EOF {
 			return nil, err
 		}
 		sliceMD5List = append(sliceMD5List, hex.EncodeToString(sliceMd5H.Sum(nil)))
 		sliceMd5H.Reset()
 	}
+	if tmpF != nil {
+		if written != streamSize {
+			return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
+		}
+		_, err = tmpF.Seek(0, io.SeekStart)
+		if err != nil {
+			return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
+		}
+	}
 	contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil))
 	sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil))
 	blockListStr, _ := utils.Json.MarshalToString(sliceMD5List)
@@ -291,7 +319,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
 		"rtype":       "1",
 		"ctype":       "11",
 		"path":        fmt.Sprintf("/%s", stream.GetName()),
-		"size":        fmt.Sprint(stream.GetSize()),
+		"size":        fmt.Sprint(streamSize),
 		"slice-md5":   sliceMd5,
 		"content-md5": contentMd5,
 		"block_list":  blockListStr,
@@ -343,7 +371,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
 			r.SetContext(ctx)
 			r.SetQueryParams(uploadParams)
 			r.SetFileReader("file", stream.GetName(),
-				driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize)))
+				driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
 		}, nil)
 		if err != nil {
 			return err
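Both baidu drivers above now follow the same shape: when the incoming stream is not already an io.ReaderAt, it is spilled to a temp file through an io.MultiWriter that simultaneously feeds the hashes, and the per-part uploads are then served from io.NewSectionReader over that file. A stdlib-only sketch of the one-pass hash-and-persist step:

```go
package main

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
	"io"
	"os"
	"strings"
)

func main() {
	src := strings.NewReader("incoming upload body")

	tmpF, err := os.CreateTemp("", "file-*")
	if err != nil {
		panic(err)
	}
	defer os.Remove(tmpF.Name())
	defer tmpF.Close()

	// One pass over the stream: hash it and persist it at the same time.
	h := md5.New()
	written, err := io.Copy(io.MultiWriter(h, tmpF), src)
	if err != nil {
		panic(err)
	}
	fmt.Println("md5:", hex.EncodeToString(h.Sum(nil)), "bytes:", written)

	// The temp file is an io.ReaderAt, so each part can be re-read
	// (and retried) independently of the others.
	part := io.NewSectionReader(tmpF, 0, 8)
	b, _ := io.ReadAll(part)
	fmt.Printf("first part: %q\n", b)
}
```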
diff --git a/drivers/cloudreve/util.go b/drivers/cloudreve/util.go
index 1fd5ed8abae..196d7303337 100644
--- a/drivers/cloudreve/util.go
+++ b/drivers/cloudreve/util.go
@@ -204,7 +204,7 @@ func (d *Cloudreve) upLocal(ctx context.Context, stream model.FileStreamer, u Up
 			req.SetContentLength(true)
 			req.SetHeader("Content-Length", strconv.FormatInt(byteSize, 10))
 			req.SetHeader("User-Agent", d.getUA())
-			req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
+			req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
 		}, nil)
 		if err != nil {
 			break
@@ -239,7 +239,7 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U
 			return err
 		}
 		req, err := http.NewRequest("POST", uploadUrl+"?chunk="+strconv.Itoa(chunk),
-			driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
+			driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
 		if err != nil {
 			return err
 		}
@@ -280,7 +280,7 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u
 		if err != nil {
 			return err
 		}
-		req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
+		req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
 		if err != nil {
 			return err
 		}
diff --git a/drivers/github/util.go b/drivers/github/util.go
index 03318784f72..7ddf8746c8f 100644
--- a/drivers/github/util.go
+++ b/drivers/github/util.go
@@ -5,7 +5,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io"
 	"strings"
 	"text/template"
 	"time"
@@ -159,7 +158,7 @@ func signCommit(m *map[string]interface{}, entity *openpgp.Entity) (string, erro
 	if err != nil {
 		return "", err
 	}
-	if _, err = io.Copy(armorWriter, &sigBuffer); err != nil {
+	if _, err = utils.CopyWithBuffer(armorWriter, &sigBuffer); err != nil {
 		return "", err
 	}
 	_ = armorWriter.Close()
diff --git a/drivers/ilanzou/driver.go b/drivers/ilanzou/driver.go
index 39a311ddbc0..044193d3584 100644
--- a/drivers/ilanzou/driver.go
+++ b/drivers/ilanzou/driver.go
@@ -2,7 +2,6 @@ package template
 
 import (
 	"context"
-	"crypto/md5"
 	"encoding/base64"
 	"encoding/hex"
 	"fmt"
@@ -17,6 +16,7 @@ import (
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/foxxorcat/mopan-sdk-go"
 	"github.com/go-resty/resty/v2"
@@ -273,23 +273,14 @@ func (d *ILanZou) Remove(ctx context.Context, obj model.Obj) error {
 
 const DefaultPartSize = 1024 * 1024 * 8
 
 func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
-	h := md5.New()
-	// need to calculate md5 of the full content
-	tempFile, err := s.CacheFullInTempFile()
-	if err != nil {
-		return nil, err
-	}
-	defer func() {
-		_ = tempFile.Close()
-	}()
-	if _, err = utils.CopyWithBuffer(h, tempFile); err != nil {
-		return nil, err
-	}
-	_, err = tempFile.Seek(0, io.SeekStart)
-	if err != nil {
-		return nil, err
+	etag := s.GetHash().GetHash(utils.MD5)
+	var err error
+	if len(etag) != utils.MD5.Width {
+		_, etag, err = stream.CacheFullInTempFileAndHash(s, utils.MD5)
+		if err != nil {
+			return nil, err
+		}
 	}
-	etag := hex.EncodeToString(h.Sum(nil))
 	// get upToken
 	res, err := d.proved("/7n/getUpToken", http.MethodPost, func(req *resty.Request) {
 		req.SetBody(base.Json{
@@ -309,7 +300,7 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreame
 	key := fmt.Sprintf("disk/%d/%d/%d/%s/%016d", now.Year(), now.Month(), now.Day(), d.account, now.UnixMilli())
 	reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{
 		Reader: &driver.SimpleReaderWithSize{
-			Reader: tempFile,
+			Reader: s,
 			Size:   s.GetSize(),
 		},
 		UpdateProgress: up,
diff --git a/drivers/mopan/driver.go b/drivers/mopan/driver.go
index 736d612a96b..f8f14300571 100644
--- a/drivers/mopan/driver.go
+++ b/drivers/mopan/driver.go
@@ -269,9 +269,6 @@ func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre
 	if err != nil {
 		return nil, err
 	}
-	defer func() {
-		_ = file.Close()
-	}()
 
 	// step.1
 	uploadPartData, err := mopan.InitUploadPartData(ctx, mopan.UpdloadFileParam{
diff --git a/drivers/netease_music/util.go b/drivers/netease_music/util.go
index 2e78be14b97..217181062ca 100644
--- a/drivers/netease_music/util.go
+++ b/drivers/netease_music/util.go
@@ -227,7 +227,6 @@ func (d *NeteaseMusic) putSongStream(ctx context.Context, stream model.FileStrea
 	if err != nil {
 		return err
 	}
-	defer tmp.Close()
 
 	u := uploader{driver: d, file: tmp}
 
diff --git a/drivers/onedrive/util.go b/drivers/onedrive/util.go
index 554349679d0..e256b7ae262 100644
--- a/drivers/onedrive/util.go
+++ b/drivers/onedrive/util.go
@@ -220,7 +220,7 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil
 		if err != nil {
 			return err
 		}
-		req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
+		req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
 		if err != nil {
 			return err
 		}
diff --git a/drivers/onedrive_app/util.go b/drivers/onedrive_app/util.go
index 1b01324e09a..5c3b6c922d8 100644
--- a/drivers/onedrive_app/util.go
+++ b/drivers/onedrive_app/util.go
@@ -170,7 +170,7 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model.
 		if err != nil {
 			return err
 		}
-		req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
+		req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
 		if err != nil {
 			return err
 		}
diff --git a/drivers/pikpak/util.go b/drivers/pikpak/util.go
index 61396aa4ece..f88f085cd25 100644
--- a/drivers/pikpak/util.go
+++ b/drivers/pikpak/util.go
@@ -7,13 +7,6 @@ import (
 	"crypto/sha1"
 	"encoding/hex"
 	"fmt"
-	"github.com/alist-org/alist/v3/internal/driver"
-	"github.com/alist-org/alist/v3/internal/model"
-	"github.com/alist-org/alist/v3/internal/op"
-	"github.com/alist-org/alist/v3/pkg/utils"
-	"github.com/aliyun/aliyun-oss-go-sdk/oss"
-	jsoniter "github.com/json-iterator/go"
-	"github.com/pkg/errors"
 	"io"
 	"net/http"
 	"path/filepath"
@@ -24,7 +17,14 @@ import (
 	"time"
 
 	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/aliyun/aliyun-oss-go-sdk/oss"
 	"github.com/go-resty/resty/v2"
+	jsoniter "github.com/json-iterator/go"
+	"github.com/pkg/errors"
 )
 
 var AndroidAlgorithms = []string{
@@ -516,7 +516,7 @@ func (d *PikPak) UploadByMultipart(ctx context.Context, params *S3Params, fileSi
 			continue
 		}
 
-		b := driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf))
+		b := driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf))
 		if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil {
 			break
 		}
diff --git a/drivers/quark_uc/driver.go b/drivers/quark_uc/driver.go
index 0f8884fac53..7f497494502 100644
--- a/drivers/quark_uc/driver.go
+++ b/drivers/quark_uc/driver.go
@@ -3,9 +3,8 @@ package quark
 import (
 	"bytes"
 	"context"
-	"crypto/md5"
-	"crypto/sha1"
 	"encoding/hex"
+	"hash"
 	"io"
 	"net/http"
 	"time"
@@ -14,6 +13,7 @@ import (
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
+	streamPkg "github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/go-resty/resty/v2"
 	log "github.com/sirupsen/logrus"
@@ -136,33 +136,33 @@ func (d *QuarkOrUC) Remove(ctx context.Context, obj model.Obj) error {
 }
 
 func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
-	tempFile, err := stream.CacheFullInTempFile()
-	if err != nil {
-		return err
-	}
-	defer func() {
-		_ = tempFile.Close()
-	}()
-	m := md5.New()
-	_, err = utils.CopyWithBuffer(m, tempFile)
-	if err != nil {
-		return err
-	}
-	_, err = tempFile.Seek(0, io.SeekStart)
-	if err != nil {
-		return err
-	}
-	md5Str := hex.EncodeToString(m.Sum(nil))
-	s := sha1.New()
-	_, err = utils.CopyWithBuffer(s, tempFile)
-	if err != nil {
-		return err
-	}
-	_, err = tempFile.Seek(0, io.SeekStart)
-	if err != nil {
-		return err
+	md5Str, sha1Str := stream.GetHash().GetHash(utils.MD5), stream.GetHash().GetHash(utils.SHA1)
+	var (
+		md5  hash.Hash
+		sha1 hash.Hash
+	)
+	writers := []io.Writer{}
+	if len(md5Str) != utils.MD5.Width {
+		md5 = utils.MD5.NewFunc()
+		writers = append(writers, md5)
+	}
+	if len(sha1Str) != utils.SHA1.Width {
+		sha1 = utils.SHA1.NewFunc()
+		writers = append(writers, sha1)
+	}
+
+	if len(writers) > 0 {
+		_, err := streamPkg.CacheFullInTempFileAndWriter(stream, io.MultiWriter(writers...))
+		if err != nil {
+			return err
+		}
+		if md5 != nil {
+			md5Str = hex.EncodeToString(md5.Sum(nil))
+		}
+		if sha1 != nil {
+			sha1Str = hex.EncodeToString(sha1.Sum(nil))
+		}
 	}
-	sha1Str := hex.EncodeToString(s.Sum(nil))
 	// pre
 	pre, err := d.upPre(stream, dstDir.GetID())
 	if err != nil {
@@ -178,27 +178,28 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File
 		return nil
 	}
 	// part up
-	partSize := pre.Metadata.PartSize
-	var part []byte
-	md5s := make([]string, 0)
-	defaultBytes := make([]byte, partSize)
 	total := stream.GetSize()
 	left := total
+	partSize := int64(pre.Metadata.PartSize)
+	part := make([]byte, partSize)
+	count := int(total / partSize)
+	if total%partSize > 0 {
+		count++
+	}
+	md5s := make([]string, 0, count)
 	partNumber := 1
 	for left > 0 {
 		if utils.IsCanceled(ctx) {
 			return ctx.Err()
 		}
-		if left > int64(partSize) {
-			part = defaultBytes
-		} else {
-			part = make([]byte, left)
+		if left < partSize {
+			part = part[:left]
 		}
-		_, err := io.ReadFull(tempFile, part)
+		n, err := io.ReadFull(stream, part)
 		if err != nil {
 			return err
 		}
-		left -= int64(len(part))
+		left -= int64(n)
 		log.Debugf("left: %d", left)
 		reader := driver.NewLimitedUploadStream(ctx, bytes.NewReader(part))
 		m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, reader)
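The quark_uc rewrite above only hashes what is missing: hashes already present in the stream's HashInfo are kept, and a writer list is built for the rest before the single caching pass. A rough stdlib equivalent of that conditional fan-out (the 32/40 literals stand in for utils.MD5.Width and utils.SHA1.Width):

```go
package main

import (
	"crypto/md5"
	"crypto/sha1"
	"encoding/hex"
	"fmt"
	"hash"
	"io"
	"strings"
)

func main() {
	// Pretend the MD5 came with the upload and only SHA-1 is missing.
	md5Str := "0123456789abcdef0123456789abcdef"
	sha1Str := ""

	var md5H, sha1H hash.Hash
	writers := []io.Writer{}
	if len(md5Str) != 32 {
		md5H = md5.New()
		writers = append(writers, md5H)
	}
	if len(sha1Str) != 40 {
		sha1H = sha1.New()
		writers = append(writers, sha1H)
	}

	if len(writers) > 0 {
		// One pass feeds every hash that still needs computing.
		if _, err := io.Copy(io.MultiWriter(writers...), strings.NewReader("file body")); err != nil {
			panic(err)
		}
		if md5H != nil {
			md5Str = hex.EncodeToString(md5H.Sum(nil))
		}
		if sha1H != nil {
			sha1Str = hex.EncodeToString(sha1H.Sum(nil))
		}
	}
	fmt.Println(md5Str, sha1Str)
}
```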
diff --git a/drivers/thunder/driver.go b/drivers/thunder/driver.go
index 7f41d003838..51396ee8038 100644
--- a/drivers/thunder/driver.go
+++ b/drivers/thunder/driver.go
@@ -12,6 +12,7 @@ import (
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
 	"github.com/aws/aws-sdk-go/aws"
@@ -333,22 +334,17 @@ func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error {
 }
 
 func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
-	hi := file.GetHash()
-	gcid := hi.GetHash(hash_extend.GCID)
+	gcid := file.GetHash().GetHash(hash_extend.GCID)
+	var err error
 	if len(gcid) < hash_extend.GCID.Width {
-		tFile, err := file.CacheFullInTempFile()
-		if err != nil {
-			return err
-		}
-
-		gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize())
+		_, gcid, err = stream.CacheFullInTempFileAndHash(file, hash_extend.GCID, file.GetSize())
 		if err != nil {
 			return err
 		}
 	}
 
 	var resp UploadTaskResponse
-	_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
+	_, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
 		r.SetContext(ctx)
 		r.SetBody(&base.Json{
 			"kind": FILE,
diff --git a/drivers/thunder_browser/driver.go b/drivers/thunder_browser/driver.go
index 7ce71f7d265..0b38d07714f 100644
--- a/drivers/thunder_browser/driver.go
+++ b/drivers/thunder_browser/driver.go
@@ -4,10 +4,15 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
+	"net/http"
+	"strings"
+
 	"github.com/alist-org/alist/v3/drivers/base"
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/op"
+	streamPkg "github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
 	"github.com/aws/aws-sdk-go/aws"
@@ -15,9 +20,6 @@ import (
 	"github.com/aws/aws-sdk-go/aws/session"
 	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 	"github.com/go-resty/resty/v2"
-	"io"
-	"net/http"
-	"strings"
 )
 
 type ThunderBrowser struct {
@@ -456,15 +458,10 @@ func (xc *XunLeiBrowserCommon) Remove(ctx context.Context, obj model.Obj) error
 }
 
 func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
-	hi := stream.GetHash()
-	gcid := hi.GetHash(hash_extend.GCID)
+	gcid := stream.GetHash().GetHash(hash_extend.GCID)
+	var err error
 	if len(gcid) < hash_extend.GCID.Width {
-		tFile, err := stream.CacheFullInTempFile()
-		if err != nil {
-			return err
-		}
-
-		gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize())
+		_, gcid, err = streamPkg.CacheFullInTempFileAndHash(stream, hash_extend.GCID, stream.GetSize())
 		if err != nil {
 			return err
 		}
@@ -481,7 +478,7 @@ func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream
 	}
 
 	var resp UploadTaskResponse
-	_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
+	_, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
 		r.SetContext(ctx)
 		r.SetBody(&js)
 	}, &resp)
diff --git a/drivers/thunderx/driver.go b/drivers/thunderx/driver.go
index 2194bdc6e9c..6ee8901a4fc 100644
--- a/drivers/thunderx/driver.go
+++ b/drivers/thunderx/driver.go
@@ -3,11 +3,15 @@ package thunderx
 import (
 	"context"
 	"fmt"
+	"net/http"
+	"strings"
+
 	"github.com/alist-org/alist/v3/drivers/base"
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/internal/stream"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
 	"github.com/aws/aws-sdk-go/aws"
@@ -15,8 +19,6 @@ import (
 	"github.com/aws/aws-sdk-go/aws/session"
 	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 	"github.com/go-resty/resty/v2"
-	"net/http"
-	"strings"
 )
 
 type ThunderX struct {
@@ -364,22 +366,17 @@ func (xc *XunLeiXCommon) Remove(ctx context.Context, obj model.Obj) error {
 }
 
 func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
-	hi := file.GetHash()
-	gcid := hi.GetHash(hash_extend.GCID)
+	gcid := file.GetHash().GetHash(hash_extend.GCID)
+	var err error
 	if len(gcid) < hash_extend.GCID.Width {
-		tFile, err := file.CacheFullInTempFile()
-		if err != nil {
-			return err
-		}
-
-		gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize())
+		_, gcid, err = stream.CacheFullInTempFileAndHash(file, hash_extend.GCID, file.GetSize())
 		if err != nil {
 			return err
 		}
 	}
 
 	var resp UploadTaskResponse
-	_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
+	_, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
 		r.SetContext(ctx)
 		r.SetBody(&base.Json{
 			"kind": FILE,
diff --git a/internal/archive/archives/utils.go b/internal/archive/archives/utils.go
index fdae10091f6..2f499a10feb 100644
--- a/internal/archive/archives/utils.go
+++ b/internal/archive/archives/utils.go
@@ -10,6 +10,7 @@
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/stream"
+	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/mholt/archives"
 )
 
@@ -73,7 +74,7 @@ func decompress(fsys fs2.FS, filePath, targetPath string, up model.UpdateProgres
 		return err
 	}
 	defer f.Close()
-	_, err = io.Copy(f, &stream.ReaderUpdatingProgress{
+	_, err = utils.CopyWithBuffer(f, &stream.ReaderUpdatingProgress{
 		Reader: &stream.SimpleReaderWithSize{
 			Reader: rc,
 			Size:   stat.Size(),
diff --git a/internal/archive/iso9660/utils.go b/internal/archive/iso9660/utils.go
index 12de8e6ea28..0e4cfb1caf3 100644
--- a/internal/archive/iso9660/utils.go
+++ b/internal/archive/iso9660/utils.go
@@ -1,14 +1,15 @@
 package iso9660
 
 import (
+	"os"
+	stdpath "path"
+	"strings"
+
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/stream"
+	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/kdomanski/iso9660"
-	"io"
-	"os"
-	stdpath "path"
-	"strings"
 )
 
 func getImage(ss *stream.SeekableStream) (*iso9660.Image, error) {
@@ -66,7 +67,7 @@ func decompress(f *iso9660.File, path string, up model.UpdateProgress) error {
 		return err
 	}
 	defer file.Close()
-	_, err = io.Copy(file, &stream.ReaderUpdatingProgress{
+	_, err = utils.CopyWithBuffer(file, &stream.ReaderUpdatingProgress{
 		Reader: &stream.SimpleReaderWithSize{
 			Reader: f.Reader(),
 			Size:   f.Size(),
diff --git a/internal/fs/archive.go b/internal/fs/archive.go
index b056decf9a2..dbae9b338de 100644
--- a/internal/fs/archive.go
+++ b/internal/fs/archive.go
@@ -90,9 +90,11 @@ func (t *ArchiveDownloadTask) RunWithoutPushUploadTask() (*ArchiveContentUploadT
 	t.SetTotalBytes(total)
 	t.status = "getting src object"
 	for _, s := range ss {
-		_, err = s.CacheFullInTempFileAndUpdateProgress(func(p float64) {
-			t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total))
-		})
+		if s.GetFile() == nil {
+			_, err = stream.CacheFullInTempFileAndUpdateProgress(s, func(p float64) {
+				t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total))
+			})
+		}
 		cur += s.GetSize()
 		if err != nil {
 			return nil, err
diff --git a/internal/model/obj.go b/internal/model/obj.go
index 552b1241e6e..f0fce7a133a 100644
--- a/internal/model/obj.go
+++ b/internal/model/obj.go
@@ -2,6 +2,7 @@ package model
 
 import (
 	"io"
+	"os"
 	"sort"
 	"strings"
 	"time"
@@ -48,7 +49,8 @@
 	RangeRead(http_range.Range) (io.Reader, error)
 	//for a non-seekable Stream, if Read is called, this function won't work
 	CacheFullInTempFile() (File, error)
-	CacheFullInTempFileAndUpdateProgress(up UpdateProgress) (File, error)
+	SetTmpFile(r *os.File)
+	GetFile() File
 }
 
 type UpdateProgress func(percentage float64)
diff --git a/internal/net/request.go b/internal/net/request.go
index d4f9321c585..a1ff6d20cf9 100644
--- a/internal/net/request.go
+++ b/internal/net/request.go
@@ -248,8 +248,9 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error {
 		size: finalSize,
 		id:   d.nextChunk,
 		buf:  buf,
+
+		newConcurrency: newConcurrency,
 	}
-	ch.newConcurrency = newConcurrency
 	d.pos += finalSize
 	d.nextChunk++
 	d.chunkChannel <- ch
diff --git a/internal/stream/stream.go b/internal/stream/stream.go
index f6b045a0238..64160915792 100644
--- a/internal/stream/stream.go
+++ b/internal/stream/stream.go
@@ -94,27 +94,17 @@ func (f *FileStream) CacheFullInTempFile() (model.File, error) {
 	f.Add(tmpF)
 	f.tmpFile = tmpF
 	f.Reader = tmpF
-	return f.tmpFile, nil
+	return tmpF, nil
 }
 
-func (f *FileStream) CacheFullInTempFileAndUpdateProgress(up model.UpdateProgress) (model.File, error) {
+func (f *FileStream) GetFile() model.File {
 	if f.tmpFile != nil {
-		return f.tmpFile, nil
+		return f.tmpFile
 	}
 	if file, ok := f.Reader.(model.File); ok {
-		return file, nil
-	}
-	tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{
-		Reader:         f,
-		UpdateProgress: up,
-	}, f.GetSize())
-	if err != nil {
-		return nil, err
+		return file
 	}
-	f.Add(tmpF)
-	f.tmpFile = tmpF
-	f.Reader = tmpF
-	return f.tmpFile, nil
+	return nil
 }
 
 const InMemoryBufMaxSize = 10 // Megabytes
@@ -127,31 +117,36 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
 		// 参考 internal/net/request.go
 		httpRange.Length = f.GetSize() - httpRange.Start
 	}
-	if f.peekBuff != nil && httpRange.Start < int64(f.peekBuff.Len()) && httpRange.Start+httpRange.Length-1 < int64(f.peekBuff.Len()) {
+	size := httpRange.Start + httpRange.Length
+	if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) {
 		return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
 	}
-	if f.tmpFile == nil {
-		if httpRange.Start == 0 && httpRange.Length <= InMemoryBufMaxSizeBytes && f.peekBuff == nil {
-			bufSize := utils.Min(httpRange.Length, f.GetSize())
-			newBuf := bytes.NewBuffer(make([]byte, 0, bufSize))
-			n, err := utils.CopyWithBufferN(newBuf, f.Reader, bufSize)
+	var cache io.ReaderAt = f.GetFile()
+	if cache == nil {
+		if size <= InMemoryBufMaxSizeBytes {
+			bufSize := min(size, f.GetSize())
+			// 使用bytes.Buffer作为io.CopyBuffer的写入对象,CopyBuffer会调用Buffer.ReadFrom
+			// 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大
+			buf := make([]byte, bufSize)
+			n, err := io.ReadFull(f.Reader, buf)
 			if err != nil {
 				return nil, err
 			}
-			if n != bufSize {
+			if n != int(bufSize) {
 				return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n)
 			}
-			f.peekBuff = bytes.NewReader(newBuf.Bytes())
+			f.peekBuff = bytes.NewReader(buf)
 			f.Reader = io.MultiReader(f.peekBuff, f.Reader)
-			return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
+			cache = f.peekBuff
 		} else {
-			_, err := f.CacheFullInTempFile()
+			var err error
+			cache, err = f.CacheFullInTempFile()
 			if err != nil {
 				return nil, err
 			}
 		}
 	}
-	return io.NewSectionReader(f.tmpFile, httpRange.Start, httpRange.Length), nil
+	return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil
 }
 
 var _ model.FileStreamer = (*SeekableStream)(nil)
@@ -176,13 +171,13 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
 	if len(fs.Mimetype) == 0 {
 		fs.Mimetype = utils.GetMimeType(fs.Obj.GetName())
 	}
-	ss := SeekableStream{FileStream: fs, Link: link}
+	ss := &SeekableStream{FileStream: fs, Link: link}
 	if ss.Reader != nil {
 		result, ok := ss.Reader.(model.File)
 		if ok {
 			ss.mFile = result
 			ss.Closers.Add(result)
-			return &ss, nil
+			return ss, nil
 		}
 	}
 	if ss.Link != nil {
@@ -198,7 +193,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
 			ss.mFile = mFile
 			ss.Reader = mFile
 			ss.Closers.Add(mFile)
-			return &ss, nil
+			return ss, nil
 		}
 		if ss.Link.RangeReadCloser != nil {
 			ss.rangeReadCloser = &RateLimitRangeReadCloser{
@@ -206,7 +201,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
 				Limiter:         ServerDownloadLimit,
 			}
 			ss.Add(ss.rangeReadCloser)
-			return &ss, nil
+			return ss, nil
 		}
 		if len(ss.Link.URL) > 0 {
 			rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link)
@@ -219,10 +214,12 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
 			}
 			ss.rangeReadCloser = rrc
 			ss.Add(rrc)
-			return &ss, nil
+			return ss, nil
 		}
 	}
-
+	if fs.Reader != nil {
+		return ss, nil
+	}
 	return nil, fmt.Errorf("illegal seekableStream")
 }
 
@@ -248,7 +245,7 @@ func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, erro
 		}
 		return rc, nil
 	}
-	return nil, fmt.Errorf("can't find mFile or rangeReadCloser")
+	return ss.FileStream.RangeRead(httpRange)
 }
 
 //func (f *FileStream) GetReader() io.Reader {
@@ -278,7 +275,7 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
 	if ss.tmpFile != nil {
 		return ss.tmpFile, nil
 	}
-	if _, ok := ss.mFile.(*os.File); ok {
+	if ss.mFile != nil {
 		return ss.mFile, nil
 	}
 	tmpF, err := utils.CreateTempFile(ss, ss.GetSize())
@@ -288,27 +285,17 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
 	ss.Add(tmpF)
 	ss.tmpFile = tmpF
 	ss.Reader = tmpF
-	return ss.tmpFile, nil
+	return tmpF, nil
 }
 
-func (ss *SeekableStream) CacheFullInTempFileAndUpdateProgress(up model.UpdateProgress) (model.File, error) {
+func (ss *SeekableStream) GetFile() model.File {
 	if ss.tmpFile != nil {
-		return ss.tmpFile, nil
-	}
-	if _, ok := ss.mFile.(*os.File); ok {
-		return ss.mFile, nil
+		return ss.tmpFile
 	}
-	tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{
-		Reader:         ss,
-		UpdateProgress: up,
-	}, ss.GetSize())
-	if err != nil {
-		return nil, err
+	if ss.mFile != nil {
+		return ss.mFile
 	}
-	ss.Add(tmpF)
-	ss.tmpFile = tmpF
-	ss.Reader = tmpF
-	return ss.tmpFile, nil
+	return nil
 }
 
 func (f *FileStream) SetTmpFile(r *os.File) {
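The RangeRead rework above serves small leading ranges from an in-memory peek buffer and stitches the stream back together with io.MultiReader, so a later sequential read still starts at offset 0. A sketch of that peek-and-reassemble idea (the 8-byte buffer is arbitrary here):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

func main() {
	src := io.Reader(strings.NewReader("HEAD....tail data"))

	// Peek the first 8 bytes into memory, roughly what the patched
	// FileStream.RangeRead does for ranges below InMemoryBufMaxSizeBytes.
	head := make([]byte, 8)
	if _, err := io.ReadFull(src, head); err != nil {
		panic(err)
	}
	peek := bytes.NewReader(head)
	src = io.MultiReader(peek, src)

	// Ranged reads inside the peeked prefix are served from memory via
	// ReadAt, which leaves peek's sequential read offset untouched.
	rng := io.NewSectionReader(peek, 4, 4)
	b, _ := io.ReadAll(rng)
	fmt.Printf("range: %q\n", b) // "...."

	// A later full read still sees everything from the start.
	all, _ := io.ReadAll(src)
	fmt.Printf("full: %q\n", all) // "HEAD....tail data"
}
```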
diff --git a/internal/stream/util.go b/internal/stream/util.go
index 01019482e15..5b935a9043e 100644
--- a/internal/stream/util.go
+++ b/internal/stream/util.go
@@ -2,6 +2,7 @@ package stream
 
 import (
 	"context"
+	"encoding/hex"
 	"fmt"
 	"io"
 	"net/http"
@@ -96,3 +97,45 @@ func (r *ReaderWithCtx) Close() error {
 	}
 	return nil
 }
+
+func CacheFullInTempFileAndUpdateProgress(stream model.FileStreamer, up model.UpdateProgress) (model.File, error) {
+	if cache := stream.GetFile(); cache != nil {
+		up(100)
+		return cache, nil
+	}
+	tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{
+		Reader:         stream,
+		UpdateProgress: up,
+	}, stream.GetSize())
+	if err == nil {
+		stream.SetTmpFile(tmpF)
+	}
+	return tmpF, err
+}
+
+func CacheFullInTempFileAndWriter(stream model.FileStreamer, w io.Writer) (model.File, error) {
+	if cache := stream.GetFile(); cache != nil {
+		_, err := cache.Seek(0, io.SeekStart)
+		if err == nil {
+			_, err = utils.CopyWithBuffer(w, cache)
+			if err == nil {
+				_, err = cache.Seek(0, io.SeekStart)
+			}
+		}
+		return cache, err
+	}
+	tmpF, err := utils.CreateTempFile(io.TeeReader(stream, w), stream.GetSize())
+	if err == nil {
+		stream.SetTmpFile(tmpF)
+	}
+	return tmpF, err
+}
+
+func CacheFullInTempFileAndHash(stream model.FileStreamer, hashType *utils.HashType, params ...any) (model.File, string, error) {
+	h := hashType.NewFunc(params...)
+	tmpF, err := CacheFullInTempFileAndWriter(stream, h)
+	if err != nil {
+		return nil, "", err
+	}
+	return tmpF, hex.EncodeToString(h.Sum(nil)), err
+}
diff --git a/server/handles/fsup.go b/server/handles/fsup.go
index 15a6328b60b..41344fb8d56 100644
--- a/server/handles/fsup.go
+++ b/server/handles/fsup.go
@@ -1,8 +1,6 @@
 package handles
 
 import (
-	"github.com/alist-org/alist/v3/internal/task"
-	"github.com/alist-org/alist/v3/pkg/utils"
 	"io"
 	"net/url"
 	stdpath "path"
@@ -12,6 +10,8 @@ import (
 	"github.com/alist-org/alist/v3/internal/fs"
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/stream"
+	"github.com/alist-org/alist/v3/internal/task"
+	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/alist-org/alist/v3/server/common"
 	"github.com/gin-gonic/gin"
 )
@@ -44,7 +44,7 @@ func FsStream(c *gin.Context) {
 	}
 	if !overwrite {
 		if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil {
-			_, _ = io.Copy(io.Discard, c.Request.Body)
+			_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
 			common.ErrorStrResp(c, "file exists", 403)
 			return
 		}
@@ -66,6 +66,10 @@
 	if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" {
 		h[utils.SHA256] = sha256
 	}
+	mimetype := c.GetHeader("Content-Type")
+	if len(mimetype) == 0 {
+		mimetype = utils.GetMimeType(name)
+	}
 	s := &stream.FileStream{
 		Obj: &model.Object{
 			Name:     name,
@@ -74,7 +78,7 @@
 			HashInfo: utils.NewHashInfoByMap(h),
 		},
 		Reader:       c.Request.Body,
-		Mimetype:     c.GetHeader("Content-Type"),
+		Mimetype:     mimetype,
 		WebPutAsTask: asTask,
 	}
 	var t task.TaskExtensionInfo
@@ -89,6 +93,9 @@
 		return
 	}
 	if t == nil {
+		if n, _ := io.ReadFull(c.Request.Body, []byte{0}); n == 1 {
+			_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
+		}
 		common.SuccessResp(c)
 		return
 	}
@@ -114,7 +121,7 @@ func FsForm(c *gin.Context) {
 	}
 	if !overwrite {
 		if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil {
-			_, _ = io.Copy(io.Discard, c.Request.Body)
+			_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
 			common.ErrorStrResp(c, "file exists", 403)
 			return
 		}
@@ -150,6 +157,10 @@
 	if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" {
 		h[utils.SHA256] = sha256
 	}
+	mimetype := file.Header.Get("Content-Type")
+	if len(mimetype) == 0 {
+		mimetype = utils.GetMimeType(name)
+	}
 	s := stream.FileStream{
 		Obj: &model.Object{
 			Name:     name,
@@ -158,7 +169,7 @@
 			HashInfo: utils.NewHashInfoByMap(h),
 		},
 		Reader:       f,
-		Mimetype:     file.Header.Get("Content-Type"),
+		Mimetype:     mimetype,
 		WebPutAsTask: asTask,
 	}
 	var t task.TaskExtensionInfo
@@ -168,12 +179,7 @@
 		}{f}
 		t, err = fs.PutAsTask(c, dir, &s)
 	} else {
-		ss, err := stream.NewSeekableStream(s, nil)
-		if err != nil {
-			common.ErrorResp(c, err, 500)
-			return
-		}
-		err = fs.PutDirectly(c, dir, ss, true)
+		err = fs.PutDirectly(c, dir, &s, true)
 	}
 	if err != nil {
 		common.ErrorResp(c, err, 500)
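On the FsStream change that probes the body before draining it: reading a single byte first means an already-consumed body costs nothing, and the connection is only drained when the client actually sent trailing data. A hedged sketch of the probe-then-drain pattern:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	body := io.Reader(strings.NewReader("leftover bytes"))

	// Probe one byte, mirroring the handler's io.ReadFull(..., []byte{0});
	// only drain the rest when something was actually left unread.
	if n, _ := io.ReadFull(body, []byte{0}); n == 1 {
		drained, _ := io.Copy(io.Discard, body)
		fmt.Println("drained:", drained+1, "bytes")
	} else {
		fmt.Println("body already consumed")
	}
}
```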