Skip to content

Commit

Permalink
Add support for progress_callback in Object#download_file (#2902)
Browse files Browse the repository at this point in the history
  • Loading branch information
alextwoods authored Aug 22, 2023
1 parent ce9343c commit 8e6aac3
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 26 deletions.
2 changes: 1 addition & 1 deletion build_tools/services.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ServiceEnumerator
MINIMUM_CORE_VERSION = "3.177.0"

# Minimum `aws-sdk-core` version for new S3 gem builds
MINIMUM_CORE_VERSION_S3 = "3.179.0"
MINIMUM_CORE_VERSION_S3 = "3.181.0"

EVENTSTREAM_PLUGIN = "Aws::Plugins::EventStreamConfiguration"

Expand Down
2 changes: 2 additions & 0 deletions gems/aws-sdk-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Unreleased Changes
------------------

* Feature - Add support for `on_chunk_received` callback.

3.180.3 (2023-08-09)
------------------

Expand Down
31 changes: 31 additions & 0 deletions gems/aws-sdk-core/lib/seahorse/client/plugins/request_callback.rb
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ class RequestCallback < Plugin
bytes in the body.
DOCS

# Client-level config option: an optional Proc invoked as each chunk of a
# response body is received. The heredoc below is the user-facing docstring;
# the callback contract (chunk, bytes received so far, total or nil) is
# implemented by OptionHandler#add_response_events in this plugin.
option(:on_chunk_received,
default: nil,
doc_type: 'Proc',
docstring: <<-DOCS)
When a Proc object is provided, it will be used as callback when each chunk
of the response body is received. It provides three arguments: the chunk,
the number of bytes received, and the total number of
bytes in the response (or nil if the server did not send a `content-length`).
DOCS

# @api private
class OptionHandler < Client::Handler
def call(context)
Expand All @@ -68,8 +78,29 @@ def call(context)
end
on_chunk_sent = context.config.on_chunk_sent if on_chunk_sent.nil?
context[:on_chunk_sent] = on_chunk_sent if on_chunk_sent

if context.params.is_a?(Hash) && context.params[:on_chunk_received]
on_chunk_received = context.params.delete(:on_chunk_received)
end
on_chunk_received = context.config.on_chunk_received if on_chunk_received.nil?

add_response_events(on_chunk_received, context) if on_chunk_received

@handler.call(context)
end

# Wires the user-supplied chunk callback into the HTTP response events.
# A small state hash (shared by both event blocks) accumulates the byte
# count and captures the content length once headers arrive.
#
# @param on_chunk_received [Proc] called with (chunk, bytes_received,
#   content_length_or_nil) for every body chunk.
# @param context [Seahorse::Client::RequestContext] request context whose
#   http_response events are subscribed to.
def add_response_events(on_chunk_received, context)
  state = { bytes_received: 0 }

  context.http_response.on_headers do |_status, headers|
    # content-length may be absent (e.g. chunked transfer encoding)
    length = headers['content-length']
    state[:content_length] = length && length.to_i
  end

  context.http_response.on_data do |chunk|
    if chunk && chunk.respond_to?(:bytesize)
      state[:bytes_received] += chunk.bytesize
    end
    on_chunk_received.call(chunk, state[:bytes_received], state[:content_length])
  end
end
end

# @api private
Expand Down
2 changes: 2 additions & 0 deletions gems/aws-sdk-s3/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Unreleased Changes
------------------

* Feature - Add support for `progress_callback` in `Object#download_file` and improve multi-threaded performance (#2901).

1.132.1 (2023-08-09)
------------------

Expand Down
42 changes: 42 additions & 0 deletions gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,10 @@ def public_url(options = {})
# obj.upload_stream do |write_stream|
# IO.copy_stream(STDIN, write_stream)
# end
# @param [Hash] options
# Additional options for {Client#create_multipart_upload},
# {Client#complete_multipart_upload},
# and {Client#upload_part} can be provided.
#
# @option options [Integer] :thread_count (10) The number of parallel
# multipart uploads
Expand All @@ -375,6 +379,9 @@ def public_url(options = {})
# @return [Boolean] Returns `true` when the object is uploaded
# without any errors.
#
# @see Client#create_multipart_upload
# @see Client#complete_multipart_upload
# @see Client#upload_part
def upload_stream(options = {}, &block)
uploading_options = options.dup
uploader = MultipartStreamUploader.new(
Expand Down Expand Up @@ -427,6 +434,13 @@ def upload_stream(options = {}, &block)
# using an open Tempfile, rewind it before uploading or else the object
# will be empty.
#
# @param [Hash] options
# Additional options for {Client#put_object}
# when file sizes are below the multipart threshold. For files larger than
# the multipart threshold, options for {Client#create_multipart_upload},
# {Client#complete_multipart_upload},
# and {Client#upload_part} can be provided.
#
# @option options [Integer] :multipart_threshold (104857600) Files larger
# than or equal to `:multipart_threshold` are uploaded using the S3
# multipart APIs.
Expand All @@ -448,6 +462,11 @@ def upload_stream(options = {}, &block)
#
# @return [Boolean] Returns `true` when the object is uploaded
# without any errors.
#
# @see Client#put_object
# @see Client#create_multipart_upload
# @see Client#complete_multipart_upload
# @see Client#upload_part
def upload_file(source, options = {})
uploading_options = options.dup
uploader = FileUploader.new(
Expand Down Expand Up @@ -475,8 +494,21 @@ def upload_file(source, options = {})
# # and the parts are downloaded in parallel
# obj.download_file('/path/to/very_large_file')
#
# You can provide a callback to monitor progress of the download:
#
# # bytes and part_sizes are each an array with 1 entry per part
# # part_sizes may not be known until the first bytes are retrieved
# progress = Proc.new do |bytes, part_sizes, file_size|
#     puts bytes.map.with_index { |b, i| "Part #{i+1}: #{b} / #{part_sizes[i]}" }.join(' ') + " Total: #{100.0 * bytes.sum / file_size}%"
# end
# obj.download_file('/path/to/file', progress_callback: progress)
#
# @param [String] destination Where to download the file to.
#
# @param [Hash] options
# Additional options for {Client#get_object} and {Client#head_object}
# may be provided.
#
# @option options [String] mode `auto`, `single_request`, `get_range`
# `single_request` mode forces only 1 GET request is made in download,
# `get_range` mode allows `chunk_size` parameter to configured in
Expand Down Expand Up @@ -505,8 +537,18 @@ def upload_file(source, options = {})
# response. For multipart downloads, this will be called for each
# part that is downloaded and validated.
#
# @option options [Proc] :progress_callback
# A Proc that will be called when each chunk of the download is received.
# It will be invoked with [bytes_read], [part_sizes], file_size.
# When the object is downloaded as parts (rather than by ranges), the
# part_sizes will not be known ahead of time and will be nil in the
# callback until the first bytes in the part are received.
#
# @return [Boolean] Returns `true` when the file is downloaded without
# any errors.
#
# @see Client#get_object
# @see Client#head_object
def download_file(destination, options = {})
downloader = FileDownloader.new(client: client)
Aws::Plugins::UserAgent.feature('resource') do
Expand Down
136 changes: 113 additions & 23 deletions gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def download(destination, options = {})
end
@on_checksum_validated = options[:on_checksum_validated]

@progress_callback = options[:progress_callback]

validate!

Aws::Plugins::UserAgent.feature('s3-transfer') do
Expand All @@ -49,7 +51,7 @@ def download(destination, options = {})
when 'get_range'
if @chunk_size
resp = @client.head_object(@params)
multithreaded_get_by_ranges(construct_chunks(resp.content_length))
multithreaded_get_by_ranges(resp.content_length)
else
msg = 'In :get_range mode, :chunk_size must be provided'
raise ArgumentError, msg
Expand Down Expand Up @@ -82,7 +84,7 @@ def multipart_download
if resp.content_length <= MIN_CHUNK_SIZE
single_request
else
multithreaded_get_by_ranges(construct_chunks(resp.content_length))
multithreaded_get_by_ranges(resp.content_length)
end
else
# partNumber is an option
Expand All @@ -99,9 +101,9 @@ def compute_mode(file_size, count)
chunk_size = compute_chunk(file_size)
part_size = (file_size.to_f / count.to_f).ceil
if chunk_size < part_size
multithreaded_get_by_ranges(construct_chunks(file_size))
multithreaded_get_by_ranges(file_size)
else
multithreaded_get_by_parts(count)
multithreaded_get_by_parts(count, file_size)
end
end

Expand Down Expand Up @@ -133,30 +135,65 @@ def batches(chunks, mode)
chunks.each_slice(@thread_count).to_a
end

def multithreaded_get_by_ranges(chunks)
thread_batches(chunks, 'range')
# Splits the object into contiguous byte ranges (sized by compute_chunk)
# and downloads them in parallel via ranged GET requests.
#
# @param file_size [Integer] total object size in bytes.
def multithreaded_get_by_ranges(file_size)
  chunk_size = compute_chunk(file_size)
  parts = []
  cursor = 0
  number = 1 # S3 part numbers are 1-based
  until cursor >= file_size
    chunk_end = [cursor + chunk_size, file_size].min
    parts << Part.new(
      part_number: number,
      size: chunk_end - cursor,
      params: @params.merge(range: "bytes=#{cursor}-#{chunk_end - 1}")
    )
    number += 1
    cursor = chunk_end
  end
  download_in_threads(PartList.new(parts), file_size)
end

def multithreaded_get_by_parts(parts)
thread_batches(parts, 'part_number')
# Downloads a multipart-uploaded object part-by-part (GetObject with
# partNumber) using the shared worker-thread pool.
#
# @param n_parts [Integer] number of parts reported by S3.
# @param total_size [Integer] total object size in bytes.
def multithreaded_get_by_parts(n_parts, total_size)
  part_list = Array.new(n_parts) do |i|
    number = i + 1 # partNumber is 1-based
    Part.new(part_number: number, params: @params.merge(part_number: number))
  end
  download_in_threads(PartList.new(part_list), total_size)
end

def thread_batches(chunks, param)
batches(chunks, param).each do |batch|
threads = []
batch.each do |chunk|
threads << Thread.new do
resp = @client.get_object(
@params.merge(param.to_sym => chunk)
)
write(resp)
if @on_checksum_validated && resp.checksum_validated
@on_checksum_validated.call(resp.checksum_validated, resp)
# Drains the shared PartList with a pool of @thread_count worker threads.
# Each worker repeatedly shifts the next part, fetches it, and writes the
# response; on any error the queue is cleared so peers stop early.
#
# NOTE(review): reconstructed from a diff rendering — confirm against the
# upstream file before relying on exact line placement.
#
# @param pending [PartList] thread-safe queue of parts to fetch.
# @param total_size [Integer] total object size, forwarded to progress.
def download_in_threads(pending, total_size)
  tracker = (MultipartProgress.new(pending, total_size, @progress_callback) if @progress_callback)
  workers = Array.new(@thread_count) do
    worker = Thread.new do
      begin
        while (next_part = pending.shift)
          if tracker
            # per-part chunk callback feeds the aggregated progress tracker
            next_part.params[:on_chunk_received] = proc do |_chunk, bytes, total|
              tracker.call(next_part.part_number, bytes, total)
            end
          end
          response = @client.get_object(next_part.params)
          write(response)
          if @on_checksum_validated && response.checksum_validated
            @on_checksum_validated.call(response.checksum_validated, response)
          end
        end
        nil
      rescue => e
        # keep other threads from downloading other parts
        pending.clear!
        raise e
      end
    end
    worker.abort_on_exception = true
    worker
  end
  # Thread#value joins and re-raises any worker error
  workers.map(&:value).compact
end

def write(resp)
Expand All @@ -166,9 +203,9 @@ def write(resp)
end

def single_request
resp = @client.get_object(
@params.merge(response_target: @path)
)
params = @params.merge(response_target: @path)
params[:on_chunk_received] = single_part_progress if @progress_callback
resp = @client.get_object(params)

return resp unless @on_checksum_validated

Expand All @@ -178,6 +215,59 @@ def single_request

resp
end

# Builds the on_chunk_received adapter for a single-request download:
# reports running bytes and total as one-element arrays (one "part") in
# the progress_callback signature (bytes[], part_sizes[], file_size).
def single_part_progress
  proc do |_received_chunk, bytes_read, total_size|
    @progress_callback.call([bytes_read], [total_size], total_size)
  end
end

# Value object for one download part: 1-based part_number, size in bytes
# (may be nil until learned from the first response), and the GetObject
# request params for that part.
# NOTE(review): Aws::Structure appears to provide keyword initialization —
# Part.new is called with keyword args elsewhere in this file; confirm.
class Part < Struct.new(:part_number, :size, :params)
include Aws::Structure
end

# @api private
# Thread-safe FIFO of Part structs shared by the download worker threads.
class PartList
  include Enumerable

  # @param parts [Array<Part>] initial parts to serve to workers.
  def initialize(parts = [])
    @lock = Mutex.new
    @parts = parts
  end

  # Pops the next pending part; returns nil once drained.
  def shift
    @lock.synchronize { @parts.shift }
  end

  # Number of parts still pending.
  def size
    @lock.synchronize { @parts.size }
  end

  # Drops all pending work (used to halt other workers after an error).
  def clear!
    @lock.synchronize { @parts.clear }
  end

  # Yields each pending part while holding the lock.
  def each(&block)
    @lock.synchronize { @parts.each(&block) }
  end
end

# @api private
# Aggregates per-part byte counts and forwards the full picture
# (bytes[], part_sizes[], total_size) to the user's progress callback.
class MultipartProgress
  # @param parts [#size, #map] the parts being downloaded (sizes may be nil).
  # @param total_size [Integer] total object size in bytes.
  # @param progress_callback [Proc] user callback to notify on each update.
  def initialize(parts, total_size, progress_callback)
    @progress_callback = progress_callback
    @total_size = total_size
    # one accumulator slot per part
    @bytes_received = Array.new(parts.size, 0)
    @part_sizes = parts.map(&:size)
  end

  def call(part_number, bytes_received, total)
    idx = part_number - 1 # part numbers start at 1
    @bytes_received[idx] = bytes_received
    # a part's size may only become known from its first response
    @part_sizes[idx] ||= total
    @progress_callback.call(@bytes_received, @part_sizes, @total_size)
  end
end
end
end
end
2 changes: 1 addition & 1 deletion gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_upload_part.rb
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def wait_until(options = {}, &block)
# @option options [required, String] :copy_source
# Specifies the source object for the copy operation. You specify the
# value in one of two formats, depending on whether you want to access
# the source object through an [access point][1]:
# the source object through an [access point][1]\:
#
# * For objects not accessed through an access point, specify the name
# of the source bucket and key of the source object, separated by a
Expand Down
Loading

0 comments on commit 8e6aac3

Please sign in to comment.