D3303: cborutil: implement support for streaming encoding, bytestring decoding
indygreg (Gregory Szorc)
phabricator at mercurial-scm.org
Fri Apr 13 21:31:05 UTC 2018
indygreg updated this revision to Diff 8152.
indygreg edited the summary of this revision.
indygreg retitled this revision from "cborutil: implement support for indefinite length CBOR types" to "cborutil: implement support for streaming encoding, bytestring decoding".
REPOSITORY
rHG Mercurial
CHANGES SINCE LAST UPDATE
https://phab.mercurial-scm.org/D3303?vs=8104&id=8152
REVISION DETAIL
https://phab.mercurial-scm.org/D3303
AFFECTED FILES
contrib/import-checker.py
mercurial/utils/cborutil.py
tests/test-cbor.py
CHANGE DETAILS
diff --git a/tests/test-cbor.py b/tests/test-cbor.py
new file mode 100644
--- /dev/null
+++ b/tests/test-cbor.py
@@ -0,0 +1,182 @@
+from __future__ import absolute_import
+
+import io
+import unittest
+
+from mercurial.thirdparty import (
+ cbor,
+)
+from mercurial.utils import (
+ cborutil,
+)
+
class BytestringTests(unittest.TestCase):
    def testsimple(self):
        self.assertEqual(
            list(cborutil.streamencode(b'foobar')),
            [b'\x46', b'foobar'])

        self.assertEqual(
            cbor.loads(b''.join(cborutil.streamencode(b'foobar'))),
            b'foobar')

    def testlong(self):
        source = b'x' * 1048576

        self.assertEqual(
            cbor.loads(b''.join(cborutil.streamencode(source))),
            source)

    def testfromiter(self):
        # This is the example from RFC 7049 Section 2.2.2.
        source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']

        self.assertEqual(
            list(cborutil.streamencodebytestringfromiter(source)),
            [
                b'\x5f',
                b'\x44',
                b'\xaa\xbb\xcc\xdd',
                b'\x43',
                b'\xee\xff\x99',
                b'\xff',
            ])

        dest = b''.join(cborutil.streamencodebytestringfromiter(source))
        self.assertEqual(cbor.loads(dest), b''.join(source))

    def testfromiterlarge(self):
        source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]

        dest = b''.join(cborutil.streamencodebytestringfromiter(source))

        self.assertEqual(cbor.loads(dest), b''.join(source))

    def testindefinite(self):
        source = b'\x00\x01\x02\x03' + b'\xff' * 16384

        it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)

        self.assertEqual(next(it), b'\x5f')
        self.assertEqual(next(it), b'\x42')
        self.assertEqual(next(it), b'\x00\x01')
        self.assertEqual(next(it), b'\x42')
        self.assertEqual(next(it), b'\x02\x03')
        self.assertEqual(next(it), b'\x42')
        self.assertEqual(next(it), b'\xff\xff')

        dest = b''.join(cborutil.streamencodeindefinitebytestring(
            source, chunksize=42))
        # ``source`` is already a single bytes object. Joining it would
        # iterate over its items, which are ints on Python 3 and make
        # b''.join() raise TypeError.
        self.assertEqual(cbor.loads(dest), source)

    def testreadtoiter(self):
        source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')

        it = cborutil.readindefinitebytestringtoiter(source)
        self.assertEqual(next(it), b'\xaa\xbb\xcc\xdd')
        self.assertEqual(next(it), b'\xee\xff\x99')

        with self.assertRaises(StopIteration):
            next(it)
+
class IntTests(unittest.TestCase):
    def testsmall(self):
        # Values below 24 fit entirely in the initial byte (major type 0).
        expectations = [
            (0, b'\x00'),
            (1, b'\x01'),
            (2, b'\x02'),
            (3, b'\x03'),
            (4, b'\x04'),
        ]
        for value, encoded in expectations:
            self.assertEqual(list(cborutil.streamencode(value)), [encoded])

    def testnegativesmall(self):
        # Negative values use major type 1 (0x20) and store -1 - value.
        expectations = [
            (-1, b'\x20'),
            (-2, b'\x21'),
            (-3, b'\x22'),
            (-4, b'\x23'),
            (-5, b'\x24'),
        ]
        for value, encoded in expectations:
            self.assertEqual(list(cborutil.streamencode(value)), [encoded])

    def testrange(self):
        # Compare against the reference encoder across a span that crosses
        # every length boundary up to a 4 byte argument.
        for value in range(-70000, 70000, 10):
            actual = b''.join(cborutil.streamencode(value))
            self.assertEqual(actual, cbor.dumps(value))
+
class ArrayTests(unittest.TestCase):
    def testempty(self):
        encoded = list(cborutil.streamencode([]))
        self.assertEqual(encoded, [b'\x80'])
        self.assertEqual(cbor.loads(b''.join(encoded)), [])

    def testbasic(self):
        expected = [
            b'\x84',
            b'\x43', b'foo',
            b'\x43', b'bar',
            b'\x01',
            b'\x29',
        ]
        self.assertEqual(
            list(cborutil.streamencode([b'foo', b'bar', 1, -10])),
            expected)

    def testemptyfromiter(self):
        encoded = b''.join(cborutil.streamencodearrayfromiter([]))
        self.assertEqual(encoded, b'\x9f\xff')

    def testfromiter1(self):
        items = [b'foo']
        expected = [
            b'\x9f',
            b'\x43', b'foo',
            b'\xff',
        ]

        self.assertEqual(
            list(cborutil.streamencodearrayfromiter(items)), expected)

        encoded = b''.join(cborutil.streamencodearrayfromiter(items))
        self.assertEqual(cbor.loads(encoded), items)
+
class MapTests(unittest.TestCase):
    def _assertroundtrips(self, source):
        # Both the definite and the indefinite length map encoders must
        # decode back to the original dict.
        definite = b''.join(cborutil.streamencode(source))
        self.assertEqual(cbor.loads(definite), source)

        indefinite = b''.join(
            cborutil.streamencodemapfromiter(source.items()))
        self.assertEqual(cbor.loads(indefinite), source)

    def testempty(self):
        encoded = list(cborutil.streamencode({}))
        self.assertEqual(encoded, [b'\xa0'])
        self.assertEqual(cbor.loads(b''.join(encoded)), {})

    def testemptyindefinite(self):
        encoded = list(cborutil.streamencodemapfromiter([]))
        self.assertEqual(encoded, [b'\xbf', b'\xff'])
        self.assertEqual(cbor.loads(b''.join(encoded)), {})

    def testone(self):
        source = {b'foo': b'bar'}
        expected = [b'\xa1', b'\x43', b'foo', b'\x43', b'bar']

        encoded = list(cborutil.streamencode(source))
        self.assertEqual(encoded, expected)
        self.assertEqual(cbor.loads(b''.join(encoded)), source)

    def testmultiple(self):
        self._assertroundtrips({
            b'foo': b'bar',
            b'baz': b'value1',
        })

    def testcomplex(self):
        # Keys of mixed types.
        self._assertroundtrips({
            b'key': 1,
            2: -10,
        })
+
if __name__ == '__main__':
    # Allow running this file directly; tests are run through Mercurial's
    # silenttestrunner module.
    import silenttestrunner
    silenttestrunner.main(__name__)
diff --git a/mercurial/utils/cborutil.py b/mercurial/utils/cborutil.py
new file mode 100644
--- /dev/null
+++ b/mercurial/utils/cborutil.py
@@ -0,0 +1,228 @@
+# cborutil.py - CBOR extensions
+#
+# Copyright 2018 Gregory Szorc <gregory.szorc at gmail.com>
+#
+# This software may be used and distributed according to the terms of the
+# GNU General Public License version 2 or any later version.
+
+from __future__ import absolute_import
+
+import struct
+
+from ..thirdparty.cbor.cbor2 import (
+ decoder as decodermod,
+)
+
# A very short summary of RFC 7049...
#
# Each data item begins with an initial byte. The 3 high bits of that byte
# denote the "major type" and the lower 5 bits denote the "subtype" (the
# RFC calls it "additional information"). Each major type has its own
# encoding mechanism.
#
# Most types have lengths. However, bytestrings, strings, arrays, and maps
# can be of indefinite length. These are denoted by a subtype with value
# 31. The sub-components of such an item come afterwards and are terminated
# by a "break" byte.

# The eight major types, numbered 0 through 7 in this order.
(MAJOR_TYPE_UINT,
 MAJOR_TYPE_NEGINT,
 MAJOR_TYPE_BYTESTRING,
 MAJOR_TYPE_STRING,
 MAJOR_TYPE_ARRAY,
 MAJOR_TYPE_MAP,
 MAJOR_TYPE_SEMANTIC,
 MAJOR_TYPE_SPECIAL) = range(8)

# Selects the subtype (low 5 bits) of an initial byte.
SUBTYPE_MASK = 0x1f

SUBTYPE_INDEFINITE = 31

# Indefinite types begin with their major type ORed with subtype 31.
BEGIN_INDEFINITE_BYTESTRING = struct.pack(
    r'>B', (MAJOR_TYPE_BYTESTRING << 5) | SUBTYPE_INDEFINITE)
BEGIN_INDEFINITE_ARRAY = struct.pack(
    r'>B', (MAJOR_TYPE_ARRAY << 5) | SUBTYPE_INDEFINITE)
BEGIN_INDEFINITE_MAP = struct.pack(
    r'>B', (MAJOR_TYPE_MAP << 5) | SUBTYPE_INDEFINITE)

# Precompiled layouts for an initial byte, optionally followed by a 1, 2, 4
# or 8 byte big-endian unsigned argument.
ENCODED_LENGTH_1 = struct.Struct(r'>B')
ENCODED_LENGTH_2 = struct.Struct(r'>BB')
ENCODED_LENGTH_3 = struct.Struct(r'>BH')
ENCODED_LENGTH_4 = struct.Struct(r'>BL')
ENCODED_LENGTH_5 = struct.Struct(r'>BQ')

# The break byte ends an indefinite length item.
BREAK_INT = 0xff
BREAK = b'\xff'
+
def encodelength(majortype, length):
    """Return the initial byte(s) for an item of ``majortype``.

    ``length`` is the item's length (or, for integers, its value). Values
    below 24 are stored directly in the initial byte's subtype. Larger
    values use subtype 24, 25, 26 or 27 followed by a 1, 2, 4 or 8 byte
    big-endian integer respectively.
    """
    initial = majortype << 5

    if length >= 4294967296:
        return ENCODED_LENGTH_5.pack(initial | 27, length)
    if length >= 65536:
        return ENCODED_LENGTH_4.pack(initial | 26, length)
    if length >= 256:
        return ENCODED_LENGTH_3.pack(initial | 25, length)
    if length >= 24:
        return ENCODED_LENGTH_2.pack(initial | 24, length)

    return ENCODED_LENGTH_1.pack(initial | length)
+
def streamencodebytestring(v):
    """Encode ``v`` as a definite length bytestring.

    Yields the header (major type and length) and then ``v`` itself, so
    the payload is passed through unchanged.
    """
    header = encodelength(MAJOR_TYPE_BYTESTRING, len(v))
    yield header
    yield v
+
def streamencodebytestringfromiter(it):
    """Encode an iterable of chunks as one indefinite length bytestring.

    Each element of ``it`` must be representable as bytes. It is emitted
    as its own definite length bytestring chunk, and the chunks are
    wrapped in the indefinite bytestring header and the BREAK byte.
    """
    yield BEGIN_INDEFINITE_BYTESTRING

    for piece in it:
        header = encodelength(MAJOR_TYPE_BYTESTRING, len(piece))
        yield header
        yield piece

    yield BREAK
+
def streamencodeindefinitebytestring(source, chunksize=65536):
    """Given a large source buffer, emit it as an indefinite length bytestring.

    Returns a generator of chunks making up the encoded CBOR data:
    - the indefinite bytestring header;
    - for each slice of ``source`` of at most ``chunksize`` bytes, a
      definite length bytestring header and the slice itself;
    - the BREAK byte.

    An empty ``source`` is encoded as a single empty chunk.

    Raises ``ValueError`` immediately (not on first iteration) if
    ``chunksize`` is less than 1. Such a value never advances through
    ``source`` and would loop forever.
    """
    if chunksize < 1:
        raise ValueError('chunksize must be at least 1; got %r' % chunksize)

    return _streamencodeindefinitebytestringchunks(source, chunksize)

def _streamencodeindefinitebytestringchunks(source, chunksize):
    """Generator backing streamencodeindefinitebytestring(); chunksize >= 1."""
    yield BEGIN_INDEFINITE_BYTESTRING

    offset = 0
    total = len(source)

    # Always emit at least one chunk so an empty source is still encoded.
    while True:
        chunk = source[offset:offset + chunksize]
        offset += len(chunk)

        yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
        yield chunk

        if offset >= total:
            break

    yield BREAK
+
def streamencodeint(v):
    """Encode an integer as a CBOR unsigned or negative integer.

    Non-negative values use major type 0 and store ``v``. Negative values
    use major type 1 and store ``-1 - v``. Values below -2**64 or at least
    2**64 raise ValueError; since this is a generator, that happens on
    first iteration.
    """
    if v < -18446744073709551616 or v >= 18446744073709551616:
        raise ValueError('big integers not supported')

    if v < 0:
        yield encodelength(MAJOR_TYPE_NEGINT, -1 - v)
    else:
        yield encodelength(MAJOR_TYPE_UINT, v)
+
def streamencodearray(l):
    """Encode an iterable with a known length as a definite length array.

    The header carries ``len(l)``, and each element is then encoded in
    order with streamencode().
    """
    yield encodelength(MAJOR_TYPE_ARRAY, len(l))

    encoded = (chunk for item in l for chunk in streamencode(item))
    for chunk in encoded:
        yield chunk
+
def streamencodearrayfromiter(it):
    """Encode the items of an iterator as an indefinite length array.

    The length need not be known up front: items are encoded as they are
    pulled from ``it``, and the array is closed with the BREAK byte.
    """
    yield BEGIN_INDEFINITE_ARRAY

    encoded = (chunk for item in it for chunk in streamencode(item))
    for chunk in encoded:
        yield chunk

    yield BREAK
+
def streamencodemap(d):
    """Encode a dictionary as a definite length map, yielding chunks.

    Yields the map header, then each key followed by its value, both
    encoded with streamencode(). Entries are emitted in sorted key order,
    so equal dicts always produce identical output.

    Python 3 cannot compare keys of different types (e.g. bytes and int).
    Keys are therefore ordered first by the name of their type and then by
    value. For a dict whose keys all share one type, this is the same order
    as a plain sort.

    Does not support indefinite length maps; see streamencodemapfromiter().
    """
    yield encodelength(MAJOR_TYPE_MAP, len(d))

    def sortkey(item):
        key = item[0]
        return type(key).__name__, key

    # items() works on Python 2 and 3; iteritems() does not exist on 3.
    for key, value in sorted(d.items(), key=sortkey):
        for chunk in streamencode(key):
            yield chunk
        for chunk in streamencode(value):
            yield chunk
+
def streamencodemapfromiter(it):
    """Encode an iterable of (key, value) pairs as an indefinite length map.

    Pairs are encoded in iteration order, key first. The map is closed with
    the BREAK byte.
    """
    yield BEGIN_INDEFINITE_MAP

    for key, value in it:
        for item in (key, value):
            for chunk in streamencode(item):
                yield chunk

    yield BREAK
+
# Maps a value's exact class to the generator function that encodes it.
# streamencode() looks up ``v.__class__`` here, so subclasses (such as
# bool, a subclass of int) are not matched.
STREAM_ENCODERS = dict([
    (bytes, streamencodebytestring),
    (int, streamencodeint),
    (list, streamencodearray),
    (tuple, streamencodearray),
    (dict, streamencodemap),
])
+
def streamencode(v):
    """Encode ``v`` to CBOR and return a generator of bytes chunks.

    The encoder is selected by the exact class of ``v`` in STREAM_ENCODERS.
    Unsupported types raise ValueError immediately, before any chunk is
    produced.
    """
    cls = v.__class__
    if cls not in STREAM_ENCODERS:
        raise ValueError('do not know how to encode %s' % type(v))

    return STREAM_ENCODERS[cls](v)
+
def readindefinitebytestringtoiter(fh, expectheader=True):
    """Read an indefinite length bytestring and yield its chunks.

    ``fh`` is an object with a ``read(N)`` method returning up to N bytes
    (e.g. a file object or ``io.BytesIO``).

    If ``expectheader`` is True, the first byte read must be the initial
    byte of an indefinite length bytestring. Otherwise, the caller has
    already consumed that byte. In that case the first byte read must begin
    the first bytestring chunk, or be the terminating BREAK byte.

    Each definite length bytestring chunk is yielded as it is read.
    Iteration stops once the BREAK byte is consumed; nothing after it is
    read.

    Raises ``CBORDecodeError`` on truncated or malformed input. Because
    this is a generator, errors surface during iteration.
    """
    read = fh.read
    decodeuint = decodermod.decode_uint
    byteasinteger = decodermod.byte_as_integer

    def readinitialbyte():
        # read() returns b'' at end of input. Report that as a decode error
        # rather than letting byte_as_integer() fail on an empty value.
        b = read(1)
        if not b:
            raise decodermod.CBORDecodeError(
                'unexpected end of input in indefinite bytestring')
        return byteasinteger(b)

    if expectheader:
        initial = readinitialbyte()

        majortype = initial >> 5
        subtype = initial & SUBTYPE_MASK

        if majortype != MAJOR_TYPE_BYTESTRING:
            raise decodermod.CBORDecodeError(
                'expected major type %d; got %d' % (MAJOR_TYPE_BYTESTRING,
                                                    majortype))

        if subtype != SUBTYPE_INDEFINITE:
            raise decodermod.CBORDecodeError(
                'expected indefinite subtype; got %d' % subtype)

    # The indefinite bytestring is composed of chunks of normal bytestrings.
    # Read chunks until we hit a BREAK byte.
    while True:
        # We need to sniff for the BREAK byte.
        initial = readinitialbyte()

        if initial == BREAK_INT:
            break

        majortype = initial >> 5
        subtype = initial & SUBTYPE_MASK

        # RFC 7049 Section 2.2.2: every chunk of an indefinite length byte
        # string must itself be a definite length byte string. Without
        # this check, any other item would be misread as a chunk length.
        if majortype != MAJOR_TYPE_BYTESTRING:
            raise decodermod.CBORDecodeError(
                'expected major type %d for bytestring chunk; got %d' % (
                    MAJOR_TYPE_BYTESTRING, majortype))

        if subtype == SUBTYPE_INDEFINITE:
            raise decodermod.CBORDecodeError(
                'nested indefinite length bytestring chunks are not allowed')

        length = decodeuint(fh, subtype)
        chunk = read(length)

        if len(chunk) != length:
            raise decodermod.CBORDecodeError(
                'failed to read bytestring chunk: got %d bytes; expected %d' % (
                    len(chunk), length))

        yield chunk
diff --git a/contrib/import-checker.py b/contrib/import-checker.py
--- a/contrib/import-checker.py
+++ b/contrib/import-checker.py
@@ -36,6 +36,8 @@
'mercurial.pure.parsers',
# third-party imports should be directly imported
'mercurial.thirdparty',
+ 'mercurial.thirdparty.cbor',
+ 'mercurial.thirdparty.cbor.cbor2',
'mercurial.thirdparty.zope',
'mercurial.thirdparty.zope.interface',
)
To: indygreg, #hg-reviewers
Cc: yuja, mercurial-devel
More information about the Mercurial-devel
mailing list