Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions drivers/baidu_netdisk/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
stdpath "path"
Expand Down Expand Up @@ -35,11 +36,15 @@ type BaiduNetdisk struct {
uploadThread int
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)

upClient *resty.Client // 上传文件使用的http客户端
uploadUrlG singleflight.Group[string]
uploadUrlMu sync.RWMutex
uploadUrl string // 上传域名
uploadUrlUpdateTime time.Time // 上传域名上次更新时间
upClient *resty.Client // 上传文件使用的http客户端
uploadUrlG singleflight.Group[string]
uploadUrlMu sync.RWMutex
uploadUrlCache map[string]uploadURLCacheEntry
}

type uploadURLCacheEntry struct {
url string
updateTime time.Time
}

var ErrUploadIDExpired = errors.New("uploadid expired")
Expand All @@ -58,6 +63,7 @@ func (d *BaiduNetdisk) Init(ctx context.Context) error {
SetRetryCount(UPLOAD_RETRY_COUNT).
SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
d.uploadUrlCache = make(map[string]uploadURLCacheEntry)
d.uploadThread, _ = strconv.Atoi(d.UploadThread)
if d.uploadThread < 1 {
d.uploadThread, d.UploadThread = 1, "1"
Expand Down Expand Up @@ -298,12 +304,22 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
return fileToObj(precreateResp.File), nil
}
}
ensureUploadURL := func() {
if precreateResp.UploadURL != "" {
return
}
precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid)
}
ensureUploadURL()

// step.2 上传分片
uploadLoop:
for attempt := 0; attempt < 2; attempt++ {
// 获取上传域名
uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid)
if precreateResp.UploadURL == "" {
ensureUploadURL()
}
uploadUrl := precreateResp.UploadURL
// 并发上传
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
retry.Attempts(1),
Expand Down Expand Up @@ -363,6 +379,7 @@ uploadLoop:
}
if errors.Is(err, ErrUploadIDExpired) {
log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
d.clearUploadUrlCache(precreateResp.Uploadid)
// 重新 precreate(所有分片都要重传)
newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
if err2 != nil {
Expand All @@ -372,6 +389,8 @@ uploadLoop:
return fileToObj(newPre.File), nil
}
precreateResp = newPre
precreateResp.UploadURL = ""
ensureUploadURL()
// 覆盖掉旧的进度
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
continue uploadLoop
Expand All @@ -390,6 +409,7 @@ uploadLoop:
newFile.Mtime = mtime
// 上传成功清理进度
base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
d.clearUploadUrlCache(precreateResp.Uploadid)
return fileToObj(newFile), nil
}

Expand Down Expand Up @@ -438,6 +458,9 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params
return err
}
log.Debugln(res.RawResponse.Status + res.String())
if res.StatusCode() != http.StatusOK {
return errs.NewErr(errs.StreamIncomplete, "baidu upload failed, status=%d, body=%s", res.StatusCode(), res.String())
}
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
respStr := res.String()
Expand Down
2 changes: 2 additions & 0 deletions drivers/baidu_netdisk/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ type PrecreateResp struct {

// return_type=2
File File `json:"info"`

UploadURL string `json:"-"` // 保存断点续传对应的上传域名
}

type UploadServerResp struct {
Expand Down
38 changes: 25 additions & 13 deletions drivers/baidu_netdisk/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,29 +394,28 @@ func (d *BaiduNetdisk) quota(ctx context.Context) (model.DiskUsage, error) {
return driver.DiskUsageFromUsedAndTotal(resp.Used, resp.Total), nil
}

// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会被缓存1h
// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会在 uploadid 生命周期内复用
// 如果获取失败,则返回 Upload API设置项。
func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
if !d.UseDynamicUploadAPI {
if !d.UseDynamicUploadAPI || uploadId == "" {
return d.UploadAPI
}
getCachedUrlFunc := func() string {
getCachedUrlFunc := func() (string, bool) {
d.uploadUrlMu.RLock()
defer d.uploadUrlMu.RUnlock()
if d.uploadUrl != "" && time.Since(d.uploadUrlUpdateTime) < UPLOAD_URL_EXPIRE_TIME {
uploadUrl := d.uploadUrl
return uploadUrl
if entry, ok := d.uploadUrlCache[uploadId]; ok {
return entry.url, true
}
return ""
return "", false
}
// 检查地址缓存
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
if uploadUrl, ok := getCachedUrlFunc(); ok {
return uploadUrl
}

uploadUrlGetFunc := func() (string, error) {
// 双重检查缓存
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
if uploadUrl, ok := getCachedUrlFunc(); ok {
return uploadUrl, nil
}

Expand All @@ -426,13 +425,15 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
}

d.uploadUrlMu.Lock()
defer d.uploadUrlMu.Unlock()
d.uploadUrl = uploadUrl
d.uploadUrlUpdateTime = time.Now()
d.uploadUrlCache[uploadId] = uploadURLCacheEntry{
url: uploadUrl,
updateTime: time.Now(),
}
d.uploadUrlMu.Unlock()
return uploadUrl, nil
}

uploadUrl, err, _ := d.uploadUrlG.Do("", uploadUrlGetFunc)
uploadUrl, err, _ := d.uploadUrlG.Do(uploadId, uploadUrlGetFunc)
if err != nil {
fallback := d.UploadAPI
log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, fallback)
Expand All @@ -441,6 +442,17 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
return uploadUrl
}

func (d *BaiduNetdisk) clearUploadUrlCache(uploadId string) {
if uploadId == "" {
return
}
d.uploadUrlMu.Lock()
if _, ok := d.uploadUrlCache[uploadId]; ok {
delete(d.uploadUrlCache, uploadId)
}
d.uploadUrlMu.Unlock()
}

// requestForUploadUrl 请求获取上传地址。
// 实测此接口不需要认证,传method和upload_version就行,不过还是按文档规范调用。
// https://pan.baidu.com/union/doc/Mlvw5hfnr
Expand Down