# Source code for petastorm.spark_utils
# Copyright (c) 2017-2018 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A set of Spark specific helper functions for the petastorm dataset"""
from six.moves.urllib.parse import urlparse
from petastorm import utils
from petastorm.etl.dataset_metadata import get_schema_from_dataset_url
from petastorm.fs_utils import FilesystemResolver
def dataset_as_rdd(dataset_url, spark_session, schema_fields=None, hdfs_driver='libhdfs3'):
    """Retrieve a Spark rdd for a given petastorm dataset.

    :param dataset_url: A string for the dataset url (e.g. hdfs:///path/to/dataset)
    :param spark_session: A spark session
    :param schema_fields: list of unischema fields to subset, or None to read all fields.
    :param hdfs_driver: A string denoting the hdfs driver to use (if using a dataset on hdfs). Current choices are
        libhdfs (java through JNI) or libhdfs3 (C++)
    :return: A rdd of dictionary records from the dataset
    """
    schema = get_schema_from_dataset_url(dataset_url, hdfs_driver=hdfs_driver)
    dataset_url_parsed = urlparse(dataset_url)
    # Resolve the dataset path using the driver's Hadoop configuration so that
    # filesystem-specific settings from the Spark context are honored.
    resolver = FilesystemResolver(dataset_url_parsed, spark_session.sparkContext._jsc.hadoopConfiguration(),
                                  hdfs_driver=hdfs_driver)
    dataset_df = spark_session.read.parquet(resolver.get_dataset_path())

    if schema_fields is not None:
        # If wanting a subset of fields, create the schema view and run a select on those fields
        schema = schema.create_schema_view(schema_fields)
        field_names = [field.name for field in schema_fields]
        dataset_df = dataset_df.select(*field_names)

    # Decode each petastorm-encoded row into a plain dict, then wrap it in the
    # schema's namedtuple so records expose fields as attributes.
    dataset_rows = dataset_df.rdd \
        .map(lambda row: utils.decode_row(row.asDict(), schema)) \
        .map(lambda record: schema.make_namedtuple(**record))

    return dataset_rows