diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index 6686eaffa..25864a105 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net/http" "net/url" "os" stdpath "path" @@ -35,11 +36,15 @@ type BaiduNetdisk struct { uploadThread int vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M) - upClient *resty.Client // 上传文件使用的http客户端 - uploadUrlG singleflight.Group[string] - uploadUrlMu sync.RWMutex - uploadUrl string // 上传域名 - uploadUrlUpdateTime time.Time // 上传域名上次更新时间 + upClient *resty.Client // 上传文件使用的http客户端 + uploadUrlG singleflight.Group[string] + uploadUrlMu sync.RWMutex + uploadUrlCache map[string]uploadURLCacheEntry +} + +type uploadURLCacheEntry struct { + url string + updateTime time.Time } var ErrUploadIDExpired = errors.New("uploadid expired") @@ -58,6 +63,7 @@ func (d *BaiduNetdisk) Init(ctx context.Context) error { SetRetryCount(UPLOAD_RETRY_COUNT). SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME). SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME) + d.uploadUrlCache = make(map[string]uploadURLCacheEntry) d.uploadThread, _ = strconv.Atoi(d.UploadThread) if d.uploadThread < 1 { d.uploadThread, d.UploadThread = 1, "1" @@ -298,12 +304,22 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return fileToObj(precreateResp.File), nil } } + ensureUploadURL := func() { + if precreateResp.UploadURL != "" { + return + } + precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid) + } + ensureUploadURL() // step.2 上传分片 uploadLoop: for attempt := 0; attempt < 2; attempt++ { // 获取上传域名 - uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid) + if precreateResp.UploadURL == "" { + ensureUploadURL() + } + uploadUrl := precreateResp.UploadURL // 并发上传 threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, retry.Attempts(1), @@ -363,6 +379,7 @@ uploadLoop: } if errors.Is(err, ErrUploadIDExpired) { log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch") + d.clearUploadUrlCache(precreateResp.Uploadid) // 重新 precreate(所有分片都要重传) newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime) if err2 != nil { @@ -372,6 +389,8 @@ uploadLoop: return fileToObj(newPre.File), nil } precreateResp = newPre + precreateResp.UploadURL = "" + ensureUploadURL() // 覆盖掉旧的进度 base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) continue uploadLoop @@ -390,6 +409,7 @@ uploadLoop: newFile.Mtime = mtime // 上传成功清理进度 base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5) + d.clearUploadUrlCache(precreateResp.Uploadid) return fileToObj(newFile), nil } @@ -438,6 +458,9 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params return err } log.Debugln(res.RawResponse.Status + res.String()) + if res.StatusCode() != http.StatusOK { + return errs.NewErr(errs.StreamIncomplete, "baidu upload failed, status=%d, body=%s", res.StatusCode(), res.String()) + } errCode := utils.Json.Get(res.Body(), "error_code").ToInt() errNo := utils.Json.Get(res.Body(), "errno").ToInt() respStr := res.String() diff --git a/drivers/baidu_netdisk/types.go b/drivers/baidu_netdisk/types.go index bb920d165..03e84b396 100644 --- a/drivers/baidu_netdisk/types.go +++ b/drivers/baidu_netdisk/types.go @@ -193,6 +193,8 @@ type PrecreateResp struct { // return_type=2 File File `json:"info"` + + UploadURL string `json:"-"` // 保存断点续传对应的上传域名 } type UploadServerResp struct { diff --git a/drivers/baidu_netdisk/util.go b/drivers/baidu_netdisk/util.go index cd38d11fe..70c1f4c2b 100644 --- a/drivers/baidu_netdisk/util.go +++ b/drivers/baidu_netdisk/util.go @@ -394,29 +394,28 @@ func (d *BaiduNetdisk) quota(ctx context.Context) (model.DiskUsage, error) { return driver.DiskUsageFromUsedAndTotal(resp.Used, resp.Total), nil } -// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会被缓存1h。 +// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会在 uploadid 生命周期内复用。 // 如果获取失败,则返回 Upload API设置项。 func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string { - if !d.UseDynamicUploadAPI { + if !d.UseDynamicUploadAPI || uploadId == "" { return d.UploadAPI } - getCachedUrlFunc := func() string { + getCachedUrlFunc := func() (string, bool) { d.uploadUrlMu.RLock() defer d.uploadUrlMu.RUnlock() - if d.uploadUrl != "" && time.Since(d.uploadUrlUpdateTime) < UPLOAD_URL_EXPIRE_TIME { - uploadUrl := d.uploadUrl - return uploadUrl + if entry, ok := d.uploadUrlCache[uploadId]; ok { + return entry.url, true } - return "" + return "", false } // 检查地址缓存 - if uploadUrl := getCachedUrlFunc(); uploadUrl != "" { + if uploadUrl, ok := getCachedUrlFunc(); ok { return uploadUrl } uploadUrlGetFunc := func() (string, error) { // 双重检查缓存 - if uploadUrl := getCachedUrlFunc(); uploadUrl != "" { + if uploadUrl, ok := getCachedUrlFunc(); ok { return uploadUrl, nil } @@ -426,13 +425,15 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string { } d.uploadUrlMu.Lock() - defer d.uploadUrlMu.Unlock() - d.uploadUrl = uploadUrl - d.uploadUrlUpdateTime = time.Now() + d.uploadUrlCache[uploadId] = uploadURLCacheEntry{ + url: uploadUrl, + updateTime: time.Now(), + } + d.uploadUrlMu.Unlock() return uploadUrl, nil } - uploadUrl, err, _ := d.uploadUrlG.Do("", uploadUrlGetFunc) + uploadUrl, err, _ := d.uploadUrlG.Do(uploadId, uploadUrlGetFunc) if err != nil { fallback := d.UploadAPI log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, fallback) @@ -441,6 +442,17 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string { return uploadUrl } +func (d *BaiduNetdisk) clearUploadUrlCache(uploadId string) { + if uploadId == "" { + return + } + d.uploadUrlMu.Lock() + if _, ok := d.uploadUrlCache[uploadId]; ok { + delete(d.uploadUrlCache, uploadId) + } + d.uploadUrlMu.Unlock() +} + // requestForUploadUrl 请求获取上传地址。 // 实测此接口不需要认证,传method和upload_version就行,不过还是按文档规范调用。 // https://pan.baidu.com/union/doc/Mlvw5hfnr