altdesktop / python-dbus-next

🚌 The next great DBus library for Python with asyncio support
https://python-dbus-next.readthedocs.io/en/latest/
MIT License

Unmarshaller optimizations #62

Closed: rjarry closed this pull request 3 years ago

rjarry commented 3 years ago

Following our discussion on Discord, here is a series of patches to optimize the Unmarshaller class.

To motivate these changes, here are some line profiling results. The test setup is somewhat complex and I did not have time to include it with the patch. In short, it runs line_profiler while subscribing to the PropertiesChanged signals of systemd units, then triggers a large number of these signals (here, by creating many virtual network devices in multiple network namespaces).

Before the changes

Timer unit: 1e-06 s

Total time: 5.26096 s
File: dbus_next/_private/unmarshaller.py
Function: read at line 42

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    42                                               def read(self, n):
    43                                                   # store previously read data in a buffer so we can resume on socket
    44                                                   # interruptions
    45    689597     511719.0      0.7      9.7          data = bytearray()
    46    689597     492045.0      0.7      9.4          if self.offset < len(self.buf):
    47                                                       data = self.buf[self.offset:self.offset + n]
    48                                                       self.offset += len(data)
    49                                                       n -= len(data)
    50    689597     323783.0      0.5      6.2          if n:
    51    688573     935647.0      1.4     17.8              read = self.stream.read(n)
    52    688573     366575.0      0.5      7.0              if read == b'':
    53                                                           raise EOFError()
    54    688573     331134.0      0.5      6.3              elif read is None:
    55       393       1210.0      3.1      0.0                  raise MarshallerStreamEndError()
    56    688180     459984.0      0.7      8.7              data.extend(read)
    57    688180     484322.0      0.7      9.2              self.buf.extend(read)
    58    688180     423061.0      0.6      8.0              if len(read) != n:
    59                                                           raise MarshallerStreamEndError()
    60    689204     459242.0      0.7      8.7          self.offset += n
    61    689204     472233.0      0.7      9.0          return bytes(data)

Total time: 4.70826 s
File: dbus_next/_private/unmarshaller.py
Function: read_string at line 109

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   109                                               def read_string(self, _=None):
   110     89534    1898808.0     21.2     40.3          str_length = self.read_uint32()
   111     89534    1394758.0     15.6     29.6          data = self.read(str_length)
   112     89534    1345369.0     15.0     28.6          self.read(1)
   113     89534      69327.0      0.8      1.5          return data.decode()

Total time: 15.5997 s
File: dbus_next/_private/unmarshaller.py
Function: _unmarshall at line 178

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   178                                               def _unmarshall(self):
   179      4951       8865.0      1.8      0.1          self.offset = 0
   180      4951     175875.0     35.5      1.1          self.endian = self.read_byte()
   181      4558       6388.0      1.4      0.0          if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN:
   182                                                       raise InvalidMessageError('Expecting endianness as the first byte')
   183      4558     126513.0     27.8      0.8          message_type = MessageType(self.read_byte())
   184      4558      94044.0     20.6      0.6          flags = MessageFlag(self.read_byte())
   185                                           
   186      4558      71461.0     15.7      0.5          protocol_version = self.read_byte()
   187                                           
   188      4558       5209.0      1.1      0.0          if protocol_version != PROTOCOL_VERSION:
   189                                                       raise InvalidMessageError(f'got unknown protocol version: {protocol_version}')
   190                                           
   191      4558     115648.0     25.4      0.7          body_len = self.read_uint32()
   192      4558      96788.0     21.2      0.6          serial = self.read_uint32()
   193                                           
   194      4558      28094.0      6.2      0.2          header_fields = {HeaderField.UNIX_FDS.name: []}
   195     27348    4046110.0    147.9     25.9          for field_struct in self.read_argument(SignatureTree('a(yv)').types[0]):
   196     22790      97859.0      4.3      0.6              field = HeaderField(field_struct[0])
   197     22790      29555.0      1.3      0.2              if field == HeaderField.UNIX_FDS:
   198                                                           header_fields[field.name].append(field_struct[1].value)
   199                                                       else:
   200     22790      77197.0      3.4      0.5                  header_fields[field.name] = field_struct[1].value
   201                                           
   202      4558      89324.0     19.6      0.6          self.align(8)
   203                                           
   204      4558      16499.0      3.6      0.1          path = header_fields.get(HeaderField.PATH.name)
   205      4558      14149.0      3.1      0.1          interface = header_fields.get(HeaderField.INTERFACE.name)
   206      4558      12471.0      2.7      0.1          member = header_fields.get(HeaderField.MEMBER.name)
   207      4558      10913.0      2.4      0.1          error_name = header_fields.get(HeaderField.ERROR_NAME.name)
   208      4558      10603.0      2.3      0.1          reply_serial = header_fields.get(HeaderField.REPLY_SERIAL.name)
   209      4558      11731.0      2.6      0.1          destination = header_fields.get(HeaderField.DESTINATION.name)
   210      4558      10689.0      2.3      0.1          sender = header_fields.get(HeaderField.SENDER.name)
   211      4558      13096.0      2.9      0.1          signature = header_fields.get(HeaderField.SIGNATURE.name, '')
   212      4558     186550.0     40.9      1.2          signature_tree = SignatureTree(signature)
   213      4558      15095.0      3.3      0.1          unix_fds = header_fields.get(HeaderField.UNIX_FDS.name)
   214                                           
   215      4558       4978.0      1.1      0.0          body = []
   216                                           
   217      4558       4925.0      1.1      0.0          if body_len:
   218     18232      26261.0      1.4      0.2              for type_ in signature_tree.types:
   219     13674    9830667.0    718.9     63.0                  body.append(self.read_argument(type_))
   220                                           
   221      4558       6732.0      1.5      0.0          self.message = Message(destination=destination,
   222      4558       4690.0      1.0      0.0                                 path=path,
   223      4558       8697.0      1.9      0.1                                 interface=interface,
   224      4558       4439.0      1.0      0.0                                 member=member,
   225      4558       4655.0      1.0      0.0                                 message_type=message_type,
   226      4558       4408.0      1.0      0.0                                 flags=flags,
   227      4558       4475.0      1.0      0.0                                 error_name=error_name,
   228      4558       4414.0      1.0      0.0                                 reply_serial=reply_serial,
   229      4558       4329.0      0.9      0.0                                 sender=sender,
   230      4558       4375.0      1.0      0.0                                 unix_fds=unix_fds,
   231      4558       4302.0      0.9      0.0                                 signature=signature_tree,
   232      4558       4407.0      1.0      0.0                                 body=body,
   233      4558     302236.0     66.3      1.9                                 serial=serial)

With these optimizations

Timer unit: 1e-06 s

Total time: 2.52228 s
File: dbus_next/_private/unmarshaller.py
Function: read at line 43

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    43                                               def read(self, n, prefetch=False):
    44                                                   """
    45                                                   Read from underlying socket into buffer and advance offset accordingly.
    46                                           
    47                                                   :arg n:
    48                                                       Number of bytes to read. If not enough bytes are available in the
    49                                                       buffer, read more from it.
    50                                                   :arg prefetch:
    51                                                       Do not update current offset after reading.
    52                                           
    53                                                   :returns:
    54                                                       Previous offset (before reading). To get the actual read bytes,
    55                                                       use the returned value and self.buf.
    56                                                   """
    57                                                   # store previously read data in a buffer so we can resume on socket
    58                                                   # interruptions
    59    545730     574351.0      1.1     22.8          missing_bytes = n - (len(self.buf) - self.offset)
    60    545730     344652.0      0.6     13.7          if missing_bytes > 0:
    61     18882     125987.0      6.7      5.0              data = self.stream.read(missing_bytes)
    62     18882      13948.0      0.7      0.6              if data == b'':
    63                                                           raise EOFError()
    64     18882      13070.0      0.7      0.5              elif data is None:
    65       450       1255.0      2.8      0.0                  raise MarshallerStreamEndError()
    66     18432      31132.0      1.7      1.2              self.buf.extend(data)
    67     18432      17326.0      0.9      0.7              if len(data) != missing_bytes:
    68                                                           raise MarshallerStreamEndError()
    69    545280     361875.0      0.7     14.3          prev = self.offset
    70    545280     309896.0      0.6     12.3          if not prefetch:
    71    531456     413507.0      0.8     16.4              self.offset += n
    72    545280     315282.0      0.6     12.5          return prev

Total time: 3.05621 s
File: dbus_next/_private/unmarshaller.py
Function: read_string at line 120

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   120                                               def read_string(self, _=None):
   121     90624    1564935.0     17.3     51.2          str_length = self.read_uint32()
   122     90624    1010484.0     11.2     33.1          o = self.read(str_length + 1)  # read terminating '\0' byte as well
   123                                                   # avoid buffer copies when slicing
   124     90624     248357.0      2.7      8.1          str_mem_slice = memoryview(self.buf)[o:o + str_length]
   125     90624     232433.0      2.6      7.6          return decode(str_mem_slice)
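The optimized read_string avoids copying the string bytes out of the buffer before decoding them. Below is a minimal, self-contained sketch of the same idea. StringReader is a hypothetical class (not the dbus_next Unmarshaller), and it assumes the whole message is already in memory. It reads the length prefix with struct.Struct.unpack_from directly from the bytearray and decodes a memoryview slice, so no intermediate bytes object is created.

```python
import struct


class StringReader:
    """Hypothetical minimal reader over an in-memory D-Bus buffer (illustration only)."""

    def __init__(self, buf, little_endian=True):
        self.buf = bytearray(buf)
        self.offset = 0
        self.uint32 = struct.Struct("<I" if little_endian else ">I")

    def align(self, n):
        # D-Bus aligns each value to its natural boundary; skip the padding bytes.
        self.offset += -self.offset % n

    def read_uint32(self):
        self.align(4)
        # unpack_from reads straight out of the bytearray: no slice copy.
        (value,) = self.uint32.unpack_from(self.buf, self.offset)
        self.offset += 4
        return value

    def read_string(self):
        length = self.read_uint32()
        start = self.offset
        self.offset += length + 1  # skip the string and its terminating NUL byte
        # Slicing a memoryview does not copy; the with block releases the view
        # so that self.buf can be resized (extended) again afterwards.
        with memoryview(self.buf)[start:start + length] as view:
            return str(view, "utf-8")
```

Releasing the view matters because a bytearray with live buffer exports cannot be resized, and the reader keeps appending to its buffer as more data arrives from the socket.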

Total time: 13.1745 s
File: dbus_next/_private/unmarshaller.py
Function: _unmarshall at line 194

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   194                                               def _unmarshall(self):
   195      5058      13285.0      2.6      0.1          self.offset = 0
   196      5058     203096.0     40.2      1.5          self.read(16, prefetch=True)
   197      4608      67682.0     14.7      0.5          self.endian = self.read_byte()
   198      4608       7785.0      1.7      0.1          if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN:
   199                                                       raise InvalidMessageError('Expecting endianness as the first byte')
   200      4608     107473.0     23.3      0.8          message_type = MessageType(self.read_byte())
   201      4608      79156.0     17.2      0.6          flags = MessageFlag(self.read_byte())
   202                                           
   203      4608      64977.0     14.1      0.5          protocol_version = self.read_byte()
   204                                           
   205      4608      13540.0      2.9      0.1          if protocol_version != PROTOCOL_VERSION:
   206                                                       raise InvalidMessageError(f'got unknown protocol version: {protocol_version}')
   207                                           
   208      4608      83779.0     18.2      0.6          body_len = self.read_uint32()
   209      4608      86027.0     18.7      0.7          serial = self.read_uint32()
   210                                           
   211      4608      72960.0     15.8      0.6          array_len = self.read_uint32()
   212      4608      99894.0     21.7      0.8          self.read(array_len, prefetch=True)
   213                                                   # backtrack offset since array length needs to be read again
   214      4608       9999.0      2.2      0.1          self.offset -= 4
   215                                           
   216      4608      32743.0      7.1      0.2          header_fields = {HeaderField.UNIX_FDS.name: []}
   217     27648    3243462.0    117.3     24.6          for field_struct in self.read_argument(SignatureTree('a(yv)').types[0]):
   218     23040     128509.0      5.6      1.0              field = HeaderField(field_struct[0])
   219     23040      40008.0      1.7      0.3              if field == HeaderField.UNIX_FDS:
   220                                                           header_fields[field.name].append(field_struct[1].value)
   221                                                       else:
   222     23040      89569.0      3.9      0.7                  header_fields[field.name] = field_struct[1].value
   223                                           
   224      4608     133292.0     28.9      1.0          self.align(8)
   225                                           
   226      4608      25092.0      5.4      0.2          path = header_fields.get(HeaderField.PATH.name)
   227      4608      15829.0      3.4      0.1          interface = header_fields.get(HeaderField.INTERFACE.name)
   228      4608      14231.0      3.1      0.1          member = header_fields.get(HeaderField.MEMBER.name)
   229      4608      15498.0      3.4      0.1          error_name = header_fields.get(HeaderField.ERROR_NAME.name)
   230      4608      19051.0      4.1      0.1          reply_serial = header_fields.get(HeaderField.REPLY_SERIAL.name)
   231      4608      13442.0      2.9      0.1          destination = header_fields.get(HeaderField.DESTINATION.name)
   232      4608      13425.0      2.9      0.1          sender = header_fields.get(HeaderField.SENDER.name)
   233      4608      13082.0      2.8      0.1          signature = header_fields.get(HeaderField.SIGNATURE.name, '')
   234      4608     222017.0     48.2      1.7          signature_tree = SignatureTree(signature)
   235      4608      24041.0      5.2      0.2          unix_fds = header_fields.get(HeaderField.UNIX_FDS.name)
   236                                           
   237      4608       7020.0      1.5      0.1          body = []
   238                                           
   239      4608       6606.0      1.4      0.1          if body_len:
   240      4608     130171.0     28.2      1.0              self.read(body_len, prefetch=True)
   241     18432      32888.0      1.8      0.2              for type_ in signature_tree.types:
   242     13824    7637776.0    552.5     58.0                  body.append(self.read_argument(type_))
   243                                           
   244      4608       8695.0      1.9      0.1          self.message = Message(destination=destination,
   245      4608       6298.0      1.4      0.0                                 path=path,
   246      4608      10138.0      2.2      0.1                                 interface=interface,
   247      4608       5327.0      1.2      0.0                                 member=member,
   248      4608       5578.0      1.2      0.0                                 message_type=message_type,
   249      4608       5409.0      1.2      0.0                                 flags=flags,
   250      4608       6982.0      1.5      0.1                                 error_name=error_name,
   251      4608       5170.0      1.1      0.0                                 reply_serial=reply_serial,
   252      4608       5535.0      1.2      0.0                                 sender=sender,
   253      4608       6056.0      1.3      0.0                                 unix_fds=unix_fds,
   254      4608       6145.0      1.3      0.0                                 signature=signature_tree,
   255      4608       7562.0      1.6      0.1                                 body=body,
   256      4608     328163.0     71.2      2.5                                 serial=serial)
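The rewritten _unmarshall starts with self.read(16, prefetch=True) because the fixed part of a D-Bus message header is exactly 16 bytes: endianness, message type, flags and protocol version (one byte each), followed by the body length, the serial and the length of the header-fields array (one uint32 each). The sketch below uses a hypothetical parse_fixed_header helper, not code from the patch, to decode those 16 bytes with a single struct.unpack_from call. The last value is the one the patch uses to prefetch the header-fields array before backtracking 4 bytes.

```python
import struct

LITTLE_ENDIAN = ord("l")
BIG_ENDIAN = ord("B")


def parse_fixed_header(buf):
    """Parse the 16 bytes every D-Bus message starts with (hypothetical helper).

    Returns (endian, message_type, flags, version, body_len, serial, fields_len).
    """
    if len(buf) < 16:
        raise ValueError("need at least 16 bytes")
    endian = buf[0]
    if endian == LITTLE_ENDIAN:
        prefix = "<"
    elif endian == BIG_ENDIAN:
        prefix = ">"
    else:
        raise ValueError("expecting endianness as the first byte")
    # Bytes 1-3: type, flags, protocol version. Bytes 4-15: body length,
    # serial and header-fields array length. With "<" or ">", struct adds
    # no padding, so the uint32 values land at offsets 4, 8 and 12.
    fields = struct.unpack_from(prefix + "BBBIII", buf, 1)
    return (endian,) + fields
```

For a little-endian signal (message type 4) with flags 1, protocol version 1, an 8-byte body, serial 42 and a 110-byte header-fields array, this returns (108, 4, 1, 1, 8, 42, 110).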

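The number of parsed messages is roughly the same in both runs (4,558 vs 4,608). However, the profiling results show that less time is spent in every profiled function: read drops from 5.26 s to 2.52 s, read_string from 4.71 s to 3.06 s, and _unmarshall from 15.60 s to 13.17 s.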

The most significant improvement is the reduction in the number of self.stream.read calls: 688,573 before versus 18,882 after. Using memoryview and unpack_from also helps a bit.
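To illustrate why the call count drops, here is a simplified, self-contained sketch of the offset-returning read(n, prefetch=False) from the patch. CountingStream and BufferedReader are hypothetical names for illustration, and the error handling is collapsed into a single EOFError instead of the patch's EOFError/MarshallerStreamEndError distinction. Only the bytes missing from the buffer are requested from the stream, and a prefetch fills the buffer without advancing the offset.

```python
import io


class CountingStream:
    """Wraps an in-memory stream and counts read() calls (illustration only)."""

    def __init__(self, data):
        self._raw = io.BytesIO(data)
        self.calls = 0

    def read(self, n):
        self.calls += 1
        return self._raw.read(n)


class BufferedReader:
    """Hypothetical, simplified version of the patched Unmarshaller.read()."""

    def __init__(self, stream):
        self.stream = stream
        self.buf = bytearray()
        self.offset = 0

    def read(self, n, prefetch=False):
        # Only ask the stream for the bytes not already buffered.
        missing = n - (len(self.buf) - self.offset)
        if missing > 0:
            data = self.stream.read(missing)
            if not data or len(data) != missing:
                raise EOFError("stream ended early")  # simplified error handling
            self.buf.extend(data)
        prev = self.offset
        if not prefetch:
            self.offset += n
        # Callers slice self.buf at the returned offset instead of receiving a copy.
        return prev
```

Reading sixteen 4-byte values one at a time from a 64-byte stream costs sixteen stream reads, whereas calling read(64, prefetch=True) first costs a single stream read, after which every 4-byte read is served from the buffer.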

acrisci commented 3 years ago

👍

rjarry commented 3 years ago

I think there may be an issue if the header fields do not end on an 8-byte boundary. Since the body starts on an 8-byte boundary, the header is followed by alignment padding, and skipping that padding with align(8) triggers an additional stream read before the body is prefetched. I forgot to take this into account. This does not break anything, but the extra read could be avoided by prefetching the alignment bytes together with the header. I'll submit another PR with a fix.
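The fix described above amounts to prefetching up to the 8-byte boundary where the body starts, rather than only up to the end of the header-fields array. The arithmetic is sketched below; message_layout is a hypothetical helper, not code from the follow-up PR. It relies on the D-Bus rule that the header is padded with up to 7 NUL bytes so the body begins on an 8-byte boundary.

```python
def message_layout(fields_len, body_len):
    """Return (header_end, body_start, message_end) byte offsets.

    Hypothetical helper: the fixed header occupies bytes 0-15, the
    header-fields array contents follow, and the body starts at the next
    multiple of 8 after the array.
    """
    header_end = 16 + fields_len
    padding = -header_end % 8  # 0 to 7 alignment bytes before the body
    body_start = header_end + padding
    return header_end, body_start, body_start + body_len
```

With a 110-byte header-fields array and an 8-byte body, the header ends at offset 126 but the body starts at 128, so prefetching 128 bytes instead of 126 lets align(8) be served from the buffer. Since body_len is already known from the fixed header, one could also prefetch all message_end bytes in a single call.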