# Source code for petastorm.predicates
# Copyright (c) 2017-2018 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Predicates for petastorm
"""
import abc
import collections.abc
import hashlib
import numpy as np
import six
import sys
@six.add_metaclass(abc.ABCMeta)
class PredicateBase(object):
    """Abstract interface implemented by every row predicate.

    A predicate reports the fields it needs through :meth:`get_fields` and decides whether a row
    is kept through :meth:`do_include`.
    """

    @abc.abstractmethod
    def get_fields(self):
        """Return the field names this predicate reads. Must be overridden by subclasses."""

    @abc.abstractmethod
    def do_include(self, values):
        """Decide whether the row described by ``values`` is kept. Must be overridden by subclasses.

        :param values: container indexed by field name that yields the row's field values
        """
def _string_to_bucket(string, bucket_num):
    """Deterministically map ``string`` to an integer bucket.

    The UTF-8 bytes of ``string`` are hashed with MD5, the digest is read as one unsigned integer
    and that integer is reduced modulo ``bucket_num``.

    :param string: text to hash
    :param bucket_num: modulus; for a positive value the result lies in ``[0, bucket_num)``
    :return: bucket index for ``string``
    """
    raw_digest = hashlib.md5(string.encode('utf-8')).digest()
    # Reading the raw digest big-endian yields the same integer as parsing its hex form in base 16.
    return int.from_bytes(raw_digest, 'big') % bucket_num
class in_set(PredicateBase):
    """ Keep a row when the value of predicate_field is a member of inclusion_values """

    def __init__(self, inclusion_values, predicate_field):
        """
        :param inclusion_values: iterable of accepted values; copied into a set
        :param predicate_field: name of the field whose value is tested
        """
        self._predicate_field = predicate_field
        self._inclusion_values = set(inclusion_values)

    def get_fields(self):
        """ Return a one-element set holding the tested field name """
        return set((self._predicate_field,))

    def do_include(self, values):
        """ Return True when the tested field's value is one of the accepted values

        :param values: container indexed by field name
        """
        candidate = values[self._predicate_field]
        return candidate in self._inclusion_values
class in_intersection(PredicateBase):
    """ Test if predicate_field list contains at least one value from inclusion_values """

    def __init__(self, inclusion_values, _predicate_field):
        """
        :param inclusion_values: iterable of values to look for; stored as a list
        :param _predicate_field: name of the iterable field that is tested
        """
        self._inclusion_values = list(inclusion_values)
        self._predicate_field = _predicate_field

    def get_fields(self):
        """ Return a one-element set holding the tested field name """
        return {self._predicate_field}

    def do_include(self, values):
        """ Return True when the tested field shares at least one element with inclusion_values

        :param values: container indexed by field name
        :raises ValueError: if the tested field's value is not iterable
        """
        field_value = values[self._predicate_field]
        if not isinstance(field_value, collections.abc.Iterable):
            raise ValueError('Predicate field should have iterable type')
        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0. np.isin keeps the input's
        # shape, so reduce with ndarray.any() (not the builtin any()) to get one boolean for
        # inputs of any dimensionality.
        return bool(np.isin(field_value, self._inclusion_values).any())
class in_lambda(PredicateBase):
    """ Adapt a custom function so it can be used as a predicate
    example: in_lambda(['labels_object_roles'], lambda labels_object_roles : len(labels_object_roles) > 3)
    """

    def __init__(self, predicate_fields, predicate_func, state_arg=None):
        """
        :param predicate_fields: list of field names whose values are passed to predicate_func,
            in this order
        :param predicate_func: callable evaluated for every row
            example: lambda labels_object_roles : len(labels_object_roles) > 3
        :param state_arg: optional extra object carrying state for predicate_func. It is passed
            as the last positional argument, and ONLY when it is not None
        :raises ValueError: if predicate_fields is not a list
        """
        fields_is_list = isinstance(predicate_fields, list)
        if not fields_is_list:
            raise ValueError('Predicate fields should be a list')
        self._state_arg = state_arg
        self._predicate_func = predicate_func
        self._predicate_fields = predicate_fields

    def get_fields(self):
        """ Return the configured field names as a set """
        return {*self._predicate_fields}

    def do_include(self, values):
        """ Call the wrapped function with the row's field values and return its result

        :param values: container indexed by field name
        """
        positional = [values[name] for name in self._predicate_fields]
        if self._state_arg is None:
            return self._predicate_func(*positional)
        # The state object always goes after the field values.
        return self._predicate_func(*positional, self._state_arg)
class in_negate(PredicateBase):
    """ A predicate used to negate another predicate. """

    def __init__(self, predicate):
        """
        :param predicate: the predicate whose outcome is inverted
        :raises ValueError: if predicate is not a PredicateBase instance
        """
        if not isinstance(predicate, PredicateBase):
            # Message previously read "is nor derived"; corrected typo.
            raise ValueError('Predicate is not derived from PredicateBase')
        self._predicate = predicate

    def get_fields(self):
        """ Return the fields required by the wrapped predicate """
        return self._predicate.get_fields()

    def do_include(self, values):
        """ Return the logical negation of the wrapped predicate's decision

        :param values: container indexed by field name
        """
        return not self._predicate.do_include(values)
class in_reduce(PredicateBase):
    """ A predicate used to aggregate other predicates using any reduce logical operation."""

    def __init__(self, predicate_list, reduce_func):
        """
        :param predicate_list: iterable of predicates to combine
        :param reduce_func: function that receives the list of individual predicate results and
            returns the aggregated result, e.g. all() implements logical 'And', any() implements
            logical 'Or'
        :raises ValueError: if any element of predicate_list is not a PredicateBase instance
        """
        # Materialise once: a generator argument would otherwise be exhausted by the validation
        # pass below and the predicate would silently end up with no children.
        predicates = list(predicate_list)
        if not all(isinstance(p, PredicateBase) for p in predicates):
            # Message previously read "is nor derived"; corrected typo.
            raise ValueError('Predicate is not derived from PredicateBase')
        self._predicate_list = predicates
        self._reduce_func = reduce_func

    def get_fields(self):
        """ Return the union of the fields required by all child predicates """
        return set().union(*(p.get_fields() for p in self._predicate_list))

    def do_include(self, values):
        """ Evaluate every child predicate and aggregate the results with reduce_func

        :param values: container indexed by field name
        """
        # A real list (not a generator) is passed because reduce_func is caller-supplied and may
        # need len() or multiple passes.
        include_list = [p.do_include(values) for p in self._predicate_list]
        return self._reduce_func(include_list)
class in_pseudorandom_split(PredicateBase):
    """ Split dataset according to a list of fractions, based on the value of predicate_field.

    The split is pseudorandom (can not supply the seed yet), i.e. the split outcome is always the same.
    The string form of the field value is hashed to an integer bucket; a row belongs to the selected
    subset when its bucket falls inside the sub-range that corresponds to that subset's fraction.

    Example:
    'fraction_list = [0.5, 0.2, 0.3]' - dataset will be split on three subsets in proportion
    subset 1: 0.5 of data
    subset 2: 0.2 of data
    subset 3: 0.3 of data

    Note, split is not exact, so avoid small fraction (e.g. 0.001) to avoid empty sets
    """

    def __init__(self, fraction_list, subset_index, predicate_field):
        """
        :param fraction_list: a list of subset fractions (real numbers in range [0:1])
        :param subset_index: defines which subset is selected; a negative value counts from the
            end of fraction_list
        :param predicate_field: name of the field whose value is hashed to assign a row to a subset
        :raises ValueError: if subset_index does not address an entry of fraction_list
        """
        subset_count = len(fraction_list)
        if not -subset_count <= subset_index < subset_count:
            raise ValueError('subset_index is out of range')
        if subset_index < 0:
            # Normalise so the "first subset has lower border 0" rule below also covers
            # subset_index == -len(fraction_list).
            subset_index += subset_count
        self._predicate_field = predicate_field

        # build CDF with one running total instead of re-summing a prefix for every entry
        subsets_high_borders = []
        running_total = 0
        for fraction in fraction_list:
            running_total += fraction
            subsets_high_borders.append(running_total)

        fraction_low = subsets_high_borders[subset_index - 1] if subset_index else 0
        fraction_high = subsets_high_borders[subset_index]

        # Scale the fraction borders to the bucket range used in do_include.
        self._bucket_low = fraction_low * (sys.maxsize - 1)
        self._bucket_high = fraction_high * (sys.maxsize - 1)

    def get_fields(self):
        """ Return a one-element set holding the split field name """
        return {self._predicate_field}

    def do_include(self, values):
        """ Return True when the row's hash bucket falls inside the selected subset's range

        :param values: mapping from field name to the row's field value
        :raises ValueError: if predicate_field is missing from values
        """
        if self._predicate_field not in values:
            raise ValueError('Tested values does not have split key: %s' % self._predicate_field)
        bucket_idx = _string_to_bucket(str(values[self._predicate_field]), sys.maxsize)
        return self._bucket_low <= bucket_idx < self._bucket_high