internal/rows/arrowbased/batchloader.go (212 changes: 182 additions & 30 deletions)
@@ -5,6 +5,9 @@ import (
"context"
"fmt"
"io"
"math"
"math/rand"
"strconv"
"strings"
"time"

@@ -193,6 +196,9 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
httpClient: bi.httpClient,
retryMax: bi.cfg.RetryMax,
retryWaitMin: bi.cfg.RetryWaitMin,
retryWaitMax: bi.cfg.RetryWaitMax,
}
task.Run()
bi.downloadTasks.Enqueue(task)
@@ -252,6 +258,9 @@ type cloudFetchDownloadTask struct {
resultChan chan cloudFetchDownloadTaskResult
speedThresholdMbps float64
httpClient *http.Client
retryMax int
retryWaitMin time.Duration
retryWaitMax time.Duration
}

func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, int64, error) {
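The three new fields plumb the driver's existing retry settings into each download task. A minimal sketch of tuning them from application code, assuming the public `WithRetries` connector option is what populates the `cfg.RetryMax` / `RetryWaitMin` / `RetryWaitMax` fields read above (hostname, path, and token are placeholders):

```go
package main

import (
	"database/sql"
	"time"

	dbsql "github.com/databricks/databricks-sql-go"
)

func main() {
	// Assumption: WithRetries feeds the same config fields the CloudFetch
	// download task now reads, so one knob governs Thrift and CloudFetch
	// retries alike. All connection values are placeholders.
	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname("example.cloud.databricks.com"),
		dbsql.WithHTTPPath("/sql/1.0/warehouses/abc123"),
		dbsql.WithAccessToken("dapi-placeholder"),
		dbsql.WithRetries(4, 1*time.Second, 30*time.Second), // retryMax, retryWaitMin, retryWaitMax
	)
	if err != nil {
		panic(err)
	}
	db := sql.OpenDB(connector)
	defer db.Close()
}
```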
@@ -295,20 +304,32 @@ func (cft *cloudFetchDownloadTask) Run() {
cft.link.RowCount,
)
downloadStart := time.Now()
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient)
rawBody, err := fetchBatchBytes(
cft.ctx,
cft.link,
cft.minTimeToExpiry,
cft.speedThresholdMbps,
cft.httpClient,
cft.retryMax,
cft.retryWaitMin,
cft.retryWaitMax,
)
if err != nil {
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
return
}

// Read all data into memory before closing
buf, err := io.ReadAll(getReader(data, cft.useLz4Compression))
data.Close() //nolint:errcheck,gosec // G104: close after reading data
downloadMs := time.Since(downloadStart).Milliseconds()
if err != nil {
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
return
buf := rawBody
if cft.useLz4Compression {
// Decompression sits outside the retry loop: malformed LZ4 is data
// corruption, not a transient network condition.
buf, err = io.ReadAll(lz4.NewReader(bytes.NewReader(rawBody)))
if err != nil {
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
return
}
}
downloadMs := time.Since(downloadStart).Milliseconds()

logger.Debug().Msgf(
"CloudFetch: downloaded data for link at offset %d row count %d",
@@ -339,43 +360,174 @@ func logCloudFetchSpeed(fullURL string, contentLength int64, duration time.Duration, speedThresholdMbps float64) {
}
}

// fetchBatchBytes downloads a single CloudFetch result link and returns the
// raw response body, still compressed if the server used LZ4. Connection-time
// failures, retryable HTTP statuses, and mid-stream body read failures are
// retried up to retryMax times with exponential backoff and equal jitter.
// Decompression and IPC parsing stay with the caller because those failures are
// not transient network conditions.
//
// Link expiry is rechecked after each backoff: a long retry chain may outlive
// a presigned URL, and continuing past expiry is guaranteed to fail.
func fetchBatchBytes(
ctx context.Context,
link *cli_service.TSparkArrowResultLink,
minTimeToExpiry time.Duration,
speedThresholdMbps float64,
httpClient *http.Client,
) (io.ReadCloser, error) {
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
return nil, errors.New(dbsqlerr.ErrLinkExpired)
retryMax int,
retryWaitMin time.Duration,
retryWaitMax time.Duration,
) ([]byte, error) {
if retryMax < 0 {
retryMax = 0
}

// TODO: Retry on HTTP errors
req, err := http.NewRequestWithContext(ctx, "GET", link.FileLink, nil)
if err != nil {
return nil, err
}
var (
lastErr error
lastStatus int
lastRetryAfter string
)

for attempt := 0; attempt <= retryMax; attempt++ {
if attempt > 0 {
wait := cloudFetchBackoff(attempt, retryWaitMin, retryWaitMax, lastRetryAfter)
logger.Debug().Msgf(
"CloudFetch: retrying download of link at offset %d (attempt %d/%d) in %v; lastStatus=%d lastErr=%v",
link.StartRowOffset, attempt, retryMax, wait, lastStatus, lastErr,
)
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(wait):
}
}

if link.HttpHeaders != nil {
for key, value := range link.HttpHeaders {
req.Header.Set(key, value)
// Check link expiry *after* backoff: a long retry chain may outlive a
// presigned URL, and there's no point spending another HTTP attempt
// (or another retry) on a link we know will be rejected.
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
return nil, errors.New(dbsqlerr.ErrLinkExpired)
}
}

startTime := time.Now()
res, err := httpClient.Do(req)
if err != nil {
return nil, err
req, err := http.NewRequestWithContext(ctx, "GET", link.FileLink, nil)
if err != nil {
return nil, err
}
if link.HttpHeaders != nil {
for key, value := range link.HttpHeaders {
req.Header.Set(key, value)
}
}

startTime := time.Now()
res, err := httpClient.Do(req)
if err != nil {
// Caller cancellation is terminal; otherwise treat transport errors
// (TCP RST, TLS timeout, etc.) as transient.
if ctx.Err() != nil {
return nil, ctx.Err()
}
lastErr = err
lastStatus = 0
lastRetryAfter = ""
continue
}

if res.StatusCode == http.StatusOK {
// Read the full body inside the retry loop so truncated 200 OK
// responses are retried just like header-time failures.
buf, readErr := io.ReadAll(res.Body)
res.Body.Close() //nolint:errcheck,gosec // G104: close after drain
if readErr != nil {
if ctx.Err() != nil {
return nil, ctx.Err()
}
lastErr = readErr
lastStatus = 0
lastRetryAfter = ""
continue
}
logCloudFetchSpeed(link.FileLink, int64(len(buf)), time.Since(startTime), speedThresholdMbps)
return buf, nil
}

// Drain and close so the underlying connection can be reused.
_, _ = io.Copy(io.Discard, res.Body)
res.Body.Close() //nolint:errcheck,gosec // G104: closing after drain

lastStatus = res.StatusCode
lastErr = nil
lastRetryAfter = ""
if res.Header != nil {
lastRetryAfter = res.Header.Get("Retry-After")
}

if !isCloudFetchRetryableStatus(res.StatusCode) {
msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode)
return nil, dbsqlerrint.NewDriverError(ctx, msg, nil)
}
}
if res.StatusCode != http.StatusOK {
msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode)
return nil, dbsqlerrint.NewDriverError(ctx, msg, err)

if lastStatus != 0 {
msg := fmt.Sprintf("%s: %s %d (after %d retries)", errArrowRowsCloudFetchDownloadFailure, "HTTP error", lastStatus, retryMax)
return nil, dbsqlerrint.NewDriverError(ctx, msg, nil)
}
msg := fmt.Sprintf("%s: %v (after %d retries)", errArrowRowsCloudFetchDownloadFailure, lastErr, retryMax)
return nil, dbsqlerrint.NewDriverError(ctx, msg, lastErr)
}
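To make the status handling concrete, here is a self-contained sketch (not the driver's code) of the behavior the loop implements: a retryable 503 is drained and retried, a 200 is read fully and returned. The real function additionally applies backoff with jitter and rechecks link expiry between attempts.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
)

func main() {
	// Toy endpoint: 503 (retryable) twice, then the payload.
	var calls int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if atomic.AddInt32(&calls, 1) <= 2 {
			w.Header().Set("Retry-After", "1") // integer form is honored, capped at retryWaitMax
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		_, _ = w.Write([]byte("arrow-ipc-bytes")) // stand-in for a real batch
	}))
	defer srv.Close()

	// Simplified attempt loop: with retryMax=2 the third attempt succeeds.
	// A non-retryable status (e.g. 403) would exit immediately instead.
	for attempt := 0; attempt <= 2; attempt++ {
		res, err := http.Get(srv.URL)
		if err != nil {
			continue // transport error: treated as transient
		}
		if res.StatusCode == http.StatusOK {
			body, _ := io.ReadAll(res.Body)
			_ = res.Body.Close()
			fmt.Printf("attempt %d: %q\n", attempt+1, body)
			return
		}
		_, _ = io.Copy(io.Discard, res.Body) // drain so the connection is reused
		_ = res.Body.Close()
	}
}
```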

// cloudFetchRetryableStatuses lists HTTP status codes from object storage that
// indicate transient conditions and warrant a retry. Mirrors AWS S3 guidance
// for SlowDown (503) / InternalError (500) plus the general 408/429/502/504.
var cloudFetchRetryableStatuses = map[int]struct{}{
http.StatusRequestTimeout: {}, // 408
http.StatusTooManyRequests: {}, // 429
http.StatusInternalServerError: {}, // 500
http.StatusBadGateway: {}, // 502
http.StatusServiceUnavailable: {}, // 503
http.StatusGatewayTimeout: {}, // 504
}

func isCloudFetchRetryableStatus(status int) bool {
_, ok := cloudFetchRetryableStatuses[status]
return ok
}

// Log download speed metrics
logCloudFetchSpeed(link.FileLink, res.ContentLength, time.Since(startTime), speedThresholdMbps)
// cloudFetchBackoff returns the wait before retry attempt N (1-based). The
// base delay is exponential, waitMin * 2^(attempt-1) capped at waitMax, with
// equal jitter applied: the actual sleep is drawn uniformly from
// [base/2, base). Equal jitter (rather than no jitter) spreads synchronized
// retries across the up-to-MaxDownloadThreads concurrent downloads, which
// would otherwise hammer the storage endpoint in lockstep after a
// region-wide blip. If the server returned a parseable integer Retry-After
// header, that value (in seconds) is honored instead, capped at waitMax.
// HTTP-date Retry-After values are ignored, same as the Thrift client's
// backoff.
func cloudFetchBackoff(attempt int, waitMin, waitMax time.Duration, retryAfter string) time.Duration {
if retryAfter != "" {
if secs, err := strconv.ParseInt(retryAfter, 10, 64); err == nil && secs >= 0 {
d := time.Duration(secs) * time.Second
if d > waitMax {
return waitMax
}
return d
}
}

return res.Body, nil
expo := float64(waitMin) * math.Pow(2, float64(attempt-1))
if expo > float64(waitMax) || math.IsInf(expo, 0) {
expo = float64(waitMax)
}
base := time.Duration(expo)
if base <= 0 {
return 0
}
half := base / 2
if half <= 0 {
return base
}
return half + time.Duration(rand.Int63n(int64(half))) //nolint:gosec // G404: jitter only, non-cryptographic
}
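The schedule is easiest to see with concrete numbers. A runnable sketch of the same formula (re-implemented here because cloudFetchBackoff is unexported; the Retry-After path is omitted):

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
	"time"
)

// Same equal-jitter formula as above, minus the Retry-After handling.
func backoff(attempt int, waitMin, waitMax time.Duration) time.Duration {
	expo := float64(waitMin) * math.Pow(2, float64(attempt-1))
	if expo > float64(waitMax) {
		expo = float64(waitMax)
	}
	base := time.Duration(expo)
	half := base / 2
	if half <= 0 {
		return base
	}
	return half + time.Duration(rand.Int63n(int64(half)))
}

func main() {
	// With waitMin=1s, waitMax=30s the base doubles per attempt
	// (1s, 2s, 4s, 8s, 16s, capped at 30s); each sleep is a uniform
	// draw from [base/2, base).
	for attempt := 1; attempt <= 6; attempt++ {
		fmt.Printf("attempt %d: %v\n", attempt, backoff(attempt, time.Second, 30*time.Second))
	}
}
```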

func getReader(r io.Reader, useLz4Compression bool) io.Reader {