Skip to content

Commit

Permalink
Allow customization of TFDV options when running within TFT.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 288032664
  • Loading branch information
paulgc authored and tensorflow-extended-team committed Jan 3, 2020
1 parent 060a49c commit d5aecd1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tfx/components/transform/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from tfx import types
from tfx.components.base import base_executor
from tfx.components.transform import labels
from tfx.components.transform import stats_options as transform_stats_options
from tfx.components.transform import messages
from tfx.components.util import value_utils
from tfx.types import artifact_utils
Expand Down Expand Up @@ -487,6 +488,7 @@ def _GenerateStats(
pcoll: beam.pvalue.PCollection,
stats_output_path: Text,
schema: schema_pb2.Schema,
stats_options: tfdv.StatsOptions,
# TODO(b/115684207): Remove this and all related code.
use_tfdv=True,
# TODO(b/115684207): Remove this and all related code.
Expand All @@ -498,6 +500,8 @@ def _GenerateStats(
pcoll: PCollection of examples.
stats_output_path: path where statistics is written to.
schema: schema.
stats_options: An instance of `tfdv.StatsOptions()` used when computing
statistics.
use_tfdv: whether use TFDV for computing statistics.
examples_are_serialized: Unused.
Expand All @@ -507,11 +511,11 @@ def _GenerateStats(
assert use_tfdv
del examples_are_serialized # Unused

stats_options.schema = schema
# pylint: disable=no-value-for-parameter
return (
pcoll
| 'GenerateStatistics' >> tfdv.GenerateStatistics(
tfdv.StatsOptions(schema=schema))
| 'GenerateStatistics' >> tfdv.GenerateStatistics(stats_options)
| 'WriteStats' >> Executor._WriteStats(stats_output_path))

# TODO(zhuo): Obviate this once TFXIO is used.
Expand Down Expand Up @@ -1102,6 +1106,8 @@ def _RunBeamImpl(self, inputs: Mapping[Text, Any],
| 'FromSerializedToArrowTables[{}]'.format(infix)
>> self._FromSerializedToArrowTables(schema_proto))

pre_transform_stats_options = (
transform_stats_options.get_pre_transform_stats_options())
([
dataset.standardized if stats_use_tfdv else dataset.serialized
for dataset in analyze_data_list
Expand All @@ -1111,6 +1117,7 @@ def _RunBeamImpl(self, inputs: Mapping[Text, Any],
self._GenerateStats(
pre_transform_feature_stats_path,
schema_proto,
stats_options=pre_transform_stats_options,
use_tfdv=stats_use_tfdv,
examples_are_serialized=True))

Expand Down Expand Up @@ -1163,6 +1170,8 @@ def _RunBeamImpl(self, inputs: Mapping[Text, Any],
transform_output_path,
tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH)

post_transform_stats_options = (
transform_stats_options.get_post_transform_stats_options())
([(dataset.transformed_and_standardized
if stats_use_tfdv else dataset.transformed_and_encoded)
for dataset in transform_data_list]
Expand All @@ -1171,6 +1180,7 @@ def _RunBeamImpl(self, inputs: Mapping[Text, Any],
self._GenerateStats(
post_transform_feature_stats_path,
transformed_schema_proto,
stats_options=post_transform_stats_options,
use_tfdv=stats_use_tfdv))

if per_set_stats_output_paths:
Expand All @@ -1186,6 +1196,7 @@ def _RunBeamImpl(self, inputs: Mapping[Text, Any],
data | 'GenerateStats[{}]'.format(infix) >> self._GenerateStats(
dataset.stats_output_path,
transformed_schema_proto,
stats_options=post_transform_stats_options,
use_tfdv=stats_use_tfdv)

if materialize_output_paths:
Expand Down
50 changes: 50 additions & 0 deletions tfx/components/transform/stats_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Lint as: python3
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stats Options for customizing TFDV in TFT."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow_data_validation as tfdv


# An instance of `tfdv.StatsOptions()` used when computing pre-transform
# statistics. If not specified, default options are used.
_PRE_TRANSFORM_STATS_OPTIONS = None

# An instance of `tfdv.StatsOptions()` used when computing post-transform
# statistics. If not specified, default options are used.
_POST_TRANSFORM_STATS_OPTIONS = None


def set_pre_transform_stats_options(stats_options: tfdv.StatsOptions):
global _PRE_TRANSFORM_STATS_OPTIONS
_PRE_TRANSFORM_STATS_OPTIONS = stats_options


def set_post_transform_stats_options(stats_options: tfdv.StatsOptions):
global _POST_TRANSFORM_STATS_OPTIONS
_POST_TRANSFORM_STATS_OPTIONS = stats_options


def get_pre_transform_stats_options() -> tfdv.StatsOptions:
return (tfdv.StatsOptions() if _PRE_TRANSFORM_STATS_OPTIONS is None
else _PRE_TRANSFORM_STATS_OPTIONS)


def get_post_transform_stats_options() -> tfdv.StatsOptions:
return (tfdv.StatsOptions() if _POST_TRANSFORM_STATS_OPTIONS is None
else _POST_TRANSFORM_STATS_OPTIONS)

0 comments on commit d5aecd1

Please sign in to comment.