# Source code for petastorm.transform
# Copyright (c) 2017-2018 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from petastorm.unischema import UnischemaField, Unischema
def edit_field(name, numpy_dtype, shape, nullable=False):
    """Build one entry for the ``edit_fields`` argument of ``TransformSpec``.

    The arguments are returned unchanged, packed in the positional order that
    ``transform_schema`` reads them back (indices 0 through 3).

    :param name: Name of the field being added or redefined.
    :param numpy_dtype: Value forwarded as ``numpy_dtype`` of the resulting ``UnischemaField``.
    :param shape: Value forwarded as ``shape`` of the resulting ``UnischemaField``.
    :param nullable: Value forwarded as ``nullable`` of the resulting ``UnischemaField``.
      Defaults to ``False``.
    :return: The 4-tuple ``(name, numpy_dtype, shape, nullable)``.
    """
    return name, numpy_dtype, shape, nullable
class TransformSpec(object):
    """Describes a user row transformation and the schema changes it implies."""

    def __init__(self, func=None, edit_fields=None, removed_fields=None, selected_fields=None):
        """TransformSpec defines a user transformation that is applied to a loaded row on a worker thread/process.

        The object defines the function (callable) that performs the transform as well as the
        schema transform: pre-transform-schema to post-transform-schema.

        ``func`` argument is a callable which takes a row as its parameter and returns a modified row.
        ``edit_fields`` and ``removed_fields`` define mutating operations performed on the original schema that
        produce a post-transform schema. ``func`` return value must comply to this post-transform schema.

        :param func: Optional. A callable. The function is called on the worker thread. It takes a dictionary that
          complies to the input schema and must return a dictionary that complies to a post-transform schema.
          In case the user wants to only remove certain fields, the user may omit this argument and specify only
          the ``removed_fields`` argument.
        :param edit_fields: Optional. A list of 4-tuples with the following fields:
          ``(name, numpy_dtype, shape, is_nullable)``. ``None`` or an empty value is stored as ``[]``.
        :param removed_fields: Optional[list]. A list of field names that will be removed from the original schema.
          ``None`` or an empty value is stored as ``[]``.
        :param selected_fields: Optional[list]. A list of field names that specifies the fields to be selected.
          If ``selected_fields`` is specified, the reader schema will preserve the field order in
          ``selected_fields``. Stored as given (``None`` when omitted).
        :raises ValueError: If both ``removed_fields`` and ``selected_fields`` are given (i.e. neither is ``None``).

        Note: For params ``removed_fields`` and ``selected_fields``, user can only specify one of them.
        """
        # The check is on "is not None", so passing two empty lists is rejected as well.
        if removed_fields is not None and selected_fields is not None:
            raise ValueError('User can only specify one of removed_fields and selected_fields in TransformSpec.')

        self.func = func
        self.edit_fields = edit_fields or []
        self.removed_fields = removed_fields or []
        # Deliberately left as None when not supplied: None means "no selection requested".
        self.selected_fields = selected_fields
def transform_schema(schema, transform_spec):
    """Creates a post-transform schema given a pre-transform schema and a transform_spec with mutation instructions.

    Steps, in order:

    1. Fields named in ``transform_spec.removed_fields`` are dropped. Names that are not in ``schema``
       trigger a warning and are otherwise ignored.
    2. Fields named in ``transform_spec.edit_fields`` are dropped from their original position and
       re-created (with ``codec=None``) at the end of the field list, in ``edit_fields`` order.
    3. If ``transform_spec.selected_fields`` is not ``None``, only the listed fields are kept and they are
       ordered as in ``selected_fields``. Names that match no field trigger a warning and are ignored.

    :param schema: A pre-transform schema
    :param transform_spec: a TransformSpec object with mutation instructions.
    :return: A post-transform schema, named after the original schema with a ``_transformed`` suffix.
    """
    removed_fields = set(transform_spec.removed_fields)
    unknown_field_names = removed_fields - set(schema.fields.keys())
    if unknown_field_names:
        # Sorted so the warning text does not depend on set iteration order.
        warnings.warn('remove_fields specified some field names that are not part of the schema. '
                      'These field names will be ignored "{}". '.format(', '.join(sorted(unknown_field_names))))

    # Edited fields are excluded here and appended below with their new definition.
    exclude_fields = {f[0] for f in transform_spec.edit_fields} | removed_fields
    fields = [v for k, v in schema.fields.items() if k not in exclude_fields]

    # Entries are indexed rather than unpacked so that anything indexable at 0..3 keeps working.
    fields.extend(
        UnischemaField(name=field_to_edit[0], numpy_dtype=field_to_edit[1],
                       shape=field_to_edit[2], codec=None, nullable=field_to_edit[3])
        for field_to_edit in transform_spec.edit_fields
    )

    selected_fields = transform_spec.selected_fields
    if selected_fields is not None:
        # Map each selected name to the index of its FIRST occurrence. This gives the same ordering
        # as list.index() without a linear scan per field.
        selected_position = {}
        for position, field_name in enumerate(selected_fields):
            selected_position.setdefault(field_name, position)

        unknown_field_names = set(selected_position) - {f.name for f in fields}
        if unknown_field_names:
            warnings.warn('selected_fields specified some field names that are not part of the schema. '
                          'These field names will be ignored "{}". '.format(', '.join(sorted(unknown_field_names))))

        # sorted() is stable, so fields sharing a name keep their relative order.
        fields = sorted((f for f in fields if f.name in selected_position),
                        key=lambda f: selected_position[f.name])

    return Unischema(schema._name + '_transformed', fields)