Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use yaml's !!binary string for strings as necessary #22298

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions bindings/pydrake/common/test/yaml_typed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,18 @@ class StringStruct:
value: str = "nominal_string"


@dc.dataclass
class PathStruct:
    # Schema with a single Path-typed field, used by the path read/write tests.
    # The default is a real Path (not a bare str) so that a default-constructed
    # instance agrees with its own type annotation.
    value: Path = Path("/path/to/nowhere")


@dc.dataclass
class AllScalarsStruct:
    # One field per supported scalar type, each with a recognizable default.
    some_bool: bool = False
    some_float: float = nan
    some_int: int = 11
    some_str: str = "nominal_string"
    # Use an actual Path so the default value matches the annotation.
    some_path: Path = Path("/path/to/nowhere")


@dc.dataclass
Expand Down Expand Up @@ -230,18 +236,58 @@ def test_read_float_missing(self, *, options):
yaml_load_typed(schema=FloatStruct, data="{}",
**options)

@run_with_multiple_values(_all_typed_read_options())
def test_read_path(self, *, options):
    """Loads assorted YAML scalars into a Path-typed field.

    Each case pairs the YAML text written after ``value:`` with the path
    string expected back. A second element of ``None`` means the loaded
    path should equal the YAML text itself.
    """
    cases = [
        ("no_directory.txt", None),
        ("/absolute/path/file.txt", None),
        ("\"\"", "."),
        ("!!str", "."),
        (".", "."),
        ("/non_lexical//path", "/non_lexical/path"),
        ("/quoted\"/path", None),
        ("\"1234\"", "1234"),
        ("'1234'", "1234"),
        ("!!str 1234", "1234"),
        ("\"1234.5\"", "1234.5"),
        ("'1234.5'", "1234.5"),
        ("!!str 1234.5", "1234.5"),
        ("\"true\"", "true"),
        ("'true'", "true"),
        ("!!str true", "true"),
    ]
    for yaml_text, override in cases:
        loaded = yaml_load_typed(
            schema=PathStruct, data=f"value: {yaml_text}", **options)
        wanted = yaml_text if override is None else override
        self.assertEqual(loaded.value, Path(wanted))

@run_with_multiple_values(_all_typed_read_options())
def test_read_path_missing(self, *, options):
    """Loads an empty mapping into PathStruct.

    When the options allow a schema field with no YAML counterpart, the
    field keeps its default; otherwise loading raises a "missing" error.
    """
    if not options["allow_schema_with_no_yaml"]:
        with self.assertRaisesRegex(RuntimeError, ".*missing.*"):
            yaml_load_typed(schema=PathStruct, data="{}", **options)
        return
    fallback = PathStruct()
    loaded = yaml_load_typed(schema=PathStruct, data="{}", **options)
    self.assertEqual(loaded.value, fallback.value, msg=repr(loaded.value))

@run_with_multiple_values(_all_typed_read_options())
def test_read_all_scalars(self, *, options):
    """Loads one document that sets every scalar field of AllScalarsStruct
    and checks each field against the value written in the YAML.
    """
    data = dedent("""
    some_bool: true
    some_float: 101.0
    some_int: 102
    some_path: /alternative/path
    some_str: foo
    """)
    loaded = yaml_load_typed(schema=AllScalarsStruct, data=data, **options)
    # Checked in the same order as the fields are declared on the schema's
    # original assertions: bool, float, int, path, str.
    expectations = (
        ("some_bool", True),
        ("some_float", 101.0),
        ("some_int", 102),
        ("some_path", Path("/alternative/path")),
        ("some_str", "foo"),
    )
    for field_name, wanted in expectations:
        self.assertEqual(getattr(loaded, field_name), wanted)

@run_with_multiple_values(_all_typed_read_options())
Expand Down Expand Up @@ -621,6 +667,34 @@ def test_load_string(self):
result = yaml_load_typed(schema=StringStruct, data=data)
self.assertEqual(result.value, "some_value")

data = dedent("""
value:
!!binary Tm9uUHJpbnRhYmxlAw==
""")
result = yaml_load_typed(schema=StringStruct, data=data)
self.assertEqual(result.value, "NonPrintable\x03")

data = dedent("""
value: !!binary |
Tm9uUHJpbnRhYmxlAw==
""")
result = yaml_load_typed(schema=StringStruct, data=data)
self.assertEqual(result.value, "NonPrintable\x03")

data = dedent("""
value:
!!str Tm9uUHJpbnRhYmxlAw==
""")
result = yaml_load_typed(schema=StringStruct, data=data)
self.assertEqual(result.value, "Tm9uUHJpbnRhYmxlAw==")

data = dedent("""
value:
"\\tall\\nprintable\\r char"
""")
result = yaml_load_typed(schema=StringStruct, data=data)
self.assertEqual(result.value, "\tall\nprintable\r char")

def test_load_string_child_name(self):
data = dedent("""
some_child_name:
Expand Down Expand Up @@ -731,12 +805,27 @@ def test_write_string(self):
cases = [
("a", "a"),
("1", "'1'"),
("NonPrintable\x03", "!!binary |\n Tm9uUHJpbnRhYmxlAw=="),
]
for value, expected_str in cases:
actual_doc = yaml_dump_typed(StringStruct(value=value))
expected_doc = f"value: {expected_str}\n"
self.assertEqual(actual_doc, expected_doc)

def test_write_path(self):
    """Dumps PathStruct for several paths and checks the emitted scalar.

    Each case pairs a Path with the exact YAML scalar text expected after
    ``value: `` in the dumped document.
    """
    cases = [
        (Path(""), "."),
        (Path("/absolute/path"), "/absolute/path"),
        (Path("relative/path"), "relative/path"),
        (Path("1234"), "'1234'"),
        (Path("1234.5"), "'1234.5'"),
        (Path("true"), "'true'"),
    ]
    for path, yaml_scalar in cases:
        document = yaml_dump_typed(PathStruct(value=path))
        self.assertEqual(document, f"value: {yaml_scalar}\n")

def test_write_list_plain(self):
# When the vector items are simple YAML scalars, we should use "flow"
# style, where they all appear on a single line.
Expand Down
44 changes: 39 additions & 5 deletions bindings/pydrake/common/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools
import types
import typing
from pathlib import Path

import numpy as np
import yaml
Expand Down Expand Up @@ -140,6 +141,19 @@ def _represent_dict(self, data):
else:
return super().represent_dict(data)

def _represent_str(self, s):
"""Handle non-printable strings."""
is_binary = False
for c in s:
if (not c.isprintable()
and not (c == '\n' or c == '\r' or c == '\t')):
is_binary = True
break
if is_binary:
return self.represent_binary(s.encode("utf-8"))
else:
return self.represent_str(s)

def _represent_undefined(self, data):
if getattr(type(data), "__module__", "").startswith("pydrake"):
raise yaml.representer.RepresenterError(
Expand All @@ -150,6 +164,7 @@ def _represent_undefined(self, data):

# Register the custom representers on _SchemaDumper, keyed by the Python type
# each one handles. (In PyYAML, a key of None installs the fallback used when
# no representer matches the object's type.)
_SchemaDumper.add_representer(None, _SchemaDumper._represent_undefined)
_SchemaDumper.add_representer(dict, _SchemaDumper._represent_dict)
_SchemaDumper.add_representer(str, _SchemaDumper._represent_str)
_SchemaDumper.add_representer(
    _SchemaDumper.ExplicitScalar,
    _SchemaDumper._represent_explicit_scalar)
Expand All @@ -165,6 +180,9 @@ class _DrakeFlowSchemaDumper(_SchemaDumper):

- For mappings: If there are no children, then formats this map onto a
single line; otherwise, format over multiple lines.

- For strings: if the string contains non-printable characters, it is
emitted as a binary string.
"""

def serialize_node(self, node, parent, index):
Expand Down Expand Up @@ -292,7 +310,11 @@ def _merge_yaml_dict_item_into_target(*, options, name, yaml_value,
raise RuntimeError(
f"Expected a {value_schema} value for '{name}' but instead got"
f" non-scalar yaml data of type {type(yaml_value)}")
new_value = value_schema(yaml_value)
if value_schema == str and isinstance(yaml_value, bytes):
# This had the !!binary tag.
new_value = yaml_value.decode('utf-8')
else:
new_value = value_schema(yaml_value)
setter(new_value)
return

Expand All @@ -315,6 +337,12 @@ def _merge_yaml_dict_item_into_target(*, options, name, yaml_value,
value_schema=nested_optional_type)
return

# Handle pathlib.Path.
if value_schema == Path:
new_value = Path(yaml_value)
setter(new_value)
return

# Handle NumPy types.
if value_schema == np.ndarray:
new_value = np.array(yaml_value, dtype=float)
Expand Down Expand Up @@ -588,10 +616,12 @@ def _yaml_dump_get_attribute(*, obj, name):
def _yaml_dump_typed_item(*, obj, schema):
"""Given an object ``obj`` and its type ``schema``, returns the plain YAML
object that should be serialized. Objects that are already primitive types
(str, float, etc.) are returned unchanged. Bare collection types (List and
Mapping) are processed recursively. Structs (dataclasses) are processed
using their schema. The result is "plain" in the sense that it's always
just a tree of primitives, lists, and dicts -- no user-defined types.
(str, float, etc.) are returned mostly unchanged; strings with unprintable
characters (excluding whitespace) will be translated to byte strings. Bare
collection types (List and Mapping) are processed recursively. Structs
(dataclasses) are processed using their schema. The result is "plain" in
the sense that it's always just a tree of primitives, lists, and dicts --
no user-defined types.
"""
assert schema is not None

Expand Down Expand Up @@ -665,6 +695,10 @@ def _yaml_dump_typed_item(*, obj, schema):
result["_tag"] = "!" + class_name
return result

# Handle pathlib.Path types.
if schema == Path:
return str(obj)

# Handle NumPy types.
if schema == np.ndarray:
# TODO(jwnimmer-tri) We should use the numpy.typing module here to
Expand Down
4 changes: 4 additions & 0 deletions common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,8 @@ drake_cc_googletest(
name = "file_source_test",
deps = [
":file_source",
"//common/test_utilities:expect_throws_message",
"//common/yaml:yaml_io",
],
)

Expand Down Expand Up @@ -1099,6 +1101,8 @@ drake_cc_googletest(
":find_resource",
":memory_file",
":temp_directory",
"//common/test_utilities:expect_throws_message",
"//common/yaml:yaml_io",
],
)

Expand Down
10 changes: 10 additions & 0 deletions common/memory_file.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ class MemoryFile final {
any number less than or equal to zero. */
std::string to_string(int contents_limit = 100) const;

/** Serialization stub.

%MemoryFile cannot actually be serialized yet. Attempting to do so will throw.
This stub merely permits FileSource to be serialized (when it contains a
`std::filesystem::path`).

@throws std::runtime_error always. */
template <typename Archive>
void Serialize([[maybe_unused]] Archive* a) {
  // The archive is deliberately ignored; the attribute keeps the stub from
  // producing unused-parameter diagnostics in every instantiation.
  throw std::runtime_error("Serialization for MemoryFile not yet supported.");
}

private:
reset_after_move<std::string> contents_;

Expand Down
42 changes: 42 additions & 0 deletions common/test/file_source_test.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
#include "drake/common/file_source.h"

#include <filesystem>

#include <gtest/gtest.h>

#include "drake/common/memory_file.h"
#include "drake/common/name_value.h"
#include "drake/common/test_utilities/expect_throws_message.h"
#include "drake/common/yaml/yaml_io.h"

namespace drake {
namespace {

namespace fs = std::filesystem;

/* We want to lock down the idea that the default value is an empty path. */
GTEST_TEST(FileSourceTest, DefaultPath) {
  FileSource dut;
  // Assert (rather than expect) so that the std::get below cannot throw.
  ASSERT_TRUE(std::holds_alternative<fs::path>(dut));
  // Holding a path is only half of the promise; it must also be empty.
  EXPECT_TRUE(std::get<fs::path>(dut).empty());
}

GTEST_TEST(FileSourceTest, ToString) {
EXPECT_EQ(to_string(FileSource("a/b/c")), "\"a/b/c\"");
EXPECT_EQ(fmt::to_string(FileSource("a/b/c")), "\"a/b/c\"");
Expand All @@ -13,5 +28,32 @@ GTEST_TEST(FileSourceTest, ToString) {
EXPECT_EQ(to_string(FileSource(file)), file.to_string());
EXPECT_EQ(fmt::to_string(FileSource(file)), file.to_string());
}

/* Minimal serializable wrapper: a struct whose only field is a FileSource, so
 that the YAML archives have something to visit. */
struct HasFileSource {
  FileSource source;

  /* Visits `source` under the name "source". */
  template <typename Archive>
  void Serialize(Archive* archive) {
    archive->Visit(DRAKE_NVP(source));
  }
};

/* A FileSource holding a path survives a YAML save/load round trip. */
GTEST_TEST(FileSourceTest, SerializePath) {
  const HasFileSource dut{.source = fs::path("/some/path")};
  const auto round_tripped =
      yaml::LoadYamlString<HasFileSource>(yaml::SaveYamlString(dut));
  // Stop here if the alternative is wrong; std::get below would throw.
  ASSERT_TRUE(std::holds_alternative<fs::path>(round_tripped.source));
  EXPECT_EQ(std::get<fs::path>(dut.source),
            std::get<fs::path>(round_tripped.source));
}

/* Saving a FileSource that holds a MemoryFile throws, because MemoryFile's
 own Serialize() is a throwing stub. */
GTEST_TEST(FileSourceTest, SerializeMemoryFile) {
  const MemoryFile in_memory("stuff", ".ext", "hint");
  const HasFileSource dut{.source = in_memory};
  DRAKE_EXPECT_THROWS_MESSAGE(
      yaml::SaveYamlString(dut),
      "Serialization for MemoryFile not yet supported.");
}

} // namespace
} // namespace drake
10 changes: 10 additions & 0 deletions common/test/memory_file_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include "drake/common/find_resource.h"
#include "drake/common/temp_directory.h"
#include "drake/common/test_utilities/expect_throws_message.h"
#include "drake/common/yaml/yaml_io.h"

namespace drake {
namespace {
Expand Down Expand Up @@ -109,5 +111,13 @@ GTEST_TEST(MemoryFileTest, ToString) {
EXPECT_THAT(fmt::to_string(file), testing::HasSubstr("\"0123456789\""));
}

/* MemoryFile provides a Serialize() so that this compiles, but actually
 saving one to YAML throws. */
GTEST_TEST(MemoryFileTest, SerializationThrows) {
  const MemoryFile file("stuff", ".ext", "hint");
  DRAKE_EXPECT_THROWS_MESSAGE(
      yaml::SaveYamlString(file),
      "Serialization for MemoryFile not yet supported.");
}

} // namespace
} // namespace drake
Loading