diff --git a/tensorflow_transform/tf_metadata/schema_utils.py b/tensorflow_transform/tf_metadata/schema_utils.py index 1d1fb981..3d49b519 100644 --- a/tensorflow_transform/tf_metadata/schema_utils.py +++ b/tensorflow_transform/tf_metadata/schema_utils.py @@ -262,7 +262,16 @@ def _standardize_default_value( assert isinstance(default_value, list), spec.default_value # Convert bytes to string if spec.dtype == tf.string: - default_value = [value.decode('utf-8') for value in default_value] + + # Handle bytes string by trying to decode them (for legacy backwards + # compatibility) and if failed, keep the default value as bytes. + def try_decode(value: bytes) -> Union[str, bytes]: + try: + return value.decode('utf-8') + except UnicodeError: + return value + + default_value = [try_decode(value) for value in default_value] # Unwrap a list with a single element. if len(default_value) == 1: default_value = default_value[0] diff --git a/tensorflow_transform/tf_metadata/schema_utils_test_cases.py b/tensorflow_transform/tf_metadata/schema_utils_test_cases.py index abab79ed..902873ce 100644 --- a/tensorflow_transform/tf_metadata/schema_utils_test_cases.py +++ b/tensorflow_transform/tf_metadata/schema_utils_test_cases.py @@ -753,6 +753,35 @@ 'seq_string_feature': schema_pb2.StringDomain(value=['a', 'b']) } }, + { + 'testcase_name': 'fixed_len_bytes_encoding', + 'ascii_proto': """ + feature { + name: "x" + type: BYTES + value_count { + min: 1 + max: 1 + } + } + tensor_representation_group { + key: "" + value { + tensor_representation { + key: "x" + value { + dense_tensor { + column_name: "x" + shape { dim { size: 1 } } + default_value { bytes_value: "\\xd0" } + } + } + } + } + } + """, + 'feature_spec': {'x': tf.io.FixedLenFeature([1], tf.string, b'\xd0')}, + }, ] INVALID_SCHEMA_PROTOS = [