Source code for petastorm.transform

#  Copyright (c) 2017-2018 Uber Technologies, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
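# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.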
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from petastorm.unischema import UnischemaField, Unischema

class TransformSpec(object):
    def __init__(self, func=None, edit_fields=None, removed_fields=None):
        """TransformSpec defines a user transformation that is applied to a loaded row on a worker thread/process.

        The object holds two things:

        * the row transform itself (``func``);
        * a description of how the pre-transform schema is mutated into the post-transform schema
          (``edit_fields`` and ``removed_fields``).

        The dictionary returned by ``func`` must comply with that post-transform schema.

        :param func: Optional. A callable invoked on the worker thread. It receives a dictionary that complies
            with the input schema and must return a dictionary that complies with the post-transform schema.
            If the user only wants to remove certain fields, this argument may be omitted and only
            ``removed_fields`` specified.
        :param edit_fields: Optional. A list of 4-tuples with the following fields:
            ``(name, numpy_dtype, shape, is_nullable)``. Defaults to an empty list.
        :param removed_fields: Optional[list]. A list of field names that will be removed from the original
            schema. Defaults to an empty list.
        """
        # Any falsy argument (None, empty list/tuple) becomes a fresh empty list, so each instance gets its
        # own list instead of sharing a mutable default.
        self.removed_fields = removed_fields if removed_fields else []
        self.edit_fields = edit_fields if edit_fields else []
        self.func = func
def transform_schema(schema, transform_spec):
    """Creates a post-transform schema given a pre-transform schema and a transform_spec with mutation instructions.

    The post-transform schema is built as follows:

    * Fields of ``schema`` are kept, in their original order, unless they are listed in
      ``transform_spec.removed_fields`` or named by an entry of ``transform_spec.edit_fields``.
    * One field per ``transform_spec.edit_fields`` entry is then appended, in the order given.
      Each entry is read as ``(name, numpy_dtype, shape, is_nullable)`` and produces a field with ``codec=None``.
      An entry whose name is not in ``schema`` simply adds a new field.

    Names in ``removed_fields`` that are not part of ``schema`` are ignored, and a warning lists them.

    :param schema: A pre-transform schema.
    :param transform_spec: A TransformSpec object with mutation instructions.
    :return: A post-transform schema, named ``schema._name + '_transformed'``.
    """
    removed_fields = set(transform_spec.removed_fields)
    unknown_field_names = removed_fields - set(schema.fields.keys())
    if unknown_field_names:
        # Sort the names so the message is deterministic. Iteration order of a set of strings can
        # differ between interpreter runs because of string hash randomization.
        warnings.warn('removed_fields specified some field names that are not part of the schema. '
                      'These field names will be ignored "{}".'.format(', '.join(sorted(unknown_field_names))))

    # Edited fields are dropped from their original position and re-appended below with the new definition.
    exclude_fields = {f[0] for f in transform_spec.edit_fields} | removed_fields
    fields = [v for k, v in schema.fields.items() if k not in exclude_fields]

    for field_to_edit in transform_spec.edit_fields:
        # Only the first four items of each entry are read; codec is always None for edited fields.
        edited_unischema_field = UnischemaField(name=field_to_edit[0], numpy_dtype=field_to_edit[1],
                                                shape=field_to_edit[2], codec=None, nullable=field_to_edit[3])
        fields.append(edited_unischema_field)

    return Unischema(schema._name + '_transformed', fields)