Load tabular data from cloud storage
Question
I have a set of .parquet files in the cloud and want to read them into memory on my remote workers quickly. How can I do this with Metaflow?
Solution
- You can load data from S3 directly to memory very quickly, at tens of gigabits per second or more, using Metaflow’s optimized S3 client, metaflow.S3.
- Once in memory, Parquet data can be decoded efficiently using Apache Arrow.
- The in-memory tables produced by Arrow are interoperable with many modern data tools, so you can work with the data in different ways without making additional copies, which speeds up processing and avoids unnecessary memory overhead.
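For instance, once the data is decoded into a pyarrow.Table, it can be handed to pandas or other Arrow-aware libraries. A minimal sketch, using a toy table with made-up column names in place of data loaded from S3:

import pyarrow as pa

# A tiny in-memory table standing in for data loaded from S3.
table = pa.table({"investment_id": [1, 2, 3], "target": [0.1, -0.2, 0.3]})

# Hand the Arrow table to pandas for downstream analysis; depending on the
# column types and chunking, pyarrow can minimize copying during conversion.
df = table.to_pandas()
print(df.head())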
1. Cloud to table
Before writing a Metaflow flow, let's see how to use the Metaflow S3 client together with Apache Arrow.
The main steps are: use the metaflow.S3.get_many method to download the .parquet partitions in parallel, load the bytes into memory on the worker instance, and decode those bytes into a pyarrow.Table object.
from metaflow import S3
import pyarrow.parquet as pq
import pyarrow
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
# Instantiate Metaflow S3 client context
s3 = S3()
# Set the URL of an S3 bucket containing .parquet files
url = "s3://outerbounds-datasets/ubiquant/investment_ids"
To check what exists at the S3 URL of interest without actually downloading the files, you can use s3.list_recursive.
files = list(s3.list_recursive([url]))
total_size = sum(f.size for f in files) / 1024**3
print("Loading%2.1dGB of data partitioned across %d files." % (total_size, len(files)))
# Download the files in parallel
loaded = s3.get_many([f.url for f in files])
Notice that the downloaded files live in local temporary storage, in a directory with a name like ./metaflow.s3.<random suffix>.
print(len(loaded))
print(loaded[0])
print(loaded[0].path)
local_tmp_file_paths = [f.path for f in loaded]
Next, read the PyArrow tables in a pool of parallel threads and then concatenate them.
The benefits of this workflow scale with the number of processors, available RAM, and I/O throughput of the machine you are loading a table on. Bigger instances can be cheaper in many cases, since they can reduce processing times at a super-linear rate. More on this later in the post.
with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as exe:
    # Decode each file with a single thread per task; parallelism comes from the pool.
    tables = exe.map(lambda f: pq.read_table(f, use_threads=False), local_tmp_file_paths)
    table = pyarrow.concat_tables(tables)

print("Table has %d rows and takes %.1f GB in memory." % (table.shape[0], table.nbytes / 1024**3))
# close s3 connection
s3.close()
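Alternatively, you can scope the client with a context manager instead of calling s3.close() yourself; the flow in section 2 uses this pattern. The client's temporary files are cleaned up when the client is closed, so decode the Parquet bytes before leaving the block. A condensed sketch of the same cloud-to-table steps:

from metaflow import S3
import pyarrow.parquet as pq
import pyarrow

url = "s3://outerbounds-datasets/ubiquant/investment_ids"

with S3() as s3:
    # Downloaded files live in the client's temporary directory, so read
    # them into Arrow tables before the with-block closes the client.
    loaded = s3.get_many([f.url for f in s3.list_recursive([url])])
    table = pyarrow.concat_tables(
        pq.read_table(f.path, use_threads=False) for f in loaded
    )

print("Rows loaded: %d" % table.shape[0])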
2. Performance benefits scale with instance size
Using the basic pattern described above, you can now write Metaflow flows that apply this fast data-loading speedup on cloud instances.
The flow below organizes the same operations presented in section 1 into Metaflow steps.
Notice that the data_processing step is annotated with @batch(..., use_tmpfs=True, ...).
The tmpfs feature extends the resources you request: it lets you use memory on the Batch instance to back a temporary file system, so the downloaded .parquet bytes never need to be staged on the local file system, which makes the cloud-to-table workflow significantly faster.
To reiterate, the benefits of this workflow scale with the number of processors, available RAM, and I/O throughput of the machine you are loading the table on, so for maximal benefit you will want an instance that can fit your entire Arrow table in memory. To get a sense of how fast this workflow can get, check out the Fast Data: Loading Tables From S3 At Lightning Speed post.
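Before settling on the memory and tmpfs_size values, it can help to estimate how large the decoded table will be. A minimal sketch, assuming the Parquet data expands roughly 3x when decoded; the real ratio depends on your dataset's compression and encodings, so measure it on a sample:

from metaflow import S3

URL = "s3://outerbounds-datasets/ubiquant/investment_ids"
ASSUMED_EXPANSION = 3  # assumed decompression factor, not a general rule

with S3() as s3:
    files = list(s3.list_recursive([URL]))

compressed_gb = sum(f.size for f in files) / 1024**3
estimated_table_gb = compressed_gb * ASSUMED_EXPANSION
print("~%.1f GB on S3, budget roughly %.1f GB of RAM for the Arrow table."
      % (compressed_gb, estimated_table_gb))

Use such an estimate to choose the memory request in the @batch decorator below, and keep in mind that the tmpfs volume also consumes instance memory.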
from metaflow import Parameter, FlowSpec, step, S3, batch, conda

class FastDataProcessing(FlowSpec):

    url = Parameter(
        "data",
        default="s3://outerbounds-datasets/ubiquant/investment_ids",
        help="S3 prefix to Parquet files")

    @step
    def start(self):
        self.next(self.data_processing)

    @conda(
        libraries={
            "pandas": "2.0.1",
            "pyarrow": "11.0.0"
        },
        python="3.10.10"
    )
    @batch(memory=32000, cpu=8, use_tmpfs=True, tmpfs_size=16000)
    @step
    def data_processing(self):
        import pyarrow.parquet as pq
        import pyarrow
        from concurrent.futures import ThreadPoolExecutor
        import multiprocessing

        with S3() as s3:
            # Check metadata about what is in the S3 url of interest.
            files = list(s3.list_recursive([self.url]))
            total_size = sum(f.size for f in files) / 1024**3
            msg = "Loading %.1f GB of data across %d files."
            print(msg % (total_size, len(files)))

            # Download N parquet files in parallel.
            loaded = s3.get_many([f.url for f in files])
            local_tmp_file_paths = [f.path for f in loaded]

            # Read N PyArrow tables from bytes and concatenate.
            n_threads = multiprocessing.cpu_count()
            with ThreadPoolExecutor(max_workers=n_threads) as exe:
                tables = exe.map(
                    lambda f: pq.read_table(f, use_threads=False),
                    local_tmp_file_paths
                )
                table = pyarrow.concat_tables(tables)

            msg = "Table has %d rows and takes %.1f GB in memory."
            print(msg % (table.shape[0], table.nbytes / 1024**3))

        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    FastDataProcessing()
python fast_data_processing.py --environment=conda run