protorpc-standalone-0.9.1/0000755000076500000240000000000012300027071016460 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/PKG-INFO0000644000076500000240000000143412300027071017557 0ustar jeremydwstaff00000000000000Metadata-Version: 1.1 Name: protorpc-standalone Version: 0.9.1 Summary: Google Protocol RPC (modified to run outside Google App Engine) Home-page: https://github.com/jeremydw/protorpc-standalone Author: Google Inc. Author-email: rafek@google.com License: Apache 2.0 Description: UNKNOWN Keywords: google protocol rpc Platform: UNKNOWN Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Operating System :: MacOS :: MacOS X Classifier: Operating System :: Microsoft :: Windows Classifier: Operating System :: POSIX :: Linux Classifier: Programming Language :: Python :: 2.7 Classifier: Topic :: Software Development :: Libraries Classifier: Topic :: Software Development :: Libraries :: Python Modules Provides: protorpc (0.9.1) protorpc-standalone-0.9.1/protorpc/0000755000076500000240000000000012300027071020330 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/__init__.py0000755000076500000240000000126112277637135022471 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Main module for ProtoRPC package.""" __author__ = 'rafek@google.com (Rafe Kaplan)' protorpc-standalone-0.9.1/protorpc/_google/0000755000076500000240000000000012300027071021743 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/_google/__init__.py0000644000076500000240000000000012277637135024067 0ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/_google/net/0000755000076500000240000000000012300027071022531 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/_google/net/__init__.py0000644000076500000240000000000012277637135024655 0ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/_google/net/proto/0000755000076500000240000000000012300027071023674 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/_google/net/proto/__init__.py0000644000076500000240000000113112277637135026026 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2007 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # protorpc-standalone-0.9.1/protorpc/_google/net/proto/message_set.py0000644000076500000240000002405212277637135026575 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2007 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """This module contains the MessageSet class, which is a special kind of protocol message which can contain other protocol messages without knowing their types. See the class's doc string for more information.""" from google.net.proto import ProtocolBuffer import logging try: from google3.net.proto import _net_proto___parse__python except ImportError: _net_proto___parse__python = None TAG_BEGIN_ITEM_GROUP = 11 TAG_END_ITEM_GROUP = 12 TAG_TYPE_ID = 16 TAG_MESSAGE = 26 class Item: def __init__(self, message, message_class=None): self.message = message self.message_class = message_class def SetToDefaultInstance(self, message_class): self.message = message_class() self.message_class = message_class def Parse(self, message_class): if self.message_class is not None: return 1 try: message_obj = message_class() message_obj.MergePartialFromString(self.message) self.message = message_obj self.message_class = message_class return 1 except ProtocolBuffer.ProtocolBufferDecodeError: logging.warn("Parse error in message inside MessageSet. 
Tried " "to parse as: " + message_class.__name__) return 0 def MergeFrom(self, other): if self.message_class is not None: if other.Parse(self.message_class): self.message.MergeFrom(other.message) elif other.message_class is not None: if not self.Parse(other.message_class): self.message = other.message_class() self.message_class = other.message_class self.message.MergeFrom(other.message) else: self.message += other.message def Copy(self): if self.message_class is None: return Item(self.message) else: new_message = self.message_class() new_message.CopyFrom(self.message) return Item(new_message, self.message_class) def Equals(self, other): if self.message_class is not None: if not other.Parse(self.message_class): return 0 return self.message.Equals(other.message) elif other.message_class is not None: if not self.Parse(other.message_class): return 0 return self.message.Equals(other.message) else: return self.message == other.message def IsInitialized(self, debug_strs=None): if self.message_class is None: return 1 else: return self.message.IsInitialized(debug_strs) def ByteSize(self, pb, type_id): message_length = 0 if self.message_class is None: message_length = len(self.message) else: message_length = self.message.ByteSize() return pb.lengthString(message_length) + pb.lengthVarInt64(type_id) + 2 def ByteSizePartial(self, pb, type_id): message_length = 0 if self.message_class is None: message_length = len(self.message) else: message_length = self.message.ByteSizePartial() return pb.lengthString(message_length) + pb.lengthVarInt64(type_id) + 2 def OutputUnchecked(self, out, type_id): out.putVarInt32(TAG_TYPE_ID) out.putVarUint64(type_id) out.putVarInt32(TAG_MESSAGE) if self.message_class is None: out.putPrefixedString(self.message) else: out.putVarInt32(self.message.ByteSize()) self.message.OutputUnchecked(out) def OutputPartial(self, out, type_id): out.putVarInt32(TAG_TYPE_ID) out.putVarUint64(type_id) out.putVarInt32(TAG_MESSAGE) if self.message_class is None: 
out.putPrefixedString(self.message) else: out.putVarInt32(self.message.ByteSizePartial()) self.message.OutputPartial(out) def Decode(decoder): type_id = 0 message = None while 1: tag = decoder.getVarInt32() if tag == TAG_END_ITEM_GROUP: break if tag == TAG_TYPE_ID: type_id = decoder.getVarUint64() continue if tag == TAG_MESSAGE: message = decoder.getPrefixedString() continue if tag == 0: raise ProtocolBuffer.ProtocolBufferDecodeError decoder.skipData(tag) if type_id == 0 or message is None: raise ProtocolBuffer.ProtocolBufferDecodeError return (type_id, message) Decode = staticmethod(Decode) class MessageSet(ProtocolBuffer.ProtocolMessage): def __init__(self, contents=None): self.items = dict() if contents is not None: self.MergeFromString(contents) def get(self, message_class): if message_class.MESSAGE_TYPE_ID not in self.items: return message_class() item = self.items[message_class.MESSAGE_TYPE_ID] if item.Parse(message_class): return item.message else: return message_class() def mutable(self, message_class): if message_class.MESSAGE_TYPE_ID not in self.items: message = message_class() self.items[message_class.MESSAGE_TYPE_ID] = Item(message, message_class) return message item = self.items[message_class.MESSAGE_TYPE_ID] if not item.Parse(message_class): item.SetToDefaultInstance(message_class) return item.message def has(self, message_class): if message_class.MESSAGE_TYPE_ID not in self.items: return 0 item = self.items[message_class.MESSAGE_TYPE_ID] return item.Parse(message_class) def has_unparsed(self, message_class): return message_class.MESSAGE_TYPE_ID in self.items def GetTypeIds(self): return self.items.keys() def NumMessages(self): return len(self.items) def remove(self, message_class): if message_class.MESSAGE_TYPE_ID in self.items: del self.items[message_class.MESSAGE_TYPE_ID] def __getitem__(self, message_class): if message_class.MESSAGE_TYPE_ID not in self.items: raise KeyError(message_class) item = self.items[message_class.MESSAGE_TYPE_ID] if 
item.Parse(message_class): return item.message else: raise KeyError(message_class) def __setitem__(self, message_class, message): self.items[message_class.MESSAGE_TYPE_ID] = Item(message, message_class) def __contains__(self, message_class): return self.has(message_class) def __delitem__(self, message_class): self.remove(message_class) def __len__(self): return len(self.items) def MergeFrom(self, other): assert other is not self for (type_id, item) in other.items.items(): if type_id in self.items: self.items[type_id].MergeFrom(item) else: self.items[type_id] = item.Copy() def Equals(self, other): if other is self: return 1 if len(self.items) != len(other.items): return 0 for (type_id, item) in other.items.items(): if type_id not in self.items: return 0 if not self.items[type_id].Equals(item): return 0 return 1 def __eq__(self, other): return ((other is not None) and (other.__class__ == self.__class__) and self.Equals(other)) def __ne__(self, other): return not (self == other) def IsInitialized(self, debug_strs=None): initialized = 1 for item in self.items.values(): if not item.IsInitialized(debug_strs): initialized = 0 return initialized def ByteSize(self): n = 2 * len(self.items) for (type_id, item) in self.items.items(): n += item.ByteSize(self, type_id) return n def ByteSizePartial(self): n = 2 * len(self.items) for (type_id, item) in self.items.items(): n += item.ByteSizePartial(self, type_id) return n def Clear(self): self.items = dict() def OutputUnchecked(self, out): for (type_id, item) in self.items.items(): out.putVarInt32(TAG_BEGIN_ITEM_GROUP) item.OutputUnchecked(out, type_id) out.putVarInt32(TAG_END_ITEM_GROUP) def OutputPartial(self, out): for (type_id, item) in self.items.items(): out.putVarInt32(TAG_BEGIN_ITEM_GROUP) item.OutputPartial(out, type_id) out.putVarInt32(TAG_END_ITEM_GROUP) def TryMerge(self, decoder): while decoder.avail() > 0: tag = decoder.getVarInt32() if tag == TAG_BEGIN_ITEM_GROUP: (type_id, message) = Item.Decode(decoder) if type_id 
in self.items: self.items[type_id].MergeFrom(Item(message)) else: self.items[type_id] = Item(message) continue if (tag == 0): raise ProtocolBuffer.ProtocolBufferDecodeError decoder.skipData(tag) def _CToASCII(self, output_format): if _net_proto___parse__python is None: return ProtocolBuffer.ProtocolMessage._CToASCII(self, output_format) else: return _net_proto___parse__python.ToASCII( self, "MessageSetInternal", output_format) def ParseASCII(self, s): if _net_proto___parse__python is None: ProtocolBuffer.ProtocolMessage.ParseASCII(self, s) else: _net_proto___parse__python.ParseASCII(self, "MessageSetInternal", s) def ParseASCIIIgnoreUnknown(self, s): if _net_proto___parse__python is None: ProtocolBuffer.ProtocolMessage.ParseASCIIIgnoreUnknown(self, s) else: _net_proto___parse__python.ParseASCIIIgnoreUnknown( self, "MessageSetInternal", s) def __str__(self, prefix="", printElemNumber=0): text = "" for (type_id, item) in self.items.items(): if item.message_class is None: text += "%s[%d] <\n" % (prefix, type_id) text += "%s (%d bytes)\n" % (prefix, len(item.message)) text += "%s>\n" % prefix else: text += "%s[%s] <\n" % (prefix, item.message_class.__name__) text += item.message.__str__(prefix + " ", printElemNumber) text += "%s>\n" % prefix return text _PROTO_DESCRIPTOR_NAME = 'MessageSet' __all__ = ['MessageSet'] protorpc-standalone-0.9.1/protorpc/_google/net/proto/ProtocolBuffer.py0000644000076500000240000006426312277637135027241 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2007 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # import array import httplib import re import struct __all__ = ['ProtocolMessage', 'Encoder', 'Decoder', 'ExtendableProtocolMessage', 'ProtocolBufferDecodeError', 'ProtocolBufferEncodeError', 'ProtocolBufferReturnError'] URL_RE = re.compile('^(https?)://([^/]+)(/.*)$') class ProtocolMessage: def __init__(self, contents=None): raise NotImplementedError def Clear(self): raise NotImplementedError def IsInitialized(self, debug_strs=None): raise NotImplementedError def Encode(self): try: return self._CEncode() except NotImplementedError: e = Encoder() self.Output(e) return e.buffer().tostring() def SerializeToString(self): return self.Encode() def SerializePartialToString(self): try: return self._CEncodePartial() except (NotImplementedError, AttributeError): e = Encoder() self.OutputPartial(e) return e.buffer().tostring() def _CEncode(self): raise NotImplementedError def _CEncodePartial(self): raise NotImplementedError def ParseFromString(self, s): self.Clear() self.MergeFromString(s) def ParsePartialFromString(self, s): self.Clear() self.MergePartialFromString(s) def MergeFromString(self, s): self.MergePartialFromString(s) dbg = [] if not self.IsInitialized(dbg): raise ProtocolBufferDecodeError, '\n\t'.join(dbg) def MergePartialFromString(self, s): try: self._CMergeFromString(s) except NotImplementedError: a = array.array('B') a.fromstring(s) d = Decoder(a, 0, len(a)) self.TryMerge(d) def _CMergeFromString(self, s): raise NotImplementedError def __getstate__(self): return self.Encode() def __setstate__(self, contents_): self.__init__(contents=contents_) def sendCommand(self, server, url, response, follow_redirects=1, secure=0, keyfile=None, certfile=None): data = self.Encode() if secure: if keyfile and certfile: conn = httplib.HTTPSConnection(server, key_file=keyfile, cert_file=certfile) else: conn = httplib.HTTPSConnection(server) else: conn = 
httplib.HTTPConnection(server) conn.putrequest("POST", url) conn.putheader("Content-Length", "%d" %len(data)) conn.endheaders() conn.send(data) resp = conn.getresponse() if follow_redirects > 0 and resp.status == 302: m = URL_RE.match(resp.getheader('Location')) if m: protocol, server, url = m.groups() return self.sendCommand(server, url, response, follow_redirects=follow_redirects - 1, secure=(protocol == 'https'), keyfile=keyfile, certfile=certfile) if resp.status != 200: raise ProtocolBufferReturnError(resp.status) if response is not None: response.ParseFromString(resp.read()) return response def sendSecureCommand(self, server, keyfile, certfile, url, response, follow_redirects=1): return self.sendCommand(server, url, response, follow_redirects=follow_redirects, secure=1, keyfile=keyfile, certfile=certfile) def __str__(self, prefix="", printElemNumber=0): raise NotImplementedError def ToASCII(self): return self._CToASCII(ProtocolMessage._SYMBOLIC_FULL_ASCII) def ToCompactASCII(self): return self._CToASCII(ProtocolMessage._NUMERIC_ASCII) def ToShortASCII(self): return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII) _NUMERIC_ASCII = 0 _SYMBOLIC_SHORT_ASCII = 1 _SYMBOLIC_FULL_ASCII = 2 def _CToASCII(self, output_format): raise NotImplementedError def ParseASCII(self, ascii_string): raise NotImplementedError def ParseASCIIIgnoreUnknown(self, ascii_string): raise NotImplementedError def Equals(self, other): raise NotImplementedError def __eq__(self, other): if other.__class__ is self.__class__: return self.Equals(other) return NotImplemented def __ne__(self, other): if other.__class__ is self.__class__: return not self.Equals(other) return NotImplemented def Output(self, e): dbg = [] if not self.IsInitialized(dbg): raise ProtocolBufferEncodeError, '\n\t'.join(dbg) self.OutputUnchecked(e) return def OutputUnchecked(self, e): raise NotImplementedError def OutputPartial(self, e): raise NotImplementedError def Parse(self, d): self.Clear() self.Merge(d) return def 
Merge(self, d): self.TryMerge(d) dbg = [] if not self.IsInitialized(dbg): raise ProtocolBufferDecodeError, '\n\t'.join(dbg) return def TryMerge(self, d): raise NotImplementedError def CopyFrom(self, pb): if (pb == self): return self.Clear() self.MergeFrom(pb) def MergeFrom(self, pb): raise NotImplementedError def lengthVarInt32(self, n): return self.lengthVarInt64(n) def lengthVarInt64(self, n): if n < 0: return 10 result = 0 while 1: result += 1 n >>= 7 if n == 0: break return result def lengthString(self, n): return self.lengthVarInt32(n) + n def DebugFormat(self, value): return "%s" % value def DebugFormatInt32(self, value): if (value <= -2000000000 or value >= 2000000000): return self.DebugFormatFixed32(value) return "%d" % value def DebugFormatInt64(self, value): if (value <= -20000000000000 or value >= 20000000000000): return self.DebugFormatFixed64(value) return "%d" % value def DebugFormatString(self, value): def escape(c): o = ord(c) if o == 10: return r"\n" if o == 39: return r"\'" if o == 34: return r'\"' if o == 92: return r"\\" if o >= 127 or o < 32: return "\\%03o" % o return c return '"' + "".join([escape(c) for c in value]) + '"' def DebugFormatFloat(self, value): return "%ff" % value def DebugFormatFixed32(self, value): if (value < 0): value += (1L<<32) return "0x%x" % value def DebugFormatFixed64(self, value): if (value < 0): value += (1L<<64) return "0x%x" % value def DebugFormatBool(self, value): if value: return "true" else: return "false" TYPE_DOUBLE = 1 TYPE_FLOAT = 2 TYPE_INT64 = 3 TYPE_UINT64 = 4 TYPE_INT32 = 5 TYPE_FIXED64 = 6 TYPE_FIXED32 = 7 TYPE_BOOL = 8 TYPE_STRING = 9 TYPE_GROUP = 10 TYPE_FOREIGN = 11 _TYPE_TO_DEBUG_STRING = { TYPE_INT32: ProtocolMessage.DebugFormatInt32, TYPE_INT64: ProtocolMessage.DebugFormatInt64, TYPE_UINT64: ProtocolMessage.DebugFormatInt64, TYPE_FLOAT: ProtocolMessage.DebugFormatFloat, TYPE_STRING: ProtocolMessage.DebugFormatString, TYPE_FIXED32: ProtocolMessage.DebugFormatFixed32, TYPE_FIXED64: 
ProtocolMessage.DebugFormatFixed64, TYPE_BOOL: ProtocolMessage.DebugFormatBool } class Encoder: NUMERIC = 0 DOUBLE = 1 STRING = 2 STARTGROUP = 3 ENDGROUP = 4 FLOAT = 5 MAX_TYPE = 6 def __init__(self): self.buf = array.array('B') return def buffer(self): return self.buf def put8(self, v): if v < 0 or v >= (1<<8): raise ProtocolBufferEncodeError, "u8 too big" self.buf.append(v & 255) return def put16(self, v): if v < 0 or v >= (1<<16): raise ProtocolBufferEncodeError, "u16 too big" self.buf.append((v >> 0) & 255) self.buf.append((v >> 8) & 255) return def put32(self, v): if v < 0 or v >= (1L<<32): raise ProtocolBufferEncodeError, "u32 too big" self.buf.append((v >> 0) & 255) self.buf.append((v >> 8) & 255) self.buf.append((v >> 16) & 255) self.buf.append((v >> 24) & 255) return def put64(self, v): if v < 0 or v >= (1L<<64): raise ProtocolBufferEncodeError, "u64 too big" self.buf.append((v >> 0) & 255) self.buf.append((v >> 8) & 255) self.buf.append((v >> 16) & 255) self.buf.append((v >> 24) & 255) self.buf.append((v >> 32) & 255) self.buf.append((v >> 40) & 255) self.buf.append((v >> 48) & 255) self.buf.append((v >> 56) & 255) return def putVarInt32(self, v): buf_append = self.buf.append if v & 127 == v: buf_append(v) return if v >= 0x80000000 or v < -0x80000000: raise ProtocolBufferEncodeError, "int32 too big" if v < 0: v += 0x10000000000000000 while True: bits = v & 127 v >>= 7 if v: bits |= 128 buf_append(bits) if not v: break return def putVarInt64(self, v): buf_append = self.buf.append if v >= 0x8000000000000000 or v < -0x8000000000000000: raise ProtocolBufferEncodeError, "int64 too big" if v < 0: v += 0x10000000000000000 while True: bits = v & 127 v >>= 7 if v: bits |= 128 buf_append(bits) if not v: break return def putVarUint64(self, v): buf_append = self.buf.append if v < 0 or v >= 0x10000000000000000: raise ProtocolBufferEncodeError, "uint64 too big" while True: bits = v & 127 v >>= 7 if v: bits |= 128 buf_append(bits) if not v: break return def 
putFloat(self, v): a = array.array('B') a.fromstring(struct.pack(" self.limit: raise ProtocolBufferDecodeError, "truncated" self.idx += n return def skipData(self, tag): t = tag & 7 if t == Encoder.NUMERIC: self.getVarInt64() elif t == Encoder.DOUBLE: self.skip(8) elif t == Encoder.STRING: n = self.getVarInt32() self.skip(n) elif t == Encoder.STARTGROUP: while 1: t = self.getVarInt32() if (t & 7) == Encoder.ENDGROUP: break else: self.skipData(t) if (t - Encoder.ENDGROUP) != (tag - Encoder.STARTGROUP): raise ProtocolBufferDecodeError, "corrupted" elif t == Encoder.ENDGROUP: raise ProtocolBufferDecodeError, "corrupted" elif t == Encoder.FLOAT: self.skip(4) else: raise ProtocolBufferDecodeError, "corrupted" def get8(self): if self.idx >= self.limit: raise ProtocolBufferDecodeError, "truncated" c = self.buf[self.idx] self.idx += 1 return c def get16(self): if self.idx + 2 > self.limit: raise ProtocolBufferDecodeError, "truncated" c = self.buf[self.idx] d = self.buf[self.idx + 1] self.idx += 2 return (d << 8) | c def get32(self): if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError, "truncated" c = self.buf[self.idx] d = self.buf[self.idx + 1] e = self.buf[self.idx + 2] f = long(self.buf[self.idx + 3]) self.idx += 4 return (f << 24) | (e << 16) | (d << 8) | c def get64(self): if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError, "truncated" c = self.buf[self.idx] d = self.buf[self.idx + 1] e = self.buf[self.idx + 2] f = long(self.buf[self.idx + 3]) g = long(self.buf[self.idx + 4]) h = long(self.buf[self.idx + 5]) i = long(self.buf[self.idx + 6]) j = long(self.buf[self.idx + 7]) self.idx += 8 return ((j << 56) | (i << 48) | (h << 40) | (g << 32) | (f << 24) | (e << 16) | (d << 8) | c) def getVarInt32(self): b = self.get8() if not (b & 128): return b result = long(0) shift = 0 while 1: result |= (long(b & 127) << shift) shift += 7 if not (b & 128): if result >= 0x10000000000000000L: raise ProtocolBufferDecodeError, "corrupted" break if shift >= 64: raise 
ProtocolBufferDecodeError, "corrupted" b = self.get8() if result >= 0x8000000000000000L: result -= 0x10000000000000000L if result >= 0x80000000L or result < -0x80000000L: raise ProtocolBufferDecodeError, "corrupted" return result def getVarInt64(self): result = self.getVarUint64() if result >= (1L << 63): result -= (1L << 64) return result def getVarUint64(self): result = long(0) shift = 0 while 1: if shift >= 64: raise ProtocolBufferDecodeError, "corrupted" b = self.get8() result |= (long(b & 127) << shift) shift += 7 if not (b & 128): if result >= (1L << 64): raise ProtocolBufferDecodeError, "corrupted" return result return result def getFloat(self): if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError, "truncated" a = self.buf[self.idx:self.idx+4] self.idx += 4 return struct.unpack(" self.limit: raise ProtocolBufferDecodeError, "truncated" a = self.buf[self.idx:self.idx+8] self.idx += 8 return struct.unpack(" self.limit: raise ProtocolBufferDecodeError, "truncated" r = self.buf[self.idx : self.idx + length] self.idx += length return r.tostring() def getRawString(self): r = self.buf[self.idx:self.limit] self.idx = self.limit return r.tostring() _TYPE_TO_METHOD = { TYPE_DOUBLE: getDouble, TYPE_FLOAT: getFloat, TYPE_FIXED64: get64, TYPE_FIXED32: get32, TYPE_INT32: getVarInt32, TYPE_INT64: getVarInt64, TYPE_UINT64: getVarUint64, TYPE_BOOL: getBoolean, TYPE_STRING: getPrefixedString } class ExtensionIdentifier(object): __slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated', 'default', 'containing_cls', 'composite_cls', 'message_name') def __init__(self, full_name, number, field_type, wire_tag, is_repeated, default): self.full_name = full_name self.number = number self.field_type = field_type self.wire_tag = wire_tag self.is_repeated = is_repeated self.default = default class ExtendableProtocolMessage(ProtocolMessage): def HasExtension(self, extension): self._VerifyExtensionIdentifier(extension) return extension in self._extension_fields 
def ClearExtension(self, extension): self._VerifyExtensionIdentifier(extension) if extension in self._extension_fields: del self._extension_fields[extension] def GetExtension(self, extension, index=None): self._VerifyExtensionIdentifier(extension) if extension in self._extension_fields: result = self._extension_fields[extension] else: if extension.is_repeated: result = [] elif extension.composite_cls: result = extension.composite_cls() else: result = extension.default if extension.is_repeated: result = result[index] return result def SetExtension(self, extension, *args): self._VerifyExtensionIdentifier(extension) if extension.composite_cls: raise TypeError( 'Cannot assign to extension "%s" because it is a composite type.' % extension.full_name) if extension.is_repeated: if (len(args) != 2): raise TypeError( 'SetExtension(extension, index, value) for repeated extension ' 'takes exactly 3 arguments: (%d given)' % len(args)) index = args[0] value = args[1] self._extension_fields[extension][index] = value else: if (len(args) != 1): raise TypeError( 'SetExtension(extension, value) for singular extension ' 'takes exactly 3 arguments: (%d given)' % len(args)) value = args[0] self._extension_fields[extension] = value def MutableExtension(self, extension, index=None): self._VerifyExtensionIdentifier(extension) if extension.composite_cls is None: raise TypeError( 'MutableExtension() cannot be applied to "%s", because it is not a ' 'composite type.' 
% extension.full_name) if extension.is_repeated: if index is None: raise TypeError( 'MutableExtension(extension, index) for repeated extension ' 'takes exactly 2 arguments: (1 given)') return self.GetExtension(extension, index) if extension in self._extension_fields: return self._extension_fields[extension] else: result = extension.composite_cls() self._extension_fields[extension] = result return result def ExtensionList(self, extension): self._VerifyExtensionIdentifier(extension) if not extension.is_repeated: raise TypeError( 'ExtensionList() cannot be applied to "%s", because it is not a ' 'repeated extension.' % extension.full_name) if extension in self._extension_fields: return self._extension_fields[extension] result = [] self._extension_fields[extension] = result return result def ExtensionSize(self, extension): self._VerifyExtensionIdentifier(extension) if not extension.is_repeated: raise TypeError( 'ExtensionSize() cannot be applied to "%s", because it is not a ' 'repeated extension.' % extension.full_name) if extension in self._extension_fields: return len(self._extension_fields[extension]) return 0 def AddExtension(self, extension, value=None): self._VerifyExtensionIdentifier(extension) if not extension.is_repeated: raise TypeError( 'AddExtension() cannot be applied to "%s", because it is not a ' 'repeated extension.' % extension.full_name) if extension in self._extension_fields: field = self._extension_fields[extension] else: field = [] self._extension_fields[extension] = field if extension.composite_cls: if value is not None: raise TypeError( 'value must not be set in AddExtension() for "%s", because it is ' 'a message type extension. Set values on the returned message ' 'instead.' % extension.full_name) msg = extension.composite_cls() field.append(msg) return msg field.append(value) def _VerifyExtensionIdentifier(self, extension): if extension.containing_cls != self.__class__: raise TypeError("Containing type of %s is %s, but not %s." 
% (extension.full_name, extension.containing_cls.__name__, self.__class__.__name__)) def _MergeExtensionFields(self, x): for ext, val in x._extension_fields.items(): if ext.is_repeated: for i in xrange(len(val)): if ext.composite_cls is None: self.AddExtension(ext, val[i]) else: self.AddExtension(ext).MergeFrom(val[i]) else: if ext.composite_cls is None: self.SetExtension(ext, val) else: self.MutableExtension(ext).MergeFrom(val) def _ListExtensions(self): result = [ext for ext in self._extension_fields.keys() if (not ext.is_repeated) or self.ExtensionSize(ext) > 0] result.sort(key = lambda item: item.number) return result def _ExtensionEquals(self, x): extensions = self._ListExtensions() if extensions != x._ListExtensions(): return False for ext in extensions: if ext.is_repeated: if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False for e1, e2 in zip(self.ExtensionList(ext), x.ExtensionList(ext)): if e1 != e2: return False else: if self.GetExtension(ext) != x.GetExtension(ext): return False return True def _OutputExtensionFields(self, out, partial, extensions, start_index, end_field_number): def OutputSingleField(ext, value): out.putVarInt32(ext.wire_tag) if ext.field_type == TYPE_GROUP: if partial: value.OutputPartial(out) else: value.OutputUnchecked(out) out.putVarInt32(ext.wire_tag + 1) elif ext.field_type == TYPE_FOREIGN: if partial: out.putVarInt32(value.ByteSizePartial()) value.OutputPartial(out) else: out.putVarInt32(value.ByteSize()) value.OutputUnchecked(out) else: Encoder._TYPE_TO_METHOD[ext.field_type](out, value) size = len(extensions) for ext_index in xrange(start_index, size): ext = extensions[ext_index] if ext.number >= end_field_number: return ext_index if ext.is_repeated: for i in xrange(len(self._extension_fields[ext])): OutputSingleField(ext, self._extension_fields[ext][i]) else: OutputSingleField(ext, self._extension_fields[ext]) return size def _ParseOneExtensionField(self, wire_tag, d): number = wire_tag >> 3 if number in 
self._extensions_by_field_number: ext = self._extensions_by_field_number[number] if wire_tag != ext.wire_tag: return if ext.field_type == TYPE_FOREIGN: length = d.getVarInt32() tmp = Decoder(d.buffer(), d.pos(), d.pos() + length) if ext.is_repeated: self.AddExtension(ext).TryMerge(tmp) else: self.MutableExtension(ext).TryMerge(tmp) d.skip(length) elif ext.field_type == TYPE_GROUP: if ext.is_repeated: self.AddExtension(ext).TryMerge(d) else: self.MutableExtension(ext).TryMerge(d) else: value = Decoder._TYPE_TO_METHOD[ext.field_type](d) if ext.is_repeated: self.AddExtension(ext, value) else: self.SetExtension(ext, value) else: d.skipData(wire_tag) def _ExtensionByteSize(self, partial): size = 0 for extension, value in self._extension_fields.items(): ftype = extension.field_type tag_size = self.lengthVarInt64(extension.wire_tag) if ftype == TYPE_GROUP: tag_size *= 2 if extension.is_repeated: size += tag_size * len(value) for single_value in value: size += self._FieldByteSize(ftype, single_value, partial) else: size += tag_size + self._FieldByteSize(ftype, value, partial) return size def _FieldByteSize(self, ftype, value, partial): size = 0 if ftype == TYPE_STRING: size = self.lengthString(len(value)) elif ftype == TYPE_FOREIGN or ftype == TYPE_GROUP: if partial: size = self.lengthString(value.ByteSizePartial()) else: size = self.lengthString(value.ByteSize()) elif ftype == TYPE_INT64 or ftype == TYPE_UINT64 or ftype == TYPE_INT32: size = self.lengthVarInt64(value) else: if ftype in Encoder._TYPE_TO_BYTE_SIZE: size = Encoder._TYPE_TO_BYTE_SIZE[ftype] else: raise AssertionError( 'Extension type %d is not recognized.' 
% ftype) return size def _ExtensionDebugString(self, prefix, printElemNumber): res = '' extensions = self._ListExtensions() for extension in extensions: value = self._extension_fields[extension] if extension.is_repeated: cnt = 0 for e in value: elm="" if printElemNumber: elm = "(%d)" % cnt if extension.composite_cls is not None: res += prefix + "[%s%s] {\n" % (extension.full_name, elm) res += e.__str__(prefix + " ", printElemNumber) res += prefix + "}\n" else: if extension.composite_cls is not None: res += prefix + "[%s] {\n" % extension.full_name res += value.__str__( prefix + " ", printElemNumber) res += prefix + "}\n" else: if extension.field_type in _TYPE_TO_DEBUG_STRING: text_value = _TYPE_TO_DEBUG_STRING[ extension.field_type](self, value) else: text_value = self.DebugFormat(value) res += prefix + "[%s]: %s\n" % (extension.full_name, text_value) return res @staticmethod def _RegisterExtension(cls, extension, composite_cls=None): extension.containing_cls = cls extension.composite_cls = composite_cls if composite_cls is not None: extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME actual_handle = cls._extensions_by_field_number.setdefault( extension.number, extension) if actual_handle is not extension: raise AssertionError( 'Extensions "%s" and "%s" both try to extend message type "%s" with' 'field number %d.' % (extension.full_name, actual_handle.full_name, cls.__name__, extension.number)) class ProtocolBufferDecodeError(Exception): pass class ProtocolBufferEncodeError(Exception): pass class ProtocolBufferReturnError(Exception): pass protorpc-standalone-0.9.1/protorpc/_google/net/proto/RawMessage.py0000644000076500000240000000440612277637135026335 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2007 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This is the Python counterpart to the RawMessage class defined in rawmessage.h. To use this, put the following line in your .proto file: python from google.net.proto.RawMessage import RawMessage """ __pychecker__ = 'no-callinit no-argsused' from google.net.proto import ProtocolBuffer class RawMessage(ProtocolBuffer.ProtocolMessage): """ This is a special subclass of ProtocolMessage that doesn't interpret its data in any way. Instead, it just stores it in a string. See rawmessage.h for more details. """ def __init__(self, initial=None): self.__contents = '' if initial is not None: self.MergeFromString(initial) def contents(self): return self.__contents def set_contents(self, contents): self.__contents = contents def Clear(self): self.__contents = '' def IsInitialized(self, debug_strs=None): return 1 def __str__(self, prefix="", printElemNumber=0): return prefix + self.DebugFormatString(self.__contents) def OutputUnchecked(self, e): e.putRawString(self.__contents) def OutputPartial(self, e): return self.OutputUnchecked(e) def TryMerge(self, d): self.__contents = d.getRawString() def MergeFrom(self, pb): assert pb is not self if pb.__class__ != self.__class__: return 0 self.__contents = pb.__contents return 1 def Equals(self, pb): return self.__contents == pb.__contents def __eq__(self, other): return (other is not None) and (other.__class__ == self.__class__) and self.Equals(other) def __ne__(self, other): return not (self == other) def ByteSize(self): return len(self.__contents) def ByteSizePartial(self): return self.ByteSize() 
protorpc-standalone-0.9.1/protorpc/definition.py0000755000076500000240000002230012277637135023057 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Stub library.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import new import sys import urllib2 from . import descriptor from . import message_types from . import messages from . import protobuf from . import remote from . import util __all__ = [ 'define_enum', 'define_field', 'define_file', 'define_message', 'define_service', 'import_file', 'import_file_set', ] # Map variant back to message field classes. def _build_variant_map(): """Map variants to fields. Returns: Dictionary mapping field variant to its associated field type. """ result = {} for name in dir(messages): value = getattr(messages, name) if isinstance(value, type) and issubclass(value, messages.Field): for variant in getattr(value, 'VARIANTS', []): result[variant] = value return result _VARIANT_MAP = _build_variant_map() _MESSAGE_TYPE_MAP = { message_types.DateTimeMessage.definition_name(): message_types.DateTimeField, } def _get_or_define_module(full_name, modules): """Helper method for defining new modules. Args: full_name: Fully qualified name of module to create or return. modules: Dictionary of all modules. Defaults to sys.modules. Returns: Named module if found in 'modules', else creates new module and inserts in 'modules'. Will also construct parent modules if necessary. 
""" module = modules.get(full_name) if not module: module = new.module(full_name) modules[full_name] = module split_name = full_name.rsplit('.', 1) if len(split_name) > 1: parent_module_name, sub_module_name = split_name parent_module = _get_or_define_module(parent_module_name, modules) setattr(parent_module, sub_module_name, module) return module def define_enum(enum_descriptor, module_name): """Define Enum class from descriptor. Args: enum_descriptor: EnumDescriptor to build Enum class from. module_name: Module name to give new descriptor class. Returns: New messages.Enum sub-class as described by enum_descriptor. """ enum_values = enum_descriptor.values or [] class_dict = dict((value.name, value.number) for value in enum_values) class_dict['__module__'] = module_name return type(str(enum_descriptor.name), (messages.Enum,), class_dict) def define_field(field_descriptor): """Define Field instance from descriptor. Args: field_descriptor: FieldDescriptor class to build field instance from. Returns: New field instance as described by enum_descriptor. """ field_class = _VARIANT_MAP[field_descriptor.variant] params = {'number': field_descriptor.number, 'variant': field_descriptor.variant, } if field_descriptor.label == descriptor.FieldDescriptor.Label.REQUIRED: params['required'] = True elif field_descriptor.label == descriptor.FieldDescriptor.Label.REPEATED: params['repeated'] = True message_type_field = _MESSAGE_TYPE_MAP.get(field_descriptor.type_name) if message_type_field: return message_type_field(**params) elif field_class in (messages.EnumField, messages.MessageField): return field_class(field_descriptor.type_name, **params) else: if field_descriptor.default_value: value = field_descriptor.default_value try: value = descriptor._DEFAULT_FROM_STRING_MAP[field_class](value) except (TypeError, ValueError, KeyError): pass # Let the value pass to the constructor. 
params['default'] = value return field_class(**params) def define_message(message_descriptor, module_name): """Define Message class from descriptor. Args: message_descriptor: MessageDescriptor to describe message class from. module_name: Module name to give to new descriptor class. Returns: New messages.Message sub-class as described by message_descriptor. """ class_dict = {'__module__': module_name} for enum in message_descriptor.enum_types or []: enum_instance = define_enum(enum, module_name) class_dict[enum.name] = enum_instance # TODO(rafek): support nested messages when supported by descriptor. for field in message_descriptor.fields or []: field_instance = define_field(field) class_dict[field.name] = field_instance class_name = message_descriptor.name.encode('utf-8') return type(class_name, (messages.Message,), class_dict) def define_service(service_descriptor, module): """Define a new service proxy. Args: service_descriptor: ServiceDescriptor class that describes the service. module: Module to add service to. Request and response types are found relative to this module. Returns: Service class proxy capable of communicating with a remote server. 
""" class_dict = {'__module__': module.__name__} class_name = service_descriptor.name.encode('utf-8') for method_descriptor in service_descriptor.methods or []: request_definition = messages.find_definition( method_descriptor.request_type, module) response_definition = messages.find_definition( method_descriptor.response_type, module) method_name = method_descriptor.name.encode('utf-8') def remote_method(self, request): """Actual service method.""" raise NotImplementedError('Method is not implemented') remote_method.__name__ = method_name remote_method_decorator = remote.method(request_definition, response_definition) class_dict[method_name] = remote_method_decorator(remote_method) service_class = type(class_name, (remote.Service,), class_dict) return service_class def define_file(file_descriptor, module=None): """Define module from FileDescriptor. Args: file_descriptor: FileDescriptor instance to describe module from. module: Module to add contained objects to. Module name overrides value in file_descriptor.package. Definitions are added to existing module if provided. Returns: If no module provided, will create a new module with its name set to the file descriptor's package. If a module is provided, returns the same module. """ if module is None: module = new.module(file_descriptor.package) for enum_descriptor in file_descriptor.enum_types or []: enum_class = define_enum(enum_descriptor, module.__name__) setattr(module, enum_descriptor.name, enum_class) for message_descriptor in file_descriptor.message_types or []: message_class = define_message(message_descriptor, module.__name__) setattr(module, message_descriptor.name, message_class) for service_descriptor in file_descriptor.service_types or []: service_class = define_service(service_descriptor, module) setattr(module, service_descriptor.name, service_class) return module @util.positional(1) def import_file(file_descriptor, modules=None): """Import FileDescriptor in to module space. 
This is like define_file except that a new module and any required parent modules are created and added to the modules parameter or sys.modules if not provided. Args: file_descriptor: FileDescriptor instance to describe module from. modules: Dictionary of modules to update. Modules and their parents that do not exist will be created. If an existing module is found that matches file_descriptor.package, that module is updated with the FileDescriptor contents. Returns: Module found in modules, else a new module. """ if not file_descriptor.package: raise ValueError('File descriptor must have package name') if modules is None: modules = sys.modules module = _get_or_define_module(file_descriptor.package.encode('utf-8'), modules) return define_file(file_descriptor, module) @util.positional(1) def import_file_set(file_set, modules=None, _open=open): """Import FileSet in to module space. Args: file_set: If string, open file and read serialized FileSet. Otherwise, a FileSet instance to import definitions from. modules: Dictionary of modules to update. Modules and their parents that do not exist will be created. If an existing module is found that matches file_descriptor.package, that module is updated with the FileDescriptor contents. _open: Used for dependency injection during tests. """ if isinstance(file_set, basestring): encoded_file = _open(file_set, 'rb') try: encoded_file_set = encoded_file.read() finally: encoded_file.close() file_set = protobuf.decode_message(descriptor.FileSet, encoded_file_set) for file_descriptor in file_set.files: # Do not reload built in protorpc classes. if not file_descriptor.package.startswith('protorpc.'): import_file(file_descriptor, modules=modules) protorpc-standalone-0.9.1/protorpc/definition_test.py0000755000076500000240000005563012277637135024132 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.stub.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import new import StringIO import sys import unittest from protorpc import definition from protorpc import descriptor from protorpc import message_types from protorpc import messages from protorpc import protobuf from protorpc import remote from protorpc import test_util import mox class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = definition class DefineEnumTest(test_util.TestCase): """Test for define_enum.""" def testDefineEnum_Empty(self): """Test defining an empty enum.""" enum_descriptor = descriptor.EnumDescriptor() enum_descriptor.name = 'Empty' enum_class = definition.define_enum(enum_descriptor, 'whatever') self.assertEquals('Empty', enum_class.__name__) self.assertEquals('whatever', enum_class.__module__) self.assertEquals(enum_descriptor, descriptor.describe_enum(enum_class)) def testDefineEnum(self): """Test defining an enum.""" red = descriptor.EnumValueDescriptor() green = descriptor.EnumValueDescriptor() blue = descriptor.EnumValueDescriptor() red.name = 'RED' red.number = 1 green.name = 'GREEN' green.number = 2 blue.name = 'BLUE' blue.number = 3 enum_descriptor = descriptor.EnumDescriptor() enum_descriptor.name = 'Colors' enum_descriptor.values = [red, green, blue] enum_class = definition.define_enum(enum_descriptor, 'whatever') self.assertEquals('Colors', enum_class.__name__) 
self.assertEquals('whatever', enum_class.__module__) self.assertEquals(enum_descriptor, descriptor.describe_enum(enum_class)) class DefineFieldTest(test_util.TestCase): """Test for define_field.""" def testDefineField_Optional(self): """Test defining an optional field instance from a method descriptor.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT32 field_descriptor.label = descriptor.FieldDescriptor.Label.OPTIONAL field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.IntegerField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.INT32, field.variant) self.assertFalse(field.required) self.assertFalse(field.repeated) def testDefineField_Required(self): """Test defining a required field instance from a method descriptor.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.STRING field_descriptor.label = descriptor.FieldDescriptor.Label.REQUIRED field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. 
self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.StringField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.STRING, field.variant) self.assertTrue(field.required) self.assertFalse(field.repeated) def testDefineField_Repeated(self): """Test defining a repeated field instance from a method descriptor.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.DOUBLE field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.FloatField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.DOUBLE, field.variant) self.assertFalse(field.required) self.assertTrue(field.repeated) def testDefineField_Message(self): """Test defining a message field.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.MESSAGE field_descriptor.type_name = 'something.yet.to.be.Defined' field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. 
self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.MessageField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.MESSAGE, field.variant) self.assertFalse(field.required) self.assertTrue(field.repeated) self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError, 'Could not find definition for ' 'something.yet.to.be.Defined', getattr, field, 'type') def testDefineField_DateTime(self): """Test defining a date time field.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_timestamp' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.MESSAGE field_descriptor.type_name = 'protorpc.message_types.DateTimeMessage' field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, message_types.DateTimeField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.MESSAGE, field.variant) self.assertFalse(field.required) self.assertTrue(field.repeated) def testDefineField_Enum(self): """Test defining an enum field.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.ENUM field_descriptor.type_name = 'something.yet.to.be.Defined' field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. 
self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.EnumField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.ENUM, field.variant) self.assertFalse(field.required) self.assertTrue(field.repeated) self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError, 'Could not find definition for ' 'something.yet.to.be.Defined', getattr, field, 'type') def testDefineField_Default_Bool(self): """Test defining a default value for a bool.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.BOOL field_descriptor.default_value = u'true' field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.BooleanField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.BOOL, field.variant) self.assertFalse(field.required) self.assertFalse(field.repeated) self.assertEqual(field.default, True) field_descriptor.default_value = u'false' field = definition.define_field(field_descriptor) self.assertEqual(field.default, False) def testDefineField_Default_Float(self): """Test defining a default value for a float.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.FLOAT field_descriptor.default_value = u'34.567' field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. 
self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.FloatField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.FLOAT, field.variant) self.assertFalse(field.required) self.assertFalse(field.repeated) self.assertEqual(field.default, 34.567) def testDefineField_Default_Int(self): """Test defining a default value for an int.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT64 field_descriptor.default_value = u'34' field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.IntegerField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.INT64, field.variant) self.assertFalse(field.required) self.assertFalse(field.repeated) self.assertEqual(field.default, 34) def testDefineField_Default_Str(self): """Test defining a default value for a str.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.STRING field_descriptor.default_value = u'Test' field = definition.define_field(field_descriptor) # Name will not be set from the original descriptor. 
self.assertFalse(hasattr(field, 'name')) self.assertTrue(isinstance(field, messages.StringField)) self.assertEquals(1, field.number) self.assertEquals(descriptor.FieldDescriptor.Variant.STRING, field.variant) self.assertFalse(field.required) self.assertFalse(field.repeated) self.assertEqual(field.default, u'Test') def testDefineField_Default_Invalid(self): """Test defining a default value that is not valid.""" field_descriptor = descriptor.FieldDescriptor() field_descriptor.name = 'a_field' field_descriptor.number = 1 field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT64 field_descriptor.default_value = u'Test' # Verify that the string is passed to the Constructor. mock = mox.Mox() mock.StubOutWithMock(messages.IntegerField, '__init__') messages.IntegerField.__init__( default=u'Test', number=1, variant=messages.Variant.INT64 ).AndRaise(messages.InvalidDefaultError) mock.ReplayAll() with self.assertRaises(messages.InvalidDefaultError): _ = definition.define_field(field_descriptor) mock.VerifyAll() mock.ResetAll() mock.UnsetStubs() class DefineMessageTest(test_util.TestCase): """Test for define_message.""" def testDefineMessageEmpty(self): """Test definition a message with no fields or enums.""" class AMessage(messages.Message): pass message_descriptor = descriptor.describe_message(AMessage) message_class = definition.define_message(message_descriptor, '__main__') self.assertEquals('AMessage', message_class.__name__) self.assertEquals('__main__', message_class.__module__) self.assertEquals(message_descriptor, descriptor.describe_message(message_class)) def testDefineMessageEnumOnly(self): """Test definition a message with only enums.""" class AMessage(messages.Message): class NestedEnum(messages.Enum): pass message_descriptor = descriptor.describe_message(AMessage) message_class = definition.define_message(message_descriptor, '__main__') self.assertEquals('AMessage', message_class.__name__) self.assertEquals('__main__', message_class.__module__) 
self.assertEquals(message_descriptor, descriptor.describe_message(message_class)) def testDefineMessageFieldsOnly(self): """Test definition a message with only fields.""" class AMessage(messages.Message): field1 = messages.IntegerField(1) field2 = messages.StringField(2) message_descriptor = descriptor.describe_message(AMessage) message_class = definition.define_message(message_descriptor, '__main__') self.assertEquals('AMessage', message_class.__name__) self.assertEquals('__main__', message_class.__module__) self.assertEquals(message_descriptor, descriptor.describe_message(message_class)) def testDefineMessage(self): """Test defining Message class from descriptor.""" class AMessage(messages.Message): class NestedEnum(messages.Enum): pass field1 = messages.IntegerField(1) field2 = messages.StringField(2) message_descriptor = descriptor.describe_message(AMessage) message_class = definition.define_message(message_descriptor, '__main__') self.assertEquals('AMessage', message_class.__name__) self.assertEquals('__main__', message_class.__module__) self.assertEquals(message_descriptor, descriptor.describe_message(message_class)) class DefineServiceTest(test_util.TestCase): """Test service proxy definition.""" def setUp(self): """Set up mock and request classes.""" self.module = new.module('stocks') class GetQuoteRequest(messages.Message): __module__ = 'stocks' symbols = messages.StringField(1, repeated=True) class GetQuoteResponse(messages.Message): __module__ = 'stocks' prices = messages.IntegerField(1, repeated=True) self.module.GetQuoteRequest = GetQuoteRequest self.module.GetQuoteResponse = GetQuoteResponse def testDefineService(self): """Test service definition from descriptor.""" method_descriptor = descriptor.MethodDescriptor() method_descriptor.name = 'get_quote' method_descriptor.request_type = 'GetQuoteRequest' method_descriptor.response_type = 'GetQuoteResponse' service_descriptor = descriptor.ServiceDescriptor() service_descriptor.name = 'Stocks' 
service_descriptor.methods = [method_descriptor] StockService = definition.define_service(service_descriptor, self.module) self.assertTrue(issubclass(StockService, remote.Service)) self.assertTrue(issubclass(StockService.Stub, remote.StubBase)) request = self.module.GetQuoteRequest() service = StockService() self.assertRaises(NotImplementedError, service.get_quote, request) self.assertEquals(self.module.GetQuoteRequest, service.get_quote.remote.request_type) self.assertEquals(self.module.GetQuoteResponse, service.get_quote.remote.response_type) class ModuleTest(test_util.TestCase): """Test for module creation and importation functions.""" def MakeFileDescriptor(self, package): """Helper method to construct FileDescriptors. Creates FileDescriptor with a MessageDescriptor and an EnumDescriptor. Args: package: Package name to give new file descriptors. Returns: New FileDescriptor instance. """ enum_descriptor = descriptor.EnumDescriptor() enum_descriptor.name = u'MyEnum' message_descriptor = descriptor.MessageDescriptor() message_descriptor.name = u'MyMessage' service_descriptor = descriptor.ServiceDescriptor() service_descriptor.name = u'MyService' file_descriptor = descriptor.FileDescriptor() file_descriptor.package = package file_descriptor.enum_types = [enum_descriptor] file_descriptor.message_types = [message_descriptor] file_descriptor.service_types = [service_descriptor] return file_descriptor def testDefineModule(self): """Test define_module function.""" file_descriptor = self.MakeFileDescriptor('my.package') module = definition.define_file(file_descriptor) self.assertEquals('my.package', module.__name__) self.assertEquals('my.package', module.MyEnum.__module__) self.assertEquals('my.package', module.MyMessage.__module__) self.assertEquals('my.package', module.MyService.__module__) self.assertEquals(file_descriptor, descriptor.describe_file(module)) def testDefineModule_ReuseModule(self): """Test updating module with additional definitions.""" file_descriptor 
= self.MakeFileDescriptor('my.package') module = new.module('override') self.assertEquals(module, definition.define_file(file_descriptor, module)) self.assertEquals('override', module.MyEnum.__module__) self.assertEquals('override', module.MyMessage.__module__) self.assertEquals('override', module.MyService.__module__) # One thing is different between original descriptor and new. file_descriptor.package = 'override' self.assertEquals(file_descriptor, descriptor.describe_file(module)) def testImportFile(self): """Test importing FileDescriptor in to module space.""" modules = {} file_descriptor = self.MakeFileDescriptor('standalone') definition.import_file(file_descriptor, modules=modules) self.assertEquals(file_descriptor, descriptor.describe_file(modules['standalone'])) def testImportFile_InToExisting(self): """Test importing FileDescriptor in to existing module.""" module = new.module('standalone') modules = {'standalone': module} file_descriptor = self.MakeFileDescriptor('standalone') definition.import_file(file_descriptor, modules=modules) self.assertEquals(module, modules['standalone']) self.assertEquals(file_descriptor, descriptor.describe_file(modules['standalone'])) def testImportFile_InToGlobalModules(self): """Test importing FileDescriptor in to global modules.""" original_modules = sys.modules try: sys.modules = dict(sys.modules) if 'standalone' in sys.modules: del sys.modules['standalone'] file_descriptor = self.MakeFileDescriptor('standalone') definition.import_file(file_descriptor) self.assertEquals(file_descriptor, descriptor.describe_file(sys.modules['standalone'])) finally: sys.modules = original_modules def testImportFile_Nested(self): """Test importing FileDescriptor in to existing nested module.""" modules = {} file_descriptor = self.MakeFileDescriptor('root.nested') definition.import_file(file_descriptor, modules=modules) self.assertEquals(modules['root'].nested, modules['root.nested']) self.assertEquals(file_descriptor, 
descriptor.describe_file(modules['root.nested'])) def testImportFile_NoPackage(self): """Test importing FileDescriptor with no package.""" file_descriptor = self.MakeFileDescriptor('does not matter') file_descriptor.reset('package') self.assertRaisesWithRegexpMatch(ValueError, 'File descriptor must have package name', definition.import_file, file_descriptor) def testImportFileSet(self): """Test importing a whole file set.""" file_set = descriptor.FileSet() file_set.files = [self.MakeFileDescriptor(u'standalone'), self.MakeFileDescriptor(u'root.nested'), self.MakeFileDescriptor(u'root.nested.nested'), ] root = new.module('root') nested = new.module('root.nested') root.nested = nested modules = { 'root': root, 'root.nested': nested, } definition.import_file_set(file_set, modules=modules) self.assertEquals(root, modules['root']) self.assertEquals(nested, modules['root.nested']) self.assertEquals(nested.nested, modules['root.nested.nested']) self.assertEquals(file_set, descriptor.describe_file_set( [modules['standalone'], modules['root.nested'], modules['root.nested.nested'], ])) def testImportFileSetFromFile(self): """Test importing a whole file set from a file.""" file_set = descriptor.FileSet() file_set.files = [self.MakeFileDescriptor(u'standalone'), self.MakeFileDescriptor(u'root.nested'), self.MakeFileDescriptor(u'root.nested.nested'), ] stream = StringIO.StringIO(protobuf.encode_message(file_set)) self.mox = mox.Mox() opener = self.mox.CreateMockAnything() opener('my-file.dat', 'rb').AndReturn(stream) self.mox.ReplayAll() modules = {} definition.import_file_set('my-file.dat', modules=modules, _open=opener) self.assertEquals(file_set, descriptor.describe_file_set( [modules['standalone'], modules['root.nested'], modules['root.nested.nested'], ])) def testImportBuiltInProtorpcClasses(self): """Test that built in Protorpc classes are skipped.""" file_set = descriptor.FileSet() file_set.files = [self.MakeFileDescriptor(u'standalone'), 
self.MakeFileDescriptor(u'root.nested'), self.MakeFileDescriptor(u'root.nested.nested'), descriptor.describe_file(descriptor), ] root = new.module('root') nested = new.module('root.nested') root.nested = nested modules = { 'root': root, 'root.nested': nested, 'protorpc.descriptor': descriptor, } definition.import_file_set(file_set, modules=modules) self.assertEquals(root, modules['root']) self.assertEquals(nested, modules['root.nested']) self.assertEquals(nested.nested, modules['root.nested.nested']) self.assertEquals(descriptor, modules['protorpc.descriptor']) self.assertEquals(file_set, descriptor.describe_file_set( [modules['standalone'], modules['root.nested'], modules['root.nested.nested'], modules['protorpc.descriptor'], ])) if __name__ == '__main__': unittest.main() protorpc-standalone-0.9.1/protorpc/descriptor.py0000755000076500000240000005240112277637135023112 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Services descriptor definitions. Contains message definitions and functions for converting service classes into transmittable message format. Describing an Enum instance, Enum class, Field class or Message class will generate an appropriate descriptor object that describes that class. This message can itself be used to transmit information to clients wishing to know the description of an enum value, enum, field or message without needing to download the source code. 
This format is also compatible with other, non-Python languages. The descriptors are modeled to be binary compatible with: http://code.google.com/p/protobuf/source/browse/trunk/src/google/protobuf/descriptor.proto NOTE: The names of types and fields are not always the same between these descriptors and the ones defined in descriptor.proto. This was done in order to make source code files that use these descriptors easier to read. For example, it is not necessary to prefix TYPE to all the values in FieldDescriptor.Variant as is done in descriptor.proto FieldDescriptorProto.Type. Example: class Pixel(messages.Message): x = messages.IntegerField(1, required=True) y = messages.IntegerField(2, required=True) color = messages.BytesField(3) # Describe Pixel class using message descriptor. fields = [] field = FieldDescriptor() field.name = 'x' field.number = 1 field.label = FieldDescriptor.Label.REQUIRED field.variant = FieldDescriptor.Variant.INT64 fields.append(field) field = FieldDescriptor() field.name = 'y' field.number = 2 field.label = FieldDescriptor.Label.REQUIRED field.variant = FieldDescriptor.Variant.INT64 fields.append(field) field = FieldDescriptor() field.name = 'color' field.number = 3 field.label = FieldDescriptor.Label.OPTIONAL field.variant = FieldDescriptor.Variant.BYTES fields.append(field) message = MessageDescriptor() message.name = 'Pixel' message.fields = fields # Describing is the equivalent of building the above message. message == describe_message(Pixel) Public Classes: EnumValueDescriptor: Describes Enum values. EnumDescriptor: Describes Enum classes. FieldDescriptor: Describes field instances. FileDescriptor: Describes a single 'file' unit. FileSet: Describes a collection of file descriptors. MessageDescriptor: Describes Message classes. MethodDescriptor: Describes a method of a service. ServiceDescriptor: Describes a services. Public Functions: describe_enum_value: Describe an individual enum-value. describe_enum: Describe an Enum class. 
describe_field: Describe a Field definition. describe_file: Describe a 'file' unit from a Python module or object. describe_file_set: Describe a file set from a list of modules or objects. describe_message: Describe a Message definition. describe_method: Describe a Method definition. describe_service: Describe a Service definition. """ __author__ = 'rafek@google.com (Rafe Kaplan)' import codecs import types from . import messages from . import util __all__ = ['EnumDescriptor', 'EnumValueDescriptor', 'FieldDescriptor', 'MessageDescriptor', 'MethodDescriptor', 'FileDescriptor', 'FileSet', 'ServiceDescriptor', 'DescriptorLibrary', 'describe_enum', 'describe_enum_value', 'describe_field', 'describe_message', 'describe_method', 'describe_file', 'describe_file_set', 'describe_service', 'describe', 'import_descriptor_loader', ] # NOTE: MessageField is missing because message fields cannot have # a default value at this time. # TODO(rafek): Support default message values. # # Map to functions that convert default values of fields of a given type # to a string. The function must return a value that is compatible with # FieldDescriptor.default_value and therefore a unicode string. _DEFAULT_TO_STRING_MAP = { messages.IntegerField: unicode, messages.FloatField: unicode, messages.BooleanField: lambda value: value and u'true' or u'false', messages.BytesField: lambda value: codecs.escape_encode(value)[0], messages.StringField: lambda value: value, messages.EnumField: lambda value: unicode(value.number), } _DEFAULT_FROM_STRING_MAP = { messages.IntegerField: int, messages.FloatField: float, messages.BooleanField: lambda value: value == u'true', messages.BytesField: lambda value: codecs.escape_decode(value)[0], messages.StringField: lambda value: value, messages.EnumField: int, } class EnumValueDescriptor(messages.Message): """Enum value descriptor. Fields: name: Name of enumeration value. number: Number of enumeration value. 
""" # TODO(rafek): Why are these listed as optional in descriptor.proto. # Harmonize? name = messages.StringField(1, required=True) number = messages.IntegerField(2, required=True, variant=messages.Variant.INT32) class EnumDescriptor(messages.Message): """Enum class descriptor. Fields: name: Name of Enum without any qualification. values: Values defined by Enum class. """ name = messages.StringField(1) values = messages.MessageField(EnumValueDescriptor, 2, repeated=True) class FieldDescriptor(messages.Message): """Field definition descriptor. Enums: Variant: Wire format hint sub-types for field. Label: Values for optional, required and repeated fields. Fields: name: Name of field. number: Number of field. variant: Variant of field. type_name: Type name for message and enum fields. default_value: String representation of default value. """ Variant = messages.Variant class Label(messages.Enum): """Field label.""" OPTIONAL = 1 REQUIRED = 2 REPEATED = 3 name = messages.StringField(1, required=True) number = messages.IntegerField(3, required=True, variant=messages.Variant.INT32) label = messages.EnumField(Label, 4, default=Label.OPTIONAL) variant = messages.EnumField(Variant, 5) type_name = messages.StringField(6) # For numeric types, contains the original text representation of the value. # For booleans, "true" or "false". # For strings, contains the default text contents (not escaped in any way). # For bytes, contains the C escaped value. All bytes < 128 are that are # traditionally considered unprintable are also escaped. default_value = messages.StringField(7) class MessageDescriptor(messages.Message): """Message definition descriptor. Fields: name: Name of Message without any qualification. fields: Fields defined for message. message_types: Nested Message classes defined on message. enum_types: Nested Enum classes defined on message. 
""" name = messages.StringField(1) fields = messages.MessageField(FieldDescriptor, 2, repeated=True) message_types = messages.MessageField( 'protorpc.descriptor.MessageDescriptor', 3, repeated=True) enum_types = messages.MessageField(EnumDescriptor, 4, repeated=True) class MethodDescriptor(messages.Message): """Service method definition descriptor. Fields: name: Name of service method. request_type: Fully qualified or relative name of request message type. response_type: Fully qualified or relative name of response message type. """ name = messages.StringField(1) request_type = messages.StringField(2) response_type = messages.StringField(3) class ServiceDescriptor(messages.Message): """Service definition descriptor. Fields: name: Name of Service without any qualification. methods: Remote methods of Service. """ name = messages.StringField(1) methods = messages.MessageField(MethodDescriptor, 2, repeated=True) class FileDescriptor(messages.Message): """Description of file containing protobuf definitions. Fields: package: Fully qualified name of package that definitions belong to. message_types: Message definitions contained in file. enum_types: Enum definitions contained in file. service_types: Service definitions contained in file. """ package = messages.StringField(2) # TODO(rafek): Add dependency field message_types = messages.MessageField(MessageDescriptor, 4, repeated=True) enum_types = messages.MessageField(EnumDescriptor, 5, repeated=True) service_types = messages.MessageField(ServiceDescriptor, 6, repeated=True) class FileSet(messages.Message): """A collection of FileDescriptors. Fields: files: Files in file-set. """ files = messages.MessageField(FileDescriptor, 1, repeated=True) def describe_enum_value(enum_value): """Build descriptor for Enum instance. Args: enum_value: Enum value to provide descriptor for. Returns: Initialized EnumValueDescriptor instance describing the Enum instance. 
""" enum_value_descriptor = EnumValueDescriptor() enum_value_descriptor.name = unicode(enum_value.name) enum_value_descriptor.number = enum_value.number return enum_value_descriptor def describe_enum(enum_definition): """Build descriptor for Enum class. Args: enum_definition: Enum class to provide descriptor for. Returns: Initialized EnumDescriptor instance describing the Enum class. """ enum_descriptor = EnumDescriptor() enum_descriptor.name = enum_definition.definition_name().split('.')[-1] values = [] for number in enum_definition.numbers(): value = enum_definition.lookup_by_number(number) values.append(describe_enum_value(value)) if values: enum_descriptor.values = values return enum_descriptor def describe_field(field_definition): """Build descriptor for Field instance. Args: field_definition: Field instance to provide descriptor for. Returns: Initialized FieldDescriptor instance describing the Field instance. """ field_descriptor = FieldDescriptor() field_descriptor.name = field_definition.name field_descriptor.number = field_definition.number field_descriptor.variant = field_definition.variant if isinstance(field_definition, messages.EnumField): field_descriptor.type_name = field_definition.type.definition_name() if isinstance(field_definition, messages.MessageField): field_descriptor.type_name = field_definition.message_type.definition_name() if field_definition.default is not None: field_descriptor.default_value = _DEFAULT_TO_STRING_MAP[ type(field_definition)](field_definition.default) # Set label. if field_definition.repeated: field_descriptor.label = FieldDescriptor.Label.REPEATED elif field_definition.required: field_descriptor.label = FieldDescriptor.Label.REQUIRED else: field_descriptor.label = FieldDescriptor.Label.OPTIONAL return field_descriptor def describe_message(message_definition): """Build descriptor for Message class. Args: message_definition: Message class to provide descriptor for. 
Returns: Initialized MessageDescriptor instance describing the Message class. """ message_descriptor = MessageDescriptor() message_descriptor.name = message_definition.definition_name().split('.')[-1] fields = sorted(message_definition.all_fields(), key=lambda v: v.number) if fields: message_descriptor.fields = [describe_field(field) for field in fields] try: nested_messages = message_definition.__messages__ except AttributeError: pass else: message_descriptors = [] for name in nested_messages: value = getattr(message_definition, name) message_descriptors.append(describe_message(value)) message_descriptor.message_types = message_descriptors try: nested_enums = message_definition.__enums__ except AttributeError: pass else: enum_descriptors = [] for name in nested_enums: value = getattr(message_definition, name) enum_descriptors.append(describe_enum(value)) message_descriptor.enum_types = enum_descriptors return message_descriptor def describe_method(method): """Build descriptor for service method. Args: method: Remote service method to describe. Returns: Initialized MethodDescriptor instance describing the service method. """ method_info = method.remote descriptor = MethodDescriptor() descriptor.name = method_info.method.func_name descriptor.request_type = method_info.request_type.definition_name() descriptor.response_type = method_info.response_type.definition_name() return descriptor def describe_service(service_class): """Build descriptor for service. Args: service_class: Service class to describe. Returns: Initialized ServiceDescriptor instance describing the service. 
""" descriptor = ServiceDescriptor() descriptor.name = service_class.__name__ methods = [] remote_methods = service_class.all_remote_methods() for name in sorted(remote_methods.iterkeys()): if name == 'get_descriptor': continue method = remote_methods[name] methods.append(describe_method(method)) if methods: descriptor.methods = methods return descriptor def describe_file(module): """Build a file from a specified Python module. Args: module: Python module to describe. Returns: Initialized FileDescriptor instance describing the module. """ # May not import remote at top of file because remote depends on this # file # TODO(rafek): Straighten out this dependency. Possibly move these functions # from descriptor to their own module. from . import remote descriptor = FileDescriptor() descriptor.package = util.get_package_for_module(module) if not descriptor.package: descriptor.package = None message_descriptors = [] enum_descriptors = [] service_descriptors = [] # Need to iterate over all top level attributes of the module looking for # message, enum and service definitions. Each definition must be itself # described. for name in sorted(dir(module)): value = getattr(module, name) if isinstance(value, type): if issubclass(value, messages.Message): message_descriptors.append(describe_message(value)) elif issubclass(value, messages.Enum): enum_descriptors.append(describe_enum(value)) elif issubclass(value, remote.Service): service_descriptors.append(describe_service(value)) if message_descriptors: descriptor.message_types = message_descriptors if enum_descriptors: descriptor.enum_types = enum_descriptors if service_descriptors: descriptor.service_types = service_descriptors return descriptor def describe_file_set(modules): """Build a file set from a specified Python modules. Args: modules: Iterable of Python module to describe. Returns: Initialized FileSet instance describing the modules. 
""" descriptor = FileSet() file_descriptors = [] for module in modules: file_descriptors.append(describe_file(module)) if file_descriptors: descriptor.files = file_descriptors return descriptor def describe(value): """Describe any value as a descriptor. Helper function for describing any object with an appropriate descriptor object. Args: value: Value to describe as a descriptor. Returns: Descriptor message class if object is describable as a descriptor, else None. """ from . import remote if isinstance(value, types.ModuleType): return describe_file(value) elif callable(value) and hasattr(value, 'remote'): return describe_method(value) elif isinstance(value, messages.Field): return describe_field(value) elif isinstance(value, messages.Enum): return describe_enum_value(value) elif isinstance(value, type): if issubclass(value, messages.Message): return describe_message(value) elif issubclass(value, messages.Enum): return describe_enum(value) elif issubclass(value, remote.Service): return describe_service(value) return None @util.positional(1) def import_descriptor_loader(definition_name, importer=__import__): """Find objects by importing modules as needed. A definition loader is a function that resolves a definition name to a descriptor. The import finder resolves definitions to their names by importing modules when necessary. Args: definition_name: Name of definition to find. importer: Import function used for importing new modules. Returns: Appropriate descriptor for any describable type located by name. Raises: DefinitionNotFoundError when a name does not refer to either a definition or a module. """ # Attempt to import descriptor as a module. 
if definition_name.startswith('.'): definition_name = definition_name[1:] if not definition_name.startswith('.'): leaf = definition_name.split('.')[-1] if definition_name: try: module = importer(definition_name, '', '', [leaf]) except ImportError: pass else: return describe(module) try: # Attempt to use messages.find_definition to find item. return describe(messages.find_definition(definition_name, importer=__import__)) except messages.DefinitionNotFoundError, err: # There are things that find_definition will not find, but if the parent # is loaded, its children can be searched for a match. split_name = definition_name.rsplit('.', 1) if len(split_name) > 1: parent, child = split_name try: parent_definition = import_descriptor_loader(parent, importer=importer) except messages.DefinitionNotFoundError: # Fall through to original error. pass else: # Check the parent definition for a matching descriptor. if isinstance(parent_definition, FileDescriptor): search_list = parent_definition.service_types or [] elif isinstance(parent_definition, ServiceDescriptor): search_list = parent_definition.methods or [] elif isinstance(parent_definition, EnumDescriptor): search_list = parent_definition.values or [] elif isinstance(parent_definition, MessageDescriptor): search_list = parent_definition.fields or [] else: search_list = [] for definition in search_list: if definition.name == child: return definition # Still didn't find. Reraise original exception. raise err class DescriptorLibrary(object): """A descriptor library is an object that contains known definitions. A descriptor library contains a cache of descriptor objects mapped by definition name. It contains all types of descriptors except for file sets. When a definition name is requested that the library does not know about it can be provided with a descriptor loader which attempt to resolve the missing descriptor. 
""" @util.positional(1) def __init__(self, descriptors=None, descriptor_loader=import_descriptor_loader): """Constructor. Args: descriptors: A dictionary or dictionary-like object that can be used to store and cache descriptors by definition name. definition_loader: A function used for resolving missing descriptors. The function takes a definition name as its parameter and returns an appropriate descriptor. It may raise DefinitionNotFoundError. """ self.__descriptor_loader = descriptor_loader self.__descriptors = descriptors or {} def lookup_descriptor(self, definition_name): """Lookup descriptor by name. Get descriptor from library by name. If descriptor is not found will attempt to find via descriptor loader if provided. Args: definition_name: Definition name to find. Returns: Descriptor that describes definition name. Raises: DefinitionNotFoundError if not descriptor exists for definition name. """ try: return self.__descriptors[definition_name] except KeyError: pass if self.__descriptor_loader: definition = self.__descriptor_loader(definition_name) self.__descriptors[definition_name] = definition return definition else: raise messages.DefinitionNotFoundError( 'Could not find definition for %s' % definition_name) def lookup_package(self, definition_name): """Determines the package name for any definition. Determine the package that any definition name belongs to. May check parent for package name and will resolve missing descriptors if provided descriptor loader. Args: definition_name: Definition name to find package for. """ while True: descriptor = self.lookup_descriptor(definition_name) if isinstance(descriptor, FileDescriptor): return descriptor.package else: index = definition_name.rfind('.') if index < 0: return None definition_name = definition_name[:index] protorpc-standalone-0.9.1/protorpc/descriptor_test.py0000755000076500000240000004632112277637135024155 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.descriptor.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import new import unittest from protorpc import descriptor from protorpc import message_types from protorpc import messages from protorpc import registry from protorpc import remote from protorpc import test_util RUSSIA = u'\u0420\u043e\u0441\u0441\u0438\u044f' class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = descriptor class DescribeEnumValueTest(test_util.TestCase): def testDescribe(self): class MyEnum(messages.Enum): MY_NAME = 10 expected = descriptor.EnumValueDescriptor() expected.name = 'MY_NAME' expected.number = 10 described = descriptor.describe_enum_value(MyEnum.MY_NAME) described.check_initialized() self.assertEquals(expected, described) class DescribeEnumTest(test_util.TestCase): def testEmptyEnum(self): class EmptyEnum(messages.Enum): pass expected = descriptor.EnumDescriptor() expected.name = 'EmptyEnum' described = descriptor.describe_enum(EmptyEnum) described.check_initialized() self.assertEquals(expected, described) def testNestedEnum(self): class MyScope(messages.Message): class NestedEnum(messages.Enum): pass expected = descriptor.EnumDescriptor() expected.name = 'NestedEnum' described = descriptor.describe_enum(MyScope.NestedEnum) described.check_initialized() self.assertEquals(expected, described) def testEnumWithItems(self): class EnumWithItems(messages.Enum): A = 3 B = 1 C = 2 expected = 
descriptor.EnumDescriptor() expected.name = 'EnumWithItems' a = descriptor.EnumValueDescriptor() a.name = 'A' a.number = 3 b = descriptor.EnumValueDescriptor() b.name = 'B' b.number = 1 c = descriptor.EnumValueDescriptor() c.name = 'C' c.number = 2 expected.values = [b, c, a] described = descriptor.describe_enum(EnumWithItems) described.check_initialized() self.assertEquals(expected, described) class DescribeFieldTest(test_util.TestCase): def testLabel(self): for repeated, required, expected_label in ( (True, False, descriptor.FieldDescriptor.Label.REPEATED), (False, True, descriptor.FieldDescriptor.Label.REQUIRED), (False, False, descriptor.FieldDescriptor.Label.OPTIONAL)): field = messages.IntegerField(10, required=required, repeated=repeated) field.name = 'a_field' expected = descriptor.FieldDescriptor() expected.name = 'a_field' expected.number = 10 expected.label = expected_label expected.variant = descriptor.FieldDescriptor.Variant.INT64 described = descriptor.describe_field(field) described.check_initialized() self.assertEquals(expected, described) def testDefault(self): for field_class, default, expected_default in ( (messages.IntegerField, 200, '200'), (messages.FloatField, 1.5, '1.5'), (messages.FloatField, 1e6, '1000000.0'), (messages.BooleanField, True, 'true'), (messages.BooleanField, False, 'false'), (messages.BytesField, 'ab\xF1', 'ab\\xf1'), (messages.StringField, RUSSIA, RUSSIA), ): field = field_class(10, default=default) field.name = u'a_field' expected = descriptor.FieldDescriptor() expected.name = u'a_field' expected.number = 10 expected.label = descriptor.FieldDescriptor.Label.OPTIONAL expected.variant = field_class.DEFAULT_VARIANT expected.default_value = expected_default described = descriptor.describe_field(field) described.check_initialized() self.assertEquals(expected, described) def testDefault_EnumField(self): class MyEnum(messages.Enum): VAL = 1 module_name = test_util.get_module_name(MyEnum) field = messages.EnumField(MyEnum, 10, 
default=MyEnum.VAL) field.name = 'a_field' expected = descriptor.FieldDescriptor() expected.name = 'a_field' expected.number = 10 expected.label = descriptor.FieldDescriptor.Label.OPTIONAL expected.variant = messages.EnumField.DEFAULT_VARIANT expected.type_name = '%s.MyEnum' % module_name expected.default_value = '1' described = descriptor.describe_field(field) self.assertEquals(expected, described) def testMessageField(self): field = messages.MessageField(descriptor.FieldDescriptor, 10) field.name = 'a_field' expected = descriptor.FieldDescriptor() expected.name = 'a_field' expected.number = 10 expected.label = descriptor.FieldDescriptor.Label.OPTIONAL expected.variant = messages.MessageField.DEFAULT_VARIANT expected.type_name = ('protorpc.descriptor.FieldDescriptor') described = descriptor.describe_field(field) described.check_initialized() self.assertEquals(expected, described) def testDateTimeField(self): field = message_types.DateTimeField(20) field.name = 'a_timestamp' expected = descriptor.FieldDescriptor() expected.name = 'a_timestamp' expected.number = 20 expected.label = descriptor.FieldDescriptor.Label.OPTIONAL expected.variant = messages.MessageField.DEFAULT_VARIANT expected.type_name = ('protorpc.message_types.DateTimeMessage') described = descriptor.describe_field(field) described.check_initialized() self.assertEquals(expected, described) class DescribeMessageTest(test_util.TestCase): def testEmptyDefinition(self): class MyMessage(messages.Message): pass expected = descriptor.MessageDescriptor() expected.name = 'MyMessage' described = descriptor.describe_message(MyMessage) described.check_initialized() self.assertEquals(expected, described) def testDefinitionWithFields(self): class MessageWithFields(messages.Message): field1 = messages.IntegerField(10) field2 = messages.StringField(30) field3 = messages.IntegerField(20) expected = descriptor.MessageDescriptor() expected.name = 'MessageWithFields' expected.fields = [ 
descriptor.describe_field(MessageWithFields.field_by_name('field1')), descriptor.describe_field(MessageWithFields.field_by_name('field3')), descriptor.describe_field(MessageWithFields.field_by_name('field2')), ] described = descriptor.describe_message(MessageWithFields) described.check_initialized() self.assertEquals(expected, described) def testNestedEnum(self): class MessageWithEnum(messages.Message): class Mood(messages.Enum): GOOD = 1 BAD = 2 UGLY = 3 class Music(messages.Enum): CLASSIC = 1 JAZZ = 2 BLUES = 3 expected = descriptor.MessageDescriptor() expected.name = 'MessageWithEnum' expected.enum_types = [descriptor.describe_enum(MessageWithEnum.Mood), descriptor.describe_enum(MessageWithEnum.Music)] described = descriptor.describe_message(MessageWithEnum) described.check_initialized() self.assertEquals(expected, described) def testNestedMessage(self): class MessageWithMessage(messages.Message): class Nesty(messages.Message): pass expected = descriptor.MessageDescriptor() expected.name = 'MessageWithMessage' expected.message_types = [ descriptor.describe_message(MessageWithMessage.Nesty)] described = descriptor.describe_message(MessageWithMessage) described.check_initialized() self.assertEquals(expected, described) class DescribeMethodTest(test_util.TestCase): """Test describing remote methods.""" def testDescribe(self): class Request(messages.Message): pass class Response(messages.Message): pass @remote.method(Request, Response) def remote_method(request): pass module_name = test_util.get_module_name(DescribeMethodTest) expected = descriptor.MethodDescriptor() expected.name = 'remote_method' expected.request_type = '%s.Request' % module_name expected.response_type = '%s.Response' % module_name described = descriptor.describe_method(remote_method) described.check_initialized() self.assertEquals(expected, described) class DescribeServiceTest(test_util.TestCase): """Test describing service classes.""" def testDescribe(self): class Request1(messages.Message): 
pass class Response1(messages.Message): pass class Request2(messages.Message): pass class Response2(messages.Message): pass class MyService(remote.Service): @remote.method(Request1, Response1) def method1(self, request): pass @remote.method(Request2, Response2) def method2(self, request): pass expected = descriptor.ServiceDescriptor() expected.name = 'MyService' expected.methods = [] expected.methods.append(descriptor.describe_method(MyService.method1)) expected.methods.append(descriptor.describe_method(MyService.method2)) described = descriptor.describe_service(MyService) described.check_initialized() self.assertEquals(expected, described) class DescribeFileTest(test_util.TestCase): """Test describing modules.""" def LoadModule(self, module_name, source): result = {'__name__': module_name, 'messages': messages, 'remote': remote, } exec source in result module = new.module(module_name) for name, value in result.iteritems(): setattr(module, name, value) return module def testEmptyModule(self): """Test describing an empty file.""" module = new.module('my.package.name') expected = descriptor.FileDescriptor() expected.package = 'my.package.name' described = descriptor.describe_file(module) described.check_initialized() self.assertEquals(expected, described) def testNoPackageName(self): """Test describing a module with no module name.""" module = new.module('') expected = descriptor.FileDescriptor() described = descriptor.describe_file(module) described.check_initialized() self.assertEquals(expected, described) def testPackageName(self): """Test using the 'package' module attribute.""" module = new.module('my.module.name') module.package = 'my.package.name' expected = descriptor.FileDescriptor() expected.package = 'my.package.name' described = descriptor.describe_file(module) described.check_initialized() self.assertEquals(expected, described) def testMain(self): """Test using the 'package' module attribute.""" module = new.module('__main__') module.__file__ = 
'/blim/blam/bloom/my_package.py' expected = descriptor.FileDescriptor() expected.package = 'my_package' described = descriptor.describe_file(module) described.check_initialized() self.assertEquals(expected, described) def testMessages(self): """Test that messages are described.""" module = self.LoadModule('my.package', 'class Message1(messages.Message): pass\n' 'class Message2(messages.Message): pass\n') message1 = descriptor.MessageDescriptor() message1.name = 'Message1' message2 = descriptor.MessageDescriptor() message2.name = 'Message2' expected = descriptor.FileDescriptor() expected.package = 'my.package' expected.message_types = [message1, message2] described = descriptor.describe_file(module) described.check_initialized() self.assertEquals(expected, described) def testEnums(self): """Test that enums are described.""" module = self.LoadModule('my.package', 'class Enum1(messages.Enum): pass\n' 'class Enum2(messages.Enum): pass\n') enum1 = descriptor.EnumDescriptor() enum1.name = 'Enum1' enum2 = descriptor.EnumDescriptor() enum2.name = 'Enum2' expected = descriptor.FileDescriptor() expected.package = 'my.package' expected.enum_types = [enum1, enum2] described = descriptor.describe_file(module) described.check_initialized() self.assertEquals(expected, described) def testServices(self): """Test that services are described.""" module = self.LoadModule('my.package', 'class Service1(remote.Service): pass\n' 'class Service2(remote.Service): pass\n') service1 = descriptor.ServiceDescriptor() service1.name = 'Service1' service2 = descriptor.ServiceDescriptor() service2.name = 'Service2' expected = descriptor.FileDescriptor() expected.package = 'my.package' expected.service_types = [service1, service2] described = descriptor.describe_file(module) described.check_initialized() self.assertEquals(expected, described) class DescribeFileSetTest(test_util.TestCase): """Test describing multiple modules.""" def testNoModules(self): """Test what happens when no modules 
provided.""" described = descriptor.describe_file_set([]) described.check_initialized() # The described FileSet.files will be None. self.assertEquals(descriptor.FileSet(), described) def testWithModules(self): """Test what happens when no modules provided.""" modules = [new.module('package1'), new.module('package1')] file1 = descriptor.FileDescriptor() file1.package = 'package1' file2 = descriptor.FileDescriptor() file2.package = 'package2' expected = descriptor.FileSet() expected.files = [file1, file1] described = descriptor.describe_file_set(modules) described.check_initialized() self.assertEquals(expected, described) class DescribeTest(test_util.TestCase): def testModule(self): self.assertEquals(descriptor.describe_file(test_util), descriptor.describe(test_util)) def testMethod(self): class Param(messages.Message): pass class Service(remote.Service): @remote.method(Param, Param) def fn(self): return Param() self.assertEquals(descriptor.describe_method(Service.fn), descriptor.describe(Service.fn)) def testField(self): self.assertEquals( descriptor.describe_field(test_util.NestedMessage.a_value), descriptor.describe(test_util.NestedMessage.a_value)) def testEnumValue(self): self.assertEquals( descriptor.describe_enum_value( test_util.OptionalMessage.SimpleEnum.VAL1), descriptor.describe(test_util.OptionalMessage.SimpleEnum.VAL1)) def testMessage(self): self.assertEquals(descriptor.describe_message(test_util.NestedMessage), descriptor.describe(test_util.NestedMessage)) def testEnum(self): self.assertEquals( descriptor.describe_enum(test_util.OptionalMessage.SimpleEnum), descriptor.describe(test_util.OptionalMessage.SimpleEnum)) def testService(self): class Service(remote.Service): pass self.assertEquals(descriptor.describe_service(Service), descriptor.describe(Service)) def testService(self): class Service(remote.Service): pass self.assertEquals(descriptor.describe_service(Service), descriptor.describe(Service)) def testUndescribable(self): class 
NonService(object): def fn(self): pass for value in (NonService, NonService.fn, 1, 'string', 1.2, None): self.assertEquals(None, descriptor.describe(value)) class ModuleFinderTest(test_util.TestCase): def testFindModule(self): self.assertEquals(descriptor.describe_file(registry), descriptor.import_descriptor_loader('protorpc.registry')) def testFindMessage(self): self.assertEquals( descriptor.describe_message(descriptor.FileSet), descriptor.import_descriptor_loader('protorpc.descriptor.FileSet')) def testFindField(self): self.assertEquals( descriptor.describe_field(descriptor.FileSet.files), descriptor.import_descriptor_loader('protorpc.descriptor.FileSet.files')) def testFindEnumValue(self): self.assertEquals( descriptor.describe_enum_value(test_util.OptionalMessage.SimpleEnum.VAL1), descriptor.import_descriptor_loader( 'protorpc.test_util.OptionalMessage.SimpleEnum.VAL1')) def testFindMethod(self): self.assertEquals( descriptor.describe_method(registry.RegistryService.services), descriptor.import_descriptor_loader( 'protorpc.registry.RegistryService.services')) def testFindService(self): self.assertEquals( descriptor.describe_service(registry.RegistryService), descriptor.import_descriptor_loader('protorpc.registry.RegistryService')) def testFindWithAbsoluteName(self): self.assertEquals( descriptor.describe_service(registry.RegistryService), descriptor.import_descriptor_loader('.protorpc.registry.RegistryService')) def testFindWrongThings(self): for name in ('a', 'protorpc.registry.RegistryService.__init__', '', ): self.assertRaisesWithRegexpMatch( messages.DefinitionNotFoundError, 'Could not find definition for %s' % name, descriptor.import_descriptor_loader, name) class DescriptorLibraryTest(test_util.TestCase): def setUp(self): self.packageless = descriptor.MessageDescriptor() self.packageless.name = 'Packageless' self.library = descriptor.DescriptorLibrary( descriptors={ 'not.real.Packageless': self.packageless, 'Packageless': self.packageless, }) def 
testLookupPackage(self): self.assertEquals('csv', self.library.lookup_package('csv')) self.assertEquals('protorpc', self.library.lookup_package('protorpc')) self.assertEquals('protorpc.registry', self.library.lookup_package('protorpc.registry')) self.assertEquals('protorpc.registry', self.library.lookup_package('.protorpc.registry')) self.assertEquals( 'protorpc.registry', self.library.lookup_package('protorpc.registry.RegistryService')) self.assertEquals( 'protorpc.registry', self.library.lookup_package( 'protorpc.registry.RegistryService.services')) def testLookupNonPackages(self): for name in ('', 'a', 'protorpc.descriptor.DescriptorLibrary'): self.assertRaisesWithRegexpMatch( messages.DefinitionNotFoundError, 'Could not find definition for %s' % name, self.library.lookup_package, name) def testNoPackage(self): self.assertRaisesWithRegexpMatch( messages.DefinitionNotFoundError, 'Could not find definition for not.real', self.library.lookup_package, 'not.real.Packageless') self.assertEquals(None, self.library.lookup_package('Packageless')) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/end2end_test.py0000755000076500000240000001222312277637135023310 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """End to end tests for ProtoRPC.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import unittest from protorpc import protojson from protorpc import remote from protorpc import test_util from protorpc import util from protorpc import webapp_test_util package = 'test_package' class EndToEndTest(webapp_test_util.EndToEndTestBase): def testSimpleRequest(self): self.assertEquals(test_util.OptionalMessage(string_value='+blar'), self.stub.optional_message(string_value='blar')) def testSimpleRequestComplexContentType(self): response = self.DoRawRequest( 'optional_message', content='{"string_value": "blar"}', content_type='application/json; charset=utf-8') headers = response.headers self.assertEquals(200, response.code) self.assertEquals('{"string_value": "+blar"}', response.read()) self.assertEquals('application/json', headers['content-type']) def testInitParameter(self): self.assertEquals(test_util.OptionalMessage(string_value='uninitialized'), self.stub.init_parameter()) self.assertEquals(test_util.OptionalMessage(string_value='initialized'), self.other_stub.init_parameter()) def testMissingContentType(self): code, content, headers = self.RawRequestError( 'optional_message', content='{"string_value": "blar"}', content_type='') self.assertEquals(400, code) self.assertEquals(util.pad_string('Bad Request'), content) self.assertEquals('text/plain; charset=utf-8', headers['content-type']) def testWrongPath(self): self.assertRaisesWithRegexpMatch(remote.ServerError, 'HTTP Error 404: Not Found', self.bad_path_stub.optional_message) def testUnsupportedContentType(self): code, content, headers = self.RawRequestError( 'optional_message', content='{"string_value": "blar"}', content_type='image/png') self.assertEquals(415, code) self.assertEquals(util.pad_string('Unsupported Media Type'), content) self.assertEquals(headers['content-type'], 'text/plain; charset=utf-8') def testUnsupportedHttpMethod(self): code, content, headers = self.RawRequestError('optional_message') 
self.assertEquals(405, code) self.assertEquals( util.pad_string('/my/service.optional_message is a ProtoRPC method.\n\n' 'Service protorpc.webapp_test_util.TestService\n\n' 'More about ProtoRPC: ' 'http://code.google.com/p/google-protorpc\n'), content) self.assertEquals(headers['content-type'], 'text/plain; charset=utf-8') def testMethodNotFound(self): self.assertRaisesWithRegexpMatch(remote.MethodNotFoundError, 'Unrecognized RPC method: does_not_exist', self.mismatched_stub.does_not_exist) def testBadMessageError(self): code, content, headers = self.RawRequestError('nested_message', content='{}') self.assertEquals(400, code) expected_content = protojson.encode_message(remote.RpcStatus( state=remote.RpcState.REQUEST_ERROR, error_message=('Error parsing ProtoRPC request ' '(Unable to parse request content: ' 'Message NestedMessage is missing ' 'required field a_value)'))) self.assertEquals(util.pad_string(expected_content), content) self.assertEquals(headers['content-type'], 'application/json') def testApplicationError(self): try: self.stub.raise_application_error() except remote.ApplicationError, err: self.assertEquals('This is an application error', err.message) self.assertEquals('ERROR_NAME', err.error_name) else: self.fail('Expected application error') def testRpcError(self): try: self.stub.raise_rpc_error() except remote.ServerError, err: self.assertEquals('Internal Server Error', err.message) else: self.fail('Expected server error') def testUnexpectedError(self): try: self.stub.raise_unexpected_error() except remote.ServerError, err: self.assertEquals('Internal Server Error', err.message) else: self.fail('Expected server error') def testBadResponse(self): try: self.stub.return_bad_message() except remote.ServerError, err: self.assertEquals('Internal Server Error', err.message) else: self.fail('Expected server error') def main(): unittest.main() if __name__ == '__main__': main() 
protorpc-standalone-0.9.1/protorpc/experimental/0000755000076500000240000000000012300027071023025 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/experimental/__init__.py0000644000076500000240000000126112277637135025163 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Main module for ProtoRPC package.""" __author__ = 'rafek@google.com (Rafe Kaplan)' protorpc-standalone-0.9.1/protorpc/generate.py0000755000076500000240000000721212277637135022526 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # __author__ = 'rafek@google.com (Rafe Kaplan)' import contextlib from . import messages from . 
import util __all__ = ['IndentationError', 'IndentWriter', ] class IndentationError(messages.Error): """Raised when end_indent is called too many times.""" class IndentWriter(object): """Utility class to make it easy to write formatted indented text. IndentWriter delegates to a file-like object and is able to keep track of the level of indentation. Each call to write_line will write a line terminated by a new line proceeded by a number of spaces indicated by the current level of indentation. IndexWriter overloads the << operator to make line writing operations clearer. The indent method returns a context manager that can be used by the Python with statement that makes generating python code easier to use. For example: index_writer << 'def factorial(n):' with index_writer.indent(): index_writer << 'if n <= 1:' with index_writer.indent(): index_writer << 'return 1' index_writer << 'else:' with index_writer.indent(): index_writer << 'return factorial(n - 1)' This would generate: def factorial(n): if n <= 1: return 1 else: return factorial(n - 1) """ @util.positional(2) def __init__(self, output, indent_space=2): """Constructor. Args: output: File-like object to wrap. indent_space: Number of spaces each level of indentation will be. """ # Private attributes: # # __output: The wrapped file-like object. # __indent_space: String to append for each level of indentation. # __indentation: The current full indentation string. self.__output = output self.__indent_space = indent_space * ' ' self.__indentation = 0 @property def indent_level(self): """Current level of indentation for IndentWriter.""" return self.__indentation def write_line(self, line): """Write line to wrapped file-like object using correct indentation. The line is written with the current level of indentation printed before it and terminated by a new line. Args: line: Line to write to wrapped file-like object. 
""" if line != '': self.__output.write(self.__indentation * self.__indent_space) self.__output.write(line) self.__output.write('\n') def begin_indent(self): """Begin a level of indentation.""" self.__indentation += 1 def end_indent(self): """Undo the most recent level of indentation. Raises: IndentationError when called with no indentation levels. """ if not self.__indentation: raise IndentationError('Unable to un-indent further') self.__indentation -= 1 @contextlib.contextmanager def indent(self): """Create indentation level compatible with the Python 'with' keyword.""" self.begin_indent() yield self.end_indent() def __lshift__(self, line): """Syntactic sugar for write_line method. Args: line: Line to write to wrapped file-like object. """ self.write_line(line) protorpc-standalone-0.9.1/protorpc/generate_proto.py0000755000076500000240000000742412277637135023756 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import with_statement __author__ = 'rafek@google.com (Rafe Kaplan)' import logging from . import descriptor from . import generate from . import messages from . 
import util __all__ = ['format_proto_file'] @util.positional(2) def format_proto_file(file_descriptor, output, indent_space=2): out = generate.IndentWriter(output, indent_space=indent_space) if file_descriptor.package: out << 'package %s;' % file_descriptor.package def write_enums(enum_descriptors): """Write nested and non-nested Enum types. Args: enum_descriptors: List of EnumDescriptor objects from which to generate enums. """ # Write enums. for enum in enum_descriptors or []: out << '' out << '' out << 'enum %s {' % enum.name out << '' with out.indent(): if enum.values: for enum_value in enum.values: out << '%s = %s;' % (enum_value.name, enum_value.number) out << '}' write_enums(file_descriptor.enum_types) def write_fields(field_descriptors): """Write fields for Message types. Args: field_descriptors: List of FieldDescriptor objects from which to generate fields. """ for field in field_descriptors or []: default_format = '' if field.default_value is not None: if field.label == descriptor.FieldDescriptor.Label.REPEATED: logging.warning('Default value for repeated field %s is not being ' 'written to proto file' % field.name) else: # Convert default value to string. if field.variant == messages.Variant.MESSAGE: logging.warning( 'Message field %s should not have default values' % field.name) default = None elif field.variant == messages.Variant.STRING: default = repr(field.default_value.encode('utf-8')) elif field.variant == messages.Variant.BYTES: default = repr(field.default_value) else: default = str(field.default_value) if default is not None: default_format = ' [default=%s]' % default if field.variant in (messages.Variant.MESSAGE, messages.Variant.ENUM): field_type = field.type_name else: field_type = str(field.variant).lower() out << '%s %s %s = %s%s;' % (str(field.label).lower(), field_type, field.name, field.number, default_format) def write_messages(message_descriptors): """Write nested and non-nested Message types. 
Args: message_descriptors: List of MessageDescriptor objects from which to generate messages. """ for message in message_descriptors or []: out << '' out << '' out << 'message %s {' % message.name with out.indent(): if message.enum_types: write_enums(message.enum_types) if message.message_types: write_messages(message.message_types) if message.fields: write_fields(message.fields) out << '}' write_messages(file_descriptor.message_types) protorpc-standalone-0.9.1/protorpc/generate_proto_test.py0000755000076500000240000001435012277637135025011 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Tests for protorpc.generate_proto_test.""" import os import shutil import cStringIO import sys import tempfile import unittest from protorpc import descriptor from protorpc import generate_proto from protorpc import test_util from protorpc import util class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = generate_proto class FormatProtoFileTest(test_util.TestCase): def setUp(self): self.file_descriptor = descriptor.FileDescriptor() self.output = cStringIO.StringIO() @property def result(self): return self.output.getvalue() def MakeMessage(self, name='MyMessage', fields=[]): message = descriptor.MessageDescriptor() message.name = name message.fields = fields messages_list = getattr(self.file_descriptor, 'fields', []) messages_list.append(message) self.file_descriptor.message_types = messages_list def testBlankPackage(self): self.file_descriptor.package = None generate_proto.format_proto_file(self.file_descriptor, self.output) self.assertEquals('', self.result) def testEmptyPackage(self): self.file_descriptor.package = 'my_package' generate_proto.format_proto_file(self.file_descriptor, self.output) self.assertEquals('package my_package;\n', self.result) def testSingleField(self): field = descriptor.FieldDescriptor() field.name = 'integer_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.INT64 self.MakeMessage(fields=[field]) generate_proto.format_proto_file(self.file_descriptor, self.output) self.assertEquals('\n\n' 'message MyMessage {\n' ' optional int64 integer_field = 1;\n' '}\n', self.result) def testSingleFieldWithDefault(self): field = descriptor.FieldDescriptor() field.name = 'integer_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.INT64 field.default_value = '10' self.MakeMessage(fields=[field]) generate_proto.format_proto_file(self.file_descriptor, 
self.output) self.assertEquals('\n\n' 'message MyMessage {\n' ' optional int64 integer_field = 1 [default=10];\n' '}\n', self.result) def testRepeatedFieldWithDefault(self): field = descriptor.FieldDescriptor() field.name = 'integer_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.REPEATED field.variant = descriptor.FieldDescriptor.Variant.INT64 field.default_value = '[10, 20]' self.MakeMessage(fields=[field]) generate_proto.format_proto_file(self.file_descriptor, self.output) self.assertEquals('\n\n' 'message MyMessage {\n' ' repeated int64 integer_field = 1;\n' '}\n', self.result) def testSingleFieldWithDefaultString(self): field = descriptor.FieldDescriptor() field.name = 'string_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.STRING field.default_value = 'hello' self.MakeMessage(fields=[field]) generate_proto.format_proto_file(self.file_descriptor, self.output) self.assertEquals('\n\n' 'message MyMessage {\n' " optional string string_field = 1 [default='hello'];\n" '}\n', self.result) def testSingleFieldWithDefaultEmptyString(self): field = descriptor.FieldDescriptor() field.name = 'string_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.STRING field.default_value = '' self.MakeMessage(fields=[field]) generate_proto.format_proto_file(self.file_descriptor, self.output) self.assertEquals('\n\n' 'message MyMessage {\n' " optional string string_field = 1 [default=''];\n" '}\n', self.result) def testSingleFieldWithDefaultMessage(self): field = descriptor.FieldDescriptor() field.name = 'message_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.MESSAGE field.type_name = 'MyNestedMessage' field.default_value = 'not valid' self.MakeMessage(fields=[field]) 
generate_proto.format_proto_file(self.file_descriptor, self.output) self.assertEquals('\n\n' 'message MyMessage {\n' " optional MyNestedMessage message_field = 1;\n" '}\n', self.result) def testSingleFieldWithDefaultEnum(self): field = descriptor.FieldDescriptor() field.name = 'enum_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.ENUM field.type_name = 'my_package.MyEnum' field.default_value = '17' self.MakeMessage(fields=[field]) generate_proto.format_proto_file(self.file_descriptor, self.output) self.assertEquals('\n\n' 'message MyMessage {\n' " optional my_package.MyEnum enum_field = 1 " "[default=17];\n" '}\n', self.result) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/generate_python.py0000755000076500000240000001516512277637135024135 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import with_statement __author__ = 'rafek@google.com (Rafe Kaplan)' from . import descriptor from . import generate from . import message_types from . import messages from . import util __all__ = ['format_python_file'] _MESSAGE_FIELD_MAP = { message_types.DateTimeMessage.definition_name(): message_types.DateTimeField, } def _write_enums(enum_descriptors, out): """Write nested and non-nested Enum types. 
Args: enum_descriptors: List of EnumDescriptor objects from which to generate enums. out: Indent writer used for generating text. """ # Write enums. for enum in enum_descriptors or []: out << '' out << '' out << 'class %s(messages.Enum):' % enum.name out << '' with out.indent(): if not enum.values: out << 'pass' else: for enum_value in enum.values: out << '%s = %s' % (enum_value.name, enum_value.number) def _write_fields(field_descriptors, out): """Write fields for Message types. Args: field_descriptors: List of FieldDescriptor objects from which to generate fields. out: Indent writer used for generating text. """ out << '' for field in field_descriptors or []: type_format = '' label_format = '' message_field = _MESSAGE_FIELD_MAP.get(field.type_name) if message_field: module = 'message_types' field_type = message_field else: module = 'messages' field_type = messages.Field.lookup_field_type_by_variant(field.variant) if field_type in (messages.EnumField, messages.MessageField): type_format = '\'%s\', ' % field.type_name if field.label == descriptor.FieldDescriptor.Label.REQUIRED: label_format = ', required=True' elif field.label == descriptor.FieldDescriptor.Label.REPEATED: label_format = ', repeated=True' if field_type.DEFAULT_VARIANT != field.variant: variant_format = ', variant=messages.Variant.%s' % field.variant else: variant_format = '' if field.default_value: if field_type in [messages.BytesField, messages.StringField, ]: default_value = repr(field.default_value) elif field_type is messages.EnumField: try: default_value = str(int(field.default_value)) except ValueError: default_value = repr(field.default_value) else: default_value = field.default_value default_format = ', default=%s' % (default_value,) else: default_format = '' out << '%s = %s.%s(%s%s%s%s%s)' % (field.name, module, field_type.__name__, type_format, field.number, label_format, variant_format, default_format) def _write_messages(message_descriptors, out): """Write nested and non-nested Message 
types. Args: message_descriptors: List of MessageDescriptor objects from which to generate messages. out: Indent writer used for generating text. """ for message in message_descriptors or []: out << '' out << '' out << 'class %s(messages.Message):' % message.name with out.indent(): if not (message.enum_types or message.message_types or message.fields): out << '' out << 'pass' else: _write_enums(message.enum_types, out) _write_messages(message.message_types, out) _write_fields(message.fields, out) def _write_methods(method_descriptors, out): """Write methods of Service types. All service method implementations raise NotImplementedError. Args: method_descriptors: List of MethodDescriptor objects from which to generate methods. out: Indent writer used for generating text. """ for method in method_descriptors: out << '' out << "@remote.method('%s', '%s')" % (method.request_type, method.response_type) out << 'def %s(self, request):' % (method.name,) with out.indent(): out << ('raise NotImplementedError' "('Method %s is not implemented')" % (method.name)) def _write_services(service_descriptors, out): """Write Service types. Args: service_descriptors: List of ServiceDescriptor instances from which to generate services. out: Indent writer used for generating text. """ for service in service_descriptors or []: out << '' out << '' out << 'class %s(remote.Service):' % service.name with out.indent(): if service.methods: _write_methods(service.methods, out) else: out << '' out << 'pass' @util.positional(2) def format_python_file(file_descriptor, output, indent_space=2): """Format FileDescriptor object as a single Python module. Services generated by this function will raise NotImplementedError. All Python classes generated by this function use delayed binding for all message fields, enum fields and method parameter types. 
For example a service method might be generated like so: class MyService(remote.Service): @remote.method('my_package.MyRequestType', 'my_package.MyResponseType') def my_method(self, request): raise NotImplementedError('Method my_method is not implemented') Args: file_descriptor: FileDescriptor instance to format as python module. output: File-like object to write module source code to. indent_space: Number of spaces for each level of Python indentation. """ out = generate.IndentWriter(output, indent_space=indent_space) out << 'from protorpc import message_types' out << 'from protorpc import messages' if file_descriptor.service_types: out << 'from protorpc import remote' if file_descriptor.package: out << "package = '%s'" % file_descriptor.package _write_enums(file_descriptor.enum_types, out) _write_messages(file_descriptor.message_types, out) _write_services(file_descriptor.service_types, out) protorpc-standalone-0.9.1/protorpc/generate_python_test.py0000755000076500000240000002653112277637135025173 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Tests for protorpc.generate_python_test.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import os import shutil import sys import tempfile import unittest from protorpc import descriptor from protorpc import generate_python from protorpc import test_util from protorpc import util class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = generate_python class FormatPythonFileTest(test_util.TestCase): def setUp(self): self.original_path = list(sys.path) self.original_modules = dict(sys.modules) sys.path = list(sys.path) self.file_descriptor = descriptor.FileDescriptor() # Create temporary directory and add to Python path so that generated # Python code can be easily parsed, imported and executed. self.temp_dir = tempfile.mkdtemp() sys.path.append(self.temp_dir) def tearDown(self): # Reset path. sys.path[:] = [] sys.path.extend(self.original_path) # Reset modules. sys.modules.clear() sys.modules.update(self.original_modules) # Remove temporary directory. try: shutil.rmtree(self.temp_dir) except IOError: pass def DoPythonTest(self, file_descriptor): """Execute python test based on a FileDescriptor object. The full test of the Python code generation is to generate a Python source code file, import the module and regenerate the FileDescriptor from it. If the generated FileDescriptor is the same as the original, it means that the generated source code correctly implements the actual FileDescriptor. """ file_name = os.path.join(self.temp_dir, '%s.py' % (file_descriptor.package or 'blank',)) source_file = open(file_name, 'wt') try: generate_python.format_python_file(file_descriptor, source_file) finally: source_file.close() module_to_import = file_descriptor.package or 'blank' module = __import__(module_to_import) if not file_descriptor.package: self.assertFalse(hasattr(module, 'package')) module.package = '' # Create package name so that comparison will work. 
reloaded_descriptor = descriptor.describe_file(module) # Need to sort both message_types fields because document order is never # Ensured. # TODO(rafek): Ensure document order. if reloaded_descriptor.message_types: reloaded_descriptor.message_types = sorted( reloaded_descriptor.message_types, key=lambda v: v.name) if file_descriptor.message_types: file_descriptor.message_types = sorted( file_descriptor.message_types, key=lambda v: v.name) self.assertEquals(file_descriptor, reloaded_descriptor) @util.positional(2) def DoMessageTest(self, field_descriptors, message_types=None, enum_types=None): """Execute message generation test based on FieldDescriptor objects. Args: field_descriptor: List of FieldDescriptor object to generate and test. message_types: List of other MessageDescriptor objects that the new Message class depends on. enum_types: List of EnumDescriptor objects that the new Message class depends on. """ file_descriptor = descriptor.FileDescriptor() file_descriptor.package = 'my_package' message_descriptor = descriptor.MessageDescriptor() message_descriptor.name = 'MyMessage' message_descriptor.fields = list(field_descriptors) file_descriptor.message_types = message_types or [] file_descriptor.message_types.append(message_descriptor) if enum_types: file_descriptor.enum_types = list(enum_types) self.DoPythonTest(file_descriptor) def testBlankPackage(self): self.DoPythonTest(descriptor.FileDescriptor()) def testEmptyPackage(self): file_descriptor = descriptor.FileDescriptor() file_descriptor.package = 'mypackage' self.DoPythonTest(file_descriptor) def testSingleField(self): field = descriptor.FieldDescriptor() field.name = 'integer_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.INT64 self.DoMessageTest([field]) def testMessageField_InternalReference(self): other_message = descriptor.MessageDescriptor() other_message.name = 'OtherMessage' field = descriptor.FieldDescriptor() 
field.name = 'message_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.MESSAGE field.type_name = 'my_package.OtherMessage' self.DoMessageTest([field], message_types=[other_message]) def testMessageField_ExternalReference(self): field = descriptor.FieldDescriptor() field.name = 'message_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.MESSAGE field.type_name = 'protorpc.registry.GetFileSetResponse' self.DoMessageTest([field]) def testEnumField_InternalReference(self): enum = descriptor.EnumDescriptor() enum.name = 'Color' field = descriptor.FieldDescriptor() field.name = 'color' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.ENUM field.type_name = 'my_package.Color' self.DoMessageTest([field], enum_types=[enum]) def testEnumField_ExternalReference(self): field = descriptor.FieldDescriptor() field.name = 'color' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.ENUM field.type_name = 'protorpc.descriptor.FieldDescriptor.Label' self.DoMessageTest([field]) def testDateTimeField(self): field = descriptor.FieldDescriptor() field.name = 'timestamp' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.MESSAGE field.type_name = 'protorpc.message_types.DateTimeMessage' self.DoMessageTest([field]) def testNonDefaultVariant(self): field = descriptor.FieldDescriptor() field.name = 'integer_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.UINT64 self.DoMessageTest([field]) def testRequiredField(self): field = descriptor.FieldDescriptor() field.name = 'integer_field' field.number = 1 field.label = 
descriptor.FieldDescriptor.Label.REQUIRED field.variant = descriptor.FieldDescriptor.Variant.INT64 self.DoMessageTest([field]) def testRepeatedField(self): field = descriptor.FieldDescriptor() field.name = 'integer_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.REPEATED field.variant = descriptor.FieldDescriptor.Variant.INT64 self.DoMessageTest([field]) def testIntegerDefaultValue(self): field = descriptor.FieldDescriptor() field.name = 'integer_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.INT64 field.default_value = '10' self.DoMessageTest([field]) def testFloatDefaultValue(self): field = descriptor.FieldDescriptor() field.name = 'float_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.DOUBLE field.default_value = '10.1' self.DoMessageTest([field]) def testStringDefaultValue(self): field = descriptor.FieldDescriptor() field.name = 'string_field' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.STRING field.default_value = u'a nice lovely string\'s "string"' self.DoMessageTest([field]) def testEnumDefaultValue(self): field = descriptor.FieldDescriptor() field.name = 'label' field.number = 1 field.label = descriptor.FieldDescriptor.Label.OPTIONAL field.variant = descriptor.FieldDescriptor.Variant.ENUM field.type_name = 'protorpc.descriptor.FieldDescriptor.Label' field.default_value = '2' self.DoMessageTest([field]) def testMultiFields(self): field1 = descriptor.FieldDescriptor() field1.name = 'integer_field' field1.number = 1 field1.label = descriptor.FieldDescriptor.Label.OPTIONAL field1.variant = descriptor.FieldDescriptor.Variant.INT64 field2 = descriptor.FieldDescriptor() field2.name = 'string_field' field2.number = 2 field2.label = descriptor.FieldDescriptor.Label.OPTIONAL field2.variant = 
descriptor.FieldDescriptor.Variant.STRING field3 = descriptor.FieldDescriptor() field3.name = 'unsigned_integer_field' field3.number = 3 field3.label = descriptor.FieldDescriptor.Label.OPTIONAL field3.variant = descriptor.FieldDescriptor.Variant.UINT64 self.DoMessageTest([field1, field2, field3]) def testNestedMessage(self): message = descriptor.MessageDescriptor() message.name = 'OuterMessage' inner_message = descriptor.MessageDescriptor() inner_message.name = 'InnerMessage' inner_inner_message = descriptor.MessageDescriptor() inner_inner_message.name = 'InnerInnerMessage' inner_message.message_types = [inner_inner_message] message.message_types = [inner_message] file_descriptor = descriptor.FileDescriptor() file_descriptor.message_types = [message] self.DoPythonTest(file_descriptor) def testNestedEnum(self): message = descriptor.MessageDescriptor() message.name = 'OuterMessage' inner_enum = descriptor.EnumDescriptor() inner_enum.name = 'InnerEnum' message.enum_types = [inner_enum] file_descriptor = descriptor.FileDescriptor() file_descriptor.message_types = [message] self.DoPythonTest(file_descriptor) def testService(self): service = descriptor.ServiceDescriptor() service.name = 'TheService' method1 = descriptor.MethodDescriptor() method1.name = 'method1' method1.request_type = 'protorpc.descriptor.FileDescriptor' method1.response_type = 'protorpc.descriptor.MethodDescriptor' service.methods = [method1] file_descriptor = descriptor.FileDescriptor() file_descriptor.service_types = [service] self.DoPythonTest(file_descriptor) # Test to make sure that implementation methods raise an exception. 
import blank service_instance = blank.TheService() self.assertRaisesWithRegexpMatch(NotImplementedError, 'Method method1 is not implemented', service_instance.method1, descriptor.FileDescriptor()) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/generate_test.py0000755000076500000240000001114112277637135023561 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.generate.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import cStringIO import sys import unittest from protorpc import generate from protorpc import test_util class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = generate class IndentWriterTest(test_util.TestCase): def setUp(self): self.out = cStringIO.StringIO() self.indent_writer = generate.IndentWriter(self.out) def testWriteLine(self): self.indent_writer.write_line('This is a line') self.indent_writer.write_line('This is another line') self.assertEquals('This is a line\n' 'This is another line\n', self.out.getvalue()) def testLeftShift(self): self.run_count = 0 def mock_write_line(line): self.run_count += 1 self.assertEquals('same as calling write_line', line) self.indent_writer.write_line = mock_write_line self.indent_writer << 'same as calling write_line' self.assertEquals(1, self.run_count) def testIndentation(self): self.indent_writer << 'indent 0' 
self.indent_writer.begin_indent() self.indent_writer << 'indent 1' self.indent_writer.begin_indent() self.indent_writer << 'indent 2' self.indent_writer.end_indent() self.indent_writer << 'end 2' self.indent_writer.end_indent() self.indent_writer << 'end 1' self.assertRaises(generate.IndentationError, self.indent_writer.end_indent) self.assertEquals('indent 0\n' ' indent 1\n' ' indent 2\n' ' end 2\n' 'end 1\n', self.out.getvalue()) def testBlankLine(self): self.indent_writer << '' self.indent_writer.begin_indent() self.indent_writer << '' self.assertEquals('\n\n', self.out.getvalue()) def testNoneInvalid(self): with self.assertRaises(TypeError): self.indent_writer << None def testAltIndentation(self): self.indent_writer = generate.IndentWriter(self.out, indent_space=3) self.indent_writer << 'indent 0' self.assertEquals(0, self.indent_writer.indent_level) self.indent_writer.begin_indent() self.indent_writer << 'indent 1' self.assertEquals(1, self.indent_writer.indent_level) self.indent_writer.begin_indent() self.indent_writer << 'indent 2' self.assertEquals(2, self.indent_writer.indent_level) self.indent_writer.end_indent() self.indent_writer << 'end 2' self.assertEquals(1, self.indent_writer.indent_level) self.indent_writer.end_indent() self.indent_writer << 'end 1' self.assertEquals(0, self.indent_writer.indent_level) self.assertRaises(generate.IndentationError, self.indent_writer.end_indent) self.assertEquals(0, self.indent_writer.indent_level) self.assertEquals('indent 0\n' ' indent 1\n' ' indent 2\n' ' end 2\n' 'end 1\n', self.out.getvalue()) def testIndent(self): self.indent_writer << 'indent 0' self.assertEquals(0, self.indent_writer.indent_level) def indent1(): self.indent_writer << 'indent 1' self.assertEquals(1, self.indent_writer.indent_level) def indent2(): self.indent_writer << 'indent 2' self.assertEquals(2, self.indent_writer.indent_level) test_util.do_with(self.indent_writer.indent(), indent2) self.assertEquals(1, self.indent_writer.indent_level) 
self.indent_writer << 'end 2' test_util.do_with(self.indent_writer.indent(), indent1) self.assertEquals(0, self.indent_writer.indent_level) self.indent_writer << 'end 1' self.assertEquals('indent 0\n' ' indent 1\n' ' indent 2\n' ' end 2\n' 'end 1\n', self.out.getvalue()) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/google_imports.py0000755000076500000240000000054012277637135023762 0ustar jeremydwstaff00000000000000"""Dynamically decide from where to import other SDK modules. All other protorpc code should import other SDK modules from this module. If necessary, add new imports here (in both places). """ __author__ = 'yey@google.com (Ye Yuan)' # pylint: disable=g-import-not-at-top # pylint: disable=unused-import from _google.net.proto import ProtocolBuffer protorpc-standalone-0.9.1/protorpc/message_types.py0000755000076500000240000000730212277637135023604 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Simple protocol message types. Includes new message and field types that are outside what is defined by the protocol buffers standard. """ __author__ = 'rafek@google.com (Rafe Kaplan)' import datetime from . import messages from . 
import util __all__ = [ 'DateTimeField', 'DateTimeMessage', 'VoidMessage', ] class VoidMessage(messages.Message): """Empty message.""" class DateTimeMessage(messages.Message): """Message to store/transmit a DateTime. Fields: milliseconds: Milliseconds since Jan 1st 1970 local time. time_zone_offset: Optional time zone offset, in minutes from UTC. """ milliseconds = messages.IntegerField(1, required=True) time_zone_offset = messages.IntegerField(2) class DateTimeField(messages.MessageField): """Field definition for datetime values. Stores a python datetime object as a field. If time zone information is included in the datetime object, it will be included in the encoded data when this is encoded/decoded. """ type = datetime.datetime message_type = DateTimeMessage @util.positional(3) def __init__(self, number, **kwargs): super(DateTimeField, self).__init__(self.message_type, number, **kwargs) def value_from_message(self, message): """Convert DateTimeMessage to a datetime. Args: A DateTimeMessage instance. Returns: A datetime instance. """ message = super(DateTimeField, self).value_from_message(message) if message.time_zone_offset is None: return datetime.datetime.utcfromtimestamp(message.milliseconds / 1000.0) # Need to subtract the time zone offset, because when we call # datetime.fromtimestamp, it will add the time zone offset to the # value we pass. milliseconds = (message.milliseconds - 60000 * message.time_zone_offset) timezone = util.TimeZoneOffset(message.time_zone_offset) return datetime.datetime.fromtimestamp(milliseconds / 1000.0, tz=timezone) def value_to_message(self, value): value = super(DateTimeField, self).value_to_message(value) # First, determine the delta from the epoch, so we can fill in # DateTimeMessage's milliseconds field. if value.tzinfo is None: time_zone_offset = 0 local_epoch = datetime.datetime.utcfromtimestamp(0) else: time_zone_offset = value.tzinfo.utcoffset(value).total_seconds() # Determine Jan 1, 1970 local time. 
local_epoch = datetime.datetime.fromtimestamp(-time_zone_offset, tz=value.tzinfo) delta = value - local_epoch # Create and fill in the DateTimeMessage, including time zone if # one was specified. message = DateTimeMessage() message.milliseconds = int(delta.total_seconds() * 1000) if value.tzinfo is not None: utc_offset = value.tzinfo.utcoffset(value) if utc_offset is not None: message.time_zone_offset = int( value.tzinfo.utcoffset(value).total_seconds() / 60) return message protorpc-standalone-0.9.1/protorpc/message_types_test.py0000755000076500000240000000571112277637135024645 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2013 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Tests for protorpc.message_types.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import datetime import unittest from protorpc import message_types from protorpc import messages from protorpc import test_util from protorpc import util class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = message_types class DateTimeFieldTest(test_util.TestCase): def testValueToMessage(self): field = message_types.DateTimeField(1) message = field.value_to_message(datetime.datetime(2033, 2, 4, 11, 22, 10)) self.assertEqual(message_types.DateTimeMessage(milliseconds=1991128930000), message) def testValueToMessageBadValue(self): field = message_types.DateTimeField(1) self.assertRaisesWithRegexpMatch( messages.EncodeError, 'Expected type datetime, got int: 20', field.value_to_message, 20) def testValueToMessageWithTimeZone(self): time_zone = util.TimeZoneOffset(60 * 10) field = message_types.DateTimeField(1) message = field.value_to_message( datetime.datetime(2033, 2, 4, 11, 22, 10, tzinfo=time_zone)) self.assertEqual(message_types.DateTimeMessage(milliseconds=1991128930000, time_zone_offset=600), message) def testValueFromMessage(self): message = message_types.DateTimeMessage(milliseconds=1991128000000) field = message_types.DateTimeField(1) timestamp = field.value_from_message(message) self.assertEqual(datetime.datetime(2033, 2, 4, 11, 6, 40), timestamp) def testValueFromMessageBadValue(self): field = message_types.DateTimeField(1) self.assertRaisesWithRegexpMatch( messages.DecodeError, 'Expected type DateTimeMessage, got VoidMessage: ', field.value_from_message, message_types.VoidMessage()) def testValueFromMessageWithTimeZone(self): message = message_types.DateTimeMessage(milliseconds=1991128000000, time_zone_offset=300) field = message_types.DateTimeField(1) timestamp = field.value_from_message(message) time_zone = util.TimeZoneOffset(60 * 5) self.assertEqual(datetime.datetime(2033, 2, 4, 11, 6, 40, tzinfo=time_zone), timestamp) if __name__ 
== '__main__': unittest.main() protorpc-standalone-0.9.1/protorpc/messages.py0000755000076500000240000015771712277637135022563 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Stand-alone implementation of in memory protocol messages. Public Classes: Enum: Represents an enumerated type. Variant: Hint for wire format to determine how to serialize. Message: Base class for user defined messages. IntegerField: Field for integer values. FloatField: Field for float values. BooleanField: Field for boolean values. BytesField: Field for binary string values. StringField: Field for UTF-8 string values. MessageField: Field for other message type values. EnumField: Field for enumerated type values. Public Exceptions (indentation indications class hierarchy): EnumDefinitionError: Raised when enumeration is incorrectly defined. FieldDefinitionError: Raised when field is incorrectly defined. InvalidVariantError: Raised when variant is not compatible with field type. InvalidDefaultError: Raised when default is not compatiable with field. InvalidNumberError: Raised when field number is out of range or reserved. MessageDefinitionError: Raised when message is incorrectly defined. DuplicateNumberError: Raised when field has duplicate number with another. ValidationError: Raised when a message or field is not valid. DefinitionNotFoundError: Raised when definition not found. 
""" __author__ = 'rafek@google.com (Rafe Kaplan)' import inspect import os import sys import traceback import types import weakref from . import util __all__ = ['MAX_ENUM_VALUE', 'MAX_FIELD_NUMBER', 'FIRST_RESERVED_FIELD_NUMBER', 'LAST_RESERVED_FIELD_NUMBER', 'Enum', 'Field', 'FieldList', 'Variant', 'Message', 'IntegerField', 'FloatField', 'BooleanField', 'BytesField', 'StringField', 'MessageField', 'EnumField', 'find_definition', 'Error', 'DecodeError', 'EncodeError', 'EnumDefinitionError', 'FieldDefinitionError', 'InvalidVariantError', 'InvalidDefaultError', 'InvalidNumberError', 'MessageDefinitionError', 'DuplicateNumberError', 'ValidationError', 'DefinitionNotFoundError', ] # TODO(rafek): Add extended module test to ensure all exceptions # in services extends Error. Error = util.Error class EnumDefinitionError(Error): """Enumeration definition error.""" class FieldDefinitionError(Error): """Field definition error.""" class InvalidVariantError(FieldDefinitionError): """Invalid variant provided to field.""" class InvalidDefaultError(FieldDefinitionError): """Invalid default provided to field.""" class InvalidNumberError(FieldDefinitionError): """Invalid number provided to field.""" class MessageDefinitionError(Error): """Message definition error.""" class DuplicateNumberError(Error): """Duplicate number assigned to field.""" class DefinitionNotFoundError(Error): """Raised when definition is not found.""" class DecodeError(Error): """Error found decoding message from encoded form.""" class EncodeError(Error): """Error found when encoding message.""" class ValidationError(Error): """Invalid value for message error.""" def __str__(self): """Prints string with field name if present on exception.""" message = Error.__str__(self) try: field_name = self.field_name except AttributeError: return message else: return message # Attributes that are reserved by a class definition that # may not be used by either Enum or Message class definitions. 
_RESERVED_ATTRIBUTE_NAMES = frozenset( ['__module__', '__doc__']) _POST_INIT_FIELD_ATTRIBUTE_NAMES = frozenset( ['name', '_message_definition', '_MessageField__type', '_EnumField__type', '_EnumField__resolved_default']) _POST_INIT_ATTRIBUTE_NAMES = frozenset( ['_message_definition']) # Maximum enumeration value as defined by the protocol buffers standard. # All enum values must be less than or equal to this value. MAX_ENUM_VALUE = (2 ** 29) - 1 # Maximum field number as defined by the protocol buffers standard. # All field numbers must be less than or equal to this value. MAX_FIELD_NUMBER = (2 ** 29) - 1 # Field numbers between 19000 and 19999 inclusive are reserved by the # protobuf protocol and may not be used by fields. FIRST_RESERVED_FIELD_NUMBER = 19000 LAST_RESERVED_FIELD_NUMBER = 19999 class _DefinitionClass(type): """Base meta-class used for definition meta-classes. The Enum and Message definition classes share some basic functionality. Both of these classes may be contained by a Message definition. After initialization, neither class may have attributes changed except for the protected _message_definition attribute, and that attribute may change only once. """ __initialized = False def __init__(cls, name, bases, dct): """Constructor.""" type.__init__(cls, name, bases, dct) # Base classes may never be initialized. if cls.__bases__ != (object,): cls.__initialized = True def message_definition(cls): """Get outer Message definition that contains this definition. Returns: Containing Message definition if definition is contained within one, else None. """ try: return cls._message_definition() except AttributeError: return None def __setattr__(cls, name, value): """Overridden so that cannot set variables on definition classes after init. Setting attributes on a class must work during the period of initialization to set the enumation value class variables and build the name/number maps. 
Once __init__ has set the __initialized flag to True prohibits setting any more values on the class. The class is in effect frozen. Args: name: Name of value to set. value: Value to set. """ if cls.__initialized and name not in _POST_INIT_ATTRIBUTE_NAMES: raise AttributeError('May not change values: %s' % name) else: type.__setattr__(cls, name, value) def __delattr__(cls, name): """Overridden so that cannot delete varaibles on definition classes.""" raise TypeError('May not delete attributes on definition class') def definition_name(cls): """Helper method for creating definition name. Names will be generated to include the classes package name, scope (if the class is nested in another definition) and class name. By default, the package name for a definition is derived from its module name. However, this value can be overriden by placing a 'package' attribute in the module that contains the definition class. For example: package = 'some.alternate.package' class MyMessage(Message): ... >>> MyMessage.definition_name() some.alternate.package.MyMessage Returns: Dot-separated fully qualified name of definition. """ outer_definition_name = cls.outer_definition_name() if outer_definition_name is None: return unicode(cls.__name__) else: return u'%s.%s' % (outer_definition_name, cls.__name__) def outer_definition_name(cls): """Helper method for creating outer definition name. Returns: If definition is nested, will return the outer definitions name, else the package name. """ outer_definition = cls.message_definition() if not outer_definition: return util.get_package_for_module(cls.__module__) else: return outer_definition.definition_name() def definition_package(cls): """Helper method for creating creating the package of a definition. Returns: Name of package that definition belongs to. 
""" outer_definition = cls.message_definition() if not outer_definition: return util.get_package_for_module(cls.__module__) else: return outer_definition.definition_package() class _EnumClass(_DefinitionClass): """Meta-class used for defining the Enum base class. Meta-class enables very specific behavior for any defined Enum class. All attributes defined on an Enum sub-class must be integers. Each attribute defined on an Enum sub-class is translated into an instance of that sub-class, with the name of the attribute as its name, and the number provided as its value. It also ensures that only one level of Enum class hierarchy is possible. In other words it is not possible to delcare sub-classes of sub-classes of Enum. This class also defines some functions in order to restrict the behavior of the Enum class and its sub-classes. It is not possible to change the behavior of the Enum class in later classes since any new classes may be defined with only integer values, and no methods. """ def __init__(cls, name, bases, dct): # Can only define one level of sub-classes below Enum. if not (bases == (object,) or bases == (Enum,)): raise EnumDefinitionError('Enum type %s may only inherit from Enum' % (name,)) cls.__by_number = {} cls.__by_name = {} # Enum base class does not need to be initialized or locked. if bases != (object,): # Replace integer with number. for attribute, value in dct.iteritems(): # Module will be in every enum class. if attribute in _RESERVED_ATTRIBUTE_NAMES: continue # Reject anything that is not an int. if not isinstance(value, (int, long)): raise EnumDefinitionError( 'May only use integers in Enum definitions. Found: %s = %s' % (attribute, value)) # Protocol buffer standard recommends non-negative values. # Reject negative values. if value < 0: raise EnumDefinitionError( 'Must use non-negative enum values. Found: %s = %d' % (attribute, value)) if value > MAX_ENUM_VALUE: raise EnumDefinitionError( 'Must use enum values less than or equal %d. 
Found: %s = %d' % (MAX_ENUM_VALUE, attribute, value)) if value in cls.__by_number: raise EnumDefinitionError( 'Value for %s = %d is already defined: %s' % (attribute, value, cls.__by_number[value].name)) # Create enum instance and list in new Enum type. instance = object.__new__(cls) cls.__init__(instance, attribute, value) cls.__by_name[instance.name] = instance cls.__by_number[instance.number] = instance setattr(cls, attribute, instance) _DefinitionClass.__init__(cls, name, bases, dct) def __iter__(cls): """Iterate over all values of enum. Yields: Enumeration instances of the Enum class in arbitrary order. """ return cls.__by_number.itervalues() def names(cls): """Get all names for Enum. Returns: An iterator for names of the enumeration in arbitrary order. """ return cls.__by_name.iterkeys() def numbers(cls): """Get all numbers for Enum. Returns: An iterator for all numbers of the enumeration in arbitrary order. """ return cls.__by_number.iterkeys() def lookup_by_name(cls, name): """Look up Enum by name. Args: name: Name of enum to find. Returns: Enum sub-class instance of that value. """ return cls.__by_name[name] def lookup_by_number(cls, number): """Look up Enum by number. Args: number: Number of enum to find. Returns: Enum sub-class instance of that value. """ return cls.__by_number[number] def __len__(cls): return len(cls.__by_name) class Enum(object): """Base class for all enumerated types.""" __metaclass__ = _EnumClass __slots__ = set(('name', 'number')) def __new__(cls, index): """Acts as look-up routine after class is initialized. The purpose of overriding __new__ is to provide a way to treat Enum subclasses as casting types, similar to how the int type functions. A program can pass a string or an integer and this method with "convert" that value in to an appropriate Enum instance. Args: index: Name or number to look up. During initialization this is always the name of the new enum value. 
Raises: TypeError: When an inappropriate index value is passed provided. """ # If is enum type of this class, return it. if isinstance(index, cls): return index # If number, look up by number. if isinstance(index, (int, long)): try: return cls.lookup_by_number(index) except KeyError: pass # If name, look up by name. if isinstance(index, basestring): try: return cls.lookup_by_name(index) except KeyError: pass raise TypeError('No such value for %s in Enum %s' % (index, cls.__name__)) def __init__(self, name, number=None): """Initialize new Enum instance. Since this should only be called during class initialization any calls that happen after the class is frozen raises an exception. """ # Immediately return if __init__ was called after _Enum.__init__(). # It means that casting operator version of the class constructor # is being used. if getattr(type(self), '_DefinitionClass__initialized'): return object.__setattr__(self, 'name', name) object.__setattr__(self, 'number', number) def __setattr__(self, name, value): raise TypeError('May not change enum values') def __str__(self): return self.name def __int__(self): return self.number def __repr__(self): return '%s(%s, %d)' % (type(self).__name__, self.name, self.number) def __cmp__(self, other): """Order is by number.""" if isinstance(other, type(self)): return cmp(self.number, other.number) return NotImplemented @classmethod def to_dict(cls): """Make dictionary version of enumerated class. Dictionary created this way can be used with def_num. Returns: A dict (name) -> number """ return dict((item.name, item.number) for item in iter(cls)) @staticmethod def def_enum(dct, name): """Define enum class from dictionary. Args: dct: Dictionary of enumerated values for type. name: Name of enum. 
""" return type(name, (Enum,), dct) # TODO(rafek): Determine to what degree this enumeration should be compatible # with FieldDescriptor.Type in: # # http://code.google.com/p/protobuf/source/browse/trunk/src/google/protobuf/descriptor.proto class Variant(Enum): """Wire format variant. Used by the 'protobuf' wire format to determine how to transmit a single piece of data. May be used by other formats. See: http://code.google.com/apis/protocolbuffers/docs/encoding.html Values: DOUBLE: 64-bit floating point number. FLOAT: 32-bit floating point number. INT64: 64-bit signed integer. UINT64: 64-bit unsigned integer. INT32: 32-bit signed integer. BOOL: Boolean value (True or False). STRING: String of UTF-8 encoded text. MESSAGE: Embedded message as byte string. BYTES: String of 8-bit bytes. UINT32: 32-bit unsigned integer. ENUM: Enum value as integer. SINT32: 32-bit signed integer. Uses "zig-zag" encoding. SINT64: 64-bit signed integer. Uses "zig-zag" encoding. """ DOUBLE = 1 FLOAT = 2 INT64 = 3 UINT64 = 4 INT32 = 5 BOOL = 8 STRING = 9 MESSAGE = 11 BYTES = 12 UINT32 = 13 ENUM = 14 SINT32 = 17 SINT64 = 18 class _MessageClass(_DefinitionClass): """Meta-class used for defining the Message base class. For more details about Message classes, see the Message class docstring. Information contained there may help understanding this class. Meta-class enables very specific behavior for any defined Message class. All attributes defined on an Message sub-class must be field instances, Enum class definitions or other Message class definitions. Each field attribute defined on an Message sub-class is added to the set of field definitions and the attribute is translated in to a slot. It also ensures that only one level of Message class hierarchy is possible. In other words it is not possible to declare sub-classes of sub-classes of Message. This class also defines some functions in order to restrict the behavior of the Message class and its sub-classes. 
It is not possible to change the behavior of the Message class in later classes since any new classes may be defined with only field, Enums and Messages, and no methods. """ def __new__(cls, name, bases, dct): """Create new Message class instance. The __new__ method of the _MessageClass type is overridden so as to allow the translation of Field instances to slots. """ by_number = {} by_name = {} variant_map = {} if bases != (object,): # Can only define one level of sub-classes below Message. if bases != (Message,): raise MessageDefinitionError( 'Message types may only inherit from Message') enums = [] messages = [] # Must not use iteritems because this loop will change the state of dct. for key, field in dct.items(): if key in _RESERVED_ATTRIBUTE_NAMES: continue if isinstance(field, type) and issubclass(field, Enum): enums.append(key) continue if (isinstance(field, type) and issubclass(field, Message) and field is not Message): messages.append(key) continue # Reject anything that is not a field. if type(field) is Field or not isinstance(field, Field): raise MessageDefinitionError( 'May only use fields in message definitions. Found: %s = %s' % (key, field)) if field.number in by_number: raise DuplicateNumberError( 'Field with number %d declared more than once in %s' % (field.number, name)) field.name = key # Place in name and number maps. by_name[key] = field by_number[field.number] = field # Add enums if any exist. if enums: dct['__enums__'] = sorted(enums) # Add messages if any exist. 
if messages: dct['__messages__'] = sorted(messages) dct['_Message__by_number'] = by_number dct['_Message__by_name'] = by_name return _DefinitionClass.__new__(cls, name, bases, dct) def __init__(cls, name, bases, dct): """Initializer required to assign references to new class.""" if bases != (object,): for value in dct.itervalues(): if isinstance(value, _DefinitionClass) and not value is Message: value._message_definition = weakref.ref(cls) for field in cls.all_fields(): field._message_definition = weakref.ref(cls) _DefinitionClass.__init__(cls, name, bases, dct) class Message(object): """Base class for user defined message objects. Used to define messages for efficient transmission across network or process space. Messages are defined using the field classes (IntegerField, FloatField, EnumField, etc.). Messages are more restricted than normal classes in that they may only contain field attributes and other Message and Enum definitions. These restrictions are in place because the structure of the Message class is intentended to itself be transmitted across network or process space and used directly by clients or even other servers. As such methods and non-field attributes could not be transmitted with the structural information causing discrepancies between different languages and implementations. Initialization and validation: A Message object is considered to be initialized if it has all required fields and any nested messages are also initialized. Calling 'check_initialized' will raise a ValidationException if it is not initialized; 'is_initialized' returns a boolean value indicating if it is valid. Validation automatically occurs when Message objects are created and populated. Validation that a given value will be compatible with a field that it is assigned to can be done through the Field instances validate() method. The validate method used on a message will check that all values of a message and its sub-messages are valid. 
Assingning an invalid value to a field will raise a ValidationException. Example: # Trade type. class TradeType(Enum): BUY = 1 SELL = 2 SHORT = 3 CALL = 4 class Lot(Message): price = IntegerField(1, required=True) quantity = IntegerField(2, required=True) class Order(Message): symbol = StringField(1, required=True) total_quantity = IntegerField(2, required=True) trade_type = EnumField(TradeType, 3, required=True) lots = MessageField(Lot, 4, repeated=True) limit = IntegerField(5) order = Order(symbol='GOOG', total_quantity=10, trade_type=TradeType.BUY) lot1 = Lot(price=304, quantity=7) lot2 = Lot(price = 305, quantity=3) order.lots = [lot1, lot2] # Now object is initialized! order.check_initialized() """ __metaclass__ = _MessageClass def __init__(self, **kwargs): """Initialize internal messages state. Args: A message can be initialized via the constructor by passing in keyword arguments corresponding to fields. For example: class Date(Message): day = IntegerField(1) month = IntegerField(2) year = IntegerField(3) Invoking: date = Date(day=6, month=6, year=1911) is the same as doing: date = Date() date.day = 6 date.month = 6 date.year = 1911 """ # Tag being an essential implementation detail must be private. self.__tags = {} self.__unrecognized_fields = {} assigned = set() for name, value in kwargs.iteritems(): setattr(self, name, value) assigned.add(name) # initialize repeated fields. for field in self.all_fields(): if field.repeated and field.name not in assigned: setattr(self, field.name, []) def check_initialized(self): """Check class for initialization status. Check that all required fields are initialized Raises: ValidationError: If message is not initialized. 
""" for name, field in self.__by_name.iteritems(): value = getattr(self, name) if value is None: if field.required: raise ValidationError("Message %s is missing required field %s" % (type(self).__name__, name)) else: try: if (isinstance(field, MessageField) and issubclass(field.message_type, Message)): if field.repeated: for item in value: item_message_value = field.value_to_message(item) item_message_value.check_initialized() else: message_value = field.value_to_message(value) message_value.check_initialized() except ValidationError, err: if not hasattr(err, 'message_name'): err.message_name = type(self).__name__ raise def is_initialized(self): """Get initialization status. Returns: True if message is valid, else False. """ try: self.check_initialized() except ValidationError: return False else: return True @classmethod def all_fields(cls): """Get all field definition objects. Ordering is arbitrary. Returns: Iterator over all values in arbitrary order. """ return cls.__by_name.itervalues() @classmethod def field_by_name(cls, name): """Get field by name. Returns: Field object associated with name. Raises: KeyError if no field found by that name. """ return cls.__by_name[name] @classmethod def field_by_number(cls, number): """Get field by number. Returns: Field object associated with number. Raises: KeyError if no field found by that number. """ return cls.__by_number[number] def get_assigned_value(self, name): """Get the assigned value of an attribute. Get the underlying value of an attribute. If value has not been set, will not return the default for the field. Args: name: Name of attribute to get. Returns: Value of attribute, None if it has not been set. """ message_type = type(self) try: field = message_type.field_by_name(name) except KeyError: raise AttributeError('Message %s has no field %s' % ( message_type.__name__, name)) return self.__tags.get(field.number) def reset(self, name): """Reset assigned value for field. 
Resetting a field will return it to its default value or None. Args: name: Name of field to reset. """ message_type = type(self) try: field = message_type.field_by_name(name) except KeyError: if name not in message_type.__by_name: raise AttributeError('Message %s has no field %s' % ( message_type.__name__, name)) self.__tags.pop(field.number, None) def all_unrecognized_fields(self): """Get the names of all unrecognized fields in this message.""" return self.__unrecognized_fields.keys() def get_unrecognized_field_info(self, key, value_default=None, variant_default=None): """Get the value and variant of an unknown field in this message. Args: key: The name or number of the field to retrieve. value_default: Value to be returned if the key isn't found. variant_default: Value to be returned as variant if the key isn't found. Returns: (value, variant), where value and variant are whatever was passed to set_unrecognized_field. """ value, variant = self.__unrecognized_fields.get(key, (value_default, variant_default)) return value, variant def set_unrecognized_field(self, key, value, variant): """Set an unrecognized field, used when decoding a message. Args: key: The name or number used to refer to this unknown value. value: The value of the field. variant: Type information needed to interpret the value or re-encode it. Raises: TypeError: If the variant is not an instance of messages.Variant. """ if not isinstance(variant, Variant): raise TypeError('Variant type %s is not valid.' % variant) self.__unrecognized_fields[key] = value, variant def __setattr__(self, name, value): """Change set behavior for messages. Messages may only be assigned values that are fields. Does not try to validate field when set. Args: name: Name of field to assign to. vlaue: Value to assign to field. Raises: AttributeError when trying to assign value that is not a field. 
""" if name in self.__by_name or name.startswith('_Message__'): object.__setattr__(self, name, value) else: raise AttributeError("May not assign arbitrary value %s " "to message %s" % (name, type(self).__name__)) def __repr__(self): """Make string representation of message. Example: class MyMessage(messages.Message): integer_value = messages.IntegerField(1) string_value = messages.StringField(2) my_message = MyMessage() my_message.integer_value = 42 my_message.string_value = u'A string' print my_message >>> Returns: String representation of message, including the values of all fields and repr of all sub-messages. """ body = ['<', type(self).__name__] for field in sorted(self.all_fields(), key=lambda f: f.number): attribute = field.name value = self.get_assigned_value(field.name) if value is not None: body.append('\n %s: %s' % (attribute, repr(value))) body.append('>') return ''.join(body) def __eq__(self, other): """Equality operator. Does field by field comparison with other message. For equality, must be same type and values of all fields must be equal. Messages not required to be initialized for comparison. Does not attempt to determine equality for values that have default values that are not set. In other words: class HasDefault(Message): attr1 = StringField(1, default='default value') message1 = HasDefault() message2 = HasDefault() message2.attr1 = 'default value' message1 != message2 Does not compare unknown values. Args: other: Other message to compare with. """ # TODO(rafek): Implement "equivalent" which does comparisons # taking default values in to consideration. if self is other: return True if type(self) is not type(other): return False return self.__tags == other.__tags def __ne__(self, other): """Not equals operator. Does field by field comparison with other message. For non-equality, must be different type or any value of a field must be non-equal to the same field in the other instance. Messages not required to be initialized for comparison. 
Args: other: Other message to compare with. """ return not self.__eq__(other) class FieldList(list): """List implementation that validates field values. This list implementation overrides all methods that add values in to a list in order to validate those new elements. Attempting to add or set list values that are not of the correct type will raise ValidationError. """ def __init__(self, field_instance, sequence): """Constructor. Args: field_instance: Instance of field that validates the list. sequence: List or tuple to construct list from. """ if not field_instance.repeated: raise FieldDefinitionError('FieldList may only accept repeated fields') self.__field = field_instance self.__field.validate(sequence) list.__init__(self, sequence) @property def field(self): """Field that validates list.""" return self.__field def __setslice__(self, i, j, sequence): """Validate slice assignment to list.""" self.__field.validate(sequence) list.__setslice__(self, i, j, sequence) def __setitem__(self, index, value): """Validate item assignment to list.""" self.__field.validate_element(value) list.__setitem__(self, index, value) def append(self, value): """Validate item appending to list.""" self.__field.validate_element(value) return list.append(self, value) def extend(self, sequence): """Validate extension of list.""" self.__field.validate(sequence) return list.extend(self, sequence) def insert(self, index, value): """Validate item insertion to list.""" self.__field.validate_element(value) return list.insert(self, index, value) # TODO(rafek): Prevent additional field subclasses. class Field(object): __variant_to_type = {} class __metaclass__(type): def __init__(cls, name, bases, dct): getattr(cls, '_Field__variant_to_type').update( (variant, cls) for variant in dct.get('VARIANTS', [])) type.__init__(cls, name, bases, dct) __initialized = False @util.positional(2) def __init__(self, number, required=False, repeated=False, variant=None, default=None): """Constructor. 
The required and repeated parameters are mutually exclusive. Setting both to True will raise a FieldDefinitionError. Sub-class Attributes: Each sub-class of Field must define the following: VARIANTS: Set of variant types accepted by that field. DEFAULT_VARIANT: Default variant type if not specified in constructor. Args: number: Number of field. Must be unique per message class. required: Whether or not field is required. Mutually exclusive with 'repeated'. repeated: Whether or not field is repeated. Mutually exclusive with 'required'. variant: Wire-format variant hint. default: Default value for field if not found in stream. Raises: InvalidVariantError when invalid variant for field is provided. InvalidDefaultError when invalid default for field is provided. FieldDefinitionError when invalid number provided or mutually exclusive fields are used. InvalidNumberError when the field number is out of range or reserved. """ if not isinstance(number, int) or not 1 <= number <= MAX_FIELD_NUMBER: raise InvalidNumberError('Invalid number for field: %s\n' 'Number must be 1 or greater and %d or less' % (number, MAX_FIELD_NUMBER)) if FIRST_RESERVED_FIELD_NUMBER <= number <= LAST_RESERVED_FIELD_NUMBER: raise InvalidNumberError('Tag number %d is a reserved number.\n' 'Numbers %d to %d are reserved' % (number, FIRST_RESERVED_FIELD_NUMBER, LAST_RESERVED_FIELD_NUMBER)) if repeated and required: raise FieldDefinitionError('Cannot set both repeated and required') if variant is None: variant = self.DEFAULT_VARIANT if repeated and default is not None: raise FieldDefinitionError('Repeated fields may not have defaults') if variant not in self.VARIANTS: raise InvalidVariantError( 'Invalid variant: %s\nValid variants for %s are %r' % (variant, type(self).__name__, sorted(self.VARIANTS))) self.number = number self.required = required self.repeated = repeated self.variant = variant if default is not None: try: self.validate_default(default) except ValidationError, err: try: name = self.name 
except AttributeError: # For when raising error before name initialization. raise InvalidDefaultError('Invalid default value for %s: %s: %s' % (self.__class__.__name__, default, err)) else: raise InvalidDefaultError('Invalid default value for field %s: ' '%s: %s' % (name, default, err)) self.__default = default self.__initialized = True def __setattr__(self, name, value): """Setter overidden to prevent assignment to fields after creation. Args: name: Name of attribute to set. value: Value to assign. """ # Special case post-init names. They need to be set after constructor. if name in _POST_INIT_FIELD_ATTRIBUTE_NAMES: object.__setattr__(self, name, value) return # All other attributes must be set before __initialized. if not self.__initialized: # Not initialized yet, allow assignment. object.__setattr__(self, name, value) else: raise AttributeError('Field objects are read-only') def __set__(self, message_instance, value): """Set value on message. Args: message_instance: Message instance to set value on. value: Value to set on message. """ # Reaches in to message instance directly to assign to private tags. if value is None: if self.repeated: raise ValidationError( 'May not assign None to repeated field %s' % self.name) else: message_instance._Message__tags.pop(self.number, None) else: if self.repeated: value = FieldList(self, value) else: self.validate(value) message_instance._Message__tags[self.number] = value def __get__(self, message_instance, message_class): if message_instance is None: return self result = message_instance._Message__tags.get(self.number) if result is None: return self.default else: return result def validate_element(self, value): """Validate single element of field. This is different from validate in that it is used on individual values of repeated fields. Args: value: Value to validate. Raises: ValidationError if value is not expected type. 
""" if not isinstance(value, self.type): if value is None: if self.required: raise ValidationError('Required field is missing') else: try: name = self.name except AttributeError: raise ValidationError('Expected type %s for %s, ' 'found %s (type %s)' % (self.type, self.__class__.__name__, value, type(value))) else: raise ValidationError('Expected type %s for field %s, ' 'found %s (type %s)' % (self.type, name, value, type(value))) def __validate(self, value, validate_element): """Internal validation function. Validate an internal value using a function to validate individual elements. Args: value: Value to validate. validate_element: Function to use to validate individual elements. Raises: ValidationError if value is not expected type. """ if not self.repeated: validate_element(value) else: # Must be a list or tuple, may not be a string. if isinstance(value, (list, tuple)): for element in value: if element is None: try: name = self.name except AttributeError: raise ValidationError('Repeated values for %s ' 'may not be None' % self.__class__.__name__) else: raise ValidationError('Repeated values for field %s ' 'may not be None' % name) validate_element(element) elif value is not None: try: name = self.name except AttributeError: raise ValidationError('%s is repeated. Found: %s' % ( self.__class__.__name__, value)) else: raise ValidationError('Field %s is repeated. Found: %s' % (name, value)) def validate(self, value): """Validate value assigned to field. Args: value: Value to validate. Raises: ValidationError if value is not expected type. """ self.__validate(value, self.validate_element) def validate_default_element(self, value): """Validate value as assigned to field default field. Some fields may allow for delayed resolution of default types necessary in the case of circular definition references. In this case, the default value might be a place holder that is resolved when needed after all the message classes are defined. Args: value: Default value to validate. 
Raises: ValidationError if value is not expected type. """ self.validate_element(value) def validate_default(self, value): """Validate default value assigned to field. Args: value: Value to validate. Raises: ValidationError if value is not expected type. """ self.__validate(value, self.validate_default_element) def message_definition(self): """Get Message definition that contains this Field definition. Returns: Containing Message definition for Field. Will return None if for some reason Field is defined outside of a Message class. """ try: return self._message_definition() except AttributeError: return None @property def default(self): """Get default value for field.""" return self.__default @classmethod def lookup_field_type_by_variant(cls, variant): return cls.__variant_to_type[variant] class IntegerField(Field): """Field definition for integer values.""" VARIANTS = frozenset([Variant.INT32, Variant.INT64, Variant.UINT32, Variant.UINT64, Variant.SINT32, Variant.SINT64, ]) DEFAULT_VARIANT = Variant.INT64 type = (int, long) class FloatField(Field): """Field definition for float values.""" VARIANTS = frozenset([Variant.FLOAT, Variant.DOUBLE, ]) DEFAULT_VARIANT = Variant.DOUBLE type = float class BooleanField(Field): """Field definition for boolean values.""" VARIANTS = frozenset([Variant.BOOL]) DEFAULT_VARIANT = Variant.BOOL type = bool class BytesField(Field): """Field definition for byte string values.""" VARIANTS = frozenset([Variant.BYTES]) DEFAULT_VARIANT = Variant.BYTES type = str class StringField(Field): """Field definition for unicode string values.""" VARIANTS = frozenset([Variant.STRING]) DEFAULT_VARIANT = Variant.STRING type = unicode def validate_element(self, value): """Validate StringField allowing for str and unicode. Raises: ValidationError if a str value is not 7-bit ascii. """ # If value is str is it considered valid. Satisfies "required=True". 
if isinstance(value, str): try: unicode(value) except UnicodeDecodeError, err: try: name = self.name except AttributeError: validation_error = ValidationError( 'Field encountered non-ASCII string %s: %s' % (value, err)) else: validation_error = ValidationError( 'Field %s encountered non-ASCII string %s: %s' % (self.name, value, err)) validation_error.field_name = self.name raise validation_error else: super(StringField, self).validate_element(value) class MessageField(Field): """Field definition for sub-message values. Message fields contain instance of other messages. Instances stored on messages stored on message fields are considered to be owned by the containing message instance and should not be shared between owning instances. Message fields must be defined to reference a single type of message. Normally message field are defined by passing the referenced message class in to the constructor. It is possible to define a message field for a type that does not yet exist by passing the name of the message in to the constructor instead of a message class. Resolution of the actual type of the message is deferred until it is needed, for example, during message verification. Names provided to the constructor must refer to a class within the same python module as the class that is using it. Names refer to messages relative to the containing messages scope. For example, the two fields of OuterMessage refer to the same message type: class Outer(Message): inner_relative = MessageField('Inner', 1) inner_absolute = MessageField('Outer.Inner', 2) class Inner(Message): ... When resolving an actual type, MessageField will traverse the entire scope of nested messages to match a message name. This makes it easy for siblings to reference siblings: class Outer(Message): class Inner(Message): sibling = MessageField('Sibling', 1) class Sibling(Message): ... 
""" VARIANTS = frozenset([Variant.MESSAGE]) DEFAULT_VARIANT = Variant.MESSAGE @util.positional(3) def __init__(self, message_type, number, required=False, repeated=False, variant=None): """Constructor. Args: message_type: Message type for field. Must be subclass of Message. number: Number of field. Must be unique per message class. required: Whether or not field is required. Mutually exclusive to 'repeated'. repeated: Whether or not field is repeated. Mutually exclusive to 'required'. variant: Wire-format variant hint. Raises: FieldDefinitionError when invalid message_type is provided. """ valid_type = (isinstance(message_type, basestring) or (message_type is not Message and isinstance(message_type, type) and issubclass(message_type, Message))) if not valid_type: raise FieldDefinitionError('Invalid message class: %s' % message_type) if isinstance(message_type, basestring): self.__type_name = message_type self.__type = None else: self.__type = message_type super(MessageField, self).__init__(number, required=required, repeated=repeated, variant=variant) @property def type(self): """Message type used for field.""" if self.__type is None: message_type = find_definition(self.__type_name, self.message_definition()) if not (message_type is not Message and isinstance(message_type, type) and issubclass(message_type, Message)): raise FieldDefinitionError('Invalid message class: %s' % message_type) self.__type = message_type return self.__type @property def message_type(self): """Underlying message type used for serialization. Will always be a sub-class of Message. This is different from type which represents the python value that message_type is mapped to for use by the user. """ return self.type def value_from_message(self, message): """Convert a message to a value instance. Used by deserializers to convert from underlying messages to value of expected user type. Args: message: A message instance of type self.message_type. Returns: Value of self.message_type. 
""" if not isinstance(message, self.message_type): raise DecodeError('Expected type %s, got %s: %r' % (self.message_type.__name__, type(message).__name__, message)) return message def value_to_message(self, value): """Convert a value instance to a message. Used by serializers to convert Python user types to underlying messages for transmission. Args: value: A value of type self.type. Returns: An instance of type self.message_type. """ if not isinstance(value, self.type): raise EncodeError('Expected type %s, got %s: %r' % (self.type.__name__, type(value).__name__, value)) return value class EnumField(Field): """Field definition for enum values. Enum fields may have default values that are delayed until the associated enum type is resolved. This is necessary to support certain circular references. For example: class Message1(Message): class Color(Enum): RED = 1 GREEN = 2 BLUE = 3 # This field default value will be validated when default is accessed. animal = EnumField('Message2.Animal', 1, default='HORSE') class Message2(Message): class Animal(Enum): DOG = 1 CAT = 2 HORSE = 3 # This fields default value will be validated right away since Color is # already fully resolved. color = EnumField(Message1.Color, 1, default='RED') """ VARIANTS = frozenset([Variant.ENUM]) DEFAULT_VARIANT = Variant.ENUM def __init__(self, enum_type, number, **kwargs): """Constructor. Args: enum_type: Enum type for field. Must be subclass of Enum. number: Number of field. Must be unique per message class. required: Whether or not field is required. Mutually exclusive to 'repeated'. repeated: Whether or not field is repeated. Mutually exclusive to 'required'. variant: Wire-format variant hint. default: Default value for field if not found in stream. Raises: FieldDefinitionError when invalid enum_type is provided. 
""" valid_type = (isinstance(enum_type, basestring) or (enum_type is not Enum and isinstance(enum_type, type) and issubclass(enum_type, Enum))) if not valid_type: raise FieldDefinitionError('Invalid enum type: %s' % enum_type) if isinstance(enum_type, basestring): self.__type_name = enum_type self.__type = None else: self.__type = enum_type super(EnumField, self).__init__(number, **kwargs) def validate_default_element(self, value): """Validate default element of Enum field. Enum fields allow for delayed resolution of default values when the type of the field has not been resolved. The default value of a field may be a string or an integer. If the Enum type of the field has been resolved, the default value is validated against that type. Args: value: Value to validate. Raises: ValidationError if value is not expected message type. """ if isinstance(value, (basestring, int, long)): # Validation of the value does not happen for delayed resolution # enumerated types. Ignore if type is not yet resolved. if self.__type: self.__type(value) return super(EnumField, self).validate_default_element(value) @property def type(self): """Enum type used for field.""" if self.__type is None: found_type = find_definition(self.__type_name, self.message_definition()) if not (found_type is not Enum and isinstance(found_type, type) and issubclass(found_type, Enum)): raise FieldDefinitionError('Invalid enum type: %s' % found_type) self.__type = found_type return self.__type @property def default(self): """Default for enum field. Will cause resolution of Enum type and unresolved default value. 
""" try: return self.__resolved_default except AttributeError: resolved_default = super(EnumField, self).default if isinstance(resolved_default, (basestring, int, long)): resolved_default = self.type(resolved_default) self.__resolved_default = resolved_default return self.__resolved_default @util.positional(2) def find_definition(name, relative_to=None, importer=__import__): """Find definition by name in module-space. The find algorthm will look for definitions by name relative to a message definition or by fully qualfied name. If no definition is found relative to the relative_to parameter it will do the same search against the container of relative_to. If relative_to is a nested Message, it will search its message_definition(). If that message has no message_definition() it will search its module. If relative_to is a module, it will attempt to look for the containing module and search relative to it. If the module is a top-level module, it will look for the a message using a fully qualified name. If no message is found then, the search fails and DefinitionNotFoundError is raised. For example, when looking for any definition 'foo.bar.ADefinition' relative to an actual message definition abc.xyz.SomeMessage: find_definition('foo.bar.ADefinition', SomeMessage) It is like looking for the following fully qualified names: abc.xyz.SomeMessage. foo.bar.ADefinition abc.xyz. foo.bar.ADefinition abc. foo.bar.ADefinition foo.bar.ADefinition When resolving the name relative to Message definitions and modules, the algorithm searches any Messages or sub-modules found in its path. Non-Message values are not searched. A name that begins with '.' is considered to be a fully qualified name. The name is always searched for from the topmost package. For example, assume two message types: abc.xyz.SomeMessage xyz.SomeMessage Searching for '.xyz.SomeMessage' relative to 'abc' will resolve to 'xyz.SomeMessage' and not 'abc.xyz.SomeMessage'. 
For this kind of name, the relative_to parameter is effectively ignored and always set to None. For more information about package name resolution, please see: http://code.google.com/apis/protocolbuffers/docs/proto.html#packages Args: name: Name of definition to find. May be fully qualified or relative name. relative_to: Search for definition relative to message definition or module. None will cause a fully qualified name search. importer: Import function to use for resolving modules. Returns: Enum or Message class definition associated with name. Raises: DefinitionNotFoundError if no definition is found in any search path. """ # Check parameters. if not (relative_to is None or isinstance(relative_to, types.ModuleType) or isinstance(relative_to, type) and issubclass(relative_to, Message)): raise TypeError('relative_to must be None, Message definition or module. ' 'Found: %s' % relative_to) name_path = name.split('.') # Handle absolute path reference. if not name_path[0]: relative_to = None name_path = name_path[1:] def search_path(): """Performs a single iteration searching the path from relative_to. This is the function that searches up the path from a relative object. fully.qualified.object . relative.or.nested.Definition ----------------------------> ^ | this part of search --+ Returns: Message or Enum at the end of name_path, else None. """ next = relative_to for node in name_path: # Look for attribute first. attribute = getattr(next, node, None) if attribute is not None: next = attribute else: # If module, look for sub-module. 
if next is None or isinstance(next, types.ModuleType): if next is None: module_name = node else: module_name = '%s.%s' % (next.__name__, node) try: fromitem = module_name.split('.')[-1] next = importer(module_name, '', '', [str(fromitem)]) except ImportError: return None else: return None if (not isinstance(next, types.ModuleType) and not (isinstance(next, type) and issubclass(next, (Message, Enum)))): return None return next while True: found = search_path() if isinstance(found, type) and issubclass(found, (Enum, Message)): return found else: # Find next relative_to to search against. # # fully.qualified.object . relative.or.nested.Definition # <--------------------- # ^ # | # does this part of search if relative_to is None: # Fully qualified search was done. Nothing found. Fail. raise DefinitionNotFoundError('Could not find definition for %s' % (name,)) else: if isinstance(relative_to, types.ModuleType): # Find parent module. module_path = relative_to.__name__.split('.')[:-1] if not module_path: relative_to = None else: # Should not raise ImportError. If it does... weird and # unexepected. Propagate. relative_to = importer( '.'.join(module_path), '', '', [module_path[-1]]) elif (isinstance(relative_to, type) and issubclass(relative_to, Message)): parent = relative_to.message_definition() if parent is None: last_module_name = relative_to.__module__.split('.')[-1] relative_to = importer( relative_to.__module__, '', '', [last_module_name]) else: relative_to = parent protorpc-standalone-0.9.1/protorpc/messages_test.py0000755000076500000240000020414212277637135023603 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.messages.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import imp import inspect import new import re import sys import types import unittest from protorpc import descriptor from protorpc import message_types from protorpc import messages from protorpc import test_util class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = messages class ValidationErrorTest(test_util.TestCase): def testStr_NoFieldName(self): """Test string version of ValidationError when no name provided.""" self.assertEquals('Validation error', str(messages.ValidationError('Validation error'))) def testStr_FieldName(self): """Test string version of ValidationError when no name provided.""" validation_error = messages.ValidationError('Validation error') validation_error.field_name = 'a_field' self.assertEquals('Validation error', str(validation_error)) class EnumTest(test_util.TestCase): def setUp(self): """Set up tests.""" # Redefine Color class in case so that changes to it (an error) in one test # does not affect other tests. 
global Color class Color(messages.Enum): RED = 20 ORANGE = 2 YELLOW = 40 GREEN = 4 BLUE = 50 INDIGO = 5 VIOLET = 80 def testNames(self): """Test that names iterates over enum names.""" self.assertEquals( set(['BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW']), set(Color.names())) def testNumbers(self): """Tests that numbers iterates of enum numbers.""" self.assertEquals(set([2, 4, 5, 20, 40, 50, 80]), set(Color.numbers())) def testIterate(self): """Test that __iter__ iterates over all enum values.""" self.assertEquals(set(Color), set([Color.RED, Color.ORANGE, Color.YELLOW, Color.GREEN, Color.BLUE, Color.INDIGO, Color.VIOLET])) def testNaturalOrder(self): """Test that natural order enumeration is in numeric order.""" self.assertEquals([Color.ORANGE, Color.GREEN, Color.INDIGO, Color.RED, Color.YELLOW, Color.BLUE, Color.VIOLET], sorted(Color)) def testByName(self): """Test look-up by name.""" self.assertEquals(Color.RED, Color.lookup_by_name('RED')) self.assertRaises(KeyError, Color.lookup_by_name, 20) self.assertRaises(KeyError, Color.lookup_by_name, Color.RED) def testByNumber(self): """Test look-up by number.""" self.assertRaises(KeyError, Color.lookup_by_number, 'RED') self.assertEquals(Color.RED, Color.lookup_by_number(20)) self.assertRaises(KeyError, Color.lookup_by_number, Color.RED) def testConstructor(self): """Test that constructor look-up by name or number.""" self.assertEquals(Color.RED, Color('RED')) self.assertEquals(Color.RED, Color(u'RED')) self.assertEquals(Color.RED, Color(20)) self.assertEquals(Color.RED, Color(20L)) self.assertEquals(Color.RED, Color(Color.RED)) self.assertRaises(TypeError, Color, 'Not exists') self.assertRaises(TypeError, Color, 'Red') self.assertRaises(TypeError, Color, 100) self.assertRaises(TypeError, Color, 10.0) def testLen(self): """Test that len function works to count enums.""" self.assertEquals(7, len(Color)) def testNoSubclasses(self): """Test that it is not possible to sub-class enum classes.""" def 
declare_subclass(): class MoreColor(Color): pass self.assertRaises(messages.EnumDefinitionError, declare_subclass) def testClassNotMutable(self): """Test that enum classes themselves are not mutable.""" self.assertRaises(AttributeError, setattr, Color, 'something_new', 10) def testInstancesMutable(self): """Test that enum instances are not mutable.""" self.assertRaises(TypeError, setattr, Color.RED, 'something_new', 10) def testDefEnum(self): """Test def_enum works by building enum class from dict.""" WeekDay = messages.Enum.def_enum({'Monday': 1, 'Tuesday': 2, 'Wednesday': 3, 'Thursday': 4, 'Friday': 6, 'Saturday': 7, 'Sunday': 8}, 'WeekDay') self.assertEquals('Wednesday', WeekDay(3).name) self.assertEquals(6, WeekDay('Friday').number) self.assertEquals(WeekDay.Sunday, WeekDay('Sunday')) def testNonInt(self): """Test that non-integer values rejection by enum def.""" self.assertRaises(messages.EnumDefinitionError, messages.Enum.def_enum, {'Bad': '1'}, 'BadEnum') def testNegativeInt(self): """Test that negative numbers rejection by enum def.""" self.assertRaises(messages.EnumDefinitionError, messages.Enum.def_enum, {'Bad': -1}, 'BadEnum') def testLowerBound(self): """Test that zero is accepted by enum def.""" class NotImportant(messages.Enum): """Testing for value zero""" VALUE = 0 self.assertEquals(0, int(NotImportant.VALUE)) def testTooLargeInt(self): """Test that numbers too large are rejected.""" self.assertRaises(messages.EnumDefinitionError, messages.Enum.def_enum, {'Bad': (2 ** 29)}, 'BadEnum') def testRepeatedInt(self): """Test duplicated numbers are forbidden.""" self.assertRaises(messages.EnumDefinitionError, messages.Enum.def_enum, {'Ok': 1, 'Repeated': 1}, 'BadEnum') def testStr(self): """Test converting to string.""" self.assertEquals('RED', str(Color.RED)) self.assertEquals('ORANGE', str(Color.ORANGE)) def testInt(self): """Test converting to int.""" self.assertEquals(20, int(Color.RED)) self.assertEquals(2, int(Color.ORANGE)) def testRepr(self): 
"""Test enum representation.""" self.assertEquals('Color(RED, 20)', repr(Color.RED)) self.assertEquals('Color(YELLOW, 40)', repr(Color.YELLOW)) def testDocstring(self): """Test that docstring is supported ok.""" class NotImportant(messages.Enum): """I have a docstring.""" VALUE1 = 1 self.assertEquals('I have a docstring.', NotImportant.__doc__) def testDeleteEnumValue(self): """Test that enum values cannot be deleted.""" self.assertRaises(TypeError, delattr, Color, 'RED') def testEnumName(self): """Test enum name.""" module_name = test_util.get_module_name(EnumTest) self.assertEquals('%s.Color' % module_name, Color.definition_name()) self.assertEquals(module_name, Color.outer_definition_name()) self.assertEquals(module_name, Color.definition_package()) def testDefinitionName_OverrideModule(self): """Test enum module is overriden by module package name.""" global package try: package = 'my.package' self.assertEquals('my.package.Color', Color.definition_name()) self.assertEquals('my.package', Color.outer_definition_name()) self.assertEquals('my.package', Color.definition_package()) finally: del package def testDefinitionName_NoModule(self): """Test what happens when there is no module for enum.""" class Enum1(messages.Enum): pass original_modules = sys.modules sys.modules = dict(sys.modules) try: del sys.modules[__name__] self.assertEquals('Enum1', Enum1.definition_name()) self.assertEquals(None, Enum1.outer_definition_name()) self.assertEquals(None, Enum1.definition_package()) self.assertEquals(unicode, type(Enum1.definition_name())) finally: sys.modules = original_modules def testDefinitionName_Nested(self): """Test nested Enum names.""" class MyMessage(messages.Message): class NestedEnum(messages.Enum): pass class NestedMessage(messages.Message): class NestedEnum(messages.Enum): pass module_name = test_util.get_module_name(EnumTest) self.assertEquals('%s.MyMessage.NestedEnum' % module_name, MyMessage.NestedEnum.definition_name()) self.assertEquals('%s.MyMessage' % 
module_name, MyMessage.NestedEnum.outer_definition_name()) self.assertEquals(module_name, MyMessage.NestedEnum.definition_package()) self.assertEquals('%s.MyMessage.NestedMessage.NestedEnum' % module_name, MyMessage.NestedMessage.NestedEnum.definition_name()) self.assertEquals( '%s.MyMessage.NestedMessage' % module_name, MyMessage.NestedMessage.NestedEnum.outer_definition_name()) self.assertEquals(module_name, MyMessage.NestedMessage.NestedEnum.definition_package()) def testMessageDefinition(self): """Test that enumeration knows its enclosing message definition.""" class OuterEnum(messages.Enum): pass self.assertEquals(None, OuterEnum.message_definition()) class OuterMessage(messages.Message): class InnerEnum(messages.Enum): pass self.assertEquals(OuterMessage, OuterMessage.InnerEnum.message_definition()) def testComparison(self): """Test comparing various enums to different types.""" class Enum1(messages.Enum): VAL1 = 1 VAL2 = 2 class Enum2(messages.Enum): VAL1 = 1 self.assertEquals(Enum1.VAL1, Enum1.VAL1) self.assertNotEquals(Enum1.VAL1, Enum1.VAL2) self.assertNotEquals(Enum1.VAL1, Enum2.VAL1) self.assertNotEquals(Enum1.VAL1, 'VAL1') self.assertNotEquals(Enum1.VAL1, 1) self.assertNotEquals(Enum1.VAL1, 2) self.assertNotEquals(Enum1.VAL1, None) self.assertNotEquals(Enum1.VAL1, Enum2.VAL1) self.assertTrue(Enum1.VAL1 < Enum1.VAL2) self.assertTrue(Enum1.VAL2 > Enum1.VAL1) self.assertNotEquals(1, Enum2.VAL1) class FieldListTest(test_util.TestCase): def setUp(self): self.integer_field = messages.IntegerField(1, repeated=True) def testConstructor(self): self.assertEquals([1, 2, 3], messages.FieldList(self.integer_field, [1, 2, 3])) self.assertEquals([1, 2, 3], messages.FieldList(self.integer_field, (1, 2, 3))) self.assertEquals([], messages.FieldList(self.integer_field, [])) def testNone(self): self.assertRaises(TypeError, messages.FieldList, self.integer_field, None) def testDoNotAutoConvertString(self): string_field = messages.StringField(1, repeated=True) 
self.assertRaises(messages.ValidationError, messages.FieldList, string_field, 'abc') def testConstructorCopies(self): a_list = [1, 3, 6] field_list = messages.FieldList(self.integer_field, a_list) self.assertFalse(a_list is field_list) self.assertFalse(field_list is messages.FieldList(self.integer_field, field_list)) def testNonRepeatedField(self): self.assertRaisesWithRegexpMatch( messages.FieldDefinitionError, 'FieldList may only accept repeated fields', messages.FieldList, messages.IntegerField(1), []) def testConstructor_InvalidValues(self): self.assertRaisesWithRegexpMatch( messages.ValidationError, re.escape("Expected type (, ) " "for IntegerField, found 1 (type )"), messages.FieldList, self.integer_field, ["1", "2", "3"]) def testConstructor_Scalars(self): self.assertRaisesWithRegexpMatch( messages.ValidationError, "IntegerField is repeated. Found: 3", messages.FieldList, self.integer_field, 3) self.assertRaisesWithRegexpMatch( messages.ValidationError, "IntegerField is repeated. 
Found: , ) " "for IntegerField, found 10 (type )"), setslice) def testSetItem(self): field_list = messages.FieldList(self.integer_field, [2]) field_list[0] = 10 self.assertEquals([10], field_list) def testSetItem_InvalidValues(self): field_list = messages.FieldList(self.integer_field, [2]) def setitem(): field_list[0] = '10' self.assertRaisesWithRegexpMatch( messages.ValidationError, re.escape("Expected type (, ) " "for IntegerField, found 10 (type )"), setitem) def testAppend(self): field_list = messages.FieldList(self.integer_field, [2]) field_list.append(10) self.assertEquals([2, 10], field_list) def testAppend_InvalidValues(self): field_list = messages.FieldList(self.integer_field, [2]) field_list.name = 'a_field' def append(): field_list.append('10') self.assertRaisesWithRegexpMatch( messages.ValidationError, re.escape("Expected type (, ) " "for IntegerField, found 10 (type )"), append) def testExtend(self): field_list = messages.FieldList(self.integer_field, [2]) field_list.extend([10]) self.assertEquals([2, 10], field_list) def testExtend_InvalidValues(self): field_list = messages.FieldList(self.integer_field, [2]) def extend(): field_list.extend(['10']) self.assertRaisesWithRegexpMatch( messages.ValidationError, re.escape("Expected type (, ) " "for IntegerField, found 10 (type )"), extend) def testInsert(self): field_list = messages.FieldList(self.integer_field, [2, 3]) field_list.insert(1, 10) self.assertEquals([2, 10, 3], field_list) def testInsert_InvalidValues(self): field_list = messages.FieldList(self.integer_field, [2, 3]) def insert(): field_list.insert(1, '10') self.assertRaisesWithRegexpMatch( messages.ValidationError, re.escape("Expected type (, ) " "for IntegerField, found 10 (type )"), insert) class FieldTest(test_util.TestCase): def ActionOnAllFieldClasses(self, action): """Test all field classes except Message and Enum. Message and Enum require separate tests. Args: action: Callable that takes the field class as a parameter. 
""" for field_class in (messages.IntegerField, messages.FloatField, messages.BooleanField, messages.BytesField, messages.StringField, ): action(field_class) def testNumberAttribute(self): """Test setting the number attribute.""" def action(field_class): # Check range. self.assertRaises(messages.InvalidNumberError, field_class, 0) self.assertRaises(messages.InvalidNumberError, field_class, -1) self.assertRaises(messages.InvalidNumberError, field_class, messages.MAX_FIELD_NUMBER + 1) # Check reserved. self.assertRaises(messages.InvalidNumberError, field_class, messages.FIRST_RESERVED_FIELD_NUMBER) self.assertRaises(messages.InvalidNumberError, field_class, messages.LAST_RESERVED_FIELD_NUMBER) self.assertRaises(messages.InvalidNumberError, field_class, '1') # This one should work. field_class(number=1) self.ActionOnAllFieldClasses(action) def testRequiredAndRepeated(self): """Test setting the required and repeated fields.""" def action(field_class): field_class(1, required=True) field_class(1, repeated=True) self.assertRaises(messages.FieldDefinitionError, field_class, 1, required=True, repeated=True) self.ActionOnAllFieldClasses(action) def testInvalidVariant(self): """Test field with invalid variants.""" def action(field_class): if field_class is not message_types.DateTimeField: self.assertRaises(messages.InvalidVariantError, field_class, 1, variant=messages.Variant.ENUM) self.ActionOnAllFieldClasses(action) def testDefaultVariant(self): """Test that default variant is used when not set.""" def action(field_class): field = field_class(1) self.assertEquals(field_class.DEFAULT_VARIANT, field.variant) self.ActionOnAllFieldClasses(action) def testAlternateVariant(self): """Test that default variant is used when not set.""" field = messages.IntegerField(1, variant=messages.Variant.UINT32) self.assertEquals(messages.Variant.UINT32, field.variant) def testDefaultFields_Single(self): """Test default field is correct type.""" defaults = {messages.IntegerField: 10, 
messages.FloatField: 1.5, messages.BooleanField: False, messages.BytesField: 'abc', messages.StringField: u'abc', } def action(field_class): field_class(1, default=defaults[field_class]) self.ActionOnAllFieldClasses(action) # Run defaults test again checking for str/unicode compatiblity. defaults[messages.StringField] = 'abc' self.ActionOnAllFieldClasses(action) def testStringField_BadUnicodeInDefault(self): """Test binary values in string field.""" self.assertRaisesWithRegexpMatch( messages.InvalidDefaultError, 'Invalid default value for StringField: \211: ' 'Field encountered non-ASCII string \211:', messages.StringField, 1, default='\x89') def testDefaultFields_InvalidSingle(self): """Test default field is correct type.""" def action(field_class): self.assertRaises(messages.InvalidDefaultError, field_class, 1, default=object()) self.ActionOnAllFieldClasses(action) def testDefaultFields_InvalidRepeated(self): """Test default field does not accept defaults.""" self.assertRaisesWithRegexpMatch( messages.FieldDefinitionError, 'Repeated fields may not have defaults', messages.StringField, 1, repeated=True, default=[1, 2, 3]) def testDefaultFields_None(self): """Test none is always acceptable.""" def action(field_class): field_class(1, default=None) field_class(1, required=True, default=None) field_class(1, repeated=True, default=None) self.ActionOnAllFieldClasses(action) def testDefaultFields_Enum(self): """Test the default for enum fields.""" class Symbol(messages.Enum): ALPHA = 1 BETA = 2 GAMMA = 3 field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA) self.assertEquals(Symbol.ALPHA, field.default) def testDefaultFields_EnumStringDelayedResolution(self): """Test that enum fields resolve default strings.""" field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', 1, default='OPTIONAL') self.assertEquals(descriptor.FieldDescriptor.Label.OPTIONAL, field.default) def testDefaultFields_EnumIntDelayedResolution(self): """Test that enum fields resolve 
default integers.""" field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', 1, default=2) self.assertEquals(descriptor.FieldDescriptor.Label.REQUIRED, field.default) def testDefaultFields_EnumOkIfTypeKnown(self): """Test that enum fields accept valid default values when type is known.""" field = messages.EnumField(descriptor.FieldDescriptor.Label, 1, default='REPEATED') self.assertEquals(descriptor.FieldDescriptor.Label.REPEATED, field.default) def testDefaultFields_EnumForceCheckIfTypeKnown(self): """Test that enum fields validate default values if type is known.""" self.assertRaisesWithRegexpMatch(TypeError, 'No such value for NOT_A_LABEL in ' 'Enum Label', messages.EnumField, descriptor.FieldDescriptor.Label, 1, default='NOT_A_LABEL') def testDefaultFields_EnumInvalidDelayedResolution(self): """Test that enum fields raise errors upon delayed resolution error.""" field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', 1, default=200) self.assertRaisesWithRegexpMatch(TypeError, 'No such value for 200 in Enum Label', getattr, field, 'default') def testValidate_Valid(self): """Test validation of valid values.""" values = {messages.IntegerField: 10, messages.FloatField: 1.5, messages.BooleanField: False, messages.BytesField: 'abc', messages.StringField: u'abc', } def action(field_class): # Optional. field = field_class(1) field.validate(values[field_class]) # Required. field = field_class(1, required=True) field.validate(values[field_class]) # Repeated. field = field_class(1, repeated=True) field.validate([]) field.validate(()) field.validate([values[field_class]]) field.validate((values[field_class],)) # Right value, but not repeated. 
self.assertRaises(messages.ValidationError, field.validate, values[field_class]) self.assertRaises(messages.ValidationError, field.validate, values[field_class]) self.ActionOnAllFieldClasses(action) def testValidate_Invalid(self): """Test validation of valid values.""" values = {messages.IntegerField: "10", messages.FloatField: 1, messages.BooleanField: 0, messages.BytesField: 10.20, messages.StringField: 42, } def action(field_class): # Optional. field = field_class(1) self.assertRaises(messages.ValidationError, field.validate, values[field_class]) # Required. field = field_class(1, required=True) self.assertRaises(messages.ValidationError, field.validate, values[field_class]) # Repeated. field = field_class(1, repeated=True) self.assertRaises(messages.ValidationError, field.validate, [values[field_class]]) self.assertRaises(messages.ValidationError, field.validate, (values[field_class],)) self.ActionOnAllFieldClasses(action) def testValidate_None(self): """Test that None is valid for non-required fields.""" def action(field_class): # Optional. field = field_class(1) field.validate(None) # Required. field = field_class(1, required=True) self.assertRaisesWithRegexpMatch(messages.ValidationError, 'Required field is missing', field.validate, None) # Repeated. field = field_class(1, repeated=True) field.validate(None) self.assertRaisesWithRegexpMatch(messages.ValidationError, 'Repeated values for %s may ' 'not be None' % field_class.__name__, field.validate, [None]) self.assertRaises(messages.ValidationError, field.validate, (None,)) self.ActionOnAllFieldClasses(action) def testValidateElement(self): """Test validation of valid values.""" values = {messages.IntegerField: 10, messages.FloatField: 1.5, messages.BooleanField: False, messages.BytesField: 'abc', messages.StringField: u'abc', } def action(field_class): # Optional. field = field_class(1) field.validate_element(values[field_class]) # Required. 
field = field_class(1, required=True) field.validate_element(values[field_class]) # Repeated. field = field_class(1, repeated=True) self.assertRaises(message.VAlidationError, field.validate_element, []) self.assertRaises(message.VAlidationError, field.validate_element, ()) field.validate_element(values[field_class]) field.validate_element(values[field_class]) # Right value, but repeated. self.assertRaises(messages.ValidationError, field.validate_element, [values[field_class]]) self.assertRaises(messages.ValidationError, field.validate_element, (values[field_class],)) def testReadOnly(self): """Test that objects are all read-only.""" def action(field_class): field = field_class(10) self.assertRaises(AttributeError, setattr, field, 'number', 20) self.assertRaises(AttributeError, setattr, field, 'anything_else', 'whatever') self.ActionOnAllFieldClasses(action) def testMessageField(self): """Test the construction of message fields.""" self.assertRaises(messages.FieldDefinitionError, messages.MessageField, str, 10) self.assertRaises(messages.FieldDefinitionError, messages.MessageField, messages.Message, 10) class MyMessage(messages.Message): pass field = messages.MessageField(MyMessage, 10) self.assertEquals(MyMessage, field.type) def testMessageField_ForwardReference(self): """Test the construction of forward reference message fields.""" global MyMessage global ForwardMessage try: class MyMessage(messages.Message): self_reference = messages.MessageField('MyMessage', 1) forward = messages.MessageField('ForwardMessage', 2) nested = messages.MessageField('ForwardMessage.NestedMessage', 3) inner = messages.MessageField('Inner', 4) class Inner(messages.Message): sibling = messages.MessageField('Sibling', 1) class Sibling(messages.Message): pass class ForwardMessage(messages.Message): class NestedMessage(messages.Message): pass self.assertEquals(MyMessage, MyMessage.field_by_name('self_reference').type) self.assertEquals(ForwardMessage, 
MyMessage.field_by_name('forward').type) self.assertEquals(ForwardMessage.NestedMessage, MyMessage.field_by_name('nested').type) self.assertEquals(MyMessage.Inner, MyMessage.field_by_name('inner').type) self.assertEquals(MyMessage.Sibling, MyMessage.Inner.field_by_name('sibling').type) finally: try: del MyMessage del ForwardMessage except: pass def testMessageField_WrongType(self): """Test that forward referencing the wrong type raises an error.""" global AnEnum try: class AnEnum(messages.Enum): pass class AnotherMessage(messages.Message): a_field = messages.MessageField('AnEnum', 1) self.assertRaises(messages.FieldDefinitionError, getattr, AnotherMessage.field_by_name('a_field'), 'type') finally: del AnEnum def testMessageFieldValidate(self): """Test validation on message field.""" class MyMessage(messages.Message): pass class AnotherMessage(messages.Message): pass field = messages.MessageField(MyMessage, 10) field.validate(MyMessage()) self.assertRaises(messages.ValidationError, field.validate, AnotherMessage()) def testMessageFieldMessageType(self): """Test message_type property.""" class MyMessage(messages.Message): pass class HasMessage(messages.Message): field = messages.MessageField(MyMessage, 1) self.assertEqual(HasMessage.field.type, HasMessage.field.message_type) def testMessageFieldValueFromMessage(self): class MyMessage(messages.Message): pass class HasMessage(messages.Message): field = messages.MessageField(MyMessage, 1) instance = MyMessage() self.assertIs(instance, HasMessage.field.value_from_message(instance)) def testMessageFieldValueFromMessageWrongType(self): class MyMessage(messages.Message): pass class HasMessage(messages.Message): field = messages.MessageField(MyMessage, 1) self.assertRaisesWithRegexpMatch( messages.DecodeError, 'Expected type MyMessage, got int: 10', HasMessage.field.value_from_message, 10) def testMessageFieldValueToMessage(self): class MyMessage(messages.Message): pass class HasMessage(messages.Message): field = 
messages.MessageField(MyMessage, 1) instance = MyMessage() self.assertIs(instance, HasMessage.field.value_to_message(instance)) def testMessageFieldValueToMessageWrongType(self): class MyMessage(messages.Message): pass class MyOtherMessage(messages.Message): pass class HasMessage(messages.Message): field = messages.MessageField(MyMessage, 1) instance = MyOtherMessage() self.assertRaisesWithRegexpMatch( messages.EncodeError, 'Expected type MyMessage, got MyOtherMessage: ', HasMessage.field.value_to_message, instance) def testIntegerField_AllowLong(self): """Test that the integer field allows for longs.""" messages.IntegerField(10, default=long(10)) def testMessageFieldValidate_Initialized(self): """Test validation on message field.""" class MyMessage(messages.Message): field1 = messages.IntegerField(1, required=True) field = messages.MessageField(MyMessage, 10) # Will validate messages where is_initialized() is False. message = MyMessage() field.validate(message) message.field1 = 20 field.validate(message) def testEnumField(self): """Test the construction of enum fields.""" self.assertRaises(messages.FieldDefinitionError, messages.EnumField, str, 10) self.assertRaises(messages.FieldDefinitionError, messages.EnumField, messages.Enum, 10) class Color(messages.Enum): RED = 1 GREEN = 2 BLUE = 3 field = messages.EnumField(Color, 10) self.assertEquals(Color, field.type) class Another(messages.Enum): VALUE = 1 self.assertRaises(messages.InvalidDefaultError, messages.EnumField, Color, 10, default=Another.VALUE) def testEnumField_ForwardReference(self): """Test the construction of forward reference enum fields.""" global MyMessage global ForwardEnum global ForwardMessage try: class MyMessage(messages.Message): forward = messages.EnumField('ForwardEnum', 1) nested = messages.EnumField('ForwardMessage.NestedEnum', 2) inner = messages.EnumField('Inner', 3) class Inner(messages.Enum): pass class ForwardEnum(messages.Enum): pass class ForwardMessage(messages.Message): class 
NestedEnum(messages.Enum): pass self.assertEquals(ForwardEnum, MyMessage.field_by_name('forward').type) self.assertEquals(ForwardMessage.NestedEnum, MyMessage.field_by_name('nested').type) self.assertEquals(MyMessage.Inner, MyMessage.field_by_name('inner').type) finally: try: del MyMessage del ForwardEnum del ForwardMessage except: pass def testEnumField_WrongType(self): """Test that forward referencing the wrong type raises an error.""" global AMessage try: class AMessage(messages.Message): pass class AnotherMessage(messages.Message): a_field = messages.EnumField('AMessage', 1) self.assertRaises(messages.FieldDefinitionError, getattr, AnotherMessage.field_by_name('a_field'), 'type') finally: del AMessage def testMessageDefinition(self): """Test that message definition is set on fields.""" class MyMessage(messages.Message): my_field = messages.StringField(1) self.assertEquals(MyMessage, MyMessage.field_by_name('my_field').message_definition()) def testNoneAssignment(self): """Test that assigning None does not change comparison.""" class MyMessage(messages.Message): my_field = messages.StringField(1) m1 = MyMessage() m2 = MyMessage() m2.my_field = None self.assertEquals(m1, m2) def testNonAsciiStr(self): """Test validation fails for non-ascii StringField values.""" class Thing(messages.Message): string_field = messages.StringField(2) thing = Thing() self.assertRaisesWithRegexpMatch( messages.ValidationError, 'Field string_field encountered non-ASCII string', setattr, thing, 'string_field', test_util.BINARY) class MessageTest(test_util.TestCase): """Tests for message class.""" def CreateMessageClass(self): """Creates a simple message class with 3 fields. Fields are defined in alphabetical order but with conflicting numeric order. 
""" class ComplexMessage(messages.Message): a3 = messages.IntegerField(3) b1 = messages.StringField(1) c2 = messages.StringField(2) return ComplexMessage def testSameNumbers(self): """Test that cannot assign two fields with same numbers.""" def action(): class BadMessage(messages.Message): f1 = messages.IntegerField(1) f2 = messages.IntegerField(1) self.assertRaises(messages.DuplicateNumberError, action) def testStrictAssignment(self): """Tests that cannot assign to unknown or non-reserved attributes.""" class SimpleMessage(messages.Message): field = messages.IntegerField(1) simple_message = SimpleMessage() self.assertRaises(AttributeError, setattr, simple_message, 'does_not_exist', 10) def testListAssignmentDoesNotCopy(self): class SimpleMessage(messages.Message): repeated = messages.IntegerField(1, repeated=True) message = SimpleMessage() original = message.repeated message.repeated = [] self.assertFalse(original is message.repeated) def testValidate_Optional(self): """Tests validation of optional fields.""" class SimpleMessage(messages.Message): non_required = messages.IntegerField(1) simple_message = SimpleMessage() simple_message.check_initialized() simple_message.non_required = 10 simple_message.check_initialized() def testValidate_Required(self): """Tests validation of required fields.""" class SimpleMessage(messages.Message): required = messages.IntegerField(1, required=True) simple_message = SimpleMessage() self.assertRaises(messages.ValidationError, simple_message.check_initialized) simple_message.required = 10 simple_message.check_initialized() def testValidate_Repeated(self): """Tests validation of repeated fields.""" class SimpleMessage(messages.Message): repeated = messages.IntegerField(1, repeated=True) simple_message = SimpleMessage() # Check valid values. for valid_value in [], [10], [10, 20], (), (10,), (10, 20): simple_message.repeated = valid_value simple_message.check_initialized() # Check cleared. 
simple_message.repeated = [] simple_message.check_initialized() # Check invalid values. for invalid_value in 10, ['10', '20'], [None], (None,): self.assertRaises(messages.ValidationError, setattr, simple_message, 'repeated', invalid_value) def testIsInitialized(self): """Tests is_initialized.""" class SimpleMessage(messages.Message): required = messages.IntegerField(1, required=True) simple_message = SimpleMessage() self.assertFalse(simple_message.is_initialized()) simple_message.required = 10 self.assertTrue(simple_message.is_initialized()) def testIsInitializedNestedField(self): """Tests is_initialized for nested fields.""" class SimpleMessage(messages.Message): required = messages.IntegerField(1, required=True) class NestedMessage(messages.Message): simple = messages.MessageField(SimpleMessage, 1) simple_message = SimpleMessage() self.assertFalse(simple_message.is_initialized()) nested_message = NestedMessage(simple=simple_message) self.assertFalse(nested_message.is_initialized()) simple_message.required = 10 self.assertTrue(simple_message.is_initialized()) self.assertTrue(nested_message.is_initialized()) def testNestedMethodsNotAllowed(self): """Test that method definitions on Message classes are not allowed.""" def action(): class WithMethods(messages.Message): def not_allowed(self): pass self.assertRaises(messages.MessageDefinitionError, action) def testNestedAttributesNotAllowed(self): """Test that attribute assignment on Message classes are not allowed.""" def int_attribute(): class WithMethods(messages.Message): not_allowed = 1 def string_attribute(): class WithMethods(messages.Message): not_allowed = 'not allowed' def enum_attribute(): class WithMethods(messages.Message): not_allowed = Color.RED for action in (int_attribute, string_attribute, enum_attribute): self.assertRaises(messages.MessageDefinitionError, action) def testNameIsSetOnFields(self): """Make sure name is set on fields after Message class init.""" class HasNamedFields(messages.Message): 
field = messages.StringField(1) self.assertEquals('field', HasNamedFields.field_by_number(1).name) def testSubclassingMessageDisallowed(self): """Not permitted to create sub-classes of message classes.""" class SuperClass(messages.Message): pass def action(): class SubClass(SuperClass): pass self.assertRaises(messages.MessageDefinitionError, action) def testAllFields(self): """Test all_fields method.""" ComplexMessage = self.CreateMessageClass() fields = list(ComplexMessage.all_fields()) # Order does not matter, so sort now. fields = sorted(fields, lambda f1, f2: cmp(f1.name, f2.name)) self.assertEquals(3, len(fields)) self.assertEquals('a3', fields[0].name) self.assertEquals('b1', fields[1].name) self.assertEquals('c2', fields[2].name) def testFieldByName(self): """Test getting field by name.""" ComplexMessage = self.CreateMessageClass() self.assertEquals(3, ComplexMessage.field_by_name('a3').number) self.assertEquals(1, ComplexMessage.field_by_name('b1').number) self.assertEquals(2, ComplexMessage.field_by_name('c2').number) self.assertRaises(KeyError, ComplexMessage.field_by_name, 'unknown') def testFieldByNumber(self): """Test getting field by number.""" ComplexMessage = self.CreateMessageClass() self.assertEquals('a3', ComplexMessage.field_by_number(3).name) self.assertEquals('b1', ComplexMessage.field_by_number(1).name) self.assertEquals('c2', ComplexMessage.field_by_number(2).name) self.assertRaises(KeyError, ComplexMessage.field_by_number, 4) def testGetAssignedValue(self): """Test getting the assigned value of a field.""" class SomeMessage(messages.Message): a_value = messages.StringField(1, default=u'a default') message = SomeMessage() self.assertEquals(None, message.get_assigned_value('a_value')) message.a_value = u'a string' self.assertEquals(u'a string', message.get_assigned_value('a_value')) message.a_value = u'a default' self.assertEquals(u'a default', message.get_assigned_value('a_value')) self.assertRaisesWithRegexpMatch( AttributeError, 'Message 
SomeMessage has no field no_such_field', message.get_assigned_value, 'no_such_field') def testReset(self): """Test resetting a field value.""" class SomeMessage(messages.Message): a_value = messages.StringField(1, default=u'a default') message = SomeMessage() self.assertRaises(AttributeError, message.reset, 'unknown') self.assertEquals(u'a default', message.a_value) message.reset('a_value') self.assertEquals(u'a default', message.a_value) message.a_value = u'a new value' self.assertEquals(u'a new value', message.a_value) message.reset('a_value') self.assertEquals(u'a default', message.a_value) def testAllowNestedEnums(self): """Test allowing nested enums in a message definition.""" class Trade(messages.Message): class Duration(messages.Enum): GTC = 1 DAY = 2 class Currency(messages.Enum): USD = 1 GBP = 2 INR = 3 # Sorted by name order seems to be the only feasible option. self.assertEquals(['Currency', 'Duration'], Trade.__enums__) # Message definition will now be set on Enumerated objects. self.assertEquals(Trade, Trade.Duration.message_definition()) def testAllowNestedMessages(self): """Test allowing nested messages in a message definition.""" class Trade(messages.Message): class Lot(messages.Message): pass class Agent(messages.Message): pass # Sorted by name order seems to be the only feasible option. self.assertEquals(['Agent', 'Lot'], Trade.__messages__) self.assertEquals(Trade, Trade.Agent.message_definition()) self.assertEquals(Trade, Trade.Lot.message_definition()) # But not Message itself. def action(): class Trade(messages.Message): NiceTry = messages.Message self.assertRaises(messages.MessageDefinitionError, action) def testDisallowClassAssignments(self): """Test setting class attributes may not happen.""" class MyMessage(messages.Message): pass self.assertRaises(AttributeError, setattr, MyMessage, 'x', 'do not assign') def testEquality(self): """Test message class equality.""" # Comparison against enums must work. 
class MyEnum(messages.Enum): val1 = 1 val2 = 2 # Comparisons against nested messages must work. class AnotherMessage(messages.Message): string = messages.StringField(1) class MyMessage(messages.Message): field1 = messages.IntegerField(1) field2 = messages.EnumField(MyEnum, 2) field3 = messages.MessageField(AnotherMessage, 3) message1 = MyMessage() self.assertNotEquals('hi', message1) self.assertNotEquals(AnotherMessage(), message1) self.assertEquals(message1, message1) message2 = MyMessage() self.assertEquals(message1, message2) message1.field1 = 10 self.assertNotEquals(message1, message2) message2.field1 = 20 self.assertNotEquals(message1, message2) message2.field1 = 10 self.assertEquals(message1, message2) message1.field2 = MyEnum.val1 self.assertNotEquals(message1, message2) message2.field2 = MyEnum.val2 self.assertNotEquals(message1, message2) message2.field2 = MyEnum.val1 self.assertEquals(message1, message2) message1.field3 = AnotherMessage() message1.field3.string = 'value1' self.assertNotEquals(message1, message2) message2.field3 = AnotherMessage() message2.field3.string = 'value2' self.assertNotEquals(message1, message2) message2.field3.string = 'value1' self.assertEquals(message1, message2) def testEqualityWithUnknowns(self): """Test message class equality with unknown fields.""" class MyMessage(messages.Message): field1 = messages.IntegerField(1) message1 = MyMessage() message2 = MyMessage() self.assertEquals(message1, message2) message1.set_unrecognized_field('unknown1', 'value1', messages.Variant.STRING) self.assertEquals(message1, message2) message1.set_unrecognized_field('unknown2', ['asdf', 3], messages.Variant.STRING) message1.set_unrecognized_field('unknown3', 4.7, messages.Variant.DOUBLE) self.assertEquals(message1, message2) def testUnrecognizedFieldInvalidVariant(self): class MyMessage(messages.Message): field1 = messages.IntegerField(1) message1 = MyMessage() self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4', 
{'unhandled': 'type'}, None) self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4', {'unhandled': 'type'}, 123) def testRepr(self): """Test represtation of Message object.""" class MyMessage(messages.Message): integer_value = messages.IntegerField(1) string_value = messages.StringField(2) unassigned = messages.StringField(3) unassigned_with_default = messages.StringField(4, default=u'a default') my_message = MyMessage() my_message.integer_value = 42 my_message.string_value = u'A string' self.assertEquals("", repr(my_message)) def testValidation(self): """Test validation of message values.""" # Test optional. class SubMessage(messages.Message): pass class Message(messages.Message): val = messages.MessageField(SubMessage, 1) message = Message() message_field = messages.MessageField(Message, 1) message_field.validate(message) message.val = SubMessage() message_field.validate(message) self.assertRaises(messages.ValidationError, setattr, message, 'val', [SubMessage()]) # Test required. class Message(messages.Message): val = messages.MessageField(SubMessage, 1, required=True) message = Message() message_field = messages.MessageField(Message, 1) message_field.validate(message) message.val = SubMessage() message_field.validate(message) self.assertRaises(messages.ValidationError, setattr, message, 'val', [SubMessage()]) # Test repeated. class Message(messages.Message): val = messages.MessageField(SubMessage, 1, repeated=True) message = Message() message_field = messages.MessageField(Message, 1) message_field.validate(message) self.assertRaisesWithRegexpMatch( messages.ValidationError, "Field val is repeated. 
Found: ", setattr, message, 'val', SubMessage()) message.val = [SubMessage()] message_field.validate(message) def testDefinitionName(self): """Test message name.""" class MyMessage(messages.Message): pass module_name = test_util.get_module_name(FieldTest) self.assertEquals('%s.MyMessage' % module_name, MyMessage.definition_name()) self.assertEquals(module_name, MyMessage.outer_definition_name()) self.assertEquals(module_name, MyMessage.definition_package()) self.assertEquals(unicode, type(MyMessage.definition_name())) self.assertEquals(unicode, type(MyMessage.outer_definition_name())) self.assertEquals(unicode, type(MyMessage.definition_package())) def testDefinitionName_OverrideModule(self): """Test message module is overriden by module package name.""" class MyMessage(messages.Message): pass global package package = 'my.package' try: self.assertEquals('my.package.MyMessage', MyMessage.definition_name()) self.assertEquals('my.package', MyMessage.outer_definition_name()) self.assertEquals('my.package', MyMessage.definition_package()) self.assertEquals(unicode, type(MyMessage.definition_name())) self.assertEquals(unicode, type(MyMessage.outer_definition_name())) self.assertEquals(unicode, type(MyMessage.definition_package())) finally: del package def testDefinitionName_NoModule(self): """Test what happens when there is no module for message.""" class MyMessage(messages.Message): pass original_modules = sys.modules sys.modules = dict(sys.modules) try: del sys.modules[__name__] self.assertEquals('MyMessage', MyMessage.definition_name()) self.assertEquals(None, MyMessage.outer_definition_name()) self.assertEquals(None, MyMessage.definition_package()) self.assertEquals(unicode, type(MyMessage.definition_name())) finally: sys.modules = original_modules def testDefinitionName_Nested(self): """Test nested message names.""" class MyMessage(messages.Message): class NestedMessage(messages.Message): class NestedMessage(messages.Message): pass module_name = 
test_util.get_module_name(MessageTest) self.assertEquals('%s.MyMessage.NestedMessage' % module_name, MyMessage.NestedMessage.definition_name()) self.assertEquals('%s.MyMessage' % module_name, MyMessage.NestedMessage.outer_definition_name()) self.assertEquals(module_name, MyMessage.NestedMessage.definition_package()) self.assertEquals('%s.MyMessage.NestedMessage.NestedMessage' % module_name, MyMessage.NestedMessage.NestedMessage.definition_name()) self.assertEquals( '%s.MyMessage.NestedMessage' % module_name, MyMessage.NestedMessage.NestedMessage.outer_definition_name()) self.assertEquals( module_name, MyMessage.NestedMessage.NestedMessage.definition_package()) def testMessageDefinition(self): """Test that enumeration knows its enclosing message definition.""" class OuterMessage(messages.Message): class InnerMessage(messages.Message): pass self.assertEquals(None, OuterMessage.message_definition()) self.assertEquals(OuterMessage, OuterMessage.InnerMessage.message_definition()) def testConstructorKwargs(self): """Test kwargs via constructor.""" class SomeMessage(messages.Message): name = messages.StringField(1) number = messages.IntegerField(2) expected = SomeMessage() expected.name = 'my name' expected.number = 200 self.assertEquals(expected, SomeMessage(name='my name', number=200)) def testConstructorNotAField(self): """Test kwargs via constructor with wrong names.""" class SomeMessage(messages.Message): pass self.assertRaisesWithRegexpMatch( AttributeError, 'May not assign arbitrary value does_not_exist to message SomeMessage', SomeMessage, does_not_exist=10) def testGetUnsetRepeatedValue(self): class SomeMessage(messages.Message): repeated = messages.IntegerField(1, repeated=True) instance = SomeMessage() self.assertEquals([], instance.repeated) self.assertTrue(isinstance(instance.repeated, messages.FieldList)) def testCompareAutoInitializedRepeatedFields(self): class SomeMessage(messages.Message): repeated = messages.IntegerField(1, repeated=True) message1 = 
SomeMessage(repeated=[]) message2 = SomeMessage() self.assertEquals(message1, message2) def testUnknownValues(self): """Test message class equality with unknown fields.""" class MyMessage(messages.Message): field1 = messages.IntegerField(1) message = MyMessage() self.assertEquals([], message.all_unrecognized_fields()) self.assertEquals((None, None), message.get_unrecognized_field_info('doesntexist')) self.assertEquals((None, None), message.get_unrecognized_field_info( 'doesntexist', None, None)) self.assertEquals(('defaultvalue', 'defaultwire'), message.get_unrecognized_field_info( 'doesntexist', 'defaultvalue', 'defaultwire')) self.assertEquals((3, None), message.get_unrecognized_field_info( 'doesntexist', value_default=3)) message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE) self.assertEquals(1, len(message.all_unrecognized_fields())) self.assertIn('exists', message.all_unrecognized_fields()) self.assertEquals((9.5, messages.Variant.DOUBLE), message.get_unrecognized_field_info('exists')) self.assertEquals((9.5, messages.Variant.DOUBLE), message.get_unrecognized_field_info('exists', 'type', 1234)) self.assertEquals((1234, None), message.get_unrecognized_field_info('doesntexist', 1234)) message.set_unrecognized_field('another', 'value', messages.Variant.STRING) self.assertEquals(2, len(message.all_unrecognized_fields())) self.assertIn('exists', message.all_unrecognized_fields()) self.assertIn('another', message.all_unrecognized_fields()) self.assertEquals((9.5, messages.Variant.DOUBLE), message.get_unrecognized_field_info('exists')) self.assertEquals(('value', messages.Variant.STRING), message.get_unrecognized_field_info('another')) message.set_unrecognized_field('typetest1', ['list', 0, ('test',)], messages.Variant.STRING) self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING), message.get_unrecognized_field_info('typetest1')) message.set_unrecognized_field('typetest2', '', messages.Variant.STRING) self.assertEquals(('', 
messages.Variant.STRING), message.get_unrecognized_field_info('typetest2')) class FindDefinitionTest(test_util.TestCase): """Test finding definitions relative to various definitions and modules.""" def setUp(self): """Set up module-space. Starts off empty.""" self.modules = {} def DefineModule(self, name): """Define a module and its parents in module space. Modules that are already defined in self.modules are not re-created. Args: name: Fully qualified name of modules to create. Returns: Deepest nested module. For example: DefineModule('a.b.c') # Returns c. """ name_path = name.split('.') full_path = [] for node in name_path: full_path.append(node) full_name = '.'.join(full_path) self.modules.setdefault(full_name, new.module(full_name)) return self.modules[name] def DefineMessage(self, module, name, children={}, add_to_module=True): """Define a new Message class in the context of a module. Used for easily describing complex Message hierarchy. Message is defined including all child definitions. Args: module: Fully qualified name of module to place Message class in. name: Name of Message to define within module. children: Define any level of nesting of children definitions. To define a message, map the name to another dictionary. The dictionary can itself contain additional definitions, and so on. To map to an Enum, define the Enum class separately and map it by name. add_to_module: If True, new Message class is added to module. If False, new Message is not added. """ # Make sure module exists. module_instance = self.DefineModule(module) # Recursively define all child messages. for attribute, value in children.items(): if isinstance(value, dict): children[attribute] = self.DefineMessage( module, attribute, value, False) # Override default __module__ variable. children['__module__'] = module # Instantiate and possibly add to module. 
message_class = new.classobj(name, (messages.Message,), dict(children)) if add_to_module: setattr(module_instance, name, message_class) return message_class def Importer(self, module, globals='', locals='', fromlist=None): """Importer function. Acts like __import__. Only loads modules from self.modules. Does not try to load real modules defined elsewhere. Does not try to handle relative imports. Args: module: Fully qualified name of module to load from self.modules. """ if fromlist is None: module = module.split('.')[0] try: return self.modules[module] except KeyError: raise ImportError() def testNoSuchModule(self): """Test searching for definitions that do no exist.""" self.assertRaises(messages.DefinitionNotFoundError, messages.find_definition, 'does.not.exist', importer=self.Importer) def testRefersToModule(self): """Test that referring to a module does not return that module.""" self.DefineModule('i.am.a.module') self.assertRaises(messages.DefinitionNotFoundError, messages.find_definition, 'i.am.a.module', importer=self.Importer) def testNoDefinition(self): """Test not finding a definition in an existing module.""" self.DefineModule('i.am.a.module') self.assertRaises(messages.DefinitionNotFoundError, messages.find_definition, 'i.am.a.module.MyMessage', importer=self.Importer) def testNotADefinition(self): """Test trying to fetch something that is not a definition.""" module = self.DefineModule('i.am.a.module') setattr(module, 'A', 'a string') self.assertRaises(messages.DefinitionNotFoundError, messages.find_definition, 'i.am.a.module.A', importer=self.Importer) def testGlobalFind(self): """Test finding definitions from fully qualified module names.""" A = self.DefineMessage('a.b.c', 'A', {}) self.assertEquals(A, messages.find_definition('a.b.c.A', importer=self.Importer)) B = self.DefineMessage('a.b.c', 'B', {'C':{}}) self.assertEquals(B.C, messages.find_definition('a.b.c.B.C', importer=self.Importer)) def testRelativeToModule(self): """Test finding definitions 
relative to modules.""" # Define modules. a = self.DefineModule('a') b = self.DefineModule('a.b') c = self.DefineModule('a.b.c') # Define messages. A = self.DefineMessage('a', 'A') B = self.DefineMessage('a.b', 'B') C = self.DefineMessage('a.b.c', 'C') D = self.DefineMessage('a.b.d', 'D') # Find A, B, C and D relative to a. self.assertEquals(A, messages.find_definition( 'A', a, importer=self.Importer)) self.assertEquals(B, messages.find_definition( 'b.B', a, importer=self.Importer)) self.assertEquals(C, messages.find_definition( 'b.c.C', a, importer=self.Importer)) self.assertEquals(D, messages.find_definition( 'b.d.D', a, importer=self.Importer)) # Find A, B, C and D relative to b. self.assertEquals(A, messages.find_definition( 'A', b, importer=self.Importer)) self.assertEquals(B, messages.find_definition( 'B', b, importer=self.Importer)) self.assertEquals(C, messages.find_definition( 'c.C', b, importer=self.Importer)) self.assertEquals(D, messages.find_definition( 'd.D', b, importer=self.Importer)) # Find A, B, C and D relative to c. Module d is the same case as c. self.assertEquals(A, messages.find_definition( 'A', c, importer=self.Importer)) self.assertEquals(B, messages.find_definition( 'B', c, importer=self.Importer)) self.assertEquals(C, messages.find_definition( 'C', c, importer=self.Importer)) self.assertEquals(D, messages.find_definition( 'd.D', c, importer=self.Importer)) def testRelativeToMessages(self): """Test finding definitions relative to Message definitions.""" A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}}) B = A.B C = A.B.C D = A.B.D # Find relative to A. self.assertEquals(A, messages.find_definition( 'A', A, importer=self.Importer)) self.assertEquals(B, messages.find_definition( 'B', A, importer=self.Importer)) self.assertEquals(C, messages.find_definition( 'B.C', A, importer=self.Importer)) self.assertEquals(D, messages.find_definition( 'B.D', A, importer=self.Importer)) # Find relative to B. 
self.assertEquals(A, messages.find_definition( 'A', B, importer=self.Importer)) self.assertEquals(B, messages.find_definition( 'B', B, importer=self.Importer)) self.assertEquals(C, messages.find_definition( 'C', B, importer=self.Importer)) self.assertEquals(D, messages.find_definition( 'D', B, importer=self.Importer)) # Find relative to C. self.assertEquals(A, messages.find_definition( 'A', C, importer=self.Importer)) self.assertEquals(B, messages.find_definition( 'B', C, importer=self.Importer)) self.assertEquals(C, messages.find_definition( 'C', C, importer=self.Importer)) self.assertEquals(D, messages.find_definition( 'D', C, importer=self.Importer)) # Find relative to C searching from c. self.assertEquals(A, messages.find_definition( 'b.A', C, importer=self.Importer)) self.assertEquals(B, messages.find_definition( 'b.A.B', C, importer=self.Importer)) self.assertEquals(C, messages.find_definition( 'b.A.B.C', C, importer=self.Importer)) self.assertEquals(D, messages.find_definition( 'b.A.B.D', C, importer=self.Importer)) def testAbsoluteReference(self): """Test finding absolute definition names.""" # Define modules. a = self.DefineModule('a') b = self.DefineModule('a.a') # Define messages. aA = self.DefineMessage('a', 'A') aaA = self.DefineMessage('a.a', 'A') # Always find a.A. 
self.assertEquals(aA, messages.find_definition('.a.A', None, importer=self.Importer)) self.assertEquals(aA, messages.find_definition('.a.A', a, importer=self.Importer)) self.assertEquals(aA, messages.find_definition('.a.A', aA, importer=self.Importer)) self.assertEquals(aA, messages.find_definition('.a.A', aaA, importer=self.Importer)) def testFindEnum(self): """Test that Enums are found.""" class Color(messages.Enum): pass A = self.DefineMessage('a', 'A', {'Color': Color}) self.assertEquals( Color, messages.find_definition('Color', A, importer=self.Importer)) def testFalseScope(self): """Test that Message definitions nested in strange objects are hidden.""" global X class X(object): class A(messages.Message): pass self.assertRaises(TypeError, messages.find_definition, 'A', X) self.assertRaises(messages.DefinitionNotFoundError, messages.find_definition, 'X.A', sys.modules[__name__]) def testSearchAttributeFirst(self): """Make sure not faked out by module, but continues searching.""" A = self.DefineMessage('a', 'A') module_A = self.DefineModule('a.A') self.assertEquals(A, messages.find_definition( 'a.A', None, importer=self.Importer)) class FindDefinitionUnicodeTests(test_util.TestCase): def testUnicodeString(self): """Test using unicode names.""" self.assertEquals('ServiceMapping', messages.find_definition( u'protorpc.registry.ServiceMapping', None).__name__) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/non_sdk_imports.py0000755000076500000240000000112112277637135024135 0ustar jeremydwstaff00000000000000"""Dynamically decide from where to import other non SDK Google modules. All other protorpc code should import other non SDK modules from this module. If necessary, add new imports here (in both places). 
""" __author__ = 'yey@google.com (Ye Yuan)' # pylint: disable=g-import-not-at-top # pylint: disable=unused-import try: from google.protobuf import descriptor normal_environment = True except ImportError: normal_environment = False if normal_environment: from google.protobuf import descriptor_pb2 from google.protobuf import message from google.protobuf import reflection protorpc-standalone-0.9.1/protorpc/protobuf.py0000755000076500000240000002534212277637135022600 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Protocol buffer support for message types. For more details about protocol buffer encoding and decoding please see: http://code.google.com/apis/protocolbuffers/docs/encoding.html Public Exceptions: DecodeError: Raised when a decode error occurs from incorrect protobuf format. Public Functions: encode_message: Encodes a message in to a protocol buffer string. decode_message: Decode from a protocol buffer string to a message. """ __author__ = 'rafek@google.com (Rafe Kaplan)' import array from . import message_types from . import messages from . import util from .google_imports import ProtocolBuffer __all__ = ['ALTERNATIVE_CONTENT_TYPES', 'CONTENT_TYPE', 'encode_message', 'decode_message', ] CONTENT_TYPE = 'application/octet-stream' ALTERNATIVE_CONTENT_TYPES = ['application/x-google-protobuf'] class _Encoder(ProtocolBuffer.Encoder): """Extension of protocol buffer encoder. 
Original protocol buffer encoder does not have complete set of methods for handling required encoding. This class adds them. """ # TODO(rafek): Implement the missing encoding types. def no_encoding(self, value): """No encoding available for type. Args: value: Value to encode. Raises: NotImplementedError at all times. """ raise NotImplementedError() def encode_enum(self, value): """Encode an enum value. Args: value: Enum to encode. """ self.putVarInt32(value.number) def encode_message(self, value): """Encode a Message in to an embedded message. Args: value: Message instance to encode. """ self.putPrefixedString(encode_message(value)) def encode_unicode_string(self, value): """Helper to properly pb encode unicode strings to UTF-8. Args: value: String value to encode. """ if isinstance(value, unicode): value = value.encode('utf-8') self.putPrefixedString(value) class _Decoder(ProtocolBuffer.Decoder): """Extension of protocol buffer decoder. Original protocol buffer decoder does not have complete set of methods for handling required decoding. This class adds them. """ # TODO(rafek): Implement the missing encoding types. def no_decoding(self): """No decoding available for type. Raises: NotImplementedError at all times. """ raise NotImplementedError() def decode_string(self): """Decode a unicode string. Returns: Next value in stream as a unicode string. """ return self.getPrefixedString().decode('UTF-8') def decode_boolean(self): """Decode a boolean value. Returns: Next value in stream as a boolean. """ return bool(self.getBoolean()) # Number of bits used to describe a protocol buffer bits used for the variant. _WIRE_TYPE_BITS = 3 _WIRE_TYPE_MASK = 7 # Maps variant to underlying wire type. Many variants map to same type. 
_VARIANT_TO_WIRE_TYPE = { messages.Variant.DOUBLE: _Encoder.DOUBLE, messages.Variant.FLOAT: _Encoder.FLOAT, messages.Variant.INT64: _Encoder.NUMERIC, messages.Variant.UINT64: _Encoder.NUMERIC, messages.Variant.INT32: _Encoder.NUMERIC, messages.Variant.BOOL: _Encoder.NUMERIC, messages.Variant.STRING: _Encoder.STRING, messages.Variant.MESSAGE: _Encoder.STRING, messages.Variant.BYTES: _Encoder.STRING, messages.Variant.UINT32: _Encoder.NUMERIC, messages.Variant.ENUM: _Encoder.NUMERIC, messages.Variant.SINT32: _Encoder.NUMERIC, messages.Variant.SINT64: _Encoder.NUMERIC, } # Maps variant to encoder method. _VARIANT_TO_ENCODER_MAP = { messages.Variant.DOUBLE: _Encoder.putDouble, messages.Variant.FLOAT: _Encoder.putFloat, messages.Variant.INT64: _Encoder.putVarInt64, messages.Variant.UINT64: _Encoder.putVarUint64, messages.Variant.INT32: _Encoder.putVarInt32, messages.Variant.BOOL: _Encoder.putBoolean, messages.Variant.STRING: _Encoder.encode_unicode_string, messages.Variant.MESSAGE: _Encoder.encode_message, messages.Variant.BYTES: _Encoder.encode_unicode_string, messages.Variant.UINT32: _Encoder.no_encoding, messages.Variant.ENUM: _Encoder.encode_enum, messages.Variant.SINT32: _Encoder.no_encoding, messages.Variant.SINT64: _Encoder.no_encoding, } # Basic wire format decoders. Used for reading unknown values. _WIRE_TYPE_TO_DECODER_MAP = { _Encoder.NUMERIC: _Decoder.getVarInt64, _Encoder.DOUBLE: _Decoder.getDouble, _Encoder.STRING: _Decoder.getPrefixedString, _Encoder.FLOAT: _Decoder.getFloat, } # Map wire type to variant. Used to find a variant for unknown values. _WIRE_TYPE_TO_VARIANT_MAP = { _Encoder.NUMERIC: messages.Variant.INT64, _Encoder.DOUBLE: messages.Variant.DOUBLE, _Encoder.STRING: messages.Variant.STRING, _Encoder.FLOAT: messages.Variant.FLOAT, } # Wire type to name mapping for error messages. _WIRE_TYPE_NAME = { _Encoder.NUMERIC: 'NUMERIC', _Encoder.DOUBLE: 'DOUBLE', _Encoder.STRING: 'STRING', _Encoder.FLOAT: 'FLOAT', } # Maps variant to decoder method. 
_VARIANT_TO_DECODER_MAP = { messages.Variant.DOUBLE: _Decoder.getDouble, messages.Variant.FLOAT: _Decoder.getFloat, messages.Variant.INT64: _Decoder.getVarInt64, messages.Variant.UINT64: _Decoder.getVarUint64, messages.Variant.INT32: _Decoder.getVarInt32, messages.Variant.BOOL: _Decoder.decode_boolean, messages.Variant.STRING: _Decoder.decode_string, messages.Variant.MESSAGE: _Decoder.getPrefixedString, messages.Variant.BYTES: _Decoder.getPrefixedString, messages.Variant.UINT32: _Decoder.no_decoding, messages.Variant.ENUM: _Decoder.getVarInt32, messages.Variant.SINT32: _Decoder.no_decoding, messages.Variant.SINT64: _Decoder.no_decoding, } def encode_message(message): """Encode Message instance to protocol buffer. Args: Message instance to encode in to protocol buffer. Returns: String encoding of Message instance in protocol buffer format. Raises: messages.ValidationError if message is not initialized. """ message.check_initialized() encoder = _Encoder() # Get all fields, from the known fields we parsed and the unknown fields # we saved. Note which ones were known, so we can process them differently. all_fields = [(field.number, field) for field in message.all_fields()] all_fields.extend((key, None) for key in message.all_unrecognized_fields() if isinstance(key, (int, long))) all_fields.sort() for field_num, field in all_fields: if field: # Known field. value = message.get_assigned_value(field.name) if value is None: continue variant = field.variant repeated = field.repeated else: # Unrecognized field. value, variant = message.get_unrecognized_field_info(field_num) if not isinstance(variant, messages.Variant): continue repeated = isinstance(value, (list, tuple)) tag = ((field_num << _WIRE_TYPE_BITS) | _VARIANT_TO_WIRE_TYPE[variant]) # Write value to wire. 
if repeated: values = value else: values = [value] for next in values: encoder.putVarInt32(tag) if isinstance(field, messages.MessageField): next = field.value_to_message(next) field_encoder = _VARIANT_TO_ENCODER_MAP[variant] field_encoder(encoder, next) return encoder.buffer().tostring() def decode_message(message_type, encoded_message): """Decode protocol buffer to Message instance. Args: message_type: Message type to decode data to. encoded_message: Encoded version of message as string. Returns: Decoded instance of message_type. Raises: DecodeError if an error occurs during decoding, such as incompatible wire format for a field. messages.ValidationError if merged message is not initialized. """ message = message_type() message_array = array.array('B') message_array.fromstring(encoded_message) try: decoder = _Decoder(message_array, 0, len(message_array)) while decoder.avail() > 0: # Decode tag and variant information. encoded_tag = decoder.getVarInt32() tag = encoded_tag >> _WIRE_TYPE_BITS wire_type = encoded_tag & _WIRE_TYPE_MASK try: found_wire_type_decoder = _WIRE_TYPE_TO_DECODER_MAP[wire_type] except: raise messages.DecodeError('No such wire type %d' % wire_type) if tag < 1: raise messages.DecodeError('Invalid tag value %d' % tag) try: field = message.field_by_number(tag) except KeyError: # Unexpected tags are ok. field = None wire_type_decoder = found_wire_type_decoder else: expected_wire_type = _VARIANT_TO_WIRE_TYPE[field.variant] if expected_wire_type != wire_type: raise messages.DecodeError('Expected wire type %s but found %s' % ( _WIRE_TYPE_NAME[expected_wire_type], _WIRE_TYPE_NAME[wire_type])) wire_type_decoder = _VARIANT_TO_DECODER_MAP[field.variant] value = wire_type_decoder(decoder) # Save unknown fields and skip additional processing. if not field: # When saving this, save it under the tag number (which should # be unique), and set the variant and value so we know how to # interpret the value later. 
variant = _WIRE_TYPE_TO_VARIANT_MAP.get(wire_type) if variant: message.set_unrecognized_field(tag, value, variant) continue # Special case Enum and Message types. if isinstance(field, messages.EnumField): try: value = field.type(value) except TypeError: raise messages.DecodeError('Invalid enum value %s' % value) elif isinstance(field, messages.MessageField): value = decode_message(field.message_type, value) value = field.value_from_message(value) # Merge value in to message. if field.repeated: values = getattr(message, field.name) if values is None: setattr(message, field.name, [value]) else: values.append(value) else: setattr(message, field.name, value) except ProtocolBuffer.ProtocolBufferDecodeError, err: raise messages.DecodeError('Decoding error: %s' % str(err)) message.check_initialized() return message protorpc-standalone-0.9.1/protorpc/protobuf_test.py0000755000076500000240000002436612277637135023644 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.protobuf.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import datetime import unittest from protorpc import message_types from protorpc import messages from protorpc import protobuf from protorpc import protorpc_test_pb2 from protorpc import test_util from protorpc import util # TODO: Add DateTimeFields to protorpc_test.proto when definition.py # supports date time fields. 
class HasDateTimeMessage(messages.Message): value = message_types.DateTimeField(1) class NestedDateTimeMessage(messages.Message): value = messages.MessageField(message_types.DateTimeMessage, 1) class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = protobuf class EncodeMessageTest(test_util.TestCase, test_util.ProtoConformanceTestBase): """Test message to protocol buffer encoding.""" PROTOLIB = protobuf def assertErrorIs(self, exception, message, function, *params, **kwargs): try: function(*params, **kwargs) self.fail('Expected to raise exception %s but did not.' % exception) except exception, err: self.assertEquals(message, str(err)) @property def encoded_partial(self): proto = protorpc_test_pb2.OptionalMessage() proto.double_value = 1.23 proto.int64_value = -100000000000 proto.int32_value = 1020 proto.string_value = u'a string' proto.enum_value = protorpc_test_pb2.OptionalMessage.VAL2 return proto.SerializeToString() @property def encoded_full(self): proto = protorpc_test_pb2.OptionalMessage() proto.double_value = 1.23 proto.float_value = -2.5 proto.int64_value = -100000000000 proto.uint64_value = 102020202020 proto.int32_value = 1020 proto.bool_value = True proto.string_value = u'a string\u044f' proto.bytes_value = 'a bytes\xff\xfe' proto.enum_value = protorpc_test_pb2.OptionalMessage.VAL2 return proto.SerializeToString() @property def encoded_repeated(self): proto = protorpc_test_pb2.RepeatedMessage() proto.double_value.append(1.23) proto.double_value.append(2.3) proto.float_value.append(-2.5) proto.float_value.append(0.5) proto.int64_value.append(-100000000000) proto.int64_value.append(20) proto.uint64_value.append(102020202020) proto.uint64_value.append(10) proto.int32_value.append(1020) proto.int32_value.append(718) proto.bool_value.append(True) proto.bool_value.append(False) proto.string_value.append(u'a string\u044f') proto.string_value.append(u'another string') proto.bytes_value.append('a bytes\xff\xfe') 
proto.bytes_value.append('another bytes') proto.enum_value.append(protorpc_test_pb2.RepeatedMessage.VAL2) proto.enum_value.append(protorpc_test_pb2.RepeatedMessage.VAL1) return proto.SerializeToString() @property def encoded_nested(self): proto = protorpc_test_pb2.HasNestedMessage() proto.nested.a_value = 'a string' return proto.SerializeToString() @property def encoded_repeated_nested(self): proto = protorpc_test_pb2.HasNestedMessage() proto.repeated_nested.add().a_value = 'a string' proto.repeated_nested.add().a_value = 'another string' return proto.SerializeToString() unexpected_tag_message = ( chr((15 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.NUMERIC) + chr(5)) @property def encoded_default_assigned(self): proto = protorpc_test_pb2.HasDefault() proto.a_value = test_util.HasDefault.a_value.default return proto.SerializeToString() @property def encoded_nested_empty(self): proto = protorpc_test_pb2.HasOptionalNestedMessage() proto.nested.Clear() return proto.SerializeToString() @property def encoded_repeated_nested_empty(self): proto = protorpc_test_pb2.HasOptionalNestedMessage() proto.repeated_nested.add() proto.repeated_nested.add() return proto.SerializeToString() @property def encoded_extend_message(self): proto = protorpc_test_pb2.RepeatedMessage() proto.add_int64_value(400) proto.add_int64_value(50) proto.add_int64_value(6000) return proto.SerializeToString() @property def encoded_string_types(self): proto = protorpc_test_pb2.OptionalMessage() proto.string_value = u'Latin' return proto.SerializeToString() @property def encoded_invalid_enum(self): encoder = protobuf._Encoder() field_num = test_util.OptionalMessage.enum_value.number tag = (field_num << protobuf._WIRE_TYPE_BITS) | encoder.NUMERIC encoder.putVarInt32(tag) encoder.putVarInt32(1000) return encoder.buffer().tostring() def testDecodeWrongWireFormat(self): """Test what happens when wrong wire format found in protobuf.""" class ExpectedProto(messages.Message): value = messages.StringField(1) 
class WrongVariant(messages.Message): value = messages.IntegerField(1) original = WrongVariant() original.value = 10 self.assertErrorIs(messages.DecodeError, 'Expected wire type STRING but found NUMERIC', protobuf.decode_message, ExpectedProto, protobuf.encode_message(original)) def testDecodeBadWireType(self): """Test what happens when non-existant wire type found in protobuf.""" # Message has tag 1, type 3 which does not exist. bad_wire_type_message = chr((1 << protobuf._WIRE_TYPE_BITS) | 3) self.assertErrorIs(messages.DecodeError, 'No such wire type 3', protobuf.decode_message, test_util.OptionalMessage, bad_wire_type_message) def testUnexpectedTagBelowOne(self): """Test that completely invalid tags generate an error.""" # Message has tag 0, type NUMERIC. invalid_tag_message = chr(protobuf._Encoder.NUMERIC) self.assertErrorIs(messages.DecodeError, 'Invalid tag value 0', protobuf.decode_message, test_util.OptionalMessage, invalid_tag_message) def testProtocolBufferDecodeError(self): """Test what happens when there a ProtocolBufferDecodeError. This is what happens when the underlying ProtocolBuffer library raises it's own decode error. """ # Message has tag 1, type DOUBLE, missing value. 
truncated_message = ( chr((1 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.DOUBLE)) self.assertErrorIs(messages.DecodeError, 'Decoding error: truncated', protobuf.decode_message, test_util.OptionalMessage, truncated_message) def testProtobufUnrecognizedField(self): """Test that unrecognized fields are serialized and can be accessed.""" decoded = protobuf.decode_message(test_util.OptionalMessage, self.unexpected_tag_message) self.assertEquals(1, len(decoded.all_unrecognized_fields())) self.assertEquals(15, decoded.all_unrecognized_fields()[0]) self.assertEquals((5, messages.Variant.INT64), decoded.get_unrecognized_field_info(15)) def testUnrecognizedFieldWrongFormat(self): """Test that unrecognized fields in the wrong format are skipped.""" class SimpleMessage(messages.Message): value = messages.IntegerField(1) message = SimpleMessage(value=3) message.set_unrecognized_field('from_json', 'test', messages.Variant.STRING) encoded = protobuf.encode_message(message) expected = ( chr((1 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.NUMERIC) + chr(3)) self.assertEquals(encoded, expected) def testProtobufDecodeDateTimeMessage(self): """Test what happens when decoding a DateTimeMessage.""" nested = NestedDateTimeMessage() nested.value = message_types.DateTimeMessage(milliseconds=2500) value = protobuf.decode_message(HasDateTimeMessage, protobuf.encode_message(nested)).value self.assertEqual(datetime.datetime(1970, 1, 1, 0, 0, 2, 500000), value) def testProtobufDecodeDateTimeMessageWithTimeZone(self): """Test what happens when decoding a DateTimeMessage with a time zone.""" nested = NestedDateTimeMessage() nested.value = message_types.DateTimeMessage(milliseconds=12345678, time_zone_offset=60) value = protobuf.decode_message(HasDateTimeMessage, protobuf.encode_message(nested)).value self.assertEqual(datetime.datetime(1970, 1, 1, 3, 25, 45, 678000, tzinfo=util.TimeZoneOffset(60)), value) def testProtobufEncodeDateTimeMessage(self): """Test what happens when encoding a 
DateTimeField.""" mine = HasDateTimeMessage(value=datetime.datetime(1970, 1, 1)) nested = NestedDateTimeMessage() nested.value = message_types.DateTimeMessage(milliseconds=0) my_encoded = protobuf.encode_message(mine) encoded = protobuf.encode_message(nested) self.assertEquals(my_encoded, encoded) def testProtobufEncodeDateTimeMessageWithTimeZone(self): """Test what happens when encoding a DateTimeField with a time zone.""" for tz_offset in (30, -30, 8 * 60, 0): mine = HasDateTimeMessage(value=datetime.datetime( 1970, 1, 1, tzinfo=util.TimeZoneOffset(tz_offset))) nested = NestedDateTimeMessage() nested.value = message_types.DateTimeMessage( milliseconds=0, time_zone_offset=tz_offset) my_encoded = protobuf.encode_message(mine) encoded = protobuf.encode_message(nested) self.assertEquals(my_encoded, encoded) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/protojson.py0000755000076500000240000002475212277637135023001 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """JSON support for message types. Public classes: MessageJSONEncoder: JSON encoder for message objects. Public functions: encode_message: Encodes a message in to a JSON string. decode_message: Merge from a JSON string in to a message. """ __author__ = 'rafek@google.com (Rafe Kaplan)' import cStringIO import base64 import logging from . import message_types from . import messages from . 
import util __all__ = [ 'ALTERNATIVE_CONTENT_TYPES', 'CONTENT_TYPE', 'MessageJSONEncoder', 'encode_message', 'decode_message', 'ProtoJson', ] def _load_json_module(): """Try to load a valid json module. There are more than one json modules that might be installed. They are mostly compatible with one another but some versions may be different. This function attempts to load various json modules in a preferred order. It does a basic check to guess if a loaded version of json is compatible. Returns: Compatible json module. Raises: ImportError if there are no json modules or the loaded json module is not compatible with ProtoRPC. """ first_import_error = None for module_name in ['json', 'simplejson']: try: module = __import__(module_name, {}, {}, 'json') if not hasattr(module, 'JSONEncoder'): message = ('json library "%s" is not compatible with ProtoRPC' % module_name) logging.warning(message) raise ImportError(message) else: return module except ImportError, err: if not first_import_error: first_import_error = err logging.error('Must use valid json library (Python 2.6 json or simplejson)') raise first_import_error json = _load_json_module() # TODO: Rename this to MessageJsonEncoder. class MessageJSONEncoder(json.JSONEncoder): """Message JSON encoder class. Extension of JSONEncoder that can build JSON from a message object. """ def __init__(self, protojson_protocol=None, **kwargs): """Constructor. Args: protojson_protocol: ProtoJson instance. """ super(MessageJSONEncoder, self).__init__(**kwargs) self.__protojson_protocol = protojson_protocol or ProtoJson.get_default() def default(self, value): """Return dictionary instance from a message object. Args: value: Value to get dictionary for. If not encodable, will call superclasses default method. 
""" if isinstance(value, messages.Enum): return str(value) if isinstance(value, messages.Message): result = {} for field in value.all_fields(): item = value.get_assigned_value(field.name) if item not in (None, [], ()): result[field.name] = self.__protojson_protocol.encode_field( field, item) # Handle unrecognized fields, so they're included when a message is # decoded then encoded. for unknown_key in value.all_unrecognized_fields(): unrecognized_field, _ = value.get_unrecognized_field_info(unknown_key) result[unknown_key] = unrecognized_field return result else: return super(MessageJSONEncoder, self).default(value) class ProtoJson(object): """ProtoRPC JSON implementation class. Implementation of JSON based protocol used for serializing and deserializing message objects. Instances of remote.ProtocolConfig constructor or used with remote.Protocols.add_protocol. See the remote.py module for more details. """ CONTENT_TYPE = 'application/json' ALTERNATIVE_CONTENT_TYPES = [ 'application/x-javascript', 'text/javascript', 'text/x-javascript', 'text/x-json', 'text/json', ] def encode_field(self, field, value): """Encode a python field value to a JSON value. Args: field: A ProtoRPC field instance. value: A python value supported by field. Returns: A JSON serializable value appropriate for field. """ if isinstance(field, messages.BytesField): if field.repeated: value = [base64.b64encode(byte) for byte in value] else: value = base64.b64encode(value) elif isinstance(field, message_types.DateTimeField): # DateTimeField stores its data as a RFC 3339 compliant string. if field.repeated: value = [i.isoformat() for i in value] else: value = value.isoformat() return value def encode_message(self, message): """Encode Message instance to JSON string. Args: Message instance to encode in to JSON string. Returns: String encoding of Message instance in protocol JSON format. Raises: messages.ValidationError if message is not initialized. 
""" message.check_initialized() return json.dumps(message, cls=MessageJSONEncoder, protojson_protocol=self) def decode_message(self, message_type, encoded_message): """Merge JSON structure to Message instance. Args: message_type: Message to decode data to. encoded_message: JSON encoded version of message. Returns: Decoded instance of message_type. Raises: ValueError: If encoded_message is not valid JSON. messages.ValidationError if merged message is not initialized. """ if not encoded_message.strip(): return message_type() dictionary = json.loads(encoded_message) message = self.__decode_dictionary(message_type, dictionary) message.check_initialized() return message def __find_variant(self, value): """Find the messages.Variant type that describes this value. Args: value: The value whose variant type is being determined. Returns: The messages.Variant value that best describes value's type, or None if it's a type we don't know how to handle. """ if isinstance(value, bool): return messages.Variant.BOOL elif isinstance(value, (int, long)): return messages.Variant.INT64 elif isinstance(value, float): return messages.Variant.DOUBLE elif isinstance(value, basestring): return messages.Variant.STRING elif isinstance(value, (list, tuple)): # Find the most specific variant that covers all elements. variant_priority = [None, messages.Variant.INT64, messages.Variant.DOUBLE, messages.Variant.STRING] chosen_priority = 0 for v in value: variant = self.__find_variant(v) try: priority = variant_priority.index(variant) except IndexError: priority = -1 if priority > chosen_priority: chosen_priority = priority return variant_priority[chosen_priority] # Unrecognized type. return None def __decode_dictionary(self, message_type, dictionary): """Merge dictionary in to message. Args: message: Message to merge dictionary in to. dictionary: Dictionary to extract information from. Dictionary is as parsed from JSON. Nested objects will also be dictionaries. 
""" message = message_type() for key, value in dictionary.iteritems(): if value is None: try: message.reset(key) except AttributeError: pass # This is an unrecognized field, skip it. continue try: field = message.field_by_name(key) except KeyError: # Save unknown values. variant = self.__find_variant(value) if variant: if key.isdigit(): key = int(key) message.set_unrecognized_field(key, value, variant) else: logging.warning('No variant found for unrecognized field: %s', key) continue # Normalize values in to a list. if isinstance(value, list): if not value: continue else: value = [value] valid_value = [] for item in value: valid_value.append(self.decode_field(field, item)) if field.repeated: existing_value = getattr(message, field.name) setattr(message, field.name, valid_value) else: setattr(message, field.name, valid_value[-1]) return message def decode_field(self, field, value): """Decode a JSON value to a python value. Args: field: A ProtoRPC field instance. value: A serialized JSON value. Return: A Python value compatible with field. 
""" if isinstance(field, messages.EnumField): try: return field.type(value) except TypeError: raise messages.DecodeError('Invalid enum value "%s"' % value[0]) elif isinstance(field, messages.BytesField): try: return base64.b64decode(value) except TypeError, err: raise messages.DecodeError('Base64 decoding error: %s' % err) elif isinstance(field, message_types.DateTimeField): try: return util.decode_datetime(value) except ValueError, err: raise messages.DecodeError(err) elif isinstance(field, messages.MessageField): return self.__decode_dictionary(field.message_type, value) elif (isinstance(field, messages.FloatField) and isinstance(value, (int, long, basestring))): try: return float(value) except: pass elif (isinstance(field, messages.IntegerField) and isinstance(value, basestring)): try: return int(value) except: pass return value @staticmethod def get_default(): """Get default instanceof ProtoJson.""" try: return ProtoJson.__default except AttributeError: ProtoJson.__default = ProtoJson() return ProtoJson.__default @staticmethod def set_default(protocol): """Set the default instance of ProtoJson. Args: protocol: A ProtoJson instance. """ if not isinstance(protocol, ProtoJson): raise TypeError('Expected protocol of type ProtoJson') ProtoJson.__default = protocol CONTENT_TYPE = ProtoJson.CONTENT_TYPE ALTERNATIVE_CONTENT_TYPES = ProtoJson.ALTERNATIVE_CONTENT_TYPES encode_message = ProtoJson.get_default().encode_message decode_message = ProtoJson.get_default().decode_message protorpc-standalone-0.9.1/protorpc/protojson_test.py0000755000076500000240000004053712277637135024037 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.protojson.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import __builtin__ import base64 import datetime import sys import unittest from protorpc import message_types from protorpc import messages from protorpc import protojson from protorpc import test_util import simplejson class MyMessage(messages.Message): """Test message containing various types.""" class Color(messages.Enum): RED = 1 GREEN = 2 BLUE = 3 class Nested(messages.Message): nested_value = messages.StringField(1) a_string = messages.StringField(2) an_integer = messages.IntegerField(3) a_float = messages.FloatField(4) a_boolean = messages.BooleanField(5) an_enum = messages.EnumField(Color, 6) a_nested = messages.MessageField(Nested, 7) a_repeated = messages.IntegerField(8, repeated=True) a_repeated_float = messages.FloatField(9, repeated=True) a_datetime = message_types.DateTimeField(10) a_repeated_datetime = message_types.DateTimeField(11, repeated=True) class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = protojson # TODO(rafek): Convert this test to the compliance test in test_util. 
class ProtojsonTest(test_util.TestCase, test_util.ProtoConformanceTestBase): """Test JSON encoding and decoding.""" PROTOLIB = protojson def CompareEncoded(self, expected_encoded, actual_encoded): """JSON encoding will be laundered to remove string differences.""" self.assertEquals(simplejson.loads(expected_encoded), simplejson.loads(actual_encoded)) encoded_empty_message = '{}' encoded_partial = """{ "double_value": 1.23, "int64_value": -100000000000, "int32_value": 1020, "string_value": "a string", "enum_value": "VAL2" } """ encoded_full = """{ "double_value": 1.23, "float_value": -2.5, "int64_value": -100000000000, "uint64_value": 102020202020, "int32_value": 1020, "bool_value": true, "string_value": "a string\u044f", "bytes_value": "YSBieXRlc//+", "enum_value": "VAL2" } """ encoded_repeated = """{ "double_value": [1.23, 2.3], "float_value": [-2.5, 0.5], "int64_value": [-100000000000, 20], "uint64_value": [102020202020, 10], "int32_value": [1020, 718], "bool_value": [true, false], "string_value": ["a string\u044f", "another string"], "bytes_value": ["YSBieXRlc//+", "YW5vdGhlciBieXRlcw=="], "enum_value": ["VAL2", "VAL1"] } """ encoded_nested = """{ "nested": { "a_value": "a string" } } """ encoded_repeated_nested = """{ "repeated_nested": [{"a_value": "a string"}, {"a_value": "another string"}] } """ unexpected_tag_message = '{"unknown": "value"}' encoded_default_assigned = '{"a_value": "a default"}' encoded_nested_empty = '{"nested": {}}' encoded_repeated_nested_empty = '{"repeated_nested": [{}, {}]}' encoded_extend_message = '{"int64_value": [400, 50, 6000]}' encoded_string_types = '{"string_value": "Latin"}' encoded_invalid_enum = '{"enum_value": "undefined"}' def testConvertIntegerToFloat(self): """Test that integers passed in to float fields are converted. This is necessary because JSON outputs integers for numbers with 0 decimals. 
""" message = protojson.decode_message(MyMessage, '{"a_float": 10}') self.assertTrue(isinstance(message.a_float, float)) self.assertEquals(10.0, message.a_float) def testConvertStringToNumbers(self): """Test that strings passed to integer fields are converted.""" message = protojson.decode_message(MyMessage, """{"an_integer": "10", "a_float": "3.5", "a_repeated": ["1", "2"], "a_repeated_float": ["1.5", "2", 10] }""") self.assertEquals(MyMessage(an_integer=10, a_float=3.5, a_repeated=[1, 2], a_repeated_float=[1.5, 2.0, 10.0]), message) def testWrongTypeAssignment(self): """Test when wrong type is assigned to a field.""" self.assertRaises(messages.ValidationError, protojson.decode_message, MyMessage, '{"a_string": 10}') self.assertRaises(messages.ValidationError, protojson.decode_message, MyMessage, '{"an_integer": 10.2}') self.assertRaises(messages.ValidationError, protojson.decode_message, MyMessage, '{"an_integer": "10.2"}') def testNumericEnumeration(self): """Test that numbers work for enum values.""" message = protojson.decode_message(MyMessage, '{"an_enum": 2}') expected_message = MyMessage() expected_message.an_enum = MyMessage.Color.GREEN self.assertEquals(expected_message, message) def testNullValues(self): """Test that null values overwrite existing values.""" self.assertEquals(MyMessage(), protojson.decode_message(MyMessage, ('{"an_integer": null,' ' "a_nested": null' '}'))) def testEmptyList(self): """Test that empty lists are ignored.""" self.assertEquals(MyMessage(), protojson.decode_message(MyMessage, '{"a_repeated": []}')) def testNotJSON(self): """Test error when string is not valid JSON.""" self.assertRaises(ValueError, protojson.decode_message, MyMessage, '{this is not json}') def testDoNotEncodeStrangeObjects(self): """Test trying to encode a strange object. The main purpose of this test is to complete coverage. It ensures that the default behavior of the JSON encoder is preserved when someone tries to serialized an unexpected type. 
""" class BogusObject(object): def check_initialized(self): pass self.assertRaises(TypeError, protojson.encode_message, BogusObject()) def testMergeEmptyString(self): """Test merging the empty or space only string.""" message = protojson.decode_message(test_util.OptionalMessage, '') self.assertEquals(test_util.OptionalMessage(), message) message = protojson.decode_message(test_util.OptionalMessage, ' ') self.assertEquals(test_util.OptionalMessage(), message) def testProtojsonUnrecognizedFieldName(self): """Test that unrecognized fields are saved and can be accessed.""" decoded = protojson.decode_message(MyMessage, ('{"an_integer": 1, "unknown_val": 2}')) self.assertEquals(decoded.an_integer, 1) self.assertEquals(1, len(decoded.all_unrecognized_fields())) self.assertEquals('unknown_val', decoded.all_unrecognized_fields()[0]) self.assertEquals((2, messages.Variant.INT64), decoded.get_unrecognized_field_info('unknown_val')) def testProtojsonUnrecognizedFieldNumber(self): """Test that unrecognized fields are saved and can be accessed.""" decoded = protojson.decode_message( MyMessage, '{"an_integer": 1, "1001": "unknown", "-123": "negative", ' '"456_mixed": 2}') self.assertEquals(decoded.an_integer, 1) self.assertEquals(3, len(decoded.all_unrecognized_fields())) self.assertIn(1001, decoded.all_unrecognized_fields()) self.assertEquals(('unknown', messages.Variant.STRING), decoded.get_unrecognized_field_info(1001)) self.assertIn('-123', decoded.all_unrecognized_fields()) self.assertEquals(('negative', messages.Variant.STRING), decoded.get_unrecognized_field_info('-123')) self.assertIn('456_mixed', decoded.all_unrecognized_fields()) self.assertEquals((2, messages.Variant.INT64), decoded.get_unrecognized_field_info('456_mixed')) def testProtojsonUnrecognizedNull(self): """Test that unrecognized fields that are None are skipped.""" decoded = protojson.decode_message( MyMessage, '{"an_integer": 1, "unrecognized_null": null}') self.assertEquals(decoded.an_integer, 1) 
self.assertEquals(decoded.all_unrecognized_fields(), []) def testUnrecognizedFieldVariants(self): """Test that unrecognized fields are mapped to the right variants.""" for encoded, expected_variant in ( ('{"an_integer": 1, "unknown_val": 2}', messages.Variant.INT64), ('{"an_integer": 1, "unknown_val": 2.0}', messages.Variant.DOUBLE), ('{"an_integer": 1, "unknown_val": "string value"}', messages.Variant.STRING), ('{"an_integer": 1, "unknown_val": [1, 2, 3]}', messages.Variant.INT64), ('{"an_integer": 1, "unknown_val": [1, 2.0, 3]}', messages.Variant.DOUBLE), ('{"an_integer": 1, "unknown_val": [1, "foo", 3]}', messages.Variant.STRING), ('{"an_integer": 1, "unknown_val": true}', messages.Variant.BOOL)): decoded = protojson.decode_message(MyMessage, encoded) self.assertEquals(decoded.an_integer, 1) self.assertEquals(1, len(decoded.all_unrecognized_fields())) self.assertEquals('unknown_val', decoded.all_unrecognized_fields()[0]) _, decoded_variant = decoded.get_unrecognized_field_info('unknown_val') self.assertEquals(expected_variant, decoded_variant) def testDecodeDateTime(self): for datetime_string, datetime_vals in ( ('2012-09-30T15:31:50.262', (2012, 9, 30, 15, 31, 50, 262000)), ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): message = protojson.decode_message( MyMessage, '{"a_datetime": "%s"}' % datetime_string) expected_message = MyMessage( a_datetime=datetime.datetime(*datetime_vals)) self.assertEquals(expected_message, message) def testDecodeInvalidDateTime(self): self.assertRaises(messages.DecodeError, protojson.decode_message, MyMessage, '{"a_datetime": "invalid"}') def testEncodeDateTime(self): for datetime_string, datetime_vals in ( ('2012-09-30T15:31:50.262000', (2012, 9, 30, 15, 31, 50, 262000)), ('2012-09-30T15:31:50.262123', (2012, 9, 30, 15, 31, 50, 262123)), ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): decoded_message = protojson.encode_message( MyMessage(a_datetime=datetime.datetime(*datetime_vals))) expected_decoding = 
'{"a_datetime": "%s"}' % datetime_string self.CompareEncoded(expected_decoding, decoded_message) def testDecodeRepeatedDateTime(self): message = protojson.decode_message( MyMessage, '{"a_repeated_datetime": ["2012-09-30T15:31:50.262", ' '"2010-01-21T09:52:00", "2000-01-01T01:00:59.999999"]}') expected_message = MyMessage( a_repeated_datetime=[ datetime.datetime(2012, 9, 30, 15, 31, 50, 262000), datetime.datetime(2010, 1, 21, 9, 52), datetime.datetime(2000, 1, 1, 1, 0, 59, 999999)]) self.assertEquals(expected_message, message) def testDecodeBadBase64BytesField(self): """Test decoding improperly encoded base64 bytes value.""" self.assertRaisesWithRegexpMatch( messages.DecodeError, 'Base64 decoding error: Incorrect padding', protojson.decode_message, test_util.OptionalMessage, '{"bytes_value": "abcdefghijklmnopq"}') class CustomProtoJson(protojson.ProtoJson): def encode_field(self, field, value): return '{encoded}' + value def decode_field(self, field, value): return '{decoded}' + value class CustomProtoJsonTest(test_util.TestCase): """Tests for serialization overriding functionality.""" def setUp(self): self.protojson = CustomProtoJson() def testEncode(self): self.assertEqual('{"a_string": "{encoded}xyz"}', self.protojson.encode_message(MyMessage(a_string='xyz'))) def testDecode(self): self.assertEqual( MyMessage(a_string='{decoded}xyz'), self.protojson.decode_message(MyMessage, '{"a_string": "xyz"}')) def testDefault(self): self.assertTrue(protojson.ProtoJson.get_default(), protojson.ProtoJson.get_default()) instance = CustomProtoJson() protojson.ProtoJson.set_default(instance) self.assertTrue(instance is protojson.ProtoJson.get_default()) class InvalidJsonModule(object): pass class ValidJsonModule(object): class JSONEncoder(object): pass class TestJsonDependencyLoading(test_util.TestCase): """Test loading various implementations of json.""" def get_import(self): """Get __import__ method. Returns: The current __import__ method. 
""" if isinstance(__builtins__, dict): return __builtins__['__import__'] else: return __builtins__.__import__ def set_import(self, new_import): """Set __import__ method. Args: new_import: Function to replace __import__. """ if isinstance(__builtins__, dict): __builtins__['__import__'] = new_import else: __builtins__.__import__ = new_import def setUp(self): """Save original import function.""" self.simplejson = sys.modules.pop('simplejson', None) self.json = sys.modules.pop('json', None) self.original_import = self.get_import() def block_all_jsons(name, *args, **kwargs): if 'json' in name: if name in sys.modules: module = sys.modules[name] module.name = name return module raise ImportError('Unable to find %s' % name) else: return self.original_import(name, *args, **kwargs) self.set_import(block_all_jsons) def tearDown(self): """Restore original import functions and any loaded modules.""" def reset_module(name, module): if module: sys.modules[name] = module else: sys.modules.pop(name, None) reset_module('simplejson', self.simplejson) reset_module('json', self.json) reload(protojson) def testLoadProtojsonWithValidJsonModule(self): """Test loading protojson module with a valid json dependency.""" sys.modules['json'] = ValidJsonModule # This will cause protojson to reload with the default json module # instead of simplejson. reload(protojson) self.assertEquals('json', protojson.json.name) def testLoadProtojsonWithSimplejsonModule(self): """Test loading protojson module with simplejson dependency.""" sys.modules['simplejson'] = ValidJsonModule # This will cause protojson to reload with the default json module # instead of simplejson. reload(protojson) self.assertEquals('simplejson', protojson.json.name) def testLoadProtojsonWithInvalidJsonModule(self): """Loading protojson module with an invalid json defaults to simplejson.""" sys.modules['json'] = InvalidJsonModule sys.modules['simplejson'] = ValidJsonModule # Ignore bad module and default back to simplejson. 
reload(protojson) self.assertEquals('simplejson', protojson.json.name) def testLoadProtojsonWithInvalidJsonModuleAndNoSimplejson(self): """Loading protojson module with invalid json and no simplejson.""" sys.modules['json'] = InvalidJsonModule # Bad module without simplejson back raises errors. self.assertRaisesWithRegexpMatch( ImportError, 'json library "json" is not compatible with ProtoRPC', reload, protojson) def testLoadProtojsonWithNoJsonModules(self): """Loading protojson module with invalid json and no simplejson.""" # No json modules raise the first exception. self.assertRaisesWithRegexpMatch( ImportError, 'Unable to find json', reload, protojson) if __name__ == '__main__': unittest.main() protorpc-standalone-0.9.1/protorpc/protorpc_test_pb2.py0000755000076500000240000003705012277637135024411 0ustar jeremydwstaff00000000000000# Generated by the protocol buffer compiler. DO NOT EDIT (except the imports)! # Replace auto generated imports with .non_sdk_imports manually! # Do the replacement and copy this comment everytime! 
from .non_sdk_imports import descriptor from .non_sdk_imports import message from .non_sdk_imports import reflection from .non_sdk_imports import descriptor_pb2 # @@protoc_insertion_point(imports) DESCRIPTOR = descriptor.FileDescriptor( name='protorpc_test.proto', package='protorpc', serialized_pb='\n\x13protorpc_test.proto\x12\x08protorpc\" \n\rNestedMessage\x12\x0f\n\x07\x61_value\x18\x01 \x02(\t\"m\n\x10HasNestedMessage\x12\'\n\x06nested\x18\x01 \x01(\x0b\x32\x17.protorpc.NestedMessage\x12\x30\n\x0frepeated_nested\x18\x02 \x03(\x0b\x32\x17.protorpc.NestedMessage\"(\n\nHasDefault\x12\x1a\n\x07\x61_value\x18\x01 \x01(\t:\ta default\"\x97\x02\n\x0fOptionalMessage\x12\x14\n\x0c\x64ouble_value\x18\x01 \x01(\x01\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x13\n\x0bint64_value\x18\x03 \x01(\x03\x12\x14\n\x0cuint64_value\x18\x04 \x01(\x04\x12\x13\n\x0bint32_value\x18\x05 \x01(\x05\x12\x12\n\nbool_value\x18\x06 \x01(\x08\x12\x14\n\x0cstring_value\x18\x07 \x01(\t\x12\x13\n\x0b\x62ytes_value\x18\x08 \x01(\x0c\x12\x38\n\nenum_value\x18\n \x01(\x0e\x32$.protorpc.OptionalMessage.SimpleEnum\" \n\nSimpleEnum\x12\x08\n\x04VAL1\x10\x01\x12\x08\n\x04VAL2\x10\x02\"\x97\x02\n\x0fRepeatedMessage\x12\x14\n\x0c\x64ouble_value\x18\x01 \x03(\x01\x12\x13\n\x0b\x66loat_value\x18\x02 \x03(\x02\x12\x13\n\x0bint64_value\x18\x03 \x03(\x03\x12\x14\n\x0cuint64_value\x18\x04 \x03(\x04\x12\x13\n\x0bint32_value\x18\x05 \x03(\x05\x12\x12\n\nbool_value\x18\x06 \x03(\x08\x12\x14\n\x0cstring_value\x18\x07 \x03(\t\x12\x13\n\x0b\x62ytes_value\x18\x08 \x03(\x0c\x12\x38\n\nenum_value\x18\n \x03(\x0e\x32$.protorpc.RepeatedMessage.SimpleEnum\" \n\nSimpleEnum\x12\x08\n\x04VAL1\x10\x01\x12\x08\n\x04VAL2\x10\x02\"y\n\x18HasOptionalNestedMessage\x12)\n\x06nested\x18\x01 \x01(\x0b\x32\x19.protorpc.OptionalMessage\x12\x32\n\x0frepeated_nested\x18\x02 \x03(\x0b\x32\x19.protorpc.OptionalMessage') _OPTIONALMESSAGE_SIMPLEENUM = descriptor.EnumDescriptor( name='SimpleEnum', 
full_name='protorpc.OptionalMessage.SimpleEnum', filename=None, file=DESCRIPTOR, values=[ descriptor.EnumValueDescriptor( name='VAL1', index=0, number=1, options=None, type=None), descriptor.EnumValueDescriptor( name='VAL2', index=1, number=2, options=None, type=None), ], containing_type=None, options=None, serialized_start=468, serialized_end=500, ) _REPEATEDMESSAGE_SIMPLEENUM = descriptor.EnumDescriptor( name='SimpleEnum', full_name='protorpc.RepeatedMessage.SimpleEnum', filename=None, file=DESCRIPTOR, values=[ descriptor.EnumValueDescriptor( name='VAL1', index=0, number=1, options=None, type=None), descriptor.EnumValueDescriptor( name='VAL2', index=1, number=2, options=None, type=None), ], containing_type=None, options=None, serialized_start=468, serialized_end=500, ) _NESTEDMESSAGE = descriptor.Descriptor( name='NestedMessage', full_name='protorpc.NestedMessage', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ descriptor.FieldDescriptor( name='a_value', full_name='protorpc.NestedMessage.a_value', index=0, number=1, type=9, cpp_type=9, label=2, has_default_value=False, default_value=unicode("", "utf-8"), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, extension_ranges=[], serialized_start=33, serialized_end=65, ) _HASNESTEDMESSAGE = descriptor.Descriptor( name='HasNestedMessage', full_name='protorpc.HasNestedMessage', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ descriptor.FieldDescriptor( name='nested', full_name='protorpc.HasNestedMessage.nested', index=0, number=1, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='repeated_nested', full_name='protorpc.HasNestedMessage.repeated_nested', index=1, number=2, 
type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, extension_ranges=[], serialized_start=67, serialized_end=176, ) _HASDEFAULT = descriptor.Descriptor( name='HasDefault', full_name='protorpc.HasDefault', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ descriptor.FieldDescriptor( name='a_value', full_name='protorpc.HasDefault.a_value', index=0, number=1, type=9, cpp_type=9, label=1, has_default_value=True, default_value=unicode("a default", "utf-8"), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, extension_ranges=[], serialized_start=178, serialized_end=218, ) _OPTIONALMESSAGE = descriptor.Descriptor( name='OptionalMessage', full_name='protorpc.OptionalMessage', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ descriptor.FieldDescriptor( name='double_value', full_name='protorpc.OptionalMessage.double_value', index=0, number=1, type=1, cpp_type=5, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='float_value', full_name='protorpc.OptionalMessage.float_value', index=1, number=2, type=2, cpp_type=6, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='int64_value', full_name='protorpc.OptionalMessage.int64_value', index=2, number=3, type=3, cpp_type=2, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, 
extension_scope=None, options=None), descriptor.FieldDescriptor( name='uint64_value', full_name='protorpc.OptionalMessage.uint64_value', index=3, number=4, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='int32_value', full_name='protorpc.OptionalMessage.int32_value', index=4, number=5, type=5, cpp_type=1, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='bool_value', full_name='protorpc.OptionalMessage.bool_value', index=5, number=6, type=8, cpp_type=7, label=1, has_default_value=False, default_value=False, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='string_value', full_name='protorpc.OptionalMessage.string_value', index=6, number=7, type=9, cpp_type=9, label=1, has_default_value=False, default_value=unicode("", "utf-8"), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='bytes_value', full_name='protorpc.OptionalMessage.bytes_value', index=7, number=8, type=12, cpp_type=9, label=1, has_default_value=False, default_value="", message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='enum_value', full_name='protorpc.OptionalMessage.enum_value', index=8, number=10, type=14, cpp_type=8, label=1, has_default_value=False, default_value=1, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ _OPTIONALMESSAGE_SIMPLEENUM, ], options=None, is_extendable=False, 
extension_ranges=[], serialized_start=221, serialized_end=500, ) _REPEATEDMESSAGE = descriptor.Descriptor( name='RepeatedMessage', full_name='protorpc.RepeatedMessage', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ descriptor.FieldDescriptor( name='double_value', full_name='protorpc.RepeatedMessage.double_value', index=0, number=1, type=1, cpp_type=5, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='float_value', full_name='protorpc.RepeatedMessage.float_value', index=1, number=2, type=2, cpp_type=6, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='int64_value', full_name='protorpc.RepeatedMessage.int64_value', index=2, number=3, type=3, cpp_type=2, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='uint64_value', full_name='protorpc.RepeatedMessage.uint64_value', index=3, number=4, type=4, cpp_type=4, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='int32_value', full_name='protorpc.RepeatedMessage.int32_value', index=4, number=5, type=5, cpp_type=1, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='bool_value', full_name='protorpc.RepeatedMessage.bool_value', index=5, number=6, type=8, cpp_type=7, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, 
is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='string_value', full_name='protorpc.RepeatedMessage.string_value', index=6, number=7, type=9, cpp_type=9, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='bytes_value', full_name='protorpc.RepeatedMessage.bytes_value', index=7, number=8, type=12, cpp_type=9, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='enum_value', full_name='protorpc.RepeatedMessage.enum_value', index=8, number=10, type=14, cpp_type=8, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ _REPEATEDMESSAGE_SIMPLEENUM, ], options=None, is_extendable=False, extension_ranges=[], serialized_start=503, serialized_end=782, ) _HASOPTIONALNESTEDMESSAGE = descriptor.Descriptor( name='HasOptionalNestedMessage', full_name='protorpc.HasOptionalNestedMessage', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ descriptor.FieldDescriptor( name='nested', full_name='protorpc.HasOptionalNestedMessage.nested', index=0, number=1, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='repeated_nested', full_name='protorpc.HasOptionalNestedMessage.repeated_nested', index=1, number=2, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], 
nested_types=[], enum_types=[ ], options=None, is_extendable=False, extension_ranges=[], serialized_start=784, serialized_end=905, ) _HASNESTEDMESSAGE.fields_by_name['nested'].message_type = _NESTEDMESSAGE _HASNESTEDMESSAGE.fields_by_name['repeated_nested'].message_type = _NESTEDMESSAGE _OPTIONALMESSAGE.fields_by_name['enum_value'].enum_type = _OPTIONALMESSAGE_SIMPLEENUM _OPTIONALMESSAGE_SIMPLEENUM.containing_type = _OPTIONALMESSAGE; _REPEATEDMESSAGE.fields_by_name['enum_value'].enum_type = _REPEATEDMESSAGE_SIMPLEENUM _REPEATEDMESSAGE_SIMPLEENUM.containing_type = _REPEATEDMESSAGE; _HASOPTIONALNESTEDMESSAGE.fields_by_name['nested'].message_type = _OPTIONALMESSAGE _HASOPTIONALNESTEDMESSAGE.fields_by_name['repeated_nested'].message_type = _OPTIONALMESSAGE DESCRIPTOR.message_types_by_name['NestedMessage'] = _NESTEDMESSAGE DESCRIPTOR.message_types_by_name['HasNestedMessage'] = _HASNESTEDMESSAGE DESCRIPTOR.message_types_by_name['HasDefault'] = _HASDEFAULT DESCRIPTOR.message_types_by_name['OptionalMessage'] = _OPTIONALMESSAGE DESCRIPTOR.message_types_by_name['RepeatedMessage'] = _REPEATEDMESSAGE DESCRIPTOR.message_types_by_name['HasOptionalNestedMessage'] = _HASOPTIONALNESTEDMESSAGE class NestedMessage(message.Message): __metaclass__ = reflection.GeneratedProtocolMessageType DESCRIPTOR = _NESTEDMESSAGE # @@protoc_insertion_point(class_scope:protorpc.NestedMessage) class HasNestedMessage(message.Message): __metaclass__ = reflection.GeneratedProtocolMessageType DESCRIPTOR = _HASNESTEDMESSAGE # @@protoc_insertion_point(class_scope:protorpc.HasNestedMessage) class HasDefault(message.Message): __metaclass__ = reflection.GeneratedProtocolMessageType DESCRIPTOR = _HASDEFAULT # @@protoc_insertion_point(class_scope:protorpc.HasDefault) class OptionalMessage(message.Message): __metaclass__ = reflection.GeneratedProtocolMessageType DESCRIPTOR = _OPTIONALMESSAGE # @@protoc_insertion_point(class_scope:protorpc.OptionalMessage) class RepeatedMessage(message.Message): __metaclass__ = 
reflection.GeneratedProtocolMessageType DESCRIPTOR = _REPEATEDMESSAGE # @@protoc_insertion_point(class_scope:protorpc.RepeatedMessage) class HasOptionalNestedMessage(message.Message): __metaclass__ = reflection.GeneratedProtocolMessageType DESCRIPTOR = _HASOPTIONALNESTEDMESSAGE # @@protoc_insertion_point(class_scope:protorpc.HasOptionalNestedMessage) # @@protoc_insertion_point(module_scope) protorpc-standalone-0.9.1/protorpc/protourlencode.py0000755000076500000240000004407512277637135024010 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """URL encoding support for messages types. Protocol support for URL encoded form parameters. Nested Fields: Nested fields are repesented by dot separated names. For example, consider the following messages: class WebPage(Message): title = StringField(1) tags = StringField(2, repeated=True) class WebSite(Message): name = StringField(1) home = MessageField(WebPage, 2) pages = MessageField(WebPage, 3, repeated=True) And consider the object: page = WebPage() page.title = 'Welcome to NewSite 2010' site = WebSite() site.name = 'NewSite 2010' site.home = page The URL encoded representation of this constellation of objects is. name=NewSite+2010&home.title=Welcome+to+NewSite+2010 An object that exists but does not have any state can be represented with a reference to its name alone with no value assigned to it. 
For example: page = WebSite() page.name = 'My Empty Site' page.home = WebPage() is represented as: name=My+Empty+Site&home= This represents a site with an empty uninitialized home page. Repeated Fields: Repeated fields are represented by the name of and the index of each value separated by a dash. For example, consider the following message: home = Page() home.title = 'Nome' news = Page() news.title = 'News' news.tags = ['news', 'articles'] instance = WebSite() instance.name = 'Super fun site' instance.pages = [home, news, preferences] An instance of this message can be represented as: name=Super+fun+site&page-0.title=Home&pages-1.title=News&... pages-1.tags-0=new&pages-1.tags-1=articles Helper classes: URLEncodedRequestBuilder: Used for encapsulating the logic used for building a request message from a URL encoded RPC. """ __author__ = 'rafek@google.com (Rafe Kaplan)' import cgi import re import urllib from . import message_types from . import messages from . import util __all__ = ['CONTENT_TYPE', 'URLEncodedRequestBuilder', 'encode_message', 'decode_message', ] CONTENT_TYPE = 'application/x-www-form-urlencoded' _FIELD_NAME_REGEX = re.compile(r'^([a-zA-Z_][a-zA-Z_0-9]*)(?:-([0-9]+))?$') class URLEncodedRequestBuilder(object): """Helper that encapsulates the logic used for building URL encoded messages. This helper is used to map query parameters from a URL encoded RPC to a message instance. """ @util.positional(2) def __init__(self, message, prefix=''): """Constructor. Args: message: Message instance to build from parameters. prefix: Prefix expected at the start of valid parameters. """ self.__parameter_prefix = prefix # The empty tuple indicates the root message, which has no path. # __messages is a full cache that makes it very easy to look up message # instances by their paths. See make_path for details about what a path # is. self.__messages = {(): message} # This is a cache that stores paths which have been checked for # correctness. 
Correctness means that an index is present for repeated # fields on the path and absent for non-repeated fields. The cache is # also used to check that indexes are added in the right order so that # dicontiguous ranges of indexes are ignored. self.__checked_indexes = set([()]) def make_path(self, parameter_name): """Parse a parameter name and build a full path to a message value. The path of a method is a tuple of 2-tuples describing the names and indexes within repeated fields from the root message (the message being constructed by the builder) to an arbitrarily nested message within it. Each 2-tuple node of a path (name, index) is: name: The name of the field that refers to the message instance. index: The index within a repeated field that refers to the message instance, None if not a repeated field. For example, consider: class VeryInner(messages.Message): ... class Inner(messages.Message): very_inner = messages.MessageField(VeryInner, 1, repeated=True) class Outer(messages.Message): inner = messages.MessageField(Inner, 1) If this builder is building an instance of Outer, that instance is referred to in the URL encoded parameters without a path. Therefore its path is (). The child 'inner' is referred to by its path (('inner', None)). The first child of repeated field 'very_inner' on the Inner instance is referred to by (('inner', None), ('very_inner', 0)). Examples: # Correct reference to model where nation is a Message, district is # repeated Message and county is any not repeated field type. >>> make_path('nation.district-2.county') (('nation', None), ('district', 2), ('county', None)) # Field is not part of model. >>> make_path('nation.made_up_field') None # nation field is not repeated and index provided. >>> make_path('nation-1') None # district field is repeated and no index provided. >>> make_path('nation.district') None Args: parameter_name: Name of query parameter as passed in from the request. 
in order to make a path, this parameter_name must point to a valid field within the message structure. Nodes of the path that refer to repeated fields must be indexed with a number, non repeated nodes must not have an index. Returns: Parsed version of the parameter_name as a tuple of tuples: attribute: Name of attribute associated with path. index: Postitive integer index when it is a repeated field, else None. Will return None if the parameter_name does not have the right prefix, does not point to a field within the message structure, does not have an index if it is a repeated field or has an index but is not a repeated field. """ if parameter_name.startswith(self.__parameter_prefix): parameter_name = parameter_name[len(self.__parameter_prefix):] else: return None path = [] name = [] message_type = type(self.__messages[()]) # Get root message. for item in parameter_name.split('.'): # This will catch sub_message.real_message_field.not_real_field if not message_type: return None item_match = _FIELD_NAME_REGEX.match(item) if not item_match: return None attribute = item_match.group(1) index = item_match.group(2) if index: index = int(index) try: field = message_type.field_by_name(attribute) except KeyError: return None if field.repeated != (index is not None): return None if isinstance(field, messages.MessageField): message_type = field.message_type else: message_type = None # Path is valid so far. Append node and continue. path.append((attribute, index)) return tuple(path) def __check_index(self, parent_path, name, index): """Check correct index use and value relative to a given path. Check that for a given path the index is present for repeated fields and that it is in range for the existing list that it will be inserted in to or appended to. Args: parent_path: Path to check against name and index. name: Name of field to check for existance. index: Index to check. 
If field is repeated, should be a number within range of the length of the field, or point to the next item for appending. """ # Don't worry about non-repeated fields. # It's also ok if index is 0 because that means next insert will append. if not index: return True parent = self.__messages.get(parent_path, None) value_list = getattr(parent, name, None) # If the list does not exist then the index should be 0. Since it is # not, path is not valid. if not value_list: return False # The index must either point to an element of the list or to the tail. return len(value_list) >= index def __check_indexes(self, path): """Check that all indexes are valid and in the right order. This method must iterate over the path and check that all references to indexes point to an existing message or to the end of the list, meaning the next value should be appended to the repeated field. Args: path: Path to check indexes for. Tuple of 2-tuples (name, index). See make_path for more information. Returns: True if all the indexes of the path are within range, else False. """ if path in self.__checked_indexes: return True # Start with the root message. parent_path = () for name, index in path: next_path = parent_path + ((name, index),) # First look in the checked indexes cache. if next_path not in self.__checked_indexes: if not self.__check_index(parent_path, name, index): return False self.__checked_indexes.add(next_path) parent_path = next_path return True def __get_or_create_path(self, path): """Get a message from the messages cache or create it and add it. This method will also create any parent messages based on the path. When a new instance of a given message is created, it is stored in __message by its path. Args: path: Path of message to get. Path must be valid, in other words __check_index(path) returns true. Tuple of 2-tuples (name, index). See make_path for more information. 
Returns: Message instance if the field being pointed to by the path is a message, else will return None for non-message fields. """ message = self.__messages.get(path, None) if message: return message parent_path = () parent = self.__messages[()] # Get the root object for name, index in path: field = parent.field_by_name(name) next_path = parent_path + ((name, index),) next_message = self.__messages.get(next_path, None) if next_message is None: next_message = field.message_type() self.__messages[next_path] = next_message if not field.repeated: setattr(parent, field.name, next_message) else: list_value = getattr(parent, field.name, None) if list_value is None: setattr(parent, field.name, [next_message]) else: list_value.append(next_message) parent_path = next_path parent = next_message return parent def add_parameter(self, parameter, values): """Add a single parameter. Adds a single parameter and its value to the request message. Args: parameter: Query string parameter to map to request. values: List of values to assign to request message. Returns: True if parameter was valid and added to the message, else False. Raises: DecodeError if the parameter refers to a valid field, and the values parameter does not have one and only one value. Non-valid query parameters may have multiple values and should not cause an error. """ path = self.make_path(parameter) if not path: return False # Must check that all indexes of all items in the path are correct before # instantiating any of them. For example, consider: # # class Repeated(object): # ... # # class Inner(object): # # repeated = messages.MessageField(Repeated, 1, repeated=True) # # class Outer(object): # # inner = messages.MessageField(Inner, 1) # # instance = Outer() # builder = URLEncodedRequestBuilder(instance) # builder.add_parameter('inner.repeated') # # assert not hasattr(instance, 'inner') # # The check is done relative to the instance of Outer pass in to the # constructor of the builder. 
This instance is not referred to at all # because all names are assumed to be relative to it. # # The 'repeated' part of the path is not correct because it is missing an # index. Because it is missing an index, it should not create an instance # of Repeated. In this case add_parameter will return False and have no # side effects. # # A correct path that would cause a new Inner instance to be inserted at # instance.inner and a new Repeated instance to be appended to the # instance.inner.repeated list would be 'inner.repeated-0'. if not self.__check_indexes(path): return False # Ok to build objects. parent_path = path[:-1] parent = self.__get_or_create_path(parent_path) name, index = path[-1] field = parent.field_by_name(name) if len(values) != 1: raise messages.DecodeError( 'Found repeated values for field %s.' % field.name) value = values[0] if isinstance(field, messages.IntegerField): converted_value = int(value) elif isinstance(field, message_types.DateTimeField): try: converted_value = util.decode_datetime(value) except ValueError, e: raise messages.DecodeError(e) elif isinstance(field, messages.MessageField): # Just make sure it's instantiated. Assignment to field or # appending to list is done in __get_or_create_path. self.__get_or_create_path(path) return True elif isinstance(field, messages.StringField): converted_value = value.decode('utf-8') elif isinstance(field, messages.BooleanField): converted_value = value.lower() == 'true' and True or False else: try: converted_value = field.type(value) except TypeError: raise messages.DecodeError('Invalid enum value "%s"' % value) if field.repeated: value_list = getattr(parent, field.name, None) if value_list is None: setattr(parent, field.name, [converted_value]) else: if index == len(value_list): value_list.append(converted_value) else: # Index should never be above len(value_list) because it was # verified during the index check above. 
value_list[index] = converted_value else: setattr(parent, field.name, converted_value) return True @util.positional(1) def encode_message(message, prefix=''): """Encode Message instance to url-encoded string. Args: message: Message instance to encode in to url-encoded string. prefix: Prefix to append to field names of contained values. Returns: String encoding of Message in URL encoded format. Raises: messages.ValidationError if message is not initialized. """ message.check_initialized() parameters = [] def build_message(parent, prefix): """Recursively build parameter list for URL response. Args: parent: Message to build parameters for. prefix: Prefix to append to field names of contained values. Returns: True if some value of parent was added to the parameters list, else False, meaning the object contained no values. """ has_any_values = False for field in sorted(parent.all_fields(), key=lambda f: f.number): next_value = parent.get_assigned_value(field.name) if next_value is None: continue # Found a value. Ultimate return value should be True. has_any_values = True # Normalize all values in to a list. if not field.repeated: next_value = [next_value] for index, item in enumerate(next_value): # Create a name with an index if it is a repeated field. if field.repeated: field_name = '%s%s-%s' % (prefix, field.name, index) else: field_name = prefix + field.name if isinstance(field, message_types.DateTimeField): # DateTimeField stores its data as a RFC 3339 compliant string. parameters.append((field_name, item.isoformat())) elif isinstance(field, messages.MessageField): # Message fields must be recursed in to in order to construct # their component parameter values. if not build_message(item, field_name + '.'): # The nested message is empty. Append an empty value to # represent it. 
parameters.append((field_name, '')) elif isinstance(field, messages.BooleanField): parameters.append((field_name, item and 'true' or 'false')) else: if isinstance(item, unicode): item = item.encode('utf-8') parameters.append((field_name, str(item))) return has_any_values build_message(message, prefix) # Also add any unrecognized values from the decoded string. for key in message.all_unrecognized_fields(): values, _ = message.get_unrecognized_field_info(key) if not isinstance(values, (list, tuple)): values = (values,) for value in values: parameters.append((key, value)) return urllib.urlencode(parameters) def decode_message(message_type, encoded_message, **kwargs): """Decode urlencoded content to message. Args: message_type: Message instance to merge URL encoded content into. encoded_message: URL encoded message. prefix: Prefix to append to field names of contained values. Returns: Decoded instance of message_type. """ message = message_type() builder = URLEncodedRequestBuilder(message, **kwargs) arguments = cgi.parse_qs(encoded_message, keep_blank_values=True) for argument, values in sorted(arguments.iteritems()): added = builder.add_parameter(argument, values) # Save off any unknown values, so they're still accessible. if not added: message.set_unrecognized_field(argument, values, messages.Variant.STRING) message.check_initialized() return message protorpc-standalone-0.9.1/protorpc/protourlencode_test.py0000755000076500000240000003554712277637135025053 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.protourlencode.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import cgi import logging import unittest import urllib from protorpc import message_types from protorpc import messages from protorpc import protourlencode from protorpc import test_util class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = protourlencode class SuperMessage(messages.Message): """A test message with a nested message field.""" sub_message = messages.MessageField(test_util.OptionalMessage, 1) sub_messages = messages.MessageField(test_util.OptionalMessage, 2, repeated=True) class SuperSuperMessage(messages.Message): """A test message with two levels of nested.""" sub_message = messages.MessageField(SuperMessage, 1) sub_messages = messages.MessageField(SuperMessage, 2, repeated=True) class URLEncodedRequestBuilderTest(test_util.TestCase): """Test the URL Encoded request builder.""" def testMakePath(self): builder = protourlencode.URLEncodedRequestBuilder(SuperSuperMessage(), prefix='pre.') self.assertEquals(None, builder.make_path('')) self.assertEquals(None, builder.make_path('no_such_field')) self.assertEquals(None, builder.make_path('pre.no_such_field')) # Missing prefix. self.assertEquals(None, builder.make_path('sub_message')) # Valid parameters. 
self.assertEquals((('sub_message', None),), builder.make_path('pre.sub_message')) self.assertEquals((('sub_message', None), ('sub_messages', 1)), builder.make_path('pre.sub_message.sub_messages-1')) self.assertEquals( (('sub_message', None), ('sub_messages', 1), ('int64_value', None)), builder.make_path('pre.sub_message.sub_messages-1.int64_value')) # Missing index. self.assertEquals( None, builder.make_path('pre.sub_message.sub_messages.integer_field')) # Has unexpected index. self.assertEquals( None, builder.make_path('pre.sub_message.sub_message-1.integer_field')) def testAddParameter_SimpleAttributes(self): message = test_util.OptionalMessage() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') self.assertTrue(builder.add_parameter('pre.int64_value', ['10'])) self.assertTrue(builder.add_parameter('pre.string_value', ['a string'])) self.assertTrue(builder.add_parameter('pre.enum_value', ['VAL1'])) self.assertEquals(10, message.int64_value) self.assertEquals('a string', message.string_value) self.assertEquals(test_util.OptionalMessage.SimpleEnum.VAL1, message.enum_value) def testAddParameter_InvalidAttributes(self): message = SuperSuperMessage() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') def assert_empty(): self.assertEquals(None, getattr(message, 'sub_message')) self.assertEquals([], getattr(message, 'sub_messages')) self.assertFalse(builder.add_parameter('pre.nothing', ['x'])) assert_empty() self.assertFalse(builder.add_parameter('pre.sub_messages', ['x'])) self.assertFalse(builder.add_parameter('pre.sub_messages-1.nothing', ['x'])) assert_empty() def testAddParameter_NestedAttributes(self): message = SuperSuperMessage() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') # Set an empty message fields. self.assertTrue(builder.add_parameter('pre.sub_message', [''])) self.assertTrue(isinstance(message.sub_message, SuperMessage)) # Add a basic attribute. 
self.assertTrue(builder.add_parameter( 'pre.sub_message.sub_message.int64_value', ['10'])) self.assertTrue(builder.add_parameter( 'pre.sub_message.sub_message.string_value', ['hello'])) self.assertTrue(10, message.sub_message.sub_message.int64_value) self.assertTrue('hello', message.sub_message.sub_message.string_value) def testAddParameter_NestedMessages(self): message = SuperSuperMessage() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') # Add a repeated empty message. self.assertTrue(builder.add_parameter( 'pre.sub_message.sub_messages-0', [''])) sub_message = message.sub_message.sub_messages[0] self.assertTrue(1, len(message.sub_message.sub_messages)) self.assertTrue(isinstance(sub_message, test_util.OptionalMessage)) self.assertEquals(None, getattr(sub_message, 'int64_value')) self.assertEquals(None, getattr(sub_message, 'string_value')) self.assertEquals(None, getattr(sub_message, 'enum_value')) # Add a repeated message with value. self.assertTrue(builder.add_parameter( 'pre.sub_message.sub_messages-1.int64_value', ['10'])) self.assertTrue(2, len(message.sub_message.sub_messages)) self.assertTrue(10, message.sub_message.sub_messages[1].int64_value) # Add another value to the same nested message. 
self.assertTrue(builder.add_parameter( 'pre.sub_message.sub_messages-1.string_value', ['a string'])) self.assertTrue(2, len(message.sub_message.sub_messages)) self.assertEquals(10, message.sub_message.sub_messages[1].int64_value) self.assertEquals('a string', message.sub_message.sub_messages[1].string_value) def testAddParameter_RepeatedValues(self): message = test_util.RepeatedMessage() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') self.assertTrue(builder.add_parameter('pre.int64_value-0', ['20'])) self.assertTrue(builder.add_parameter('pre.int64_value-1', ['30'])) self.assertEquals([20, 30], message.int64_value) self.assertTrue(builder.add_parameter('pre.string_value-0', ['hi'])) self.assertTrue(builder.add_parameter('pre.string_value-1', ['lo'])) self.assertTrue(builder.add_parameter('pre.string_value-1', ['dups overwrite'])) self.assertEquals(['hi', 'dups overwrite'], message.string_value) def testAddParameter_InvalidValuesMayRepeat(self): message = test_util.OptionalMessage() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') self.assertFalse(builder.add_parameter('nothing', [1, 2, 3])) def testAddParameter_RepeatedParameters(self): message = test_util.OptionalMessage() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') self.assertRaises(messages.DecodeError, builder.add_parameter, 'pre.int64_value', [1, 2, 3]) self.assertRaises(messages.DecodeError, builder.add_parameter, 'pre.int64_value', []) def testAddParameter_UnexpectedNestedValue(self): """Test getting a nested value on a non-message sub-field.""" message = test_util.HasNestedMessage() builder = protourlencode.URLEncodedRequestBuilder(message, 'pre.') self.assertFalse(builder.add_parameter('pre.nested.a_value.whatever', ['1'])) def testInvalidFieldFormat(self): message = test_util.OptionalMessage() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') self.assertFalse(builder.add_parameter('pre.illegal%20', 
['1'])) def testAddParameter_UnexpectedNestedValue(self): """Test getting a nested value on a non-message sub-field There is an odd corner case where if trying to insert a repeated value on an nested repeated message that would normally succeed in being created should fail. This case can only be tested when the first message of the nested messages already exists. Another case is trying to access an indexed value nested within a non-message field. """ class HasRepeated(messages.Message): values = messages.IntegerField(1, repeated=True) class HasNestedRepeated(messages.Message): nested = messages.MessageField(HasRepeated, 1, repeated=True) message = HasNestedRepeated() builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') self.assertTrue(builder.add_parameter('pre.nested-0.values-0', ['1'])) # Try to create an indexed value on a non-message field. self.assertFalse(builder.add_parameter('pre.nested-0.values-0.unknown-0', ['1'])) # Try to create an out of range indexed field on an otherwise valid # repeated message field. 
self.assertFalse(builder.add_parameter('pre.nested-1.values-1', ['1'])) class ProtourlencodeConformanceTest(test_util.TestCase, test_util.ProtoConformanceTestBase): PROTOLIB = protourlencode encoded_partial = urllib.urlencode([('double_value', 1.23), ('int64_value', -100000000000), ('int32_value', 1020), ('string_value', u'a string'), ('enum_value', 'VAL2'), ]) encoded_full = urllib.urlencode([('double_value', 1.23), ('float_value', -2.5), ('int64_value', -100000000000), ('uint64_value', 102020202020), ('int32_value', 1020), ('bool_value', 'true'), ('string_value', u'a string\u044f'.encode('utf-8')), ('bytes_value', 'a bytes\xff\xfe'), ('enum_value', 'VAL2'), ]) encoded_repeated = urllib.urlencode([('double_value-0', 1.23), ('double_value-1', 2.3), ('float_value-0', -2.5), ('float_value-1', 0.5), ('int64_value-0', -100000000000), ('int64_value-1', 20), ('uint64_value-0', 102020202020), ('uint64_value-1', 10), ('int32_value-0', 1020), ('int32_value-1', 718), ('bool_value-0', 'true'), ('bool_value-1', 'false'), ('string_value-0', u'a string\u044f'.encode('utf-8')), ('string_value-1', u'another string'.encode('utf-8')), ('bytes_value-0', 'a bytes\xff\xfe'), ('bytes_value-1', 'another bytes'), ('enum_value-0', 'VAL2'), ('enum_value-1', 'VAL1'), ]) encoded_nested = urllib.urlencode([('nested.a_value', 'a string'), ]) encoded_repeated_nested = urllib.urlencode( [('repeated_nested-0.a_value', 'a string'), ('repeated_nested-1.a_value', 'another string'), ]) unexpected_tag_message = 'unexpected=whatever' encoded_default_assigned = urllib.urlencode([('a_value', 'a default'), ]) encoded_nested_empty = urllib.urlencode([('nested', '')]) encoded_repeated_nested_empty = urllib.urlencode([('repeated_nested-0', ''), ('repeated_nested-1', '')]) encoded_extend_message = urllib.urlencode([('int64_value-0', 400), ('int64_value-1', 50), ('int64_value-2', 6000)]) encoded_string_types = urllib.urlencode( [('string_value', 'Latin')]) encoded_invalid_enum = urllib.urlencode([('enum_value', 
'undefined')]) def testParameterPrefix(self): """Test using the 'prefix' parameter to encode_message.""" class MyMessage(messages.Message): number = messages.IntegerField(1) names = messages.StringField(2, repeated=True) message = MyMessage() message.number = 10 message.names = [u'Fred', u'Lisa'] encoded_message = protourlencode.encode_message(message, prefix='prefix-') self.assertEquals({'prefix-number': ['10'], 'prefix-names-0': ['Fred'], 'prefix-names-1': ['Lisa'], }, cgi.parse_qs(encoded_message)) self.assertEquals(message, protourlencode.decode_message(MyMessage, encoded_message, prefix='prefix-')) def testProtourlencodeUnrecognizedField(self): """Test that unrecognized fields are saved and can be accessed.""" class MyMessage(messages.Message): number = messages.IntegerField(1) decoded = protourlencode.decode_message(MyMessage, self.unexpected_tag_message) self.assertEquals(1, len(decoded.all_unrecognized_fields())) self.assertEquals('unexpected', decoded.all_unrecognized_fields()[0]) # Unknown values set to a list of however many values had that name. self.assertEquals((['whatever'], messages.Variant.STRING), decoded.get_unrecognized_field_info('unexpected')) repeated_unknown = urllib.urlencode([('repeated', 400), ('repeated', 'test'), ('repeated', '123.456')]) decoded2 = protourlencode.decode_message(MyMessage, repeated_unknown) self.assertEquals((['400', 'test', '123.456'], messages.Variant.STRING), decoded2.get_unrecognized_field_info('repeated')) def testDecodeInvalidDateTime(self): class MyMessage(messages.Message): a_datetime = message_types.DateTimeField(1) self.assertRaises(messages.DecodeError, protourlencode.decode_message, MyMessage, 'a_datetime=invalid') if __name__ == '__main__': unittest.main() protorpc-standalone-0.9.1/protorpc/registry.py0000755000076500000240000002011412277637135022600 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Service regsitry for service discovery. The registry service can be deployed on a server in order to provide a central place where remote clients can discover available. On the server side, each service is registered by their name which is unique to the registry. Typically this name provides enough information to identify the service and locate it within a server. For example, for an HTTP based registry the name is the URL path on the host where the service is invocable. The registry is also able to resolve the full descriptor.FileSet necessary to describe the service and all required data-types (messages and enums). A configured registry is itself a remote service and should reference itself. """ import sys from . import descriptor from . import messages from . import remote from . import util __all__ = [ 'ServiceMapping', 'ServicesResponse', 'GetFileSetRequest', 'GetFileSetResponse', 'RegistryService', ] class ServiceMapping(messages.Message): """Description of registered service. Fields: name: Name of service. On HTTP based services this will be the URL path used for invocation. definition: Fully qualified name of the service definition. Useful for clients that can look up service definitions based on an existing repository of definitions. 
""" name = messages.StringField(1, required=True) definition = messages.StringField(2, required=True) class ServicesResponse(messages.Message): """Response containing all registered services. May also contain complete descriptor file-set for all services known by the registry. Fields: services: Service mappings for all registered services in registry. file_set: Descriptor file-set describing all services, messages and enum types needed for use with all requested services if asked for in the request. """ services = messages.MessageField(ServiceMapping, 1, repeated=True) class GetFileSetRequest(messages.Message): """Request for service descriptor file-set. Request to retrieve file sets for specific services. Fields: names: Names of services to retrieve file-set for. """ names = messages.StringField(1, repeated=True) class GetFileSetResponse(messages.Message): """Descriptor file-set for all names in GetFileSetRequest. Fields: file_set: Descriptor file-set containing all descriptors for services, messages and enum types needed for listed names in request. """ file_set = messages.MessageField(descriptor.FileSet, 1, required=True) class RegistryService(remote.Service): """Registry service. Maps names to services and is able to describe all descriptor file-sets necessary to use contined services. On an HTTP based server, the name is the URL path to the service. """ @util.positional(2) def __init__(self, registry, modules=None): """Constructor. Args: registry: Map of name to service class. This map is not copied and may be modified after the reigstry service has been configured. modules: Module dict to draw descriptors from. Defaults to sys.modules. """ # Private Attributes: # __registry: Map of name to service class. Refers to same instance as # registry parameter. # __modules: Mapping of module name to module. # __definition_to_modules: Mapping of definition types to set of modules # that they refer to. 
This cache is used to make repeated look-ups # faster and to prevent circular references from causing endless loops. self.__registry = registry if modules is None: modules = sys.modules self.__modules = modules # This cache will only last for a single request. self.__definition_to_modules = {} def __find_modules_for_message(self, message_type): """Find modules referred to by a message type. Determines the entire list of modules ultimately referred to by message_type by iterating over all of its message and enum fields. Includes modules referred to fields within its referred messages. Args: message_type: Message type to find all referring modules for. Returns: Set of modules referred to by message_type by traversing all its message and enum fields. """ # TODO(rafek): Maybe this should be a method on Message and Service? def get_dependencies(message_type, seen=None): """Get all dependency definitions of a message type. This function works by collecting the types of all enumeration and message fields defined within the message type. When encountering a message field, it will recursivly find all of the associated message's dependencies. It will terminate on circular dependencies by keeping track of what definitions it already via the seen set. Args: message_type: Message type to get dependencies for. seen: Set of definitions that have already been visited. Returns: All dependency message and enumerated types associated with this message including the message itself. 
""" if seen is None: seen = set() seen.add(message_type) for field in message_type.all_fields(): if isinstance(field, messages.MessageField): if field.message_type not in seen: get_dependencies(field.message_type, seen) elif isinstance(field, messages.EnumField): seen.add(field.type) return seen found_modules = self.__definition_to_modules.setdefault(message_type, set()) if not found_modules: dependencies = get_dependencies(message_type) found_modules.update(self.__modules[definition.__module__] for definition in dependencies) return found_modules def __describe_file_set(self, names): """Get file-set for named services. Args: names: List of names to get file-set for. Returns: descriptor.FileSet containing all the descriptors for all modules ultimately referred to by all service types request by names parameter. """ service_modules = set() if names: for service in (self.__registry[name] for name in names): found_modules = self.__definition_to_modules.setdefault(service, set()) if not found_modules: found_modules.add(self.__modules[service.__module__]) for method_name in service.all_remote_methods(): method = getattr(service, method_name) for message_type in (method.remote.request_type, method.remote.response_type): found_modules.update( self.__find_modules_for_message(message_type)) service_modules.update(found_modules) return descriptor.describe_file_set(service_modules) @property def registry(self): """Get service registry associated with this service instance.""" return self.__registry @remote.method(response_type=ServicesResponse) def services(self, request): """Get all registered services.""" response = ServicesResponse() response.services = [] for name, service_class in self.__registry.iteritems(): mapping = ServiceMapping() mapping.name = name.decode('utf-8') mapping.definition = service_class.definition_name().decode('utf-8') response.services.append(mapping) return response @remote.method(GetFileSetRequest, GetFileSetResponse) def get_file_set(self, 
request): """Get file-set for registered servies.""" response = GetFileSetResponse() response.file_set = self.__describe_file_set(request.names) return response protorpc-standalone-0.9.1/protorpc/registry_test.py0000755000076500000240000000742512277637135023651 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.message.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import sys import unittest from protorpc import descriptor from protorpc import message_types from protorpc import messages from protorpc import registry from protorpc import remote from protorpc import test_util class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = registry class MyService1(remote.Service): """Test service that refers to messages in another module.""" @remote.method(test_util.NestedMessage, test_util.NestedMessage) def a_method(self, request): pass class MyService2(remote.Service): """Test service that does not refer to messages in another module.""" class RegistryServiceTest(test_util.TestCase): def setUp(self): self.registry = { 'my-service1': MyService1, 'my-service2': MyService2, } self.modules = { __name__: sys.modules[__name__], test_util.__name__: test_util, } self.registry_service = registry.RegistryService(self.registry, modules=self.modules) def CheckServiceMappings(self, mappings): module_name = test_util.get_module_name(RegistryServiceTest) 
service1_mapping = registry.ServiceMapping() service1_mapping.name = 'my-service1' service1_mapping.definition = '%s.MyService1' % module_name service2_mapping = registry.ServiceMapping() service2_mapping.name = 'my-service2' service2_mapping.definition = '%s.MyService2' % module_name self.assertIterEqual(mappings, [service1_mapping, service2_mapping]) def testServices(self): response = self.registry_service.services(message_types.VoidMessage()) self.CheckServiceMappings(response.services) def testGetFileSet_All(self): request = registry.GetFileSetRequest() request.names = ['my-service1', 'my-service2'] response = self.registry_service.get_file_set(request) expected_file_set = descriptor.describe_file_set(self.modules.values()) self.assertIterEqual(expected_file_set.files, response.file_set.files) def testGetFileSet_None(self): request = registry.GetFileSetRequest() response = self.registry_service.get_file_set(request) self.assertEquals(descriptor.FileSet(), response.file_set) def testGetFileSet_ReferenceOtherModules(self): request = registry.GetFileSetRequest() request.names = ['my-service1'] response = self.registry_service.get_file_set(request) # Will suck in and describe the test_util module. expected_file_set = descriptor.describe_file_set(self.modules.values()) self.assertIterEqual(expected_file_set.files, response.file_set.files) def testGetFileSet_DoNotReferenceOtherModules(self): request = registry.GetFileSetRequest() request.names = ['my-service2'] response = self.registry_service.get_file_set(request) # Service does not reference test_util, so will only describe this module. 
expected_file_set = descriptor.describe_file_set([self.modules[__name__]]) self.assertIterEqual(expected_file_set.files, response.file_set.files) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/remote.py0000755000076500000240000011422012277637135022225 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Remote service library. This module contains classes that are useful for building remote services that conform to a standard request and response model. To conform to this model a service must be like the following class: # Each service instance only handles a single request and is then discarded. # Make these objects light weight. class Service(object): # It must be possible to construct service objects without any parameters. # If your constructor needs extra information you should provide a # no-argument factory function to create service instances. def __init__(self): ... # Each remote method must use the 'method' decorator, passing the request # and response message types. The remote method itself must take a single # parameter which is an instance of RequestMessage and return an instance # of ResponseMessage. @method(RequestMessage, ResponseMessage) def remote_method(self, request): # Return an instance of ResponseMessage. 
# A service object may optionally implement an 'initialize_request_state' # method that takes as a parameter a single instance of a RequestState. If # a service does not implement this method it will not receive the request # state. def initialize_request_state(self, state): ... The 'Service' class is provided as a convenient base class that provides the above functionality. It implements all required and optional methods for a service. It also has convenience methods for creating factory functions that can pass persistent global state to a new service instance. The 'method' decorator is used to declare which methods of a class are meant to service RPCs. While this decorator is not responsible for handling actual remote method invocations, such as handling sockets, handling various RPC protocols and checking messages for correctness, it does attach information to methods that responsible classes can examine and ensure the correctness of the RPC. When the method decorator is used on a method, the wrapper method will have a 'remote' property associated with it. The 'remote' property contains the request_type and response_type expected by the methods implementation. On its own, the method decorator does not provide any support for subclassing remote methods. In order to extend a service, one would need to redecorate the sub-classes methods. For example: class MyService(Service): @method(DoSomethingRequest, DoSomethingResponse) def do_stuff(self, request): ... implement do_stuff ... class MyBetterService(MyService): @method(DoSomethingRequest, DoSomethingResponse) def do_stuff(self, request): response = super(MyBetterService, self).do_stuff.remote.method(request) ... do stuff with response ... return response A Service subclass also has a Stub class that can be used with a transport for making RPCs. When a stub is created, it is capable of doing both synchronous and asynchronous RPCs if the underlying transport supports it. 
To make a stub using an HTTP transport do: my_service = MyService.Stub(HttpTransport('')) For synchronous calls, just call the expected methods on the service stub: request = DoSomethingRequest() ... response = my_service.do_something(request) Each stub instance has an async object that can be used for initiating asynchronous RPCs if the underlying protocol transport supports it. To make an asynchronous call, do: rpc = my_service.async.do_something(request) response = rpc.get_response() """ from __future__ import with_statement __author__ = 'rafek@google.com (Rafe Kaplan)' import logging import sys import threading from wsgiref import headers as wsgi_headers from . import message_types from . import messages from . import protobuf from . import protojson from . import util __all__ = [ 'ApplicationError', 'MethodNotFoundError', 'NetworkError', 'RequestError', 'RpcError', 'ServerError', 'ServiceConfigurationError', 'ServiceDefinitionError', 'HttpRequestState', 'ProtocolConfig', 'Protocols', 'RequestState', 'RpcState', 'RpcStatus', 'Service', 'StubBase', 'check_rpc_status', 'get_remote_method_info', 'is_error_status', 'method', 'remote', ] class ServiceDefinitionError(messages.Error): """Raised when a service is improperly defined.""" class ServiceConfigurationError(messages.Error): """Raised when a service is incorrectly configured.""" # TODO: Use error_name to map to specific exception message types. class RpcStatus(messages.Message): """Status of on-going or complete RPC. Fields: state: State of RPC. error_name: Error name set by application. Only set when status is APPLICATION_ERROR. For use by application to transmit specific reason for error. error_message: Error message associated with status. """ class State(messages.Enum): """Enumeration of possible RPC states. Values: OK: Completed successfully. RUNNING: Still running, not complete. REQUEST_ERROR: Request was malformed or incomplete. SERVER_ERROR: Server experienced an unexpected error. 
NETWORK_ERROR: An error occured on the network. APPLICATION_ERROR: The application is indicating an error. When in this state, RPC should also set application_error. """ OK = 0 RUNNING = 1 REQUEST_ERROR = 2 SERVER_ERROR = 3 NETWORK_ERROR = 4 APPLICATION_ERROR = 5 METHOD_NOT_FOUND_ERROR = 6 state = messages.EnumField(State, 1, required=True) error_message = messages.StringField(2) error_name = messages.StringField(3) RpcState = RpcStatus.State class RpcError(messages.Error): """Base class for RPC errors. Each sub-class of RpcError is associated with an error value from RpcState and has an attribute STATE that refers to that value. """ def __init__(self, message, cause=None): super(RpcError, self).__init__(message) self.cause = cause @classmethod def from_state(cls, state): """Get error class from RpcState. Args: state: RpcState value. Can be enum value itself, string or int. Returns: Exception class mapped to value if state is an error. Returns None if state is OK or RUNNING. """ return _RPC_STATE_TO_ERROR.get(RpcState(state)) class RequestError(RpcError): """Raised when wrong request objects received during method invocation.""" STATE = RpcState.REQUEST_ERROR class MethodNotFoundError(RequestError): """Raised when unknown method requested by RPC.""" STATE = RpcState.METHOD_NOT_FOUND_ERROR class NetworkError(RpcError): """Raised when network error occurs during RPC.""" STATE = RpcState.NETWORK_ERROR class ServerError(RpcError): """Unexpected error occured on server.""" STATE = RpcState.SERVER_ERROR class ApplicationError(RpcError): """Raised for application specific errors. Attributes: error_name: Application specific error name for exception. """ STATE = RpcState.APPLICATION_ERROR def __init__(self, message, error_name=None): """Constructor. Args: message: Application specific error message. error_name: Application specific error name. Must be None, string or unicode string. 
""" super(ApplicationError, self).__init__(message) self.error_name = error_name def __str__(self): return self.args[0] def __repr__(self): if self.error_name is None: error_format = '' else: error_format = ', %r' % self.error_name return '%s(%r%s)' % (type(self).__name__, self.args[0], error_format) _RPC_STATE_TO_ERROR = { RpcState.REQUEST_ERROR: RequestError, RpcState.NETWORK_ERROR: NetworkError, RpcState.SERVER_ERROR: ServerError, RpcState.APPLICATION_ERROR: ApplicationError, RpcState.METHOD_NOT_FOUND_ERROR: MethodNotFoundError, } class _RemoteMethodInfo(object): """Object for encapsulating remote method information. An instance of this method is associated with the 'remote' attribute of the methods 'invoke_remote_method' instance. Instances of this class are created by the remote decorator and should not be created directly. """ def __init__(self, method, request_type, response_type): """Constructor. Args: method: The method which implements the remote method. This is a function that will act as an instance method of a class definition that is decorated by '@method'. It must always take 'self' as its first parameter. request_type: Expected request type for the remote method. response_type: Expected response type for the remote method. 
""" self.__method = method self.__request_type = request_type self.__response_type = response_type @property def method(self): """Original undecorated method.""" return self.__method @property def request_type(self): """Expected request type for remote method.""" if isinstance(self.__request_type, basestring): self.__request_type = messages.find_definition( self.__request_type, relative_to=sys.modules[self.__method.__module__]) return self.__request_type @property def response_type(self): """Expected response type for remote method.""" if isinstance(self.__response_type, basestring): self.__response_type = messages.find_definition( self.__response_type, relative_to=sys.modules[self.__method.__module__]) return self.__response_type def method(request_type=message_types.VoidMessage, response_type=message_types.VoidMessage): """Method decorator for creating remote methods. Args: request_type: Message type of expected request. response_type: Message type of expected response. Returns: 'remote_method_wrapper' function. Raises: TypeError: if the request_type or response_type parameters are not proper subclasses of messages.Message. """ if (not isinstance(request_type, basestring) and (not isinstance(request_type, type) or not issubclass(request_type, messages.Message) or request_type is messages.Message)): raise TypeError( 'Must provide message class for request-type. Found %s', request_type) if (not isinstance(response_type, basestring) and (not isinstance(response_type, type) or not issubclass(response_type, messages.Message) or response_type is messages.Message)): raise TypeError( 'Must provide message class for response-type. Found %s', response_type) def remote_method_wrapper(method): """Decorator used to wrap method. Args: method: Original method being wrapped. Returns: 'invoke_remote_method' function responsible for actual invocation. 
This invocation function instance is assigned an attribute 'remote' which contains information about the remote method: request_type: Expected request type for remote method. response_type: Response type returned from remote method. Raises: TypeError: If request_type or response_type is not a subclass of Message or is the Message class itself. """ def invoke_remote_method(service_instance, request): """Function used to replace original method. Invoke wrapped remote method. Checks to ensure that request and response objects are the correct types. Does not check whether messages are initialized. Args: service_instance: The service object whose method is being invoked. This is passed to 'self' during the invocation of the original method. request: Request message. Returns: Results of calling wrapped remote method. Raises: RequestError: Request object is not of the correct type. ServerError: Response object is not of the correct type. """ if not isinstance(request, remote_method_info.request_type): raise RequestError('Method %s.%s expected request type %s, ' 'received %s' % (type(service_instance).__name__, method.__name__, remote_method_info.request_type, type(request))) response = method(service_instance, request) if not isinstance(response, remote_method_info.response_type): raise ServerError('Method %s.%s expected response type %s, ' 'sent %s' % (type(service_instance).__name__, method.__name__, remote_method_info.response_type, type(response))) return response remote_method_info = _RemoteMethodInfo(method, request_type, response_type) invoke_remote_method.remote = remote_method_info invoke_remote_method.__name__ = method.__name__ return invoke_remote_method return remote_method_wrapper def remote(request_type, response_type): """Temporary backward compatibility alias for method.""" logging.warning('The remote decorator has been renamed method. 
It will be ' 'removed in very soon from future versions of ProtoRPC.') return method(request_type, response_type) def get_remote_method_info(method): """Get remote method info object from remote method. Returns: Remote method info object if method is a remote method, else None. """ if not callable(method): return None try: method_info = method.remote except AttributeError: return None if not isinstance(method_info, _RemoteMethodInfo): return None return method_info class StubBase(object): """Base class for client side service stubs. The remote method stubs are created by the _ServiceClass meta-class when a Service class is first created. The resulting stub will extend both this class and the service class it handles communications for. Assume that there is a service: class NewContactRequest(messages.Message): name = messages.StringField(1, required=True) phone = messages.StringField(2) email = messages.StringField(3) class NewContactResponse(message.Message): contact_id = messages.StringField(1) class AccountService(remote.Service): @remote.method(NewContactRequest, NewContactResponse): def new_contact(self, request): ... implementation ... A stub of this service can be called in two ways. The first is to pass in a correctly initialized NewContactRequest message: request = NewContactRequest() request.name = 'Bob Somebody' request.phone = '+1 415 555 1234' response = account_service_stub.new_contact(request) The second way is to pass in keyword parameters that correspond with the root request message type: account_service_stub.new_contact(name='Bob Somebody', phone='+1 415 555 1234') The second form will create a request message of the appropriate type. """ def __init__(self, transport): """Constructor. Args: transport: Underlying transport to communicate with remote service. 
""" self.__transport = transport @property def transport(self): """Transport used to communicate with remote service.""" return self.__transport class _ServiceClass(type): """Meta-class for service class.""" def __new_async_method(cls, remote): """Create asynchronous method for Async handler. Args: remote: RemoteInfo to create method for. """ def async_method(self, *args, **kwargs): """Asynchronous remote method. Args: self: Instance of StubBase.Async subclass. Stub methods either take a single positional argument when a full request message is passed in, or keyword arguments, but not both. See docstring for StubBase for more information on how to use remote stub methods. Returns: Rpc instance used to represent asynchronous RPC. """ if args and kwargs: raise TypeError('May not provide both args and kwargs') if not args: # Construct request object from arguments. request = remote.request_type() for name, value in kwargs.iteritems(): setattr(request, name, value) else: # First argument is request object. request = args[0] return self.transport.send_rpc(remote, request) async_method.__name__ = remote.method.__name__ async_method = util.positional(2)(async_method) async_method.remote = remote return async_method def __new_sync_method(cls, async_method): """Create synchronous method for stub. Args: async_method: asynchronous method to delegate calls to. """ def sync_method(self, *args, **kwargs): """Synchronous remote method. Args: self: Instance of StubBase.Async subclass. args: Tuple (request,): request: Request object. kwargs: Field values for request. Must be empty if request object is provided. Returns: Response message from synchronized RPC. """ return async_method(self.async, *args, **kwargs).response sync_method.__name__ = async_method.__name__ sync_method.remote = async_method.remote return sync_method def __create_async_methods(cls, remote_methods): """Construct a dictionary of asynchronous methods based on remote methods. 
Args: remote_methods: Dictionary of methods with associated RemoteInfo objects. Returns: Dictionary of asynchronous methods with assocaited RemoteInfo objects. Results added to AsyncStub subclass. """ async_methods = {} for method_name, method in remote_methods.iteritems(): async_methods[method_name] = cls.__new_async_method(method.remote) return async_methods def __create_sync_methods(cls, async_methods): """Construct a dictionary of synchronous methods based on remote methods. Args: async_methods: Dictionary of async methods to delegate calls to. Returns: Dictionary of synchronous methods with assocaited RemoteInfo objects. Results added to Stub subclass. """ sync_methods = {} for method_name, async_method in async_methods.iteritems(): sync_methods[method_name] = cls.__new_sync_method(async_method) return sync_methods def __new__(cls, name, bases, dct): """Instantiate new service class instance.""" if StubBase not in bases: # Collect existing remote methods. base_methods = {} for base in bases: try: remote_methods = base.__remote_methods except AttributeError: pass else: base_methods.update(remote_methods) # Set this class private attribute so that base_methods do not have # to be recacluated in __init__. dct['_ServiceClass__base_methods'] = base_methods for attribute, value in dct.iteritems(): base_method = base_methods.get(attribute, None) if base_method: if not callable(value): raise ServiceDefinitionError( 'Must override %s in %s with a method.' % ( attribute, name)) if get_remote_method_info(value): raise ServiceDefinitionError( 'Do not use method decorator when overloading remote method %s ' 'on service %s.' 
% (attribute, name)) base_remote_method_info = get_remote_method_info(base_method) remote_decorator = method( base_remote_method_info.request_type, base_remote_method_info.response_type) new_remote_method = remote_decorator(value) dct[attribute] = new_remote_method return type.__new__(cls, name, bases, dct) def __init__(cls, name, bases, dct): """Create uninitialized state on new class.""" type.__init__(cls, name, bases, dct) # Only service implementation classes should have remote methods and stub # sub classes created. Stub implementations have their own methods passed # in to the type constructor. if StubBase not in bases: # Create list of remote methods. cls.__remote_methods = dict(cls.__base_methods) for attribute, value in dct.iteritems(): value = getattr(cls, attribute) remote_method_info = get_remote_method_info(value) if remote_method_info: cls.__remote_methods[attribute] = value # Build asynchronous stub class. stub_attributes = {'Service': cls} async_methods = cls.__create_async_methods(cls.__remote_methods) stub_attributes.update(async_methods) async_class = type('AsyncStub', (StubBase, cls), stub_attributes) cls.AsyncStub = async_class # Constructor for synchronous stub class. def __init__(self, transport): """Constructor. Args: transport: Underlying transport to communicate with remote service. """ super(cls.Stub, self).__init__(transport) self.async = cls.AsyncStub(transport) # Build synchronous stub class. stub_attributes = {'Service': cls, '__init__': __init__} stub_attributes.update(cls.__create_sync_methods(async_methods)) cls.Stub = type('Stub', (StubBase, cls), stub_attributes) @staticmethod def all_remote_methods(cls): """Get all remote methods of service. Returns: Dict from method name to unbound method. """ return dict(cls.__remote_methods) class RequestState(object): """Request state information. Properties: remote_host: Remote host name where request originated. remote_address: IP address where request originated. 
server_host: Host of server within which service resides. server_port: Post which service has recevied request from. """ @util.positional(1) def __init__(self, remote_host=None, remote_address=None, server_host=None, server_port=None): """Constructor. Args: remote_host: Assigned to property. remote_address: Assigned to property. server_host: Assigned to property. server_port: Assigned to property. """ self.__remote_host = remote_host self.__remote_address = remote_address self.__server_host = server_host self.__server_port = server_port @property def remote_host(self): return self.__remote_host @property def remote_address(self): return self.__remote_address @property def server_host(self): return self.__server_host @property def server_port(self): return self.__server_port def _repr_items(self): for name in ['remote_host', 'remote_address', 'server_host', 'server_port']: yield name, getattr(self, name) def __repr__(self): """String representation of state.""" state = [self.__class__.__name__] for name, value in self._repr_items(): if value: state.append('%s=%r' % (name, value)) return '<%s>' % (' '.join(state),) class HttpRequestState(RequestState): """HTTP request state information. NOTE: Does not attempt to represent certain types of information from the request such as the query string as query strings are not permitted in ProtoRPC URLs unless required by the underlying message format. Properties: headers: wsgiref.headers.Headers instance of HTTP request headers. http_method: HTTP method as a string. service_path: Path on HTTP service where service is mounted. This path will not include the remote method name. """ @util.positional(1) def __init__(self, http_method=None, service_path=None, headers=None, **kwargs): """Constructor. Args: Same as RequestState, including: http_method: Assigned to property. service_path: Assigned to property. headers: HTTP request headers. If instance of Headers, assigned to property without copying. 
If dict, will convert to name value pairs for use with Headers constructor. Otherwise, passed as parameters to Headers constructor. """ super(HttpRequestState, self).__init__(**kwargs) self.__http_method = http_method self.__service_path = service_path # Initialize headers. if isinstance(headers, dict): header_list = [] for key, value in sorted(headers.items()): if not isinstance(value, list): value = [value] for item in value: header_list.append((key, item)) headers = header_list self.__headers = wsgi_headers.Headers(headers or []) @property def http_method(self): return self.__http_method @property def service_path(self): return self.__service_path @property def headers(self): return self.__headers def _repr_items(self): for item in super(HttpRequestState, self)._repr_items(): yield item for name in ['http_method', 'service_path']: yield name, getattr(self, name) yield 'headers', list(self.headers.items()) class Service(object): """Service base class. Base class used for defining remote services. Contains reflection functions, useful helpers and built-in remote methods. Services are expected to be constructed via either a constructor or factory which takes no parameters. However, it might be required that some state or configuration is passed in to a service across multiple requests. To do this, define parameters to the constructor of the service and use the 'new_factory' class method to build a constructor that will transmit parameters to the constructor. For example: class MyService(Service): def __init__(self, configuration, state): self.configuration = configuration self.state = state configuration = MyServiceConfiguration() global_state = MyServiceState() my_service_factory = MyService.new_factory(configuration, state=global_state) The contract with any service handler is that a new service object is created to handle each user request, and that the construction does not take any parameters. 
The factory satisfies this condition: new_instance = my_service_factory() assert new_instance.state is global_state Attributes: request_state: RequestState set via initialize_request_state. """ __metaclass__ = _ServiceClass __request_state = None @classmethod def all_remote_methods(cls): """Get all remote methods for service class. Built-in methods do not appear in the dictionary of remote methods. Returns: Dictionary mapping method name to remote method. """ return _ServiceClass.all_remote_methods(cls) @classmethod def new_factory(cls, *args, **kwargs): """Create factory for service. Useful for passing configuration or state objects to the service. Accepts arbitrary parameters and keywords, however, underlying service must accept also accept not other parameters in its constructor. Args: args: Args to pass to service constructor. kwargs: Keyword arguments to pass to service constructor. Returns: Factory function that will create a new instance and forward args and keywords to the constructor. """ def service_factory(): return cls(*args, **kwargs) # Update docstring so that it is easier to debug. full_class_name = '%s.%s' % (cls.__module__, cls.__name__) service_factory.func_doc = ( 'Creates new instances of service %s.\n\n' 'Returns:\n' ' New instance of %s.' % (cls.__name__, full_class_name)) # Update name so that it is easier to debug the factory function. service_factory.func_name = '%s_service_factory' % cls.__name__ service_factory.service_class = cls return service_factory def initialize_request_state(self, request_state): """Save request state for use in remote method. Args: request_state: RequestState instance. """ self.__request_state = request_state @classmethod def definition_name(cls): """Get definition name for Service class. Package name is determined by the global 'package' attribute in the module that contains the Service definition. If no 'package' attribute is available, uses module name. If no module is found, just uses class name as name. 
Returns: Fully qualified service name. """ try: return cls.__definition_name except AttributeError: outer_definition_name = cls.outer_definition_name() if outer_definition_name is None: cls.__definition_name = cls.__name__ else: cls.__definition_name = '%s.%s' % (outer_definition_name, cls.__name__) return cls.__definition_name @classmethod def outer_definition_name(cls): """Get outer definition name. Returns: Package for service. Services are never nested inside other definitions. """ return cls.definition_package() @classmethod def definition_package(cls): """Get package for service. Returns: Package name for service. """ try: return cls.__definition_package except AttributeError: cls.__definition_package = util.get_package_for_module(cls.__module__) return cls.__definition_package @property def request_state(self): """Request state associated with this Service instance.""" return self.__request_state def is_error_status(status): """Function that determines whether the RPC status is an error. Args: status: Initialized RpcStatus message to check for errors. """ status.check_initialized() return RpcError.from_state(status.state) is not None def check_rpc_status(status): """Function converts an error status to a raised exception. Args: status: Initialized RpcStatus message to check for errors. Raises: RpcError according to state set on status, if it is an error state. """ status.check_initialized() error_class = RpcError.from_state(status.state) if error_class is not None: if error_class is ApplicationError: raise error_class(status.error_message, status.error_name) else: raise error_class(status.error_message) class ProtocolConfig(object): """Configuration for single protocol mapping. A read-only protocol configuration provides a given protocol implementation with a name and a set of content-types that it recognizes. Properties: protocol: The protocol implementation for configuration (usually a module, for example, protojson, protobuf, etc.). 
This is an object that has the following attributes: CONTENT_TYPE: Used as the default content-type if default_content_type is not set. ALTERNATIVE_CONTENT_TYPES (optional): A list of alternative content-types to the default that indicate the same protocol. encode_message: Function that matches the signature of ProtocolConfig.encode_message. Used for encoding a ProtoRPC message. decode_message: Function that matches the signature of ProtocolConfig.decode_message. Used for decoding a ProtoRPC message. name: Name of protocol configuration. default_content_type: The default content type for the protocol. Overrides CONTENT_TYPE defined on protocol. alternative_content_types: A list of alternative content-types supported by the protocol. Must not contain the default content-type, nor duplicates. Overrides ALTERNATIVE_CONTENT_TYPE defined on protocol. content_types: A list of all content-types supported by configuration. Combination of default content-type and alternatives. """ def __init__(self, protocol, name, default_content_type=None, alternative_content_types=None): """Constructor. Args: protocol: The protocol implementation for configuration. name: The name of the protocol configuration. default_content_type: The default content-type for protocol. If none provided it will check protocol.CONTENT_TYPE. alternative_content_types: A list of content-types. If none provided, it will check protocol.ALTERNATIVE_CONTENT_TYPES. If that attribute does not exist, will be an empty tuple. Raises: ServiceConfigurationError if there are any duplicate content-types. 
""" self.__protocol = protocol self.__name = name self.__default_content_type = (default_content_type or protocol.CONTENT_TYPE).lower() if alternative_content_types is None: alternative_content_types = getattr(protocol, 'ALTERNATIVE_CONTENT_TYPES', ()) self.__alternative_content_types = tuple( content_type.lower() for content_type in alternative_content_types) self.__content_types = ( (self.__default_content_type,) + self.__alternative_content_types) # Detect duplicate content types in definition. previous_type = None for content_type in sorted(self.content_types): if content_type == previous_type: raise ServiceConfigurationError( 'Duplicate content-type %s' % content_type) previous_type = content_type @property def protocol(self): return self.__protocol @property def name(self): return self.__name @property def default_content_type(self): return self.__default_content_type @property def alternate_content_types(self): return self.__alternative_content_types @property def content_types(self): return self.__content_types def encode_message(self, message): """Encode message. Args: message: Message instance to encode. Returns: String encoding of Message instance encoded in protocol's format. """ return self.__protocol.encode_message(message) def decode_message(self, message_type, encoded_message): """Decode buffer to Message instance. Args: message_type: Message type to decode data to. encoded_message: Encoded version of message as string. Returns: Decoded instance of message_type. """ return self.__protocol.decode_message(message_type, encoded_message) class Protocols(object): """Collection of protocol configurations. Used to describe a complete set of content-type mappings for multiple protocol configurations. Properties: names: Sorted list of the names of registered protocols. content_types: Sorted list of supported content-types. 
""" __default_protocols = None __lock = threading.Lock() def __init__(self): """Constructor.""" self.__by_name = {} self.__by_content_type = {} def add_protocol_config(self, config): """Add a protocol configuration to protocol mapping. Args: config: A ProtocolConfig. Raises: ServiceConfigurationError if protocol.name is already registered or any of it's content-types are already registered. """ if config.name in self.__by_name: raise ServiceConfigurationError( 'Protocol name %r is already in use' % config.name) for content_type in config.content_types: if content_type in self.__by_content_type: raise ServiceConfigurationError( 'Content type %r is already in use' % content_type) self.__by_name[config.name] = config self.__by_content_type.update((t, config) for t in config.content_types) def add_protocol(self, *args, **kwargs): """Add a protocol configuration from basic parameters. Simple helper method that creates and registeres a ProtocolConfig instance. """ self.add_protocol_config(ProtocolConfig(*args, **kwargs)) @property def names(self): return tuple(sorted(self.__by_name)) @property def content_types(self): return tuple(sorted(self.__by_content_type)) def lookup_by_name(self, name): """Look up a ProtocolConfig by name. Args: name: Name of protocol to look for. Returns: ProtocolConfig associated with name. Raises: KeyError if there is no protocol for name. """ return self.__by_name[name.lower()] def lookup_by_content_type(self, content_type): """Look up a ProtocolConfig by content-type. Args: content_type: Content-type to find protocol configuration for. Returns: ProtocolConfig associated with content-type. Raises: KeyError if there is no protocol for content-type. """ return self.__by_content_type[content_type.lower()] @classmethod def new_default(cls): """Create default protocols configuration. Returns: New Protocols instance configured for protobuf and protorpc. 
""" protocols = cls() protocols.add_protocol(protobuf, 'protobuf') protocols.add_protocol(protojson.ProtoJson.get_default(), 'protojson') return protocols @classmethod def get_default(cls): """Get the global default Protocols instance. Returns: Current global default Protocols instance. """ default_protocols = cls.__default_protocols if default_protocols is None: with cls.__lock: default_protocols = cls.__default_protocols if default_protocols is None: default_protocols = cls.new_default() cls.__default_protocols = default_protocols return default_protocols @classmethod def set_default(cls, protocols): """Set the global default Protocols instance. Args: protocols: A Protocols instance. Raises: TypeError: If protocols is not an instance of Protocols. """ if not isinstance(protocols, Protocols): raise TypeError( 'Expected value of type "Protocols", found %r' % protocols) with cls.__lock: cls.__default_protocols = protocols protorpc-standalone-0.9.1/protorpc/remote_test.py0000755000076500000240000007463612277637135023304 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Tests for protorpc.remote.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import sys import types import unittest from wsgiref import headers from protorpc import descriptor from protorpc import message_types from protorpc import messages from protorpc import protobuf from protorpc import protojson from protorpc import remote from protorpc import test_util from protorpc import transport import mox class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = remote class Request(messages.Message): """Test request message.""" value = messages.StringField(1) class Response(messages.Message): """Test response message.""" value = messages.StringField(1) class MyService(remote.Service): @remote.method(Request, Response) def remote_method(self, request): response = Response() response.value = request.value return response class SimpleRequest(messages.Message): """Simple request message type used for tests.""" param1 = messages.StringField(1) param2 = messages.StringField(2) class SimpleResponse(messages.Message): """Simple response message type used for tests.""" class BasicService(remote.Service): """A basic service with decorated remote method.""" def __init__(self): self.request_ids = [] @remote.method(SimpleRequest, SimpleResponse) def remote_method(self, request): self.request_ids.append(id(request)) return SimpleResponse() class RpcErrorTest(test_util.TestCase): def testFromStatus(self): for state in remote.RpcState: exception = remote.RpcError.from_state self.assertEquals(remote.ServerError, remote.RpcError.from_state('SERVER_ERROR')) class ApplicationErrorTest(test_util.TestCase): def testErrorCode(self): self.assertEquals('blam', remote.ApplicationError('an error', 'blam').error_name) def testStr(self): self.assertEquals('an error', str(remote.ApplicationError('an error', 1))) def testRepr(self): self.assertEquals("ApplicationError('an error', 1)", repr(remote.ApplicationError('an error', 1))) 
self.assertEquals("ApplicationError('an error')", repr(remote.ApplicationError('an error'))) class MethodTest(test_util.TestCase): """Test remote method decorator.""" def testMethod(self): """Test use of remote decorator.""" self.assertEquals(SimpleRequest, BasicService.remote_method.remote.request_type) self.assertEquals(SimpleResponse, BasicService.remote_method.remote.response_type) self.assertTrue(isinstance(BasicService.remote_method.remote.method, types.FunctionType)) def testMethodMessageResolution(self): """Test use of remote decorator to resolve message types by name.""" class OtherService(remote.Service): @remote.method('SimpleRequest', 'SimpleResponse') def remote_method(self, request): pass self.assertEquals(SimpleRequest, OtherService.remote_method.remote.request_type) self.assertEquals(SimpleResponse, OtherService.remote_method.remote.response_type) def testMethodMessageResolution_NotFound(self): """Test failure to find message types.""" class OtherService(remote.Service): @remote.method('NoSuchRequest', 'NoSuchResponse') def remote_method(self, request): pass self.assertRaisesWithRegexpMatch( messages.DefinitionNotFoundError, 'Could not find definition for NoSuchRequest', getattr, OtherService.remote_method.remote, 'request_type') self.assertRaisesWithRegexpMatch( messages.DefinitionNotFoundError, 'Could not find definition for NoSuchResponse', getattr, OtherService.remote_method.remote, 'response_type') def testInvocation(self): """Test that invocation passes request through properly.""" service = BasicService() request = SimpleRequest() self.assertEquals(SimpleResponse(), service.remote_method(request)) self.assertEquals([id(request)], service.request_ids) def testInvocation_WrongRequestType(self): """Wrong request type passed to remote method.""" service = BasicService() self.assertRaises(remote.RequestError, service.remote_method, 'wrong') self.assertRaises(remote.RequestError, service.remote_method, None) self.assertRaises(remote.RequestError, 
service.remote_method, SimpleResponse()) def testInvocation_WrongResponseType(self): """Wrong response type returned from remote method.""" class AnotherService(object): @remote.method(SimpleRequest, SimpleResponse) def remote_method(self, unused_request): return self.return_this service = AnotherService() service.return_this = 'wrong' self.assertRaises(remote.ServerError, service.remote_method, SimpleRequest()) service.return_this = None self.assertRaises(remote.ServerError, service.remote_method, SimpleRequest()) service.return_this = SimpleRequest() self.assertRaises(remote.ServerError, service.remote_method, SimpleRequest()) def testBadRequestType(self): """Test bad request types used in remote definition.""" for request_type in (None, 1020, messages.Message, str): def declare(): class BadService(object): @remote.method(request_type, SimpleResponse) def remote_method(self, request): pass self.assertRaises(TypeError, declare) def testBadResponseType(self): """Test bad response types used in remote definition.""" for response_type in (None, 1020, messages.Message, str): def declare(): class BadService(object): @remote.method(SimpleRequest, response_type) def remote_method(self, request): pass self.assertRaises(TypeError, declare) class GetRemoteMethodTest(test_util.TestCase): """Test for is_remote_method.""" def testGetRemoteMethod(self): """Test valid remote method detection.""" class Service(object): @remote.method(Request, Response) def remote_method(self, request): pass self.assertEquals(Service.remote_method.remote, remote.get_remote_method_info(Service.remote_method)) self.assertTrue(Service.remote_method.remote, remote.get_remote_method_info(Service().remote_method)) def testGetNotRemoteMethod(self): """Test positive result on a remote method.""" class NotService(object): def not_remote_method(self, request): pass def fn(self): pass class NotReallyRemote(object): """Test negative result on many bad values for remote methods.""" def not_really(self, 
request): pass not_really.remote = 'something else' for not_remote in [NotService.not_remote_method, NotService().not_remote_method, NotReallyRemote.not_really, NotReallyRemote().not_really, None, 1, 'a string', fn]: self.assertEquals(None, remote.get_remote_method_info(not_remote)) class RequestStateTest(test_util.TestCase): """Test request state.""" STATE_CLASS = remote.RequestState def testConstructor(self): """Test constructor.""" state = self.STATE_CLASS(remote_host='remote-host', remote_address='remote-address', server_host='server-host', server_port=10) self.assertEquals('remote-host', state.remote_host) self.assertEquals('remote-address', state.remote_address) self.assertEquals('server-host', state.server_host) self.assertEquals(10, state.server_port) state = self.STATE_CLASS() self.assertEquals(None, state.remote_host) self.assertEquals(None, state.remote_address) self.assertEquals(None, state.server_host) self.assertEquals(None, state.server_port) def testConstructorError(self): """Test unexpected keyword argument.""" self.assertRaises(TypeError, self.STATE_CLASS, x=10) def testRepr(self): """Test string representation.""" self.assertEquals('<%s>' % self.STATE_CLASS.__name__, repr(self.STATE_CLASS())) self.assertEquals("<%s remote_host='abc'>" % self.STATE_CLASS.__name__, repr(self.STATE_CLASS(remote_host='abc'))) self.assertEquals("<%s remote_host='abc' " "remote_address='def'>" % self.STATE_CLASS.__name__, repr(self.STATE_CLASS(remote_host='abc', remote_address='def'))) self.assertEquals("<%s remote_host='abc' " "remote_address='def' " "server_host='ghi'>" % self.STATE_CLASS.__name__, repr(self.STATE_CLASS(remote_host='abc', remote_address='def', server_host='ghi'))) self.assertEquals("<%s remote_host='abc' " "remote_address='def' " "server_host='ghi' " 'server_port=102>' % self.STATE_CLASS.__name__, repr(self.STATE_CLASS(remote_host='abc', remote_address='def', server_host='ghi', server_port=102))) class HttpRequestStateTest(RequestStateTest): 
STATE_CLASS = remote.HttpRequestState def testHttpMethod(self): state = remote.HttpRequestState(http_method='GET') self.assertEquals('GET', state.http_method) def testHttpMethod(self): state = remote.HttpRequestState(service_path='/bar') self.assertEquals('/bar', state.service_path) def testHeadersList(self): state = remote.HttpRequestState( headers=[('a', 'b'), ('c', 'd'), ('c', 'e')]) self.assertEquals(['a', 'c', 'c'], state.headers.keys()) self.assertEquals(['b'], state.headers.get_all('a')) self.assertEquals(['d', 'e'], state.headers.get_all('c')) def testHeadersDict(self): state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']}) self.assertEquals(['a', 'c', 'c'], sorted(state.headers.keys())) self.assertEquals(['b'], state.headers.get_all('a')) self.assertEquals(['d', 'e'], state.headers.get_all('c')) def testRepr(self): super(HttpRequestStateTest, self).testRepr() self.assertEquals("<%s remote_host='abc' " "remote_address='def' " "server_host='ghi' " 'server_port=102 ' "http_method='POST' " "service_path='/bar' " "headers=[('a', 'b'), ('c', 'd')]>" % self.STATE_CLASS.__name__, repr(self.STATE_CLASS(remote_host='abc', remote_address='def', server_host='ghi', server_port=102, http_method='POST', service_path='/bar', headers={'a': 'b', 'c': 'd'}, ))) class ServiceTest(test_util.TestCase): """Test Service class.""" def testServiceBase_AllRemoteMethods(self): """Test that service base class has no remote methods.""" self.assertEquals({}, remote.Service.all_remote_methods()) def testAllRemoteMethods(self): """Test all_remote_methods with properly Service subclass.""" self.assertEquals({'remote_method': MyService.remote_method}, MyService.all_remote_methods()) def testAllRemoteMethods_SubClass(self): """Test all_remote_methods on a sub-class of a service.""" class SubClass(MyService): @remote.method(Request, Response) def sub_class_method(self, request): pass self.assertEquals({'remote_method': SubClass.remote_method, 'sub_class_method': 
SubClass.sub_class_method, }, SubClass.all_remote_methods()) def testOverrideMethod(self): """Test that trying to override a remote method with remote decorator.""" class SubClass(MyService): def remote_method(self, request): response = super(SubClass, self).remote_method(request) response.value = '(%s)' % response.value return response self.assertEquals({'remote_method': SubClass.remote_method, }, SubClass.all_remote_methods()) instance = SubClass() self.assertEquals('(Hello)', instance.remote_method(Request(value='Hello')).value) self.assertEquals(Request, SubClass.remote_method.remote.request_type) self.assertEquals(Response, SubClass.remote_method.remote.response_type) def testOverrideMethodWithRemote(self): """Test trying to override a remote method with remote decorator.""" def do_override(): class SubClass(MyService): @remote.method(Request, Response) def remote_method(self, request): pass self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError, 'Do not use method decorator when ' 'overloading remote method remote_method ' 'on service SubClass', do_override) def testOverrideMethodWithInvalidValue(self): """Test trying to override a remote method with remote decorator.""" def do_override(bad_value): class SubClass(MyService): remote_method = bad_value for bad_value in [None, 1, 'string', {}]: self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError, 'Must override remote_method in ' 'SubClass with a method', do_override, bad_value) def testCallingRemoteMethod(self): """Test invoking a remote method.""" expected = Response() expected.value = 'what was passed in' request = Request() request.value = 'what was passed in' service = MyService() self.assertEquals(expected, service.remote_method(request)) def testFactory(self): """Test using factory to pass in state.""" class StatefulService(remote.Service): def __init__(self, a, b, c=None): self.a = a self.b = b self.c = c state = [1, 2, 3] factory = StatefulService.new_factory(1, state) module_name = 
ServiceTest.__module__ pattern = ('Creates new instances of service StatefulService.\n\n' 'Returns:\n' ' New instance of %s.StatefulService.' % module_name) self.assertEqual(pattern, factory.func_doc) self.assertEquals('StatefulService_service_factory', factory.func_name) self.assertEquals(StatefulService, factory.service_class) service = factory() self.assertEquals(1, service.a) self.assertEquals(id(state), id(service.b)) self.assertEquals(None, service.c) factory = StatefulService.new_factory(2, b=3, c=4) service = factory() self.assertEquals(2, service.a) self.assertEquals(3, service.b) self.assertEquals(4, service.c) def testFactoryError(self): """Test misusing a factory.""" # Passing positional argument that is not accepted by class. self.assertRaises(TypeError, remote.Service.new_factory(1)) # Passing keyword argument that is not accepted by class. self.assertRaises(TypeError, remote.Service.new_factory(x=1)) class StatefulService(remote.Service): def __init__(self, a): pass # Missing required parameter. 
self.assertRaises(TypeError, StatefulService.new_factory()) def testDefinitionName(self): """Test getting service definition name.""" class TheService(remote.Service): pass module_name = test_util.get_module_name(ServiceTest) self.assertEqual(TheService.definition_name(), '%s.TheService' % module_name) self.assertTrue(TheService.outer_definition_name(), module_name) self.assertTrue(TheService.definition_package(), module_name) def testDefinitionNameWithPackage(self): """Test getting service definition name when package defined.""" global package package = 'my.package' try: class TheService(remote.Service): pass self.assertEquals('my.package.TheService', TheService.definition_name()) self.assertEquals('my.package', TheService.outer_definition_name()) self.assertEquals('my.package', TheService.definition_package()) finally: del package def testDefinitionNameWithNoModule(self): """Test getting service definition name when package defined.""" module = sys.modules[__name__] try: del sys.modules[__name__] class TheService(remote.Service): pass self.assertEquals('TheService', TheService.definition_name()) self.assertEquals(None, TheService.outer_definition_name()) self.assertEquals(None, TheService.definition_package()) finally: sys.modules[__name__] = module class StubTest(test_util.TestCase): def setUp(self): self.mox = mox.Mox() self.transport = self.mox.CreateMockAnything() def testDefinitionName(self): self.assertEquals(BasicService.definition_name(), BasicService.Stub.definition_name()) self.assertEquals(BasicService.outer_definition_name(), BasicService.Stub.outer_definition_name()) self.assertEquals(BasicService.definition_package(), BasicService.Stub.definition_package()) def testRemoteMethods(self): self.assertEquals(BasicService.all_remote_methods(), BasicService.Stub.all_remote_methods()) def testSync_WithRequest(self): stub = BasicService.Stub(self.transport) request = SimpleRequest() request.param1 = 'val1' request.param2 = 'val2' response = SimpleResponse() 
rpc = transport.Rpc(request) rpc.set_response(response) self.transport.send_rpc(BasicService.remote_method.remote, request).AndReturn(rpc) self.mox.ReplayAll() self.assertEquals(SimpleResponse(), stub.remote_method(request)) self.mox.VerifyAll() def testSync_WithKwargs(self): stub = BasicService.Stub(self.transport) request = SimpleRequest() request.param1 = 'val1' request.param2 = 'val2' response = SimpleResponse() rpc = transport.Rpc(request) rpc.set_response(response) self.transport.send_rpc(BasicService.remote_method.remote, request).AndReturn(rpc) self.mox.ReplayAll() self.assertEquals(SimpleResponse(), stub.remote_method(param1='val1', param2='val2')) self.mox.VerifyAll() def testAsync_WithRequest(self): stub = BasicService.Stub(self.transport) request = SimpleRequest() request.param1 = 'val1' request.param2 = 'val2' response = SimpleResponse() rpc = transport.Rpc(request) self.transport.send_rpc(BasicService.remote_method.remote, request).AndReturn(rpc) self.mox.ReplayAll() self.assertEquals(rpc, stub.async.remote_method(request)) self.mox.VerifyAll() def testAsync_WithKwargs(self): stub = BasicService.Stub(self.transport) request = SimpleRequest() request.param1 = 'val1' request.param2 = 'val2' response = SimpleResponse() rpc = transport.Rpc(request) self.transport.send_rpc(BasicService.remote_method.remote, request).AndReturn(rpc) self.mox.ReplayAll() self.assertEquals(rpc, stub.async.remote_method(param1='val1', param2='val2')) self.mox.VerifyAll() def testAsync_WithRequestAndKwargs(self): stub = BasicService.Stub(self.transport) request = SimpleRequest() request.param1 = 'val1' request.param2 = 'val2' response = SimpleResponse() self.mox.ReplayAll() self.assertRaisesWithRegexpMatch( TypeError, r'May not provide both args and kwargs', stub.async.remote_method, request, param1='val1', param2='val2') self.mox.VerifyAll() def testAsync_WithTooManyPositionals(self): stub = BasicService.Stub(self.transport) request = SimpleRequest() request.param1 = 'val1' 
request.param2 = 'val2' response = SimpleResponse() self.mox.ReplayAll() self.assertRaisesWithRegexpMatch( TypeError, r'remote_method\(\) takes at most 2 positional arguments \(3 given\)', stub.async.remote_method, request, 'another value') self.mox.VerifyAll() class IsErrorStatusTest(test_util.TestCase): def testIsError(self): for state in (s for s in remote.RpcState if s > remote.RpcState.RUNNING): status = remote.RpcStatus(state=state) self.assertTrue(remote.is_error_status(status)) def testIsNotError(self): for state in (s for s in remote.RpcState if s <= remote.RpcState.RUNNING): status = remote.RpcStatus(state=state) self.assertFalse(remote.is_error_status(status)) def testStateNone(self): self.assertRaises(messages.ValidationError, remote.is_error_status, remote.RpcStatus()) class CheckRpcStatusTest(test_util.TestCase): def testStateNone(self): self.assertRaises(messages.ValidationError, remote.check_rpc_status, remote.RpcStatus()) def testNoError(self): for state in (remote.RpcState.OK, remote.RpcState.RUNNING): remote.check_rpc_status(remote.RpcStatus(state=state)) def testErrorState(self): status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR, error_message='a request error') self.assertRaisesWithRegexpMatch(remote.RequestError, 'a request error', remote.check_rpc_status, status) def testApplicationErrorState(self): status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR, error_message='an application error', error_name='blam') try: remote.check_rpc_status(status) self.fail('Should have raised application error.') except remote.ApplicationError, err: self.assertEquals('an application error', str(err)) self.assertEquals('blam', err.error_name) class ProtocolConfigTest(test_util.TestCase): def testConstructor(self): config = remote.ProtocolConfig( protojson, 'proto1', 'application/X-Json', iter(['text/Json', 'text/JavaScript'])) self.assertEquals(protojson, config.protocol) self.assertEquals('proto1', config.name) 
self.assertEquals('application/x-json', config.default_content_type) self.assertEquals(('text/json', 'text/javascript'), config.alternate_content_types) self.assertEquals(('application/x-json', 'text/json', 'text/javascript'), config.content_types) def testConstructorDefaults(self): config = remote.ProtocolConfig(protojson, 'proto2') self.assertEquals(protojson, config.protocol) self.assertEquals('proto2', config.name) self.assertEquals('application/json', config.default_content_type) self.assertEquals(('application/x-javascript', 'text/javascript', 'text/x-javascript', 'text/x-json', 'text/json'), config.alternate_content_types) self.assertEquals(('application/json', 'application/x-javascript', 'text/javascript', 'text/x-javascript', 'text/x-json', 'text/json'), config.content_types) def testEmptyAlternativeTypes(self): config = remote.ProtocolConfig(protojson, 'proto2', alternative_content_types=()) self.assertEquals(protojson, config.protocol) self.assertEquals('proto2', config.name) self.assertEquals('application/json', config.default_content_type) self.assertEquals((), config.alternate_content_types) self.assertEquals(('application/json',), config.content_types) def testDuplicateContentTypes(self): self.assertRaises(remote.ServiceConfigurationError, remote.ProtocolConfig, protojson, 'json', 'text/plain', ('text/plain',)) self.assertRaises(remote.ServiceConfigurationError, remote.ProtocolConfig, protojson, 'json', 'text/plain', ('text/html', 'text/html')) def testEncodeMessage(self): config = remote.ProtocolConfig(protojson, 'proto2') encoded_message = config.encode_message( remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, error_message='bad error')) # Convert back to a dictionary from JSON. 
dict_message = protojson.json.loads(encoded_message) self.assertEquals({'state': 'SERVER_ERROR', 'error_message': 'bad error'}, dict_message) def testDecodeMessage(self): config = remote.ProtocolConfig(protojson, 'proto2') self.assertEquals( remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, error_message="bad error"), config.decode_message( remote.RpcStatus, '{"state": "SERVER_ERROR", "error_message": "bad error"}')) class ProtocolsTest(test_util.TestCase): def setUp(self): self.protocols = remote.Protocols() def testEmpty(self): self.assertEquals((), self.protocols.names) self.assertEquals((), self.protocols.content_types) def testAddProtocolAllDefaults(self): self.protocols.add_protocol(protojson, 'json') self.assertEquals(('json',), self.protocols.names) self.assertEquals(('application/json', 'application/x-javascript', 'text/javascript', 'text/json', 'text/x-javascript', 'text/x-json'), self.protocols.content_types) def testAddProtocolNoDefaultAlternatives(self): class Protocol(object): CONTENT_TYPE = 'text/plain' self.protocols.add_protocol(Protocol, 'text') self.assertEquals(('text',), self.protocols.names) self.assertEquals(('text/plain',), self.protocols.content_types) def testAddProtocolOverrideDefaults(self): self.protocols.add_protocol(protojson, 'json', default_content_type='text/blar', alternative_content_types=('text/blam', 'text/blim')) self.assertEquals(('json',), self.protocols.names) self.assertEquals(('text/blam', 'text/blar', 'text/blim'), self.protocols.content_types) def testLookupByName(self): self.protocols.add_protocol(protojson, 'json') self.protocols.add_protocol(protojson, 'json2', default_content_type='text/plain', alternative_content_types=()) self.assertEquals('json', self.protocols.lookup_by_name('JsOn').name) self.assertEquals('json2', self.protocols.lookup_by_name('Json2').name) def testLookupByContentType(self): self.protocols.add_protocol(protojson, 'json') self.protocols.add_protocol(protojson, 'json2', 
default_content_type='text/plain', alternative_content_types=()) self.assertEquals( 'json', self.protocols.lookup_by_content_type('AppliCation/Json').name) self.assertEquals( 'json', self.protocols.lookup_by_content_type('text/x-Json').name) self.assertEquals( 'json2', self.protocols.lookup_by_content_type('text/Plain').name) def testNewDefault(self): protocols = remote.Protocols.new_default() self.assertEquals(('protobuf', 'protojson'), protocols.names) protobuf_protocol = protocols.lookup_by_name('protobuf') self.assertEquals(protobuf, protobuf_protocol.protocol) protojson_protocol = protocols.lookup_by_name('protojson') self.assertEquals(protojson.ProtoJson.get_default(), protojson_protocol.protocol) def testGetDefaultProtocols(self): protocols = remote.Protocols.get_default() self.assertEquals(('protobuf', 'protojson'), protocols.names) protobuf_protocol = protocols.lookup_by_name('protobuf') self.assertEquals(protobuf, protobuf_protocol.protocol) protojson_protocol = protocols.lookup_by_name('protojson') self.assertEquals(protojson.ProtoJson.get_default(), protojson_protocol.protocol) self.assertTrue(protocols is remote.Protocols.get_default()) def testSetDefaultProtocols(self): protocols = remote.Protocols() remote.Protocols.set_default(protocols) self.assertTrue(protocols is remote.Protocols.get_default()) def testSetDefaultWithoutProtocols(self): self.assertRaises(TypeError, remote.Protocols.set_default, None) self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols') self.assertRaises(TypeError, remote.Protocols.set_default, {}) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/test_util.py0000755000076500000240000005376212277637135022763 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Test utilities for message testing. Includes module interface test to ensure that public parts of module are correctly declared in __all__. Includes message types that correspond to those defined in services_test.proto. Includes additional test utilities to make sure encoding/decoding libraries conform. """ __author__ = 'rafek@google.com (Rafe Kaplan)' import cgi import datetime import inspect import os import re import socket import types import unittest from . import message_types from . import messages from . import util # Unicode of the word "Russian" in cyrillic. RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439' # All characters binary value interspersed with nulls. BINARY = ''.join(chr(value) + '\0' for value in range(256)) class TestCase(unittest.TestCase): def assertRaisesWithRegexpMatch(self, exception, regexp, function, *params, **kwargs): """Check that exception is raised and text matches regular expression. Args: exception: Exception type that is expected. regexp: String regular expression that is expected in error message. function: Callable to test. params: Parameters to forward to function. kwargs: Keyword arguments to forward to function. """ try: function(*params, **kwargs) self.fail('Expected exception %s was not raised' % exception.__name__) except exception, err: match = bool(re.match(regexp, str(err))) self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp, err)) def assertHeaderSame(self, header1, header2): """Check that two HTTP headers are the same. Args: header1: Header value string 1. header2: header value string 2. 
""" value1, params1 = cgi.parse_header(header1) value2, params2 = cgi.parse_header(header2) self.assertEqual(value1, value2) self.assertEqual(params1, params2) def assertIterEqual(self, iter1, iter2): """Check that two iterators or iterables are equal independent of order. Similar to Python 2.7 assertItemsEqual. Named differently in order to avoid potential conflict. Args: iter1: An iterator or iterable. iter2: An iterator or iterable. """ list1 = list(iter1) list2 = list(iter2) unmatched1 = list() while list1: item1 = list1[0] del list1[0] for index in range(len(list2)): if item1 == list2[index]: del list2[index] break else: unmatched1.append(item1) error_message = [] for item in unmatched1: error_message.append( ' Item from iter1 not found in iter2: %r' % item) for item in list2: error_message.append( ' Item from iter2 not found in iter1: %r' % item) if error_message: self.fail('Collections not equivalent:\n' + '\n'.join(error_message)) class ModuleInterfaceTest(object): """Test to ensure module interface is carefully constructed. A module interface is the set of public objects listed in the module __all__ attribute. Modules that that are considered public should have this interface carefully declared. At all times, the __all__ attribute should have objects intended to be publically used and all other objects in the module should be considered unused. Protected attributes (those beginning with '_') and other imported modules should not be part of this set of variables. An exception is for variables that begin and end with '__' which are implicitly part of the interface (eg. __name__, __file__, __all__ itself, etc.). Modules that are imported in to the tested modules are an exception and may be left out of the __all__ definition. The test is done by checking the value of what would otherwise be a public name and not allowing it to be exported if it is an instance of a module. Modules that are explicitly exported are for the time being not permitted. 
To use this test class a module should define a new class that inherits first from ModuleInterfaceTest and then from test_util.TestCase. No other tests should be added to this test case, making the order of inheritance less important, but if setUp for some reason is overidden, it is important that ModuleInterfaceTest is first in the list so that its setUp method is invoked. Multiple inheretance is required so that ModuleInterfaceTest is not itself a test, and is not itself executed as one. The test class is expected to have the following class attributes defined: MODULE: A reference to the module that is being validated for interface correctness. Example: Module definition (hello.py): import sys __all__ = ['hello'] def _get_outputter(): return sys.stdout def hello(): _get_outputter().write('Hello\n') Test definition: import unittest from protorpc import test_util import hello class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = hello class HelloTest(test_util.TestCase): ... Test 'hello' module ... if __name__ == '__main__': unittest.main() """ def setUp(self): """Set up makes sure that MODULE and IMPORTED_MODULES is defined. This is a basic configuration test for the test itself so does not get it's own test case. """ if not hasattr(self, 'MODULE'): self.fail( "You must define 'MODULE' on ModuleInterfaceTest sub-class %s." % type(self).__name__) def testAllExist(self): """Test that all attributes defined in __all__ exist.""" missing_attributes = [] for attribute in self.MODULE.__all__: if not hasattr(self.MODULE, attribute): missing_attributes.append(attribute) if missing_attributes: self.fail('%s of __all__ are not defined in module.' 
% missing_attributes) def testAllExported(self): """Test that all public attributes not imported are in __all__.""" missing_attributes = [] for attribute in dir(self.MODULE): if not attribute.startswith('_'): if (attribute not in self.MODULE.__all__ and not isinstance(getattr(self.MODULE, attribute), types.ModuleType) and attribute != 'with_statement'): missing_attributes.append(attribute) if missing_attributes: self.fail('%s are not modules and not defined in __all__.' % missing_attributes) def testNoExportedProtectedVariables(self): """Test that there are no protected variables listed in __all__.""" protected_variables = [] for attribute in self.MODULE.__all__: if attribute.startswith('_'): protected_variables.append(attribute) if protected_variables: self.fail('%s are protected variables and may not be exported.' % protected_variables) def testNoExportedModules(self): """Test that no modules exist in __all__.""" exported_modules = [] for attribute in self.MODULE.__all__: try: value = getattr(self.MODULE, attribute) except AttributeError: # This is a different error case tested for in testAllExist. pass else: if isinstance(value, types.ModuleType): exported_modules.append(attribute) if exported_modules: self.fail('%s are modules and may not be exported.' 
% exported_modules) class NestedMessage(messages.Message): """Simple message that gets nested in another message.""" a_value = messages.StringField(1, required=True) class HasNestedMessage(messages.Message): """Message that has another message nested in it.""" nested = messages.MessageField(NestedMessage, 1) repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True) class HasDefault(messages.Message): """Has a default value.""" a_value = messages.StringField(1, default=u'a default') class OptionalMessage(messages.Message): """Contains all message types.""" class SimpleEnum(messages.Enum): """Simple enumeration type.""" VAL1 = 1 VAL2 = 2 double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE) float_value = messages.FloatField(2, variant=messages.Variant.FLOAT) int64_value = messages.IntegerField(3, variant=messages.Variant.INT64) uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64) int32_value = messages.IntegerField(5, variant=messages.Variant.INT32) bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL) string_value = messages.StringField(7, variant=messages.Variant.STRING) bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES) enum_value = messages.EnumField(SimpleEnum, 10) # TODO(rafek): Add support for these variants. 
# uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32) # sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32) # sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64) class RepeatedMessage(messages.Message): """Contains all message types as repeated fields.""" class SimpleEnum(messages.Enum): """Simple enumeration type.""" VAL1 = 1 VAL2 = 2 double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE, repeated=True) float_value = messages.FloatField(2, variant=messages.Variant.FLOAT, repeated=True) int64_value = messages.IntegerField(3, variant=messages.Variant.INT64, repeated=True) uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64, repeated=True) int32_value = messages.IntegerField(5, variant=messages.Variant.INT32, repeated=True) bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL, repeated=True) string_value = messages.StringField(7, variant=messages.Variant.STRING, repeated=True) bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES, repeated=True) #uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32) enum_value = messages.EnumField(SimpleEnum, 10, repeated=True) #sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32) #sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64) class HasOptionalNestedMessage(messages.Message): nested = messages.MessageField(OptionalMessage, 1) repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True) class ProtoConformanceTestBase(object): """Protocol conformance test base class. Each supported protocol should implement two methods that support encoding and decoding of Message objects in that format: encode_message(message) - Serialize to encoding. encode_message(message, encoded_message) - Deserialize from encoding. 
Tests for the modules where these functions are implemented should extend this class in order to support basic behavioral expectations. This ensures that protocols correctly encode and decode message transparently to the caller. In order to support these test, the base class should also extend the TestCase class and implement the following class attributes which define the encoded version of certain protocol buffers: encoded_partial: encoded_full: encoded_repeated: encoded_nested: > encoded_repeated_nested: , ] > unexpected_tag_message: An encoded message that has an undefined tag or number in the stream. encoded_default_assigned: encoded_nested_empty: > encoded_invalid_enum: """ encoded_empty_message = '' def testEncodeInvalidMessage(self): message = NestedMessage() self.assertRaises(messages.ValidationError, self.PROTOLIB.encode_message, message) def CompareEncoded(self, expected_encoded, actual_encoded): """Compare two encoded protocol values. Can be overridden by sub-classes to special case comparison. For example, to eliminate white space from output that is not relevant to encoding. Args: expected_encoded: Expected string encoded value. actual_encoded: Actual string encoded value. 
""" self.assertEquals(expected_encoded, actual_encoded) def EncodeDecode(self, encoded, expected_message): message = self.PROTOLIB.decode_message(type(expected_message), encoded) self.assertEquals(expected_message, message) self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message)) def testEmptyMessage(self): self.EncodeDecode(self.encoded_empty_message, OptionalMessage()) def testPartial(self): """Test message with a few values set.""" message = OptionalMessage() message.double_value = 1.23 message.int64_value = -100000000000 message.int32_value = 1020 message.string_value = u'a string' message.enum_value = OptionalMessage.SimpleEnum.VAL2 self.EncodeDecode(self.encoded_partial, message) def testFull(self): """Test all types.""" message = OptionalMessage() message.double_value = 1.23 message.float_value = -2.5 message.int64_value = -100000000000 message.uint64_value = 102020202020 message.int32_value = 1020 message.bool_value = True message.string_value = u'a string\u044f' message.bytes_value = 'a bytes\xff\xfe' message.enum_value = OptionalMessage.SimpleEnum.VAL2 self.EncodeDecode(self.encoded_full, message) def testRepeated(self): """Test repeated fields.""" message = RepeatedMessage() message.double_value = [1.23, 2.3] message.float_value = [-2.5, 0.5] message.int64_value = [-100000000000, 20] message.uint64_value = [102020202020, 10] message.int32_value = [1020, 718] message.bool_value = [True, False] message.string_value = [u'a string\u044f', u'another string'] message.bytes_value = ['a bytes\xff\xfe', 'another bytes'] message.enum_value = [RepeatedMessage.SimpleEnum.VAL2, RepeatedMessage.SimpleEnum.VAL1] self.EncodeDecode(self.encoded_repeated, message) def testNested(self): """Test nested messages.""" nested_message = NestedMessage() nested_message.a_value = u'a string' message = HasNestedMessage() message.nested = nested_message self.EncodeDecode(self.encoded_nested, message) def testRepeatedNested(self): """Test repeated nested messages.""" 
nested_message1 = NestedMessage() nested_message1.a_value = u'a string' nested_message2 = NestedMessage() nested_message2.a_value = u'another string' message = HasNestedMessage() message.repeated_nested = [nested_message1, nested_message2] self.EncodeDecode(self.encoded_repeated_nested, message) def testStringTypes(self): """Test that encoding str on StringField works.""" message = OptionalMessage() message.string_value = 'Latin' self.EncodeDecode(self.encoded_string_types, message) def testEncodeUninitialized(self): """Test that cannot encode uninitialized message.""" required = NestedMessage() self.assertRaisesWithRegexpMatch(messages.ValidationError, "Message NestedMessage is missing " "required field a_value", self.PROTOLIB.encode_message, required) def testUnexpectedField(self): """Test decoding and encoding unexpected fields.""" loaded_message = self.PROTOLIB.decode_message(OptionalMessage, self.unexpected_tag_message) # Message should be equal to an empty message, since unknown values aren't # included in equality. self.assertEquals(OptionalMessage(), loaded_message) # Verify that the encoded message matches the source, including the # unknown value. 
self.assertEquals(self.unexpected_tag_message, self.PROTOLIB.encode_message(loaded_message)) def testDoNotSendDefault(self): """Test that default is not sent when nothing is assigned.""" self.EncodeDecode(self.encoded_empty_message, HasDefault()) def testSendDefaultExplicitlyAssigned(self): """Test that default is sent when explcitly assigned.""" message = HasDefault() message.a_value = HasDefault.a_value.default self.EncodeDecode(self.encoded_default_assigned, message) def testEncodingNestedEmptyMessage(self): """Test encoding a nested empty message.""" message = HasOptionalNestedMessage() message.nested = OptionalMessage() self.EncodeDecode(self.encoded_nested_empty, message) def testEncodingRepeatedNestedEmptyMessage(self): """Test encoding a nested empty message.""" message = HasOptionalNestedMessage() message.repeated_nested = [OptionalMessage(), OptionalMessage()] self.EncodeDecode(self.encoded_repeated_nested_empty, message) def testContentType(self): self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str)) def testDecodeInvalidEnumType(self): self.assertRaisesWithRegexpMatch(messages.DecodeError, 'Invalid enum value ', self.PROTOLIB.decode_message, OptionalMessage, self.encoded_invalid_enum) def testDateTimeNoTimeZone(self): """Test that DateTimeFields are encoded/decoded correctly.""" class MyMessage(messages.Message): value = message_types.DateTimeField(1) value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000) message = MyMessage(value=value) decoded = self.PROTOLIB.decode_message( MyMessage, self.PROTOLIB.encode_message(message)) self.assertEquals(decoded.value, value) def testDateTimeWithTimeZone(self): """Test DateTimeFields with time zones.""" class MyMessage(messages.Message): value = message_types.DateTimeField(1) value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000, util.TimeZoneOffset(8 * 60)) message = MyMessage(value=value) decoded = self.PROTOLIB.decode_message( MyMessage, self.PROTOLIB.encode_message(message)) 
self.assertEquals(decoded.value, value) def do_with(context, function, *args, **kwargs): """Simulate a with statement. Avoids need to import with from future. Does not support simulation of 'as'. Args: context: Context object normally used with 'with'. function: Callable to evoke. Replaces with-block. """ context.__enter__() try: function(*args, **kwargs) except: context.__exit__(*sys.exc_info()) finally: context.__exit__(None, None, None) def pick_unused_port(): """Find an unused port to use in tests. Derived from Damon Kohlers example: http://code.activestate.com/recipes/531822-pick-unused-port """ temp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: temp.bind(('localhost', 0)) port = temp.getsockname()[1] finally: temp.close() return port def get_module_name(module_attribute): """Get the module name. Args: module_attribute: An attribute of the module. Returns: The fully qualified module name or simple module name where 'module_attribute' is defined if the module name is "__main__". """ if module_attribute.__module__ == '__main__': module_file = inspect.getfile(module_attribute) default = os.path.basename(module_file).split('.')[0] return default else: return module_attribute.__module__ protorpc-standalone-0.9.1/protorpc/transport.py0000755000076500000240000003100512277637135022765 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Transport library for ProtoRPC. 
Contains underlying infrastructure used for communicating RPCs over low level transports such as HTTP. Includes HTTP transport built over urllib2. """ import httplib import logging import os import socket import sys import urlparse from . import messages from . import protobuf from . import remote from . import util __all__ = [ 'RpcStateError', 'HttpTransport', 'LocalTransport', 'Rpc', 'Transport', ] class RpcStateError(messages.Error): """Raised when trying to put RPC in to an invalid state.""" class Rpc(object): """Represents a client side RPC. An RPC is created by the transport class and is used with a single RPC. While an RPC is still in process, the response is set to None. When it is complete the response will contain the response message. """ def __init__(self, request): """Constructor. Args: request: Request associated with this RPC. """ self.__request = request self.__response = None self.__state = remote.RpcState.RUNNING self.__error_message = None self.__error_name = None @property def request(self): """Request associated with RPC.""" return self.__request @property def response(self): """Response associated with RPC.""" self.wait() self.__check_status() return self.__response @property def state(self): """State associated with RPC.""" return self.__state @property def error_message(self): """Error, if any, associated with RPC.""" self.wait() return self.__error_message @property def error_name(self): """Error name, if any, associated with RPC.""" self.wait() return self.__error_name def wait(self): """Wait for an RPC to finish.""" if self.__state == remote.RpcState.RUNNING: self._wait_impl() def _wait_impl(self): """Implementation for wait().""" raise NotImplementedError() def __check_status(self): error_class = remote.RpcError.from_state(self.__state) if error_class is not None: if error_class is remote.ApplicationError: raise error_class(self.__error_message, self.__error_name) else: raise error_class(self.__error_message) def __set_state(self, state, 
error_message=None, error_name=None): if self.__state != remote.RpcState.RUNNING: raise RpcStateError( 'RPC must be in RUNNING state to change to %s' % state) if state == remote.RpcState.RUNNING: raise RpcStateError('RPC is already in RUNNING state') self.__state = state self.__error_message = error_message self.__error_name = error_name def set_response(self, response): # TODO: Even more specific type checking. if not isinstance(response, messages.Message): raise TypeError('Expected Message type, received %r' % (response)) self.__response = response self.__set_state(remote.RpcState.OK) def set_status(self, status): status.check_initialized() self.__set_state(status.state, status.error_message, status.error_name) class Transport(object): """Transport base class. Provides basic support for implementing a ProtoRPC transport such as one that can send and receive messages over HTTP. Implementations override _start_rpc. This method receives a RemoteInfo instance and a request Message. The transport is expected to set the rpc response or raise an exception before termination. """ @util.positional(1) def __init__(self, protocol=protobuf): """Constructor. Args: protocol: If string, will look up a protocol from the default Protocols instance by name. Can also be an instance of remote.ProtocolConfig. If neither, it must be an object that implements a protocol interface by implementing encode_message, decode_message and set CONTENT_TYPE. For example, the modules protobuf and protojson can be used directly. 
""" if isinstance(protocol, basestring): protocols = remote.Protocols.get_default() try: protocol = protocols.lookup_by_name(protocol) except KeyError: protocol = protocols.lookup_by_content_type(protocol) if isinstance(protocol, remote.ProtocolConfig): self.__protocol = protocol.protocol self.__protocol_config = protocol else: self.__protocol = protocol self.__protocol_config = remote.ProtocolConfig( protocol, 'default', default_content_type=protocol.CONTENT_TYPE) @property def protocol(self): """Protocol associated with this transport.""" return self.__protocol @property def protocol_config(self): """Protocol associated with this transport.""" return self.__protocol_config def send_rpc(self, remote_info, request): """Initiate sending an RPC over the transport. Args: remote_info: RemoteInfo instance describing remote method. request: Request message to send to service. Returns: An Rpc instance intialized with the request.. """ request.check_initialized() rpc = self._start_rpc(remote_info, request) return rpc def _start_rpc(self, remote_info, request): """Start a remote procedure call. Args: remote_info: RemoteInfo instance describing remote method. request: Request message to send to service. Returns: An Rpc instance initialized with the request. """ raise NotImplementedError() class HttpTransport(Transport): """Transport for communicating with HTTP servers.""" @util.positional(2) def __init__(self, service_url, protocol=protobuf): """Constructor. Args: service_url: URL where the service is located. All communication via the transport will go to this URL. protocol: The protocol implementation. Must implement encode_message and decode_message. Can also be an instance of remote.ProtocolConfig. """ super(HttpTransport, self).__init__(protocol=protocol) self.__service_url = service_url def __get_rpc_status(self, response, content): """Get RPC status from HTTP response. Args: response: HTTPResponse object. content: Content read from HTTP response. 
Returns: RpcStatus object parsed from response, else an RpcStatus with a generic HTTP error. """ # Status above 400 may have RpcStatus content. if response.status >= 400: content_type = response.getheader('content-type') if content_type == self.protocol_config.default_content_type: try: rpc_status = self.protocol.decode_message(remote.RpcStatus, content) except Exception, decode_err: logging.warning( 'An error occurred trying to parse status: %s\n%s', str(decode_err), content) else: if rpc_status.is_initialized(): return rpc_status else: logging.warning( 'Body does not result in an initialized RpcStatus message:\n%s', content) # If no RpcStatus message present, attempt to forward any content. If empty # use standard error message. if not content.strip(): content = httplib.responses.get(response.status, 'Unknown Error') return remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, error_message='HTTP Error %s: %s' % ( response.status, content or 'Unknown Error')) def __set_response(self, remote_info, connection, rpc): """Set response on RPC. Sets response or status from HTTP request. Implements the wait method of Rpc instance. Args: remote_info: Remote info for invoked RPC. connection: HTTPConnection that is making request. rpc: Rpc instance. """ try: response = connection.getresponse() content = response.read() if response.status == httplib.OK: response = self.protocol.decode_message(remote_info.response_type, content) rpc.set_response(response) else: status = self.__get_rpc_status(response, content) rpc.set_status(status) finally: connection.close() def _start_rpc(self, remote_info, request): """Start a remote procedure call. Args: remote_info: A RemoteInfo instance for this RPC. request: The request message for this RPC. Returns: An Rpc instance initialized with a Request. 
""" method_url = '%s.%s' % (self.__service_url, remote_info.method.func_name) encoded_request = self.protocol.encode_message(request) url = urlparse.urlparse(method_url) if url.scheme == 'https': connection_type = httplib.HTTPSConnection else: connection_type = httplib.HTTPConnection connection = connection_type(url.hostname, url.port) try: self._send_http_request(connection, url.path, encoded_request) rpc = Rpc(request) except remote.RpcError: # Pass through all ProtoRPC errors connection.close() raise except socket.error, err: connection.close() raise remote.NetworkError('Socket error: %s %r' % (type(err).__name__, err.args), err) except Exception, err: connection.close() raise remote.NetworkError('Error communicating with HTTP server', err) else: wait_impl = lambda: self.__set_response(remote_info, connection, rpc) rpc._wait_impl = wait_impl return rpc def _send_http_request(self, connection, http_path, encoded_request): connection.request( 'POST', http_path, encoded_request, headers={'Content-type': self.protocol_config.default_content_type, 'Content-length': len(encoded_request)}) class LocalTransport(Transport): """Local transport that sends messages directly to services. Useful in tests or creating code that can work with either local or remote services. Using LocalTransport is preferrable to simply instantiating a single instance of a service and reusing it. The entire request process involves instantiating a new instance of a service, initializing it with request state and then invoking the remote method for every request. """ def __init__(self, service_factory): """Constructor. Args: service_factory: Service factory or class. 
""" super(LocalTransport, self).__init__() self.__service_class = getattr(service_factory, 'service_class', service_factory) self.__service_factory = service_factory @property def service_class(self): return self.__service_class @property def service_factory(self): return self.__service_factory def _start_rpc(self, remote_info, request): """Start a remote procedure call. Args: remote_info: RemoteInfo instance describing remote method. request: Request message to send to service. Returns: An Rpc instance initialized with the request. """ rpc = Rpc(request) def wait_impl(): instance = self.__service_factory() try: initalize_request_state = instance.initialize_request_state except AttributeError: pass else: host = unicode(os.uname()[1]) initalize_request_state(remote.RequestState(remote_host=host, remote_address=u'127.0.0.1', server_host=host, server_port=-1)) try: response = remote_info.method(instance, request) assert isinstance(response, remote_info.response_type) except remote.ApplicationError: raise except: exc_type, exc_value, traceback = sys.exc_info() message = 'Unexpected error %s: %s' % (exc_type.__name__, exc_value) raise remote.ServerError, message, traceback rpc.set_response(response) rpc._wait_impl = wait_impl return rpc protorpc-standalone-0.9.1/protorpc/transport_test.py0000755000076500000240000004034012277637135024026 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import httplib import os import socket import unittest from protorpc import messages from protorpc import protobuf from protorpc import protojson from protorpc import remote from protorpc import test_util from protorpc import transport from protorpc import webapp_test_util from protorpc.wsgi import util as wsgi_util import mox package = 'transport_test' class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = transport class Message(messages.Message): value = messages.StringField(1) class Service(remote.Service): @remote.method(Message, Message) def method(self, request): pass # Remove when RPC is no longer subclasses. class TestRpc(transport.Rpc): waited = False def _wait_impl(self): self.waited = True class RpcTest(test_util.TestCase): def setUp(self): self.request = Message(value=u'request') self.response = Message(value=u'response') self.status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR, error_message='an error', error_name='blam') self.rpc = TestRpc(self.request) def testConstructor(self): self.assertEquals(self.request, self.rpc.request) self.assertEquals(remote.RpcState.RUNNING, self.rpc.state) self.assertEquals(None, self.rpc.error_message) self.assertEquals(None, self.rpc.error_name) def response(self): self.assertFalse(self.rpc.waited) self.assertEquals(None, self.rpc.response) self.assertTrue(self.rpc.waited) def testSetResponse(self): self.rpc.set_response(self.response) self.assertEquals(self.request, self.rpc.request) self.assertEquals(remote.RpcState.OK, self.rpc.state) self.assertEquals(self.response, self.rpc.response) self.assertEquals(None, self.rpc.error_message) self.assertEquals(None, self.rpc.error_name) def testSetResponseAlreadySet(self): self.rpc.set_response(self.response) self.assertRaisesWithRegexpMatch( transport.RpcStateError, 'RPC must be in RUNNING state to change to OK', self.rpc.set_response, self.response) def testSetResponseAlreadyError(self): self.rpc.set_status(self.status) 
self.assertRaisesWithRegexpMatch( transport.RpcStateError, 'RPC must be in RUNNING state to change to OK', self.rpc.set_response, self.response) def testSetStatus(self): self.rpc.set_status(self.status) self.assertEquals(self.request, self.rpc.request) self.assertEquals(remote.RpcState.APPLICATION_ERROR, self.rpc.state) self.assertEquals('an error', self.rpc.error_message) self.assertEquals('blam', self.rpc.error_name) self.assertRaisesWithRegexpMatch(remote.ApplicationError, 'an error', getattr, self.rpc, 'response') def testSetStatusAlreadySet(self): self.rpc.set_response(self.response) self.assertRaisesWithRegexpMatch( transport.RpcStateError, 'RPC must be in RUNNING state to change to OK', self.rpc.set_response, self.response) def testSetNonMessage(self): self.assertRaisesWithRegexpMatch( TypeError, 'Expected Message type, received 10', self.rpc.set_response, 10) def testSetStatusAlreadyError(self): self.rpc.set_status(self.status) self.assertRaisesWithRegexpMatch( transport.RpcStateError, 'RPC must be in RUNNING state to change to OK', self.rpc.set_response, self.response) def testSetUninitializedStatus(self): self.assertRaises(messages.ValidationError, self.rpc.set_status, remote.RpcStatus()) class TransportTest(test_util.TestCase): def setUp(self): remote.Protocols.set_default(remote.Protocols.new_default()) def do_test(self, protocol, trans): request = Message() request.value = u'request' response = Message() response.value = u'response' encoded_request = protocol.encode_message(request) encoded_response = protocol.encode_message(response) self.assertEquals(protocol, trans.protocol) received_rpc = [None] def transport_rpc(remote, rpc_request): self.assertEquals(remote, Service.method.remote) self.assertEquals(request, rpc_request) rpc = TestRpc(request) rpc.set_response(response) return rpc trans._start_rpc = transport_rpc rpc = trans.send_rpc(Service.method.remote, request) self.assertEquals(response, rpc.response) def testDefaultProtocol(self): trans = 
transport.Transport() self.do_test(protobuf, trans) self.assertEquals(protobuf, trans.protocol_config.protocol) self.assertEquals('default', trans.protocol_config.name) def testAlternateProtocol(self): trans = transport.Transport(protocol=protojson) self.do_test(protojson, trans) self.assertEquals(protojson, trans.protocol_config.protocol) self.assertEquals('default', trans.protocol_config.name) def testProtocolConfig(self): protocol_config = remote.ProtocolConfig( protojson, 'protoconfig', 'image/png') trans = transport.Transport(protocol=protocol_config) self.do_test(protojson, trans) self.assertTrue(trans.protocol_config is protocol_config) def testProtocolByName(self): remote.Protocols.get_default().add_protocol( protojson, 'png', 'image/png', ()) trans = transport.Transport(protocol='png') self.do_test(protojson, trans) @remote.method(Message, Message) def my_method(self, request): self.fail('self.my_method should not be directly invoked.') class FakeConnectionClass(object): def __init__(self, mox): self.request = mox.CreateMockAnything() self.response = mox.CreateMockAnything() class HttpTransportTest(webapp_test_util.WebServerTestBase): def setUp(self): # Do not need much parent construction functionality. self.schema = 'http' self.server = None self.request = Message(value=u'The request value') self.encoded_request = protojson.encode_message(self.request) self.response = Message(value=u'The response value') self.encoded_response = protojson.encode_message(self.response) def testCallSucceeds(self): self.ResetServer(wsgi_util.static_page(self.encoded_response, content_type='application/json')) rpc = self.connection.send_rpc(my_method.remote, self.request) self.assertEquals(self.response, rpc.response) def testHttps(self): self.schema = 'https' self.ResetServer(wsgi_util.static_page(self.encoded_response, content_type='application/json')) # Create a fake https connection function that really just calls http. 
self.used_https = False def https_connection(*args, **kwargs): self.used_https = True return httplib.HTTPConnection(*args, **kwargs) original_https_connection = httplib.HTTPSConnection httplib.HTTPSConnection = https_connection try: rpc = self.connection.send_rpc(my_method.remote, self.request) finally: httplib.HTTPSConnection = original_https_connection self.assertEquals(self.response, rpc.response) self.assertTrue(self.used_https) def testHttpSocketError(self): self.ResetServer(wsgi_util.static_page(self.encoded_response, content_type='application/json')) bad_transport = transport.HttpTransport('http://localhost:-1/blar') try: bad_transport.send_rpc(my_method.remote, self.request) except remote.NetworkError, err: self.assertTrue(str(err).startswith('Socket error: gaierror (')) self.assertEquals(socket.gaierror, type(err.cause)) self.assertEquals(8, abs(err.cause.args[0])) # Sign is sys depednent. else: self.fail('Expected error') def testHttpRequestError(self): self.ResetServer(wsgi_util.static_page(self.encoded_response, content_type='application/json')) def request_error(*args, **kwargs): raise TypeError('Generic Error') original_request = httplib.HTTPConnection.request httplib.HTTPConnection.request = request_error try: try: self.connection.send_rpc(my_method.remote, self.request) except remote.NetworkError, err: self.assertEquals('Error communicating with HTTP server', str(err)) self.assertEquals(TypeError, type(err.cause)) self.assertEquals('Generic Error', str(err.cause)) else: self.fail('Expected error') finally: httplib.HTTPConnection.request = original_request def testHandleGenericServiceError(self): self.ResetServer(wsgi_util.error(httplib.INTERNAL_SERVER_ERROR, 'arbitrary error', content_type='text/plain')) rpc = self.connection.send_rpc(my_method.remote, self.request) try: rpc.response except remote.ServerError, err: self.assertEquals('HTTP Error 500: arbitrary error', str(err).strip()) else: self.fail('Expected ServerError') def 
testHandleGenericServiceErrorNoMessage(self): self.ResetServer(wsgi_util.error(httplib.NOT_IMPLEMENTED, ' ', content_type='text/plain')) rpc = self.connection.send_rpc(my_method.remote, self.request) try: rpc.response except remote.ServerError, err: self.assertEquals('HTTP Error 501: Not Implemented', str(err).strip()) else: self.fail('Expected ServerError') def testHandleStatusContent(self): self.ResetServer(wsgi_util.static_page('{"state": "REQUEST_ERROR",' ' "error_message": "a request error"' '}', status=httplib.BAD_REQUEST, content_type='application/json')) rpc = self.connection.send_rpc(my_method.remote, self.request) try: rpc.response except remote.RequestError, err: self.assertEquals('a request error', str(err)) else: self.fail('Expected RequestError') def testHandleApplicationError(self): self.ResetServer(wsgi_util.static_page('{"state": "APPLICATION_ERROR",' ' "error_message": "an app error",' ' "error_name": "MY_ERROR_NAME"}', status=httplib.BAD_REQUEST, content_type='application/json')) rpc = self.connection.send_rpc(my_method.remote, self.request) try: rpc.response except remote.ApplicationError, err: self.assertEquals('an app error', str(err)) self.assertEquals('MY_ERROR_NAME', err.error_name) else: self.fail('Expected RequestError') def testHandleUnparsableErrorContent(self): self.ResetServer(wsgi_util.static_page('oops', status=httplib.BAD_REQUEST, content_type='application/json')) rpc = self.connection.send_rpc(my_method.remote, self.request) try: rpc.response except remote.ServerError, err: self.assertEquals('HTTP Error 400: oops', str(err)) else: self.fail('Expected ServerError') def testHandleEmptyBadRpcStatus(self): self.ResetServer(wsgi_util.static_page('{"error_message": "x"}', status=httplib.BAD_REQUEST, content_type='application/json')) rpc = self.connection.send_rpc(my_method.remote, self.request) try: rpc.response except remote.ServerError, err: self.assertEquals('HTTP Error 400: {"error_message": "x"}', str(err)) else: 
self.fail('Expected ServerError') def testUseProtocolConfigContentType(self): expected_content_type = 'image/png' def expect_content_type(environ, start_response): self.assertEquals(expected_content_type, environ['CONTENT_TYPE']) app = wsgi_util.static_page('', content_type=environ['CONTENT_TYPE']) return app(environ, start_response) self.ResetServer(expect_content_type) protocol_config = remote.ProtocolConfig(protojson, 'json', 'image/png') self.connection = self.CreateTransport(self.service_url, protocol_config) rpc = self.connection.send_rpc(my_method.remote, self.request) self.assertEquals(Message(), rpc.response) class SimpleRequest(messages.Message): content = messages.StringField(1) class SimpleResponse(messages.Message): content = messages.StringField(1) factory_value = messages.StringField(2) remote_host = messages.StringField(3) remote_address = messages.StringField(4) server_host = messages.StringField(5) server_port = messages.IntegerField(6) class LocalService(remote.Service): def __init__(self, factory_value='default'): self.factory_value = factory_value @remote.method(SimpleRequest, SimpleResponse) def call_method(self, request): return SimpleResponse(content=request.content, factory_value=self.factory_value, remote_host=self.request_state.remote_host, remote_address=self.request_state.remote_address, server_host=self.request_state.server_host, server_port=self.request_state.server_port) @remote.method() def raise_totally_unexpected(self, request): raise TypeError('Kablam') @remote.method() def raise_unexpected(self, request): raise remote.RequestError('Huh?') @remote.method() def raise_application_error(self, request): raise remote.ApplicationError('App error', 10) class LocalTransportTest(test_util.TestCase): def CreateService(self, factory_value='default'): return def testBasicCallWithClass(self): stub = LocalService.Stub(transport.LocalTransport(LocalService)) response = stub.call_method(content='Hello') 
self.assertEquals(SimpleResponse(content='Hello', factory_value='default', remote_host=os.uname()[1], remote_address='127.0.0.1', server_host=os.uname()[1], server_port=-1), response) def testBasicCallWithFactory(self): stub = LocalService.Stub( transport.LocalTransport(LocalService.new_factory('assigned'))) response = stub.call_method(content='Hello') self.assertEquals(SimpleResponse(content='Hello', factory_value='assigned', remote_host=os.uname()[1], remote_address='127.0.0.1', server_host=os.uname()[1], server_port=-1), response) def testTotallyUnexpectedError(self): stub = LocalService.Stub(transport.LocalTransport(LocalService)) self.assertRaisesWithRegexpMatch( remote.ServerError, 'Unexpected error TypeError: Kablam', stub.raise_totally_unexpected) def testUnexpectedError(self): stub = LocalService.Stub(transport.LocalTransport(LocalService)) self.assertRaisesWithRegexpMatch( remote.ServerError, 'Unexpected error RequestError: Huh?', stub.raise_unexpected) def testApplicationError(self): stub = LocalService.Stub(transport.LocalTransport(LocalService)) self.assertRaisesWithRegexpMatch( remote.ApplicationError, 'App error', stub.raise_application_error) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/util.py0000755000076500000240000003444212277637135021716 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Common utility library.""" from __future__ import with_statement __author__ = ['rafek@google.com (Rafe Kaplan)', 'guido@google.com (Guido van Rossum)', ] import cgi import datetime import inspect import os import re import sys __all__ = ['AcceptItem', 'AcceptError', 'Error', 'choose_content_type', 'decode_datetime', 'get_package_for_module', 'pad_string', 'parse_accept_header', 'positional', 'PROTORPC_PROJECT_URL', 'TimeZoneOffset', ] class Error(Exception): """Base class for protorpc exceptions.""" class AcceptError(Error): """Raised when there is an error parsing the accept header.""" PROTORPC_PROJECT_URL = 'http://code.google.com/p/google-protorpc' _TIME_ZONE_RE_STRING = r""" # Examples: # +01:00 # -05:30 # Z12:00 ((?PZ) | (?P[-+]) (?P\d\d) : (?P\d\d))$ """ _TIME_ZONE_RE = re.compile(_TIME_ZONE_RE_STRING, re.IGNORECASE | re.VERBOSE) def pad_string(string): """Pad a string for safe HTTP error responses. Prevents Internet Explorer from displaying their own error messages when sent as the content of error responses. Args: string: A string. Returns: Formatted string left justified within a 512 byte field. """ return string.ljust(512) def positional(max_positional_args): """A decorator to declare that only the first N arguments may be positional. This decorator makes it easy to support Python 3 style keyword-only parameters. For example, in Python 3 it is possible to write: def fn(pos1, *, kwonly1=None, kwonly1=None): ... All named parameters after * must be a keyword: fn(10, 'kw1', 'kw2') # Raises exception. fn(10, kwonly1='kw1') # Ok. Example: To define a function like above, do: @positional(1) def fn(pos1, kwonly1=None, kwonly2=None): ... If no default value is provided to a keyword argument, it becomes a required keyword argument: @positional(0) def fn(required_kw): ... This must be called with the keyword parameter: fn() # Raises exception. fn(10) # Raises exception. fn(required_kw=10) # Ok. 
When defining instance or class methods always remember to account for 'self' and 'cls': class MyClass(object): @positional(2) def my_method(self, pos1, kwonly1=None): ... @classmethod @positional(2) def my_method(cls, pos1, kwonly1=None): ... One can omit the argument to 'positional' altogether, and then no arguments with default values may be passed positionally. This would be equivalent to placing a '*' before the first argument with a default value in Python 3. If there are no arguments with default values, and no argument is given to 'positional', an error is raised. @positional def fn(arg1, arg2, required_kw1=None, required_kw2=0): ... fn(1, 3, 5) # Raises exception. fn(1, 3) # Ok. fn(1, 3, required_kw1=5) # Ok. Args: max_positional_arguments: Maximum number of positional arguments. All parameters after the this index must be keyword only. Returns: A decorator that prevents using arguments after max_positional_args from being used as positional parameters. Raises: TypeError if a keyword-only argument is provided as a positional parameter. ValueError if no maximum number of arguments is provided and the function has no arguments with default values. """ def positional_decorator(wrapped): def positional_wrapper(*args, **kwargs): if len(args) > max_positional_args: plural_s = '' if max_positional_args != 1: plural_s = 's' raise TypeError('%s() takes at most %d positional argument%s ' '(%d given)' % (wrapped.__name__, max_positional_args, plural_s, len(args))) return wrapped(*args, **kwargs) return positional_wrapper if isinstance(max_positional_args, (int, long)): return positional_decorator else: args, _, _, defaults = inspect.getargspec(max_positional_args) if defaults is None: raise ValueError( 'Functions with no keyword arguments must specify ' 'max_positional_args') return positional(len(args) - len(defaults))(max_positional_args) # TODO(rafek): Support 'level' from the Accept header standard. 
class AcceptItem(object): """Encapsulate a single entry of an Accept header. Parses and extracts relevent values from an Accept header and implements a sort order based on the priority of each requested type as defined here: http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html Accept headers are normally a list of comma separated items. Each item has the format of a normal HTTP header. For example: Accept: text/plain, text/html, text/*, */* This header means to prefer plain text over HTML, HTML over any other kind of text and text over any other kind of supported format. This class does not attempt to parse the list of items from the Accept header. The constructor expects the unparsed sub header and the index within the Accept header that the fragment was found. Properties: index: The index that this accept item was found in the Accept header. main_type: The main type of the content type. sub_type: The sub type of the content type. q: The q value extracted from the header as a float. If there is no q value, defaults to 1.0. values: All header attributes parsed form the sub-header. sort_key: A tuple (no_main_type, no_sub_type, q, no_values, index): no_main_type: */* has the least priority. no_sub_type: Items with no sub-type have less priority. q: Items with lower q value have less priority. no_values: Items with no values have less priority. index: Index of item in accept header is the last priority. """ __CONTENT_TYPE_REGEX = re.compile(r'^([^/]+)/([^/]+)$') def __init__(self, accept_header, index): """Parse component of an Accept header. Args: accept_header: Unparsed sub-expression of accept header. index: The index that this accept item was found in the Accept header. 
""" accept_header = accept_header.lower() content_type, values = cgi.parse_header(accept_header) match = self.__CONTENT_TYPE_REGEX.match(content_type) if not match: raise AcceptError('Not valid Accept header: %s' % accept_header) self.__index = index self.__main_type = match.group(1) self.__sub_type = match.group(2) self.__q = float(values.get('q', 1)) self.__values = values if self.__main_type == '*': self.__main_type = None if self.__sub_type == '*': self.__sub_type = None self.__sort_key = (not self.__main_type, not self.__sub_type, -self.__q, not self.__values, self.__index) @property def index(self): return self.__index @property def main_type(self): return self.__main_type @property def sub_type(self): return self.__sub_type @property def q(self): return self.__q @property def values(self): """Copy the dictionary of values parsed from the header fragment.""" return dict(self.__values) @property def sort_key(self): return self.__sort_key def match(self, content_type): """Determine if the given accept header matches content type. Args: content_type: Unparsed content type string. Returns: True if accept header matches content type, else False. 
""" content_type, _ = cgi.parse_header(content_type) match = self.__CONTENT_TYPE_REGEX.match(content_type.lower()) if not match: return False main_type, sub_type = match.group(1), match.group(2) if not(main_type and sub_type): return False return ((self.__main_type is None or self.__main_type == main_type) and (self.__sub_type is None or self.__sub_type == sub_type)) def __cmp__(self, other): """Comparison operator based on sort keys.""" if not isinstance(other, AcceptItem): return NotImplemented return cmp(self.sort_key, other.sort_key) def __str__(self): """Rebuilds Accept header.""" content_type = '%s/%s' % (self.__main_type or '*', self.__sub_type or '*') values = self.values if values: value_strings = ['%s=%s' % (i, v) for i, v in values.iteritems()] return '%s; %s' % (content_type, '; '.join(value_strings)) else: return content_type def __repr__(self): return 'AcceptItem(%r, %d)' % (str(self), self.__index) def parse_accept_header(accept_header): """Parse accept header. Args: accept_header: Unparsed accept header. Does not include name of header. Returns: List of AcceptItem instances sorted according to their priority. """ accept_items = [] for index, header in enumerate(accept_header.split(',')): accept_items.append(AcceptItem(header, index)) return sorted(accept_items) def choose_content_type(accept_header, supported_types): """Choose most appropriate supported type based on what client accepts. Args: accept_header: Unparsed accept header. Does not include name of header. supported_types: List of content-types supported by the server. The index of the supported types determines which supported type is prefered by the server should the accept header match more than one at the same priority. Returns: The preferred supported type if the accept header matches any, else None. 
""" for accept_item in parse_accept_header(accept_header): for supported_type in supported_types: if accept_item.match(supported_type): return supported_type return None @positional(1) def get_package_for_module(module): """Get package name for a module. Helper calculates the package name of a module. Args: module: Module to get name for. If module is a string, try to find module in sys.modules. Returns: If module contains 'package' attribute, uses that as package name. Else, if module is not the '__main__' module, the module __name__. Else, the base name of the module file name. Else None. """ if isinstance(module, basestring): try: module = sys.modules[module] except KeyError: return None try: return unicode(module.package) except AttributeError: if module.__name__ == '__main__': try: file_name = module.__file__ except AttributeError: pass else: base_name = os.path.basename(file_name) split_name = os.path.splitext(base_name) if len(split_name) == 1: return unicode(base_name) else: return u'.'.join(split_name[:-1]) return unicode(module.__name__) class TimeZoneOffset(datetime.tzinfo): """Time zone information as encoded/decoded for DateTimeFields.""" def __init__(self, offset): """Initialize a time zone offset. Args: offset: Integer or timedelta time zone offset, in minutes from UTC. This can be negative. """ super(TimeZoneOffset, self).__init__() if isinstance(offset, datetime.timedelta): offset = offset.total_seconds() self.__offset = offset def utcoffset(self, dt): """Get the a timedelta with the time zone's offset from UTC. Returns: The time zone offset from UTC, as a timedelta. """ return datetime.timedelta(minutes=self.__offset) def dst(self, dt): """Get the daylight savings time offset. The formats that ProtoRPC uses to encode/decode time zone information don't contain any information about daylight savings time. So this always returns a timedelta of 0. Returns: A timedelta of 0. 
""" return datetime.timedelta(0) def decode_datetime(encoded_datetime): """Decode a DateTimeField parameter from a string to a python datetime. Args: encoded_datetime: A string in RFC 3339 format. Returns: A datetime object with the date and time specified in encoded_datetime. Raises: ValueError: If the string is not in a recognized format. """ # Check if the string includes a time zone offset. Break out the # part that doesn't include time zone info. Convert to uppercase # because all our comparisons should be case-insensitive. time_zone_match = _TIME_ZONE_RE.search(encoded_datetime) if time_zone_match: time_string = encoded_datetime[:time_zone_match.start(1)].upper() else: time_string = encoded_datetime.upper() if '.' in time_string: format_string = '%Y-%m-%dT%H:%M:%S.%f' else: format_string = '%Y-%m-%dT%H:%M:%S' decoded_datetime = datetime.datetime.strptime(time_string, format_string) if not time_zone_match: return decoded_datetime # Time zone info was included in the parameter. Add a tzinfo # object to the datetime. Datetimes can't be changed after they're # created, so we'll need to create a new one. if time_zone_match.group('z'): offset_minutes = 0 else: sign = time_zone_match.group('sign') hours, minutes = [int(value) for value in time_zone_match.group('hours', 'minutes')] offset_minutes = hours * 60 + minutes if sign == '-': offset_minutes *= -1 return datetime.datetime(decoded_datetime.year, decoded_datetime.month, decoded_datetime.day, decoded_datetime.hour, decoded_datetime.minute, decoded_datetime.second, decoded_datetime.microsecond, TimeZoneOffset(offset_minutes)) protorpc-standalone-0.9.1/protorpc/util_test.py0000755000076500000240000003361312277637135022754 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.util.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import datetime import new import random import sys import unittest from protorpc import test_util from protorpc import util class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = util class PadStringTest(test_util.TestCase): def testPadEmptyString(self): self.assertEquals(' ' * 512, util.pad_string('')) def testPadString(self): self.assertEquals('hello' + (507 * ' '), util.pad_string('hello')) def testPadLongString(self): self.assertEquals('x' * 1000, util.pad_string('x' * 1000)) class UtilTest(test_util.TestCase): def testDecoratedFunction_LengthZero(self): @util.positional(0) def fn(kwonly=1): return [kwonly] self.assertEquals([1], fn()) self.assertEquals([2], fn(kwonly=2)) self.assertRaisesWithRegexpMatch(TypeError, r'fn\(\) takes at most 0 positional ' r'arguments \(1 given\)', fn, 1) def testDecoratedFunction_LengthOne(self): @util.positional(1) def fn(pos, kwonly=1): return [pos, kwonly] self.assertEquals([1, 1], fn(1)) self.assertEquals([2, 2], fn(2, kwonly=2)) self.assertRaisesWithRegexpMatch(TypeError, r'fn\(\) takes at most 1 positional ' r'argument \(2 given\)', fn, 2, 3) def testDecoratedFunction_LengthTwoWithDefault(self): @util.positional(2) def fn(pos1, pos2=1, kwonly=1): return [pos1, pos2, kwonly] self.assertEquals([1, 1, 1], fn(1)) self.assertEquals([2, 2, 1], fn(2, 2)) self.assertEquals([2, 3, 4], fn(2, 3, kwonly=4)) self.assertRaisesWithRegexpMatch(TypeError, r'fn\(\) takes at most 2 positional ' r'arguments \(3 given\)', fn, 2, 3, 4) 
def testDecoratedMethod(self): class MyClass(object): @util.positional(2) def meth(self, pos1, kwonly=1): return [pos1, kwonly] self.assertEquals([1, 1], MyClass().meth(1)) self.assertEquals([2, 2], MyClass().meth(2, kwonly=2)) self.assertRaisesWithRegexpMatch(TypeError, r'meth\(\) takes at most 2 positional ' r'arguments \(3 given\)', MyClass().meth, 2, 3) def testDefaultDecoration(self): @util.positional def fn(a, b, c=None): return a, b, c self.assertEquals((1, 2, 3), fn(1, 2, c=3)) self.assertEquals((3, 4, None), fn(3, b=4)) self.assertRaisesWithRegexpMatch(TypeError, r'fn\(\) takes at most 2 positional ' r'arguments \(3 given\)', fn, 2, 3, 4) def testDefaultDecorationNoKwdsFails(self): def fn(a): return a with self.assertRaisesRegexp( ValueError, 'Functions with no keyword arguments must specify ' 'max_positional_args'): util.positional(fn) class AcceptItemTest(test_util.TestCase): def CheckAttributes(self, item, main_type, sub_type, q=1, values={}, index=1): self.assertEquals(index, item.index) self.assertEquals(main_type, item.main_type) self.assertEquals(sub_type, item.sub_type) self.assertEquals(q, item.q) self.assertEquals(values, item.values) def testParse(self): self.CheckAttributes(util.AcceptItem('*/*', 1), None, None) self.CheckAttributes(util.AcceptItem('text/*', 1), 'text', None) self.CheckAttributes(util.AcceptItem('text/plain', 1), 'text', 'plain') self.CheckAttributes( util.AcceptItem('text/plain; q=0.3', 1), 'text', 'plain', 0.3, values={'q': '0.3'}) self.CheckAttributes( util.AcceptItem('text/plain; level=2', 1), 'text', 'plain', values={'level': '2'}) self.CheckAttributes( util.AcceptItem('text/plain', 10), 'text', 'plain', index=10) def testCaseInsensitive(self): self.CheckAttributes(util.AcceptItem('Text/Plain', 1), 'text', 'plain') def testBadValue(self): self.assertRaises(util.AcceptError, util.AcceptItem, 'bad value', 1) self.assertRaises(util.AcceptError, util.AcceptItem, 'bad value/', 1) self.assertRaises(util.AcceptError, 
util.AcceptItem, '/bad value', 1) def testSortKey(self): item = util.AcceptItem('main/sub; q=0.2; level=3', 11) self.assertEquals((False, False, -0.2, False, 11), item.sort_key) item = util.AcceptItem('main/*', 12) self.assertEquals((False, True, -1, True, 12), item.sort_key) item = util.AcceptItem('*/*', 1) self.assertEquals((True, True, -1, True, 1), item.sort_key) def testSort(self): i1 = util.AcceptItem('text/*', 1) i2 = util.AcceptItem('text/html', 2) i3 = util.AcceptItem('text/html; q=0.9', 3) i4 = util.AcceptItem('text/html; q=0.3', 4) i5 = util.AcceptItem('text/xml', 5) i6 = util.AcceptItem('text/html; level=1', 6) i7 = util.AcceptItem('*/*', 7) items = [i1, i2 ,i3 ,i4 ,i5 ,i6, i7] random.shuffle(items) self.assertEquals([i6, i2, i5, i3, i4, i1, i7], sorted(items)) def testMatchAll(self): item = util.AcceptItem('*/*', 1) self.assertTrue(item.match('text/html')) self.assertTrue(item.match('text/plain; level=1')) self.assertTrue(item.match('image/png')) self.assertTrue(item.match('image/png; q=0.3')) def testMatchMainType(self): item = util.AcceptItem('text/*', 1) self.assertTrue(item.match('text/html')) self.assertTrue(item.match('text/plain; level=1')) self.assertFalse(item.match('image/png')) self.assertFalse(item.match('image/png; q=0.3')) def testMatchFullType(self): item = util.AcceptItem('text/plain', 1) self.assertFalse(item.match('text/html')) self.assertTrue(item.match('text/plain; level=1')) self.assertFalse(item.match('image/png')) self.assertFalse(item.match('image/png; q=0.3')) def testMatchCaseInsensitive(self): item = util.AcceptItem('text/plain', 1) self.assertTrue(item.match('tExt/pLain')) def testStr(self): self.assertHeaderSame('*/*', str(util.AcceptItem('*/*', 1))) self.assertHeaderSame('text/*', str(util.AcceptItem('text/*', 1))) self.assertHeaderSame('text/plain', str(util.AcceptItem('text/plain', 1))) self.assertHeaderSame('text/plain; q=0.2', str(util.AcceptItem('text/plain; q=0.2', 1))) self.assertHeaderSame( 'text/plain; q=0.2; 
level=1', str(util.AcceptItem('text/plain; level=1; q=0.2', 1))) def testRepr(self): self.assertEquals("AcceptItem('*/*', 1)", repr(util.AcceptItem('*/*', 1))) self.assertEquals("AcceptItem('text/plain', 11)", repr(util.AcceptItem('text/plain', 11))) def testValues(self): item = util.AcceptItem('text/plain; a=1; b=2; c=3;', 1) values = item.values self.assertEquals(dict(a="1", b="2", c="3"), values) values['a'] = "7" self.assertNotEquals(values, item.values) class ParseAcceptHeaderTest(test_util.TestCase): def testIndex(self): accept_header = """text/*, text/html, text/html; q=0.9, text/xml, text/html; level=1, */*""" accepts = util.parse_accept_header(accept_header) self.assertEquals(6, len(accepts)) self.assertEquals([4, 1, 3, 2, 0, 5], [a.index for a in accepts]) class ChooseContentTypeTest(test_util.TestCase): def testIgnoreUnrequested(self): self.assertEquals('application/json', util.choose_content_type( 'text/plain, application/json, */*', ['application/X-Google-protobuf', 'application/json' ])) def testUseCorrectPreferenceIndex(self): self.assertEquals('application/json', util.choose_content_type( '*/*, text/plain, application/json', ['application/X-Google-protobuf', 'application/json' ])) def testPreferFirstInList(self): self.assertEquals('application/X-Google-protobuf', util.choose_content_type( '*/*', ['application/X-Google-protobuf', 'application/json' ])) def testCaseInsensitive(self): self.assertEquals('application/X-Google-protobuf', util.choose_content_type( 'application/x-google-protobuf', ['application/X-Google-protobuf', 'application/json' ])) class GetPackageForModuleTest(test_util.TestCase): def setUp(self): self.original_modules = dict(sys.modules) def tearDown(self): sys.modules.clear() sys.modules.update(self.original_modules) def CreateModule(self, name, file_name=None): if file_name is None: file_name = '%s.py' % name module = new.module(name) sys.modules[name] = module return module def assertPackageEquals(self, expected, actual): 
self.assertEquals(expected, actual) if actual is not None: self.assertTrue(isinstance(actual, unicode)) def testByString(self): module = self.CreateModule('service_module') module.package = 'my_package' self.assertPackageEquals('my_package', util.get_package_for_module('service_module')) def testModuleNameNotInSys(self): self.assertPackageEquals(None, util.get_package_for_module('service_module')) def testHasPackage(self): module = self.CreateModule('service_module') module.package = 'my_package' self.assertPackageEquals('my_package', util.get_package_for_module(module)) def testHasModuleName(self): module = self.CreateModule('service_module') self.assertPackageEquals('service_module', util.get_package_for_module(module)) def testIsMain(self): module = self.CreateModule('__main__') module.__file__ = '/bing/blam/bloom/blarm/my_file.py' self.assertPackageEquals('my_file', util.get_package_for_module(module)) def testIsMainCompiled(self): module = self.CreateModule('__main__') module.__file__ = '/bing/blam/bloom/blarm/my_file.pyc' self.assertPackageEquals('my_file', util.get_package_for_module(module)) def testNoExtension(self): module = self.CreateModule('__main__') module.__file__ = '/bing/blam/bloom/blarm/my_file' self.assertPackageEquals('my_file', util.get_package_for_module(module)) def testNoPackageAtAll(self): module = self.CreateModule('__main__') self.assertPackageEquals('__main__', util.get_package_for_module(module)) class DateTimeTests(test_util.TestCase): def testDecodeDateTime(self): """Test that a RFC 3339 datetime string is decoded properly.""" for datetime_string, datetime_vals in ( ('2012-09-30T15:31:50.262', (2012, 9, 30, 15, 31, 50, 262000)), ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): decoded = util.decode_datetime(datetime_string) expected = datetime.datetime(*datetime_vals) self.assertEquals(expected, decoded) def testDateTimeTimeZones(self): """Test that a datetime string with a timezone is decoded correctly.""" for 
datetime_string, datetime_vals in ( ('2012-09-30T15:31:50.262-06:00', (2012, 9, 30, 15, 31, 50, 262000, util.TimeZoneOffset(-360))), ('2012-09-30T15:31:50.262+01:30', (2012, 9, 30, 15, 31, 50, 262000, util.TimeZoneOffset(90))), ('2012-09-30T15:31:50+00:05', (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(5))), ('2012-09-30T15:31:50+00:00', (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), ('2012-09-30t15:31:50-00:00', (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), ('2012-09-30t15:31:50z', (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), ('2012-09-30T15:31:50-23:00', (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(-1380)))): decoded = util.decode_datetime(datetime_string) expected = datetime.datetime(*datetime_vals) self.assertEquals(expected, decoded) def testDecodeDateTimeInvalid(self): """Test that decoding malformed datetime strings raises execptions.""" for datetime_string in ('invalid', '2012-09-30T15:31:50.', '-08:00 2012-09-30T15:31:50.262', '2012-09-30T15:31', '2012-09-30T15:31Z', '2012-09-30T15:31:50ZZ', '2012-09-30T15:31:50.262 blah blah -08:00', '1000-99-99T25:99:99.999-99:99'): self.assertRaises(ValueError, util.decode_datetime, datetime_string) def testTimeZoneOffsetDelta(self): """Test that delta works with TimeZoneOffset.""" time_zone = util.TimeZoneOffset(datetime.timedelta(minutes=3)) epoch = time_zone.utcoffset(datetime.datetime.utcfromtimestamp(0)) self.assertEqual(10800, epoch.total_seconds()) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/webapp/0000755000076500000240000000000012300027071021606 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/webapp/__init__.py0000755000076500000240000000121012277637135023741 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # __author__ = 'rafek@google.com (Rafe Kaplan)' protorpc-standalone-0.9.1/protorpc/webapp/forms.py0000755000076500000240000001240412277637135023337 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Webapp forms interface to ProtoRPC services. This webapp application is automatically configured to work with ProtoRPCs that have a configured protorpc.RegistryService. This webapp is automatically added to the registry service URL at /forms (default is /protorpc/form) when configured using the service_handlers.service_mapping function. 
""" import os import webapp2 as webapp try: from .google_imports import template except: import logging logging.warning('template module not available, some features may not work.') __all__ = ['FormsHandler', 'ResourceHandler', 'DEFAULT_REGISTRY_PATH', ] _TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static') _FORMS_TEMPLATE = os.path.join(_TEMPLATES_DIR, 'forms.html') _METHODS_TEMPLATE = os.path.join(_TEMPLATES_DIR, 'methods.html') DEFAULT_REGISTRY_PATH = '/protorpc' class ResourceHandler(webapp.RequestHandler): """Serves static resources without needing to add static files to app.yaml.""" __RESOURCE_MAP = { 'forms.js': 'text/javascript', 'jquery-1.4.2.min.js': 'text/javascript', 'jquery.json-2.2.min.js': 'text/javascript', } def get(self, relative): """Serve known static files. If static file is not known, will return 404 to client. Response items are cached for 300 seconds. Args: relative: Name of static file relative to main FormsHandler. """ content_type = self.__RESOURCE_MAP.get(relative, None) if not content_type: self.response.set_status(404) self.response.out.write('Resource not found.') return path = os.path.join(_TEMPLATES_DIR, relative) self.response.headers['Content-Type'] = content_type static_file = open(path) try: contents = static_file.read() finally: static_file.close() self.response.out.write(contents) class FormsHandler(webapp.RequestHandler): """Handler for display HTML/javascript forms of ProtoRPC method calls. When accessed with no query parameters, will show a web page that displays all services and methods on the associated registry path. Links on this page fill in the service_path and method_name query parameters back to this same handler. When provided with service_path and method_name parameters will display a dynamic form representing the request message for that method. When sent, the form sends a JSON request to the ProtoRPC method and displays the response in the HTML page. 
Attribute: registry_path: Read-only registry path known by this handler. """ def __init__(self, registry_path=DEFAULT_REGISTRY_PATH): """Constructor. When configuring a FormsHandler to use with a webapp application do not pass the request handler class in directly. Instead use new_factory to ensure that the FormsHandler is created with the correct registry path for each request. Args: registry_path: Absolute path on server where the ProtoRPC RegsitryService is located. """ assert registry_path self.__registry_path = registry_path @property def registry_path(self): return self.__registry_path def get(self): """Send forms and method page to user. By default, displays a web page listing all services and methods registered on the server. Methods have links to display the actual method form. If both parameters are set, will display form for method. Query Parameters: service_path: Path to service to display method of. Optional. method_name: Name of method to display form for. Optional. """ params = {'forms_path': self.request.path.rstrip('/'), 'hostname': self.request.host, 'registry_path': self.__registry_path, } service_path = self.request.get('path', None) method_name = self.request.get('method', None) if service_path and method_name: form_template = _METHODS_TEMPLATE params['service_path'] = service_path params['method_name'] = method_name else: form_template = _FORMS_TEMPLATE self.response.out.write(template.render(form_template, params)) @classmethod def new_factory(cls, registry_path=DEFAULT_REGISTRY_PATH): """Construct a factory for use with WSGIApplication. This method is called automatically with the correct registry path when services are configured via service_handlers.service_mapping. Args: registry_path: Absolute path on server where the ProtoRPC RegsitryService is located. Returns: Factory function that creates a properly configured FormsHandler instance. 
""" def forms_factory(): return cls(registry_path) return forms_factory protorpc-standalone-0.9.1/protorpc/webapp/forms_test.py0000755000076500000240000000772012277637135024403 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.forms.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import os import unittest from protorpc import test_util from protorpc import webapp_test_util from protorpc.webapp import forms from protorpc.webapp.google_imports import template class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = forms def RenderTemplate(name, **params): """Load content from static file. Args: name: Name of static file to load from static directory. params: Passed in to webapp template generator. Returns: Contents of static file. """ path = os.path.join(forms._TEMPLATES_DIR, name) return template.render(path, params) class ResourceHandlerTest(webapp_test_util.RequestHandlerTestBase): def CreateRequestHandler(self): return forms.ResourceHandler() def DoStaticContentTest(self, name, expected_type): """Run the static content test. Loads expected static content from source and compares with results in response. Checks content-type and cache header. Args: name: Name of file that should be served. expected_type: Expected content-type of served file. 
""" self.handler.get(name) content = RenderTemplate(name) self.CheckResponse('200 OK', {'content-type': expected_type, }, content) def testGet(self): self.DoStaticContentTest('forms.js', 'text/javascript') def testNoSuchFile(self): self.handler.get('unknown.txt') self.CheckResponse('404 Not Found', {}, 'Resource not found.') class FormsHandlerTest(webapp_test_util.RequestHandlerTestBase): def CreateRequestHandler(self): handler = forms.FormsHandler('/myreg') self.assertEquals('/myreg', handler.registry_path) return handler def testGetForm(self): self.handler.get() content = RenderTemplate( 'forms.html', forms_path='/tmp/myhandler', hostname=self.request.host, registry_path='/myreg') self.CheckResponse('200 OK', {}, content) def testGet_MissingPath(self): self.ResetHandler({'QUERY_STRING': 'method=my_method'}) self.handler.get() content = RenderTemplate( 'forms.html', forms_path='/tmp/myhandler', hostname=self.request.host, registry_path='/myreg') self.CheckResponse('200 OK', {}, content) def testGet_MissingMethod(self): self.ResetHandler({'QUERY_STRING': 'path=/my-path'}) self.handler.get() content = RenderTemplate( 'forms.html', forms_path='/tmp/myhandler', hostname=self.request.host, registry_path='/myreg') self.CheckResponse('200 OK', {}, content) def testGetMethod(self): self.ResetHandler({'QUERY_STRING': 'path=/my-path&method=my_method'}) self.handler.get() content = RenderTemplate( 'methods.html', forms_path='/tmp/myhandler', hostname=self.request.host, registry_path='/myreg', service_path='/my-path', method_name='my_method') self.CheckResponse('200 OK', {}, content) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/webapp/google_imports.py0000755000076500000240000000117112277637135025241 0ustar jeremydwstaff00000000000000"""Dynamically decide from where to import other SDK modules. All other protorpc.webapp code should import other SDK modules from this module. 
If necessary, add new imports here (in both places).
"""

__author__ = 'yey@google.com (Ye Yuan)'

# pylint: disable=g-import-not-at-top
# pylint: disable=unused-import

import os
import sys

# Detect whether the App Engine SDK is importable; the webapp imports below
# are only attempted when it is.
try:
  from google.appengine import ext
  normal_environment = True
except ImportError:
  normal_environment = False


if normal_environment:
  from google.appengine.ext import webapp
  from google.appengine.ext.webapp import util as webapp_util
  from google.appengine.ext.webapp import template
protorpc-standalone-0.9.1/protorpc/webapp/service_handlers.py0000755000076500000240000007114212300025654025515 0ustar jeremydwstaff00000000000000#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Handlers for remote services.

This module contains classes that may be used to build a service
on top of the App Engine Webapp framework.

The services request handler can be configured to handle requests in a number
of different request formats.  Every request format must have a way to map
the request to the service handler's declared request message.Message class.
The handler can also send a response in any format that can be mapped from
the response message.Message class.

Participants in an RPC:

  There are four classes involved with the life cycle of an RPC.

    Service factory: A user-defined service factory that is responsible for
      instantiating an RPC service.  The methods intended for use as RPC
      methods must be decorated by the 'remote' decorator.

    RPCMapper: Responsible for determining whether or not a specific request
      matches a particular RPC format and translating between the actual
      request/response and the underlying message types.  A single instance
      of an RPCMapper sub-class is required per service configuration.  Each
      mapper must be usable across multiple requests.

    ServiceHandler: A webapp.RequestHandler sub-class that responds to the
      webapp framework.  It mediates between the RPCMapper and service
      implementation class during a request.  As determined by the Webapp
      framework, a new ServiceHandler instance is created to handle each
      user request.  A handler is never used to handle more than one request.

    ServiceHandlerFactory: A class that is responsible for creating a new,
      properly configured ServiceHandler instance for each request.  The
      factory is configured by providing it with a set of RPCMapper
      instances.  Each time the factory is called it creates a new service
      instance and a ServiceHandler wrapping it.  If the service instance
      defines initialize_request_state, it is passed an HttpRequestState
      describing the request.  A single instance of an RPCMapper sub-class
      is required to configure each service.  Each mapper instance must be
      usable across multiple requests.

RPC mappers:

  RPC mappers translate between a single HTTP based RPC protocol and the
  underlying service implementation.  Each RPC mapper must be configured
  with the following information to determine if it is an appropriate
  mapper for a given request:

    http_methods: Set of HTTP methods supported by handler.

    content_types: Set of supported content types.

    default_content_type: Default content type for handler responses.

  Built-in mapper implementations:

    URLEncodedRPCMapper: Matches requests that are compatible with post
      forms with the 'application/x-www-form-urlencoded' content-type.  It
      translates post parameters into request parameters.  (A request with
      no content-type at all is rejected by the handler rather than being
      treated as form-encoded.)

    ProtobufRPCMapper: Matches requests that are compatible with post
      forms with the 'application/x-google-protobuf' content-type.  It
      reads the contents of a binary post request.

    JSONRPCMapper: Matches POST requests whose content-type is the protojson
      content type or one of several JSON/JavaScript aliases, and decodes
      the request body as JSON.

Public Exceptions:
  Error: Base class for service handler errors.
  ServiceConfigurationError: Raised when a service is not correctly
    configured.
  RequestError: Raised by RPC mappers when there is an error in its request
    or request format.
  ResponseError: Raised by RPC mappers when there is an error in its response.
"""

__author__ = 'rafek@google.com (Rafe Kaplan)'

import httplib
import logging

import webapp2 as webapp
# NOTE: webapp_util is not imported in this standalone build.
#from .google_imports import webapp_util

from .. import messages
from .. import protobuf
from .. import protojson
from .. import protourlencode
from .. import registry
from .. import remote
from .. import util
from . import forms

__all__ = [
    'Error',
    'RequestError',
    'ResponseError',
    'ServiceConfigurationError',

    'DEFAULT_REGISTRY_PATH',

    'ProtobufRPCMapper',
    'RPCMapper',
    'ServiceHandler',
    'ServiceHandlerFactory',
    'URLEncodedRPCMapper',
    'JSONRPCMapper',
    'service_mapping',
    'run_services',
]


class Error(Exception):
  """Base class for all errors in the service handlers module."""


class ServiceConfigurationError(Error):
  """Raised when a service is not configured correctly."""


class RequestError(Error):
  """Raised when an incoming request cannot be converted to a message."""


class ResponseError(Error):
  """Raised when a response message cannot be encoded."""


# Content types of the protocols supported by the built-in mappers.
_URLENCODED_CONTENT_TYPE = protourlencode.CONTENT_TYPE
_PROTOBUF_CONTENT_TYPE = protobuf.CONTENT_TYPE
_JSON_CONTENT_TYPE = protojson.CONTENT_TYPE

# Additional request content types that are also accepted as JSON.
_EXTRA_JSON_CONTENT_TYPES = ['application/x-javascript',
                             'text/javascript',
                             'text/x-javascript',
                             'text/x-json',
                             'text/json',
                            ]

# The whole method pattern is an optional regex.  Its single group captures
# the remote method name that follows a '.' after the service path (for
# example 'my_method' in '/my_service.my_method').  The captured value is
# passed to 'get' and 'post' on the ServiceHandler.
_METHOD_PATTERN = r'(?:\.([^?]*))?'
DEFAULT_REGISTRY_PATH = forms.DEFAULT_REGISTRY_PATH class RPCMapper(object): """Interface to mediate between request and service object. Request mappers are implemented to support various types of RPC protocols. It is responsible for identifying whether a given request matches a particular protocol, resolve the remote method to invoke and mediate between the request and appropriate protocol messages for the remote method. """ @util.positional(4) def __init__(self, http_methods, default_content_type, protocol, content_types=None): """Constructor. Args: http_methods: Set of HTTP methods supported by mapper. default_content_type: Default content type supported by mapper. protocol: The protocol implementation. Must implement encode_message and decode_message. content_types: Set of additionally supported content types. """ self.__http_methods = frozenset(http_methods) self.__default_content_type = default_content_type self.__protocol = protocol if content_types is None: content_types = [] self.__content_types = frozenset([self.__default_content_type] + content_types) @property def http_methods(self): return self.__http_methods @property def default_content_type(self): return self.__default_content_type @property def content_types(self): return self.__content_types def build_request(self, handler, request_type): """Build request message based on request. Each request mapper implementation is responsible for converting a request to an appropriate message instance. Args: handler: RequestHandler instance that is servicing request. Must be initialized with request object and been previously determined to matching the protocol of the RPCMapper. request_type: Message type to build. Returns: Instance of request_type populated by protocol buffer in request body. Raises: RequestError if the mapper implementation is not able to correctly convert the request to the appropriate message. 
""" try: return self.__protocol.decode_message(request_type, handler.request.body) except (messages.ValidationError, messages.DecodeError), err: raise RequestError('Unable to parse request content: %s' % err) def build_response(self, handler, response, pad_string=False): """Build response based on service object response message. Each request mapper implementation is responsible for converting a response message to an appropriate handler response. Args: handler: RequestHandler instance that is servicing request. Must be initialized with request object and been previously determined to matching the protocol of the RPCMapper. response: Response message as returned from the service object. Raises: ResponseError if the mapper implementation is not able to correctly convert the message to an appropriate response. """ try: encoded_message = self.__protocol.encode_message(response) except messages.ValidationError, err: raise ResponseError('Unable to encode message: %s' % err) else: handler.response.headers['Content-Type'] = self.default_content_type handler.response.out.write(encoded_message) class ServiceHandlerFactory(object): """Factory class used for instantiating new service handlers. Normally a handler class is passed directly to the webapp framework so that it can be simply instantiated to handle a single request. The service handler, however, must be configured with additional information so that it knows how to instantiate a service object. This class acts the same as a normal RequestHandler class by overriding the __call__ method to correctly configures a ServiceHandler instance with a new service object. The factory must also provide a set of RPCMapper instances which examine a request to determine what protocol is being used and mediates between the request and the service object. The mapping of a service handler must have a single group indicating the part of the URL path that maps to the request method. 
This group must exist but can be optional for the request (the group may be followed by '?' in the regular expression matching the request). Usage: stock_factory = ServiceHandlerFactory(StockService) ... configure stock_factory by adding RPCMapper instances ... application = webapp.WSGIApplication( [stock_factory.mapping('/stocks')]) Default usage: application = webapp.WSGIApplication( [ServiceHandlerFactory.default(StockService).mapping('/stocks')]) """ def __init__(self, service_factory): """Constructor. Args: service_factory: Service factory to instantiate and provide to service handler. """ self.__service_factory = service_factory self.__request_mappers = [] def all_request_mappers(self): """Get all request mappers. Returns: Iterator of all request mappers used by this service factory. """ return iter(self.__request_mappers) def add_request_mapper(self, mapper): """Add request mapper to end of request mapper list.""" self.__request_mappers.append(mapper) def __call__(self): """Construct a new service handler instance.""" return ServiceHandler(self, self.__service_factory()) @property def service_factory(self): """Service factory associated with this factory.""" return self.__service_factory @staticmethod def __check_path(path): """Check a path parameter. Make sure a provided path parameter is compatible with the webapp URL mapping. Args: path: Path to check. This is a plain path, not a regular expression. Raises: ValueError if path does not start with /, path ends with /. """ if path.endswith('/'): raise ValueError('Path %s must not end with /.' % path) def mapping(self, path): """Convenience method to map service to application. Args: path: Path to map service to. It must be a simple path with a leading / and no trailing /. Returns: Mapping from service URL to service handler factory. 
""" self.__check_path(path) service_url_pattern = r'(%s)%s' % (path, _METHOD_PATTERN) return service_url_pattern, self @classmethod def default(cls, service_factory, parameter_prefix=''): """Convenience method to map default factory configuration to application. Creates a standardized default service factory configuration that pre-maps the URL encoded protocol handler to the factory. Args: service_factory: Service factory to instantiate and provide to service handler. method_parameter: The name of the form parameter used to determine the method to invoke used by the URLEncodedRPCMapper. If None, no parameter is used and the mapper will only match against the form path-name. Defaults to 'method'. parameter_prefix: If provided, all the parameters in the form are expected to begin with that prefix by the URLEncodedRPCMapper. Returns: Mapping from service URL to service handler factory. """ factory = cls(service_factory) factory.add_request_mapper(ProtobufRPCMapper()) factory.add_request_mapper(JSONRPCMapper()) return factory class ServiceHandler(webapp.RequestHandler): """Web handler for RPC service. Overridden methods: get: All requests handled by 'handle' method. HTTP method stored in attribute. Takes remote_method parameter as derived from the URL mapping. post: All requests handled by 'handle' method. HTTP method stored in attribute. Takes remote_method parameter as derived from the URL mapping. redirect: Not implemented for this service handler. New methods: handle: Handle request for both GET and POST. Attributes (in addition to attributes in RequestHandler): service: Service instance associated with request being handled. method: Method of request. Used by RPCMapper to determine match. remote_method: Sub-path as provided to the 'get' and 'post' methods. """ def __init__(self, factory, service): """Constructor. Args: factory: Instance of ServiceFactory used for constructing new service instances used for handling requests. 
service: Service instance used for handling RPC. """ self.__factory = factory self.__service = service @property def service(self): return self.__service def __show_info(self, service_path, remote_method): self.response.headers['content-type'] = 'text/plain; charset=utf-8' response_message = [] if remote_method: response_message.append('%s.%s is a ProtoRPC method.\n\n' %( service_path, remote_method)) else: response_message.append('%s is a ProtoRPC service.\n\n' % service_path) definition_name_function = getattr(self.__service, 'definition_name', None) if definition_name_function: definition_name = definition_name_function() else: definition_name = '%s.%s' % (self.__service.__module__, self.__service.__class__.__name__) response_message.append('Service %s\n\n' % definition_name) response_message.append('More about ProtoRPC: ') response_message.append('http://code.google.com/p/google-protorpc\n') self.response.out.write(util.pad_string(''.join(response_message))) def get(self, service_path, remote_method): """Handler method for GET requests. Args: service_path: Service path derived from request URL. remote_method: Sub-path after service path has been matched. """ self.handle('GET', service_path, remote_method) def post(self, service_path, remote_method): """Handler method for POST requests. Args: service_path: Service path derived from request URL. remote_method: Sub-path after service path has been matched. 
""" self.handle('POST', service_path, remote_method) def redirect(self, uri, permanent=False): """Not supported for services.""" raise NotImplementedError('Services do not currently support redirection.') def __send_error(self, http_code, status_state, error_message, mapper, error_name=None): status = remote.RpcStatus(state=status_state, error_message=error_message, error_name=error_name) mapper.build_response(self, status) self.response.headers['content-type'] = mapper.default_content_type logging.error(error_message) response_content = self.response.out.getvalue() padding = ' ' * max(0, 512 - len(response_content)) self.response.out.write(padding) self.response.set_status(http_code, error_message) def __send_simple_error(self, code, message, pad=True): """Send error to caller without embedded message.""" self.response.headers['content-type'] = 'text/plain; charset=utf-8' logging.error(message) self.response.set_status(code, message) response_message = httplib.responses.get(code, 'Unknown Error') if pad: response_message = util.pad_string(response_message) self.response.out.write(response_message) def __get_content_type(self): content_type = self.request.headers.get('content-type', None) if not content_type: content_type = self.request.environ.get('HTTP_CONTENT_TYPE', None) if not content_type: return None # Lop off parameters from the end (for example content-encoding) return content_type.split(';', 1)[0].lower() def __headers(self, content_type): for name in self.request.headers: name = name.lower() if name == 'content-type': value = content_type elif name == 'content-length': value = str(len(self.request.body)) else: value = self.request.headers.get(name, '') yield name, value def handle(self, http_method, service_path, remote_method): """Handle a service request. The handle method will handle either a GET or POST response. It is up to the individual mappers from the handler factory to determine which request methods they can service. 
If the protocol is not recognized, the request does not provide a correct request for that protocol or the service object does not support the requested RPC method, will return error code 400 in the response. Args: http_method: HTTP method of request. service_path: Service path derived from request URL. remote_method: Sub-path after service path has been matched. """ self.response.headers['x-content-type-options'] = 'nosniff' if not remote_method and http_method == 'GET': # Special case a normal get request, presumably via a browser. self.error(405) self.__show_info(service_path, remote_method) return content_type = self.__get_content_type() # Provide server state to the service. If the service object does not have # an "initialize_request_state" method, will not attempt to assign state. try: state_initializer = self.service.initialize_request_state except AttributeError: pass else: server_port = self.request.environ.get('SERVER_PORT', None) if server_port: server_port = int(server_port) request_state = remote.HttpRequestState( remote_host=self.request.environ.get('REMOTE_HOST', None), remote_address=self.request.environ.get('REMOTE_ADDR', None), server_host=self.request.environ.get('SERVER_HOST', None), server_port=server_port, http_method=http_method, service_path=service_path, headers=list(self.__headers(content_type))) state_initializer(request_state) if not content_type: self.__send_simple_error(400, 'Invalid RPC request: missing content-type') return # Search for mapper to mediate request. 
for mapper in self.__factory.all_request_mappers(): if content_type in mapper.content_types: break else: if http_method == 'GET': self.error(httplib.UNSUPPORTED_MEDIA_TYPE) self.__show_info(service_path, remote_method) else: self.__send_simple_error(httplib.UNSUPPORTED_MEDIA_TYPE, 'Unsupported content-type: %s' % content_type) return try: if http_method not in mapper.http_methods: if http_method == 'GET': self.error(httplib.METHOD_NOT_ALLOWED) self.__show_info(service_path, remote_method) else: self.__send_simple_error(httplib.METHOD_NOT_ALLOWED, 'Unsupported HTTP method: %s' % http_method) return try: try: method = getattr(self.service, remote_method) method_info = method.remote except AttributeError, err: self.__send_error( 400, remote.RpcState.METHOD_NOT_FOUND_ERROR, 'Unrecognized RPC method: %s' % remote_method, mapper) return request = mapper.build_request(self, method_info.request_type) except (RequestError, messages.DecodeError), err: self.__send_error(400, remote.RpcState.REQUEST_ERROR, 'Error parsing ProtoRPC request (%s)' % err, mapper) return try: response = method(request) except remote.ApplicationError, err: self.__send_error(400, remote.RpcState.APPLICATION_ERROR, err.message, mapper, err.error_name) return mapper.build_response(self, response) except Exception, err: logging.error('An unexpected error occured when handling RPC: %s', err, exc_info=1) self.__send_error(500, remote.RpcState.SERVER_ERROR, 'Internal Server Error', mapper) return # TODO(rafek): Support tag-id only forms. class URLEncodedRPCMapper(RPCMapper): """Request mapper for application/x-www-form-urlencoded forms. This mapper is useful for building forms that can invoke RPC. Many services are also configured to work using URL encoded request information because of its perceived ease of programming and debugging. The mapper must be provided with at least method_parameter or remote_method_pattern so that it is possible to determine how to determine the requests RPC method. 
If both are provided, the service will respond to both method request types, however, only one may be present in a given request. If both types are detected, the request will not match. """ def __init__(self, parameter_prefix=''): """Constructor. Args: parameter_prefix: If provided, all the parameters in the form are expected to begin with that prefix. """ # Private attributes: # __parameter_prefix: parameter prefix as provided by constructor # parameter. super(URLEncodedRPCMapper, self).__init__(['POST'], _URLENCODED_CONTENT_TYPE, self) self.__parameter_prefix = parameter_prefix def encode_message(self, message): """Encode a message using parameter prefix. Args: message: Message to URL Encode. Returns: URL encoded message. """ return protourlencode.encode_message(message, prefix=self.__parameter_prefix) @property def parameter_prefix(self): """Prefix all form parameters are expected to begin with.""" return self.__parameter_prefix def build_request(self, handler, request_type): """Build request from URL encoded HTTP request. Constructs message from names of URL encoded parameters. If this service handler has a parameter prefix, parameters must begin with it or are ignored. Args: handler: RequestHandler instance that is servicing request. request_type: Message type to build. Returns: Instance of request_type populated by protocol buffer in request parameters. Raises: RequestError if message type contains nested message field or repeated message field. Will raise RequestError if there are any repeated parameters. """ request = request_type() builder = protourlencode.URLEncodedRequestBuilder( request, prefix=self.__parameter_prefix) for argument in sorted(handler.request.arguments()): values = handler.request.get_all(argument) try: builder.add_parameter(argument, values) except messages.DecodeError, err: raise RequestError(str(err)) return request class ProtobufRPCMapper(RPCMapper): """Request mapper for application/x-protobuf service requests. 
This mapper will parse protocol buffer from a POST body and return the request as a protocol buffer. """ def __init__(self): super(ProtobufRPCMapper, self).__init__(['POST'], _PROTOBUF_CONTENT_TYPE, protobuf) class JSONRPCMapper(RPCMapper): """Request mapper for application/x-protobuf service requests. This mapper will parse protocol buffer from a POST body and return the request as a protocol buffer. """ def __init__(self): super(JSONRPCMapper, self).__init__( ['POST'], _JSON_CONTENT_TYPE, protojson, content_types=_EXTRA_JSON_CONTENT_TYPES) def service_mapping(services, registry_path=DEFAULT_REGISTRY_PATH): """Create a services mapping for use with webapp. Creates basic default configuration and registration for ProtoRPC services. Each service listed in the service mapping has a standard service handler factory created for it. The list of mappings can either be an explicit path to service mapping or just services. If mappings are just services, they will automatically be mapped to their default name. For exampel: package = 'my_package' class MyService(remote.Service): ... server_mapping([('/my_path', MyService), # Maps to /my_path MyService, # Maps to /my_package/MyService ]) Specifying a service mapping: Normally services are mapped to URL paths by specifying a tuple (path, service): path: The path the service resides on. service: The service class or service factory for creating new instances of the service. For more information about service factories, please see remote.Service.new_factory. If no tuple is provided, and therefore no path specified, a default path is calculated by using the fully qualified service name using a URL path separator for each of its components instead of a '.'. Args: services: Can be service type, service factory or string definition name of service being mapped or list of tuples (path, service): path: Path on server to map service to. service: Service type, service factory or string definition name of service being mapped. 
Can also be a dict. If so, the keys are treated as the path and values as the service. registry_path: Path to give to registry service. Use None to disable registry service. Returns: List of tuples defining a mapping of request handlers compatible with a webapp application. Raises: ServiceConfigurationError when duplicate paths are provided. """ if isinstance(services, dict): services = services.iteritems() mapping = [] registry_map = {} if registry_path is not None: registry_service = registry.RegistryService.new_factory(registry_map) services = list(services) + [(registry_path, registry_service)] mapping.append((registry_path + r'/form(?:/)?', forms.FormsHandler.new_factory(registry_path))) mapping.append((registry_path + r'/form/(.+)', forms.ResourceHandler)) paths = set() for service_item in services: infer_path = not isinstance(service_item, (list, tuple)) if infer_path: service = service_item else: service = service_item[1] service_class = getattr(service, 'service_class', service) if infer_path: path = '/' + service_class.definition_name().replace('.', '/') else: path = service_item[0] if path in paths: raise ServiceConfigurationError( 'Path %r is already defined in service mapping' % path.encode('utf-8')) else: paths.add(path) # Create service mapping for webapp. new_mapping = ServiceHandlerFactory.default(service).mapping(path) mapping.append(new_mapping) # Update registry with service class. registry_map[path] = service_class return mapping def run_services(services, registry_path=DEFAULT_REGISTRY_PATH): """Handle CGI request using service mapping. Args: Same as service_mapping. """ mappings = service_mapping(services, registry_path=registry_path) application = webapp.WSGIApplication(mappings) webapp_util.run_wsgi_app(application) protorpc-standalone-0.9.1/protorpc/webapp/service_handlers_test.py0000755000076500000240000013267712277637135026607 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for protorpc.service_handlers.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import cgi import cStringIO import os import re import sys import unittest import urllib from protorpc import messages from protorpc import protobuf from protorpc import protojson from protorpc import protourlencode from protorpc import message_types from protorpc import registry from protorpc import remote from protorpc import test_util from protorpc import util from protorpc import webapp_test_util from protorpc.webapp import forms from protorpc.webapp import service_handlers from protorpc.webapp.google_imports import webapp import mox package = 'test_package' class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase): MODULE = service_handlers class Enum1(messages.Enum): """A test enum class.""" VAL1 = 1 VAL2 = 2 VAL3 = 3 class Request1(messages.Message): """A test request message type.""" integer_field = messages.IntegerField(1) string_field = messages.StringField(2) enum_field = messages.EnumField(Enum1, 3) class Response1(messages.Message): """A test response message type.""" integer_field = messages.IntegerField(1) string_field = messages.StringField(2) enum_field = messages.EnumField(Enum1, 3) class SuperMessage(messages.Message): """A test message with a nested message field.""" sub_message = messages.MessageField(Request1, 1) sub_messages = messages.MessageField(Request1, 2, repeated=True) class 
SuperSuperMessage(messages.Message): """A test message with two levels of nested.""" sub_message = messages.MessageField(SuperMessage, 1) sub_messages = messages.MessageField(Request1, 2, repeated=True) class RepeatedMessage(messages.Message): """A test message with a repeated field.""" ints = messages.IntegerField(1, repeated=True) strings = messages.StringField(2, repeated=True) enums = messages.EnumField(Enum1, 3, repeated=True) class Service(object): """A simple service that takes a Request1 and returns Request2.""" @remote.method(Request1, Response1) def method1(self, request): response = Response1() if hasattr(request, 'integer_field'): response.integer_field = request.integer_field if hasattr(request, 'string_field'): response.string_field = request.string_field if hasattr(request, 'enum_field'): response.enum_field = request.enum_field return response @remote.method(RepeatedMessage, RepeatedMessage) def repeated_method(self, request): response = RepeatedMessage() if hasattr(request, 'ints'): response = request.ints return response def not_remote(self): pass def VerifyResponse(test, response, expected_status, expected_status_message, expected_content, expected_content_type='application/x-www-form-urlencoded'): def write(content): if expected_content == '': test.assertEquals(util.pad_string(''), content) else: test.assertNotEquals(-1, content.find(expected_content), 'Expected to find:\n%s\n\nActual content: \n%s' % ( expected_content, content)) def start_response(response, headers): status, message = response.split(' ', 1) test.assertEquals(expected_status, status) test.assertEquals(expected_status_message, message) for name, value in headers: if name.lower() == 'content-type': test.assertEquals(expected_content_type, value) for name, value in headers: if name.lower() == 'x-content-type-options': test.assertEquals('nosniff', value) elif name.lower() == 'content-type': test.assertFalse(value.lower().startswith('text/html')) return write 
response.wsgi_write(start_response) class ServiceHandlerFactoryTest(test_util.TestCase): """Tests for the service handler factory.""" def testAllRequestMappers(self): """Test all_request_mappers method.""" configuration = service_handlers.ServiceHandlerFactory(Service) mapper1 = service_handlers.RPCMapper(['whatever'], 'whatever', None) mapper2 = service_handlers.RPCMapper(['whatever'], 'whatever', None) configuration.add_request_mapper(mapper1) self.assertEquals([mapper1], list(configuration.all_request_mappers())) configuration.add_request_mapper(mapper2) self.assertEquals([mapper1, mapper2], list(configuration.all_request_mappers())) def testServiceFactory(self): """Test that service_factory attribute is set.""" handler_factory = service_handlers.ServiceHandlerFactory(Service) self.assertEquals(Service, handler_factory.service_factory) def testFactoryMethod(self): """Test that factory creates correct instance of class.""" factory = service_handlers.ServiceHandlerFactory(Service) handler = factory() self.assertTrue(isinstance(handler, service_handlers.ServiceHandler)) self.assertTrue(isinstance(handler.service, Service)) def testMapping(self): """Test the mapping method.""" factory = service_handlers.ServiceHandlerFactory(Service) path, mapped_factory = factory.mapping('/my_service') self.assertEquals(r'(/my_service)' + service_handlers._METHOD_PATTERN, path) self.assertEquals(id(factory), id(mapped_factory)) match = re.match(path, '/my_service.my_method') self.assertEquals('/my_service', match.group(1)) self.assertEquals('my_method', match.group(2)) path, mapped_factory = factory.mapping('/my_service/nested') self.assertEquals('(/my_service/nested)' + service_handlers._METHOD_PATTERN, path) match = re.match(path, '/my_service/nested.my_method') self.assertEquals('/my_service/nested', match.group(1)) self.assertEquals('my_method', match.group(2)) def testRegexMapping(self): """Test the mapping method using a regex.""" factory = 
service_handlers.ServiceHandlerFactory(Service) path, mapped_factory = factory.mapping('.*/my_service') self.assertEquals(r'(.*/my_service)' + service_handlers._METHOD_PATTERN, path) self.assertEquals(id(factory), id(mapped_factory)) match = re.match(path, '/whatever_preceeds/my_service.my_method') self.assertEquals('/whatever_preceeds/my_service', match.group(1)) self.assertEquals('my_method', match.group(2)) match = re.match(path, '/something_else/my_service.my_other_method') self.assertEquals('/something_else/my_service', match.group(1)) self.assertEquals('my_other_method', match.group(2)) def testMapping_BadPath(self): """Test bad parameterse to the mapping method.""" factory = service_handlers.ServiceHandlerFactory(Service) self.assertRaises(ValueError, factory.mapping, '/my_service/') def testDefault(self): """Test the default factory convenience method.""" handler_factory = service_handlers.ServiceHandlerFactory.default( Service, parameter_prefix='my_prefix.') self.assertEquals(Service, handler_factory.service_factory) mappers = handler_factory.all_request_mappers() # Verify Protobuf encoded mapper. protobuf_mapper = mappers.next() self.assertTrue(isinstance(protobuf_mapper, service_handlers.ProtobufRPCMapper)) # Verify JSON encoded mapper. json_mapper = mappers.next() self.assertTrue(isinstance(json_mapper, service_handlers.JSONRPCMapper)) # Should have no more mappers. 
self.assertRaises(StopIteration, mappers.next) class ServiceHandlerTest(webapp_test_util.RequestHandlerTestBase): """Test the ServiceHandler class.""" def setUp(self): self.mox = mox.Mox() self.service_factory = Service self.remote_host = 'remote.host.com' self.server_host = 'server.host.com' self.ResetRequestHandler() self.request = Request1() self.request.integer_field = 1 self.request.string_field = 'a' self.request.enum_field = Enum1.VAL1 def ResetRequestHandler(self): super(ServiceHandlerTest, self).setUp() def CreateService(self): return self.service_factory() def CreateRequestHandler(self): self.rpc_mapper1 = self.mox.CreateMock(service_handlers.RPCMapper) self.rpc_mapper1.http_methods = set(['POST']) self.rpc_mapper1.content_types = set(['application/x-www-form-urlencoded']) self.rpc_mapper1.default_content_type = 'application/x-www-form-urlencoded' self.rpc_mapper2 = self.mox.CreateMock(service_handlers.RPCMapper) self.rpc_mapper2.http_methods = set(['GET']) self.rpc_mapper2.content_types = set(['application/json']) self.rpc_mapper2.default_content_type = 'application/json' self.factory = service_handlers.ServiceHandlerFactory( self.CreateService) self.factory.add_request_mapper(self.rpc_mapper1) self.factory.add_request_mapper(self.rpc_mapper2) return self.factory() def GetEnvironment(self): """Create handler to test.""" environ = super(ServiceHandlerTest, self).GetEnvironment() if self.remote_host: environ['REMOTE_HOST'] = self.remote_host if self.server_host: environ['SERVER_HOST'] = self.server_host return environ def VerifyResponse(self, *args, **kwargs): VerifyResponse(self, self.response, *args, **kwargs) def ExpectRpcError(self, mapper, state, error_message, error_name=None): mapper.build_response(self.handler, remote.RpcStatus(state=state, error_message=error_message, error_name=error_name)) def testRedirect(self): """Test that redirection is disabled.""" self.assertRaises(NotImplementedError, self.handler.redirect, '/') def testFirstMapper(self): 
"""Make sure service attribute works when matches first RPCMapper.""" self.rpc_mapper1.build_request( self.handler, Request1).AndReturn(self.request) def build_response(handler, response): output = '%s %s %s' % (response.integer_field, response.string_field, response.enum_field) handler.response.headers['content-type'] = ( 'application/x-www-form-urlencoded') handler.response.out.write(output) self.rpc_mapper1.build_response( self.handler, mox.IsA(Response1)).WithSideEffects(build_response) self.mox.ReplayAll() self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('200', 'OK', '1 a VAL1') self.mox.VerifyAll() def testSecondMapper(self): """Make sure service attribute works when matches first RPCMapper. Demonstrates the multiplicity of the RPCMapper configuration. """ self.rpc_mapper2.build_request( self.handler, Request1).AndReturn(self.request) def build_response(handler, response): output = '%s %s %s' % (response.integer_field, response.string_field, response.enum_field) handler.response.headers['content-type'] = ( 'application/x-www-form-urlencoded') handler.response.out.write(output) self.rpc_mapper2.build_response( self.handler, mox.IsA(Response1)).WithSideEffects(build_response) self.mox.ReplayAll() self.handler.request.headers['Content-Type'] = 'application/json' self.handler.handle('GET', '/my_service', 'method1') self.VerifyResponse('200', 'OK', '1 a VAL1') self.mox.VerifyAll() def testCaseInsensitiveContentType(self): """Ensure that matching content-type is case insensitive.""" request = Request1() request.integer_field = 1 request.string_field = 'a' request.enum_field = Enum1.VAL1 self.rpc_mapper1.build_request(self.handler, Request1).AndReturn(self.request) def build_response(handler, response): output = '%s %s %s' % (response.integer_field, response.string_field, response.enum_field) handler.response.out.write(output) handler.response.headers['content-type'] = 'text/plain' self.rpc_mapper1.build_response( self.handler, 
mox.IsA(Response1)).WithSideEffects(build_response) self.mox.ReplayAll() self.handler.request.headers['Content-Type'] = ('ApPlIcAtIoN/' 'X-wWw-FoRm-UrLeNcOdEd') self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('200', 'OK', '1 a VAL1', 'text/plain') self.mox.VerifyAll() def testContentTypeWithParameters(self): """Test that content types have parameters parsed out.""" request = Request1() request.integer_field = 1 request.string_field = 'a' request.enum_field = Enum1.VAL1 self.rpc_mapper1.build_request(self.handler, Request1).AndReturn(self.request) def build_response(handler, response): output = '%s %s %s' % (response.integer_field, response.string_field, response.enum_field) handler.response.headers['content-type'] = ( 'application/x-www-form-urlencoded') handler.response.out.write(output) self.rpc_mapper1.build_response( self.handler, mox.IsA(Response1)).WithSideEffects(build_response) self.mox.ReplayAll() self.handler.request.headers['Content-Type'] = ('application/' 'x-www-form-urlencoded' + '; a=b; c=d') self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('200', 'OK', '1 a VAL1') self.mox.VerifyAll() def testContentFromHeaderOnly(self): """Test getting content-type from HTTP_CONTENT_TYPE directly. Some bad web server implementations might decide not to set CONTENT_TYPE for POST requests where there is an empty body. In these cases, need to get content-type directly from webob environ key HTTP_CONTENT_TYPE. 
""" request = Request1() request.integer_field = 1 request.string_field = 'a' request.enum_field = Enum1.VAL1 self.rpc_mapper1.build_request(self.handler, Request1).AndReturn(self.request) def build_response(handler, response): output = '%s %s %s' % (response.integer_field, response.string_field, response.enum_field) handler.response.headers['Content-Type'] = ( 'application/x-www-form-urlencoded') handler.response.out.write(output) self.rpc_mapper1.build_response( self.handler, mox.IsA(Response1)).WithSideEffects(build_response) self.mox.ReplayAll() self.handler.request.headers['Content-Type'] = None self.handler.request.environ['HTTP_CONTENT_TYPE'] = ( 'application/x-www-form-urlencoded') self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('200', 'OK', '1 a VAL1', 'application/x-www-form-urlencoded') self.mox.VerifyAll() def testRequestState(self): """Make sure request state is passed in to handler that supports it.""" class ServiceWithState(object): initialize_request_state = self.mox.CreateMockAnything() @remote.method(Request1, Response1) def method1(self, request): return Response1() self.service_factory = ServiceWithState # Reset handler with new service type. 
self.ResetRequestHandler() self.rpc_mapper1.build_request( self.handler, Request1).AndReturn(Request1()) def build_response(handler, response): handler.response.headers['Content-Type'] = ( 'application/x-www-form-urlencoded') handler.response.out.write('whatever') self.rpc_mapper1.build_response( self.handler, mox.IsA(Response1)).WithSideEffects(build_response) def verify_state(state): return ( 'remote.host.com' == state.remote_host and '127.0.0.1' == state.remote_address and 'server.host.com' == state.server_host and 8080 == state.server_port and 'POST' == state.http_method and '/my_service' == state.service_path and 'application/x-www-form-urlencoded' == state.headers['content-type'] and 'dev_appserver_login="test:test@example.com:True"' == state.headers['cookie']) ServiceWithState.initialize_request_state(mox.Func(verify_state)) self.mox.ReplayAll() self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('200', 'OK', 'whatever') self.mox.VerifyAll() def testRequestState_MissingHosts(self): """Make sure missing state environment values are handled gracefully.""" class ServiceWithState(object): initialize_request_state = self.mox.CreateMockAnything() @remote.method(Request1, Response1) def method1(self, request): return Response1() self.service_factory = ServiceWithState self.remote_host = None self.server_host = None # Reset handler with new service type. 
self.ResetRequestHandler() self.rpc_mapper1.build_request( self.handler, Request1).AndReturn(Request1()) def build_response(handler, response): handler.response.headers['Content-Type'] = ( 'application/x-www-form-urlencoded') handler.response.out.write('whatever') self.rpc_mapper1.build_response( self.handler, mox.IsA(Response1)).WithSideEffects(build_response) def verify_state(state): return (None is state.remote_host and '127.0.0.1' == state.remote_address and None is state.server_host and 8080 == state.server_port) ServiceWithState.initialize_request_state(mox.Func(verify_state)) self.mox.ReplayAll() self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('200', 'OK', 'whatever') self.mox.VerifyAll() def testNoMatch_UnknownHTTPMethod(self): """Test what happens when no RPCMapper matches.""" self.mox.ReplayAll() self.handler.handle('UNKNOWN', '/my_service', 'does_not_matter') self.VerifyResponse('405', 'Unsupported HTTP method: UNKNOWN', 'Method Not Allowed', 'text/plain; charset=utf-8') self.mox.VerifyAll() def testNoMatch_GetNotSupported(self): """Test what happens when GET is not supported.""" self.mox.ReplayAll() self.handler.handle('GET', '/my_service', 'method1') self.VerifyResponse('405', 'Method Not Allowed', '/my_service.method1 is a ProtoRPC method.\n\n' 'Service %s.Service\n\n' 'More about ProtoRPC: ' 'http://code.google.com/p/google-protorpc' % (__name__,), 'text/plain; charset=utf-8') self.mox.VerifyAll() def testNoMatch_UnknownContentType(self): """Test what happens when no RPCMapper matches.""" self.mox.ReplayAll() self.handler.request.headers['Content-Type'] = 'image/png' self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('415', 'Unsupported content-type: image/png', 'Unsupported Media Type', 'text/plain; charset=utf-8') self.mox.VerifyAll() def testNoMatch_NoContentType(self): """Test what happens when no RPCMapper matches..""" self.mox.ReplayAll() self.handler.request.environ.pop('HTTP_CONTENT_TYPE', 
None) self.handler.request.headers.pop('Content-Type', None) self.handler.handle('/my_service', 'POST', 'method1') self.VerifyResponse('400', 'Invalid RPC request: missing content-type', 'Bad Request', 'text/plain; charset=utf-8') self.mox.VerifyAll() def testNoSuchMethod(self): """When service method not found.""" self.ExpectRpcError(self.rpc_mapper1, remote.RpcState.METHOD_NOT_FOUND_ERROR, 'Unrecognized RPC method: no_such_method') self.mox.ReplayAll() self.handler.handle('POST', '/my_service', 'no_such_method') self.VerifyResponse('400', 'Unrecognized RPC method: no_such_method', '') self.mox.VerifyAll() def testNoSuchRemoteMethod(self): """When service method exists but is not remote.""" self.ExpectRpcError(self.rpc_mapper1, remote.RpcState.METHOD_NOT_FOUND_ERROR, 'Unrecognized RPC method: not_remote') self.mox.ReplayAll() self.handler.handle('POST', '/my_service', 'not_remote') self.VerifyResponse('400', 'Unrecognized RPC method: not_remote', '') self.mox.VerifyAll() def testRequestError(self): """RequestError handling.""" def build_request(handler, request): raise service_handlers.RequestError('This is a request error') self.rpc_mapper1.build_request( self.handler, Request1).WithSideEffects(build_request) self.ExpectRpcError(self.rpc_mapper1, remote.RpcState.REQUEST_ERROR, 'Error parsing ProtoRPC request ' '(This is a request error)') self.mox.ReplayAll() self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('400', 'Error parsing ProtoRPC request ' '(This is a request error)', '') self.mox.VerifyAll() def testDecodeError(self): """DecodeError handling.""" def build_request(handler, request): raise messages.DecodeError('This is a decode error') self.rpc_mapper1.build_request( self.handler, Request1).WithSideEffects(build_request) self.ExpectRpcError(self.rpc_mapper1, remote.RpcState.REQUEST_ERROR, r'Error parsing ProtoRPC request ' r'(This is a decode error)') self.mox.ReplayAll() self.handler.handle('POST', '/my_service', 'method1') 
self.VerifyResponse('400', 'Error parsing ProtoRPC request ' '(This is a decode error)', '') self.mox.VerifyAll() def testResponseException(self): """Test what happens when build_response raises ResponseError.""" self.rpc_mapper1.build_request( self.handler, Request1).AndReturn(self.request) self.rpc_mapper1.build_response( self.handler, mox.IsA(Response1)).AndRaise( service_handlers.ResponseError) self.ExpectRpcError(self.rpc_mapper1, remote.RpcState.SERVER_ERROR, 'Internal Server Error') self.mox.ReplayAll() self.handler.handle('POST', '/my_service', 'method1') self.VerifyResponse('500', 'Internal Server Error', '') self.mox.VerifyAll() def testGet(self): """Test that GET goes to 'handle' properly.""" self.handler.handle = self.mox.CreateMockAnything() self.handler.handle('GET', '/my_service', 'method1') self.handler.handle('GET', '/my_other_service', 'method2') self.mox.ReplayAll() self.handler.get('/my_service', 'method1') self.handler.get('/my_other_service', 'method2') self.mox.VerifyAll() def testPost(self): """Test that POST goes to 'handle' properly.""" self.handler.handle = self.mox.CreateMockAnything() self.handler.handle('POST', '/my_service', 'method1') self.handler.handle('POST', '/my_other_service', 'method2') self.mox.ReplayAll() self.handler.post('/my_service', 'method1') self.handler.post('/my_other_service', 'method2') self.mox.VerifyAll() def testGetNoMethod(self): self.handler.get('/my_service', '') self.assertEquals(405, self.handler.response.status) self.assertEquals( util.pad_string('/my_service is a ProtoRPC service.\n\n' 'Service %s.Service\n\n' 'More about ProtoRPC: ' 'http://code.google.com/p/google-protorpc\n' % __name__), self.handler.response.out.getvalue()) self.assertEquals( 'nosniff', self.handler.response.headers['x-content-type-options']) def testGetNotSupported(self): self.handler.get('/my_service', 'method1') self.assertEquals(405, self.handler.response.status) expected_message = ('/my_service.method1 is a ProtoRPC method.\n\n' 
'Service %s.Service\n\n' 'More about ProtoRPC: ' 'http://code.google.com/p/google-protorpc\n' % __name__) self.assertEquals(util.pad_string(expected_message), self.handler.response.out.getvalue()) self.assertEquals( 'nosniff', self.handler.response.headers['x-content-type-options']) def testGetUnknownContentType(self): self.handler.request.headers['content-type'] = 'image/png' self.handler.get('/my_service', 'method1') self.assertEquals(415, self.handler.response.status) self.assertEquals( util.pad_string('/my_service.method1 is a ProtoRPC method.\n\n' 'Service %s.Service\n\n' 'More about ProtoRPC: ' 'http://code.google.com/p/google-protorpc\n' % __name__), self.handler.response.out.getvalue()) self.assertEquals( 'nosniff', self.handler.response.headers['x-content-type-options']) class MissingContentLengthTests(ServiceHandlerTest): """Test for when content-length is not set in the environment. This test moves CONTENT_LENGTH from the environment to the content-length header. """ def GetEnvironment(self): environment = super(MissingContentLengthTests, self).GetEnvironment() content_length = str(environment.pop('CONTENT_LENGTH', '0')) environment['HTTP_CONTENT_LENGTH'] = content_length return environment class MissingContentTypeTests(ServiceHandlerTest): """Test for when content-type is not set in the environment. This test moves CONTENT_TYPE from the environment to the content-type header. """ def GetEnvironment(self): environment = super(MissingContentTypeTests, self).GetEnvironment() content_type = str(environment.pop('CONTENT_TYPE', '')) environment['HTTP_CONTENT_TYPE'] = content_type return environment class RPCMapperTestBase(test_util.TestCase): def setUp(self): """Set up test framework.""" self.Reinitialize() def Reinitialize(self, input='', get=False, path_method='method1', content_type='text/plain'): """Allows reinitialization of test with custom input values and POST. Args: input: Query string or POST input. get: Use GET method if True. Use POST if False. 
""" self.factory = service_handlers.ServiceHandlerFactory(Service) self.service_handler = service_handlers.ServiceHandler(self.factory, Service()) self.service_handler.remote_method = path_method request_path = '/servicepath' if path_method: request_path += '/' + path_method if get: request_path += '?' + input if get: environ = {'wsgi.input': cStringIO.StringIO(''), 'CONTENT_LENGTH': '0', 'QUERY_STRING': input, 'REQUEST_METHOD': 'GET', 'PATH_INFO': request_path, } self.service_handler.method = 'GET' else: environ = {'wsgi.input': cStringIO.StringIO(input), 'CONTENT_LENGTH': str(len(input)), 'QUERY_STRING': '', 'REQUEST_METHOD': 'POST', 'PATH_INFO': request_path, } self.service_handler.method = 'POST' self.request = webapp.Request(environ) self.response = webapp.Response() self.service_handler.initialize(self.request, self.response) self.service_handler.request.headers['Content-Type'] = content_type class RPCMapperTest(RPCMapperTestBase, webapp_test_util.RequestHandlerTestBase): """Test the RPCMapper base class.""" def setUp(self): RPCMapperTestBase.setUp(self) webapp_test_util.RequestHandlerTestBase.setUp(self) self.mox = mox.Mox() self.protocol = self.mox.CreateMockAnything() def GetEnvironment(self): """Get environment. Return bogus content in body. Returns: dict of CGI environment. 
""" environment = super(RPCMapperTest, self).GetEnvironment() environment['wsgi.input'] = cStringIO.StringIO('my body') environment['CONTENT_LENGTH'] = len('my body') return environment def testContentTypes_JustDefault(self): """Test content type attributes.""" self.mox.ReplayAll() mapper = service_handlers.RPCMapper(['GET', 'POST'], 'my-content-type', self.protocol) self.assertEquals(frozenset(['GET', 'POST']), mapper.http_methods) self.assertEquals('my-content-type', mapper.default_content_type) self.assertEquals(frozenset(['my-content-type']), mapper.content_types) self.mox.VerifyAll() def testContentTypes_Extended(self): """Test content type attributes.""" self.mox.ReplayAll() mapper = service_handlers.RPCMapper(['GET', 'POST'], 'my-content-type', self.protocol, content_types=['a', 'b']) self.assertEquals(frozenset(['GET', 'POST']), mapper.http_methods) self.assertEquals('my-content-type', mapper.default_content_type) self.assertEquals(frozenset(['my-content-type', 'a', 'b']), mapper.content_types) self.mox.VerifyAll() def testBuildRequest(self): """Test building a request.""" expected_request = Request1() self.protocol.decode_message(Request1, 'my body').AndReturn(expected_request) self.mox.ReplayAll() mapper = service_handlers.RPCMapper(['POST'], 'my-content-type', self.protocol) request = mapper.build_request(self.handler, Request1) self.assertTrue(expected_request is request) def testBuildRequest_ValidationError(self): """Test building a request generating a validation error.""" expected_request = Request1() self.protocol.decode_message( Request1, 'my body').AndRaise(messages.ValidationError('xyz')) self.mox.ReplayAll() mapper = service_handlers.RPCMapper(['POST'], 'my-content-type', self.protocol) self.assertRaisesWithRegexpMatch( service_handlers.RequestError, 'Unable to parse request content: xyz', mapper.build_request, self.handler, Request1) def testBuildRequest_DecodeError(self): """Test building a request generating a decode error.""" 
expected_request = Request1() self.protocol.decode_message( Request1, 'my body').AndRaise(messages.DecodeError('xyz')) self.mox.ReplayAll() mapper = service_handlers.RPCMapper(['POST'], 'my-content-type', self.protocol) self.assertRaisesWithRegexpMatch( service_handlers.RequestError, 'Unable to parse request content: xyz', mapper.build_request, self.handler, Request1) def testBuildResponse(self): """Test building a response.""" response = Response1() self.protocol.encode_message(response).AndReturn('encoded') self.mox.ReplayAll() mapper = service_handlers.RPCMapper(['POST'], 'my-content-type', self.protocol) request = mapper.build_response(self.handler, response) self.assertEquals('my-content-type', self.handler.response.headers['Content-Type']) self.assertEquals('encoded', self.handler.response.out.getvalue()) def testBuildResponse(self): """Test building a response.""" response = Response1() self.protocol.encode_message(response).AndRaise( messages.ValidationError('xyz')) self.mox.ReplayAll() mapper = service_handlers.RPCMapper(['POST'], 'my-content-type', self.protocol) self.assertRaisesWithRegexpMatch(service_handlers.ResponseError, 'Unable to encode message: xyz', mapper.build_response, self.handler, response) class ProtocolMapperTestBase(object): """Base class for basic protocol mapper tests.""" def setUp(self): """Reinitialize test specifically for protocol buffer mapper.""" super(ProtocolMapperTestBase, self).setUp() self.Reinitialize(path_method='my_method', content_type='application/x-google-protobuf') self.request_message = Request1() self.request_message.integer_field = 1 self.request_message.string_field = u'something' self.request_message.enum_field = Enum1.VAL1 self.response_message = Response1() self.response_message.integer_field = 1 self.response_message.string_field = u'something' self.response_message.enum_field = Enum1.VAL1 def testBuildRequest(self): """Test request building.""" 
self.Reinitialize(self.protocol.encode_message(self.request_message), content_type=self.content_type) mapper = self.mapper() parsed_request = mapper.build_request(self.service_handler, Request1) self.assertEquals(self.request_message, parsed_request) def testBuildResponse(self): """Test response building.""" mapper = self.mapper() mapper.build_response(self.service_handler, self.response_message) self.assertEquals(self.protocol.encode_message(self.response_message), self.service_handler.response.out.getvalue()) def testWholeRequest(self): """Test the basic flow of a request with mapper class.""" body = self.protocol.encode_message(self.request_message) self.Reinitialize(input=body, content_type=self.content_type) self.factory.add_request_mapper(self.mapper()) self.service_handler.handle('POST', '/my_service', 'method1') VerifyResponse(self, self.service_handler.response, '200', 'OK', self.protocol.encode_message(self.response_message), self.content_type) class URLEncodedRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): """Test the URL encoded RPC mapper.""" content_type = 'application/x-www-form-urlencoded' protocol = protourlencode mapper = service_handlers.URLEncodedRPCMapper def testBuildRequest_Prefix(self): """Test building request with parameter prefix.""" self.Reinitialize(urllib.urlencode([('prefix_integer_field', '10'), ('prefix_string_field', 'a string'), ('prefix_enum_field', 'VAL1'), ]), self.content_type) url_encoded_mapper = service_handlers.URLEncodedRPCMapper( parameter_prefix='prefix_') request = url_encoded_mapper.build_request(self.service_handler, Request1) self.assertEquals(10, request.integer_field) self.assertEquals('a string', request.string_field) self.assertEquals(Enum1.VAL1, request.enum_field) def testBuildRequest_DecodeError(self): """Test trying to build request that causes a decode error.""" self.Reinitialize(urllib.urlencode((('integer_field', '10'), ('integer_field', '20'), )), content_type=self.content_type) 
url_encoded_mapper = service_handlers.URLEncodedRPCMapper() self.assertRaises(service_handlers.RequestError, url_encoded_mapper.build_request, self.service_handler, Service.method1.remote.request_type) def testBuildResponse_Prefix(self): """Test building a response with parameter prefix.""" response = Response1() response.integer_field = 10 response.string_field = u'a string' response.enum_field = Enum1.VAL3 url_encoded_mapper = service_handlers.URLEncodedRPCMapper( parameter_prefix='prefix_') url_encoded_mapper.build_response(self.service_handler, response) self.assertEquals('application/x-www-form-urlencoded', self.response.headers['content-type']) self.assertEquals(cgi.parse_qs(self.response.out.getvalue(), True, True), {'prefix_integer_field': ['10'], 'prefix_string_field': [u'a string'], 'prefix_enum_field': ['VAL3'], }) class ProtobufRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): """Test the protobuf encoded RPC mapper.""" content_type = 'application/octet-stream' protocol = protobuf mapper = service_handlers.ProtobufRPCMapper class JSONRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): """Test the URL encoded RPC mapper.""" content_type = 'application/json' protocol = protojson mapper = service_handlers.JSONRPCMapper class MyService(remote.Service): def __init__(self, value='default'): self.value = value class ServiceMappingTest(test_util.TestCase): def CheckFormMappings(self, mapping, registry_path='/protorpc'): """Check to make sure that form mapping is configured as expected. Args: mapping: Mapping that should contain forms handlers. """ pattern, factory = mapping[0] self.assertEquals('%s/form(?:/)?' 
% registry_path, pattern) handler = factory() self.assertTrue(isinstance(handler, forms.FormsHandler)) self.assertEquals(registry_path, handler.registry_path) pattern, factory = mapping[1] self.assertEquals('%s/form/(.+)' % registry_path, pattern) self.assertEquals(forms.ResourceHandler, factory) def DoMappingTest(self, services, registry_path='/myreg', expected_paths=None): mapped_services = mapping = service_handlers.service_mapping(services, registry_path) if registry_path: form_mapping = mapping[:2] mapped_registry_path, mapped_registry_factory = mapping[-1] mapped_services = mapping[2:-1] self.CheckFormMappings(form_mapping, registry_path=registry_path) self.assertEquals(r'(%s)%s' % (registry_path, service_handlers._METHOD_PATTERN), mapped_registry_path) self.assertEquals(registry.RegistryService, mapped_registry_factory.service_factory.service_class) # Verify registry knows about other services. expected_registry = {registry_path: registry.RegistryService} for path, factory in dict(services).iteritems(): if isinstance(factory, type) and issubclass(factory, remote.Service): expected_registry[path] = factory else: expected_registry[path] = factory.service_class self.assertEquals(expected_registry, mapped_registry_factory().service.registry) # Verify that services are mapped to URL. 
self.assertEquals(len(services), len(mapped_services)) for path, service in dict(services).iteritems(): mapped_path = r'(%s)%s' % (path, service_handlers._METHOD_PATTERN) mapped_factory = dict(mapped_services)[mapped_path] self.assertEquals(service, mapped_factory.service_factory) def testServiceMapping_Empty(self): """Test an empty service mapping.""" self.DoMappingTest({}) def testServiceMapping_ByClass(self): """Test mapping a service by class.""" self.DoMappingTest({'/my-service': MyService}) def testServiceMapping_ByFactory(self): """Test mapping a service by factory.""" self.DoMappingTest({'/my-service': MyService.new_factory('new-value')}) def testServiceMapping_ByList(self): """Test mapping a service by factory.""" self.DoMappingTest( [('/my-service1', MyService.new_factory('service1')), ('/my-service2', MyService.new_factory('service2')), ]) def testServiceMapping_NoRegistry(self): """Test mapping a service by class.""" mapping = self.DoMappingTest({'/my-service': MyService}, None) def testDefaultMappingWithClass(self): """Test setting path just from the class. Path of the mapping will be the fully qualified ProtoRPC service name with '.' replaced with '/'. 
For example: com.nowhere.service.TheService -> /com/nowhere/service/TheService """ mapping = service_handlers.service_mapping([MyService]) mapped_services = mapping[2:-1] self.assertEquals(1, len(mapped_services)) path, factory = mapped_services[0] self.assertEquals( r'(/test_package/MyService)' + service_handlers._METHOD_PATTERN, path) self.assertEquals(MyService, factory.service_factory) def testDefaultMappingWithFactory(self): mapping = service_handlers.service_mapping( [MyService.new_factory('service1')]) mapped_services = mapping[2:-1] self.assertEquals(1, len(mapped_services)) path, factory = mapped_services[0] self.assertEquals( r'(/test_package/MyService)' + service_handlers._METHOD_PATTERN, path) self.assertEquals(MyService, factory.service_factory.service_class) def testMappingDuplicateExplicitServiceName(self): self.assertRaisesWithRegexpMatch( service_handlers.ServiceConfigurationError, "Path '/my_path' is already defined in service mapping", service_handlers.service_mapping, [('/my_path', MyService), ('/my_path', MyService), ]) def testMappingDuplicateServiceName(self): self.assertRaisesWithRegexpMatch( service_handlers.ServiceConfigurationError, "Path '/test_package/MyService' is already defined in service mapping", service_handlers.service_mapping, [MyService, MyService]) class GetCalled(remote.Service): def __init__(self, test): self.test = test @remote.method(Request1, Response1) def my_method(self, request): self.test.request = request return Response1(string_field='a response') class TestRunServices(test_util.TestCase): def DoRequest(self, path, request, response_type, reg_path='/protorpc'): stdin = sys.stdin stdout = sys.stdout environ = os.environ try: sys.stdin = cStringIO.StringIO(protojson.encode_message(request)) sys.stdout = cStringIO.StringIO() os.environ = webapp_test_util.GetDefaultEnvironment() os.environ['PATH_INFO'] = path os.environ['REQUEST_METHOD'] = 'POST' os.environ['CONTENT_TYPE'] = 'application/json' os.environ['wsgi.input'] = 
sys.stdin os.environ['wsgi.output'] = sys.stdout os.environ['CONTENT_LENGTH'] = len(sys.stdin.getvalue()) service_handlers.run_services( [('/my_service', GetCalled.new_factory(self))], reg_path) header, body = sys.stdout.getvalue().split('\n\n', 1) return (header.split('\n')[0], protojson.decode_message(response_type, body)) finally: sys.stdin = stdin sys.stdout = stdout os.environ = environ def testRequest(self): request = Request1(string_field='request value') status, response = self.DoRequest('/my_service.my_method', request, Response1) self.assertEquals('Status: 200 OK', status) self.assertEquals(request, self.request) self.assertEquals(Response1(string_field='a response'), response) def testRegistry(self): request = Request1(string_field='request value') status, response = self.DoRequest('/protorpc.services', message_types.VoidMessage(), registry.ServicesResponse) self.assertEquals('Status: 200 OK', status) self.assertIterEqual([ registry.ServiceMapping( name='/protorpc', definition='protorpc.registry.RegistryService'), registry.ServiceMapping( name='/my_service', definition='test_package.GetCalled'), ], response.services) def testRunServicesWithOutRegistry(self): request = Request1(string_field='request value') status, response = self.DoRequest('/protorpc.services', message_types.VoidMessage(), registry.ServicesResponse, reg_path=None) self.assertEquals('Status: 404 Not Found', status) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/webapp_test_util.py0000755000076500000240000003011012277637135024277 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2010 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Testing utilities for the webapp libraries. GetDefaultEnvironment: Method for easily setting up CGI environment. RequestHandlerTestBase: Base class for setting up handler tests. """ __author__ = 'rafek@google.com (Rafe Kaplan)' import cStringIO import threading import urllib2 from wsgiref import simple_server from wsgiref import validate from . import protojson from . import remote from . import test_util from . import transport from .webapp import service_handlers from .webapp.google_imports import webapp class TestService(remote.Service): """Service used to do end to end tests with.""" @remote.method(test_util.OptionalMessage, test_util.OptionalMessage) def optional_message(self, request): if request.string_value: request.string_value = '+%s' % request.string_value return request def GetDefaultEnvironment(): """Function for creating a default CGI environment.""" return { 'LC_NUMERIC': 'C', 'wsgi.multiprocess': True, 'SERVER_PROTOCOL': 'HTTP/1.0', 'SERVER_SOFTWARE': 'Dev AppServer 0.1', 'SCRIPT_NAME': '', 'LOGNAME': 'nickjohnson', 'USER': 'nickjohnson', 'QUERY_STRING': 'foo=bar&foo=baz&foo2=123', 'PATH': '/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/bin/X11', 'LANG': 'en_US', 'LANGUAGE': 'en', 'REMOTE_ADDR': '127.0.0.1', 'LC_MONETARY': 'C', 'CONTENT_TYPE': 'application/x-www-form-urlencoded', 'wsgi.url_scheme': 'http', 'SERVER_PORT': '8080', 'HOME': '/home/mruser', 'USERNAME': 'mruser', 'CONTENT_LENGTH': '', 'USER_IS_ADMIN': '1', 'PYTHONPATH': '/tmp/setup', 'LC_TIME': 'C', 'HTTP_USER_AGENT': 'Mozilla/5.0 (X11; U; Linux i686 (x86_64); en-US; ' 'rv:1.8.1.6) 
Gecko/20070725 Firefox/2.0.0.6', 'wsgi.multithread': False, 'wsgi.version': (1, 0), 'USER_EMAIL': 'test@example.com', 'USER_EMAIL': '112', 'wsgi.input': cStringIO.StringIO(), 'PATH_TRANSLATED': '/tmp/request.py', 'SERVER_NAME': 'localhost', 'GATEWAY_INTERFACE': 'CGI/1.1', 'wsgi.run_once': True, 'LC_COLLATE': 'C', 'HOSTNAME': 'myhost', 'wsgi.errors': cStringIO.StringIO(), 'PWD': '/tmp', 'REQUEST_METHOD': 'GET', 'MAIL': '/dev/null', 'MAILCHECK': '0', 'USER_NICKNAME': 'test', 'HTTP_COOKIE': 'dev_appserver_login="test:test@example.com:True"', 'PATH_INFO': '/tmp/myhandler' } class RequestHandlerTestBase(test_util.TestCase): """Base class for writing RequestHandler tests. To test a specific request handler override CreateRequestHandler. To change the environment for that handler override GetEnvironment. """ def setUp(self): """Set up test for request handler.""" self.ResetHandler() def GetEnvironment(self): """Get environment. Override for more specific configurations. Returns: dict of CGI environment. """ return GetDefaultEnvironment() def CreateRequestHandler(self): """Create RequestHandler instances. Override to create more specific kinds of RequestHandler instances. Returns: RequestHandler instance used in test. """ return webapp.RequestHandler() def CheckResponse(self, expected_status, expected_headers, expected_content): """Check that the web response is as expected. Args: expected_status: Expected status message. expected_headers: Dictionary of expected headers. Will ignore unexpected headers and only check the value of those expected. expected_content: Expected body. 
""" def check_content(content): self.assertEquals(expected_content, content) def start_response(status, headers): self.assertEquals(expected_status, status) found_keys = set() for name, value in headers: name = name.lower() try: expected_value = expected_headers[name] except KeyError: pass else: found_keys.add(name) self.assertEquals(expected_value, value) missing_headers = set(expected_headers.iterkeys()) - found_keys if missing_headers: self.fail('Expected keys %r not found' % (list(missing_headers),)) return check_content self.handler.response.wsgi_write(start_response) def ResetHandler(self, change_environ=None): """Reset this tests environment with environment changes. Resets the entire test with a new handler which includes some changes to the default request environment. Args: change_environ: Dictionary of values that are added to default environment. """ environment = self.GetEnvironment() environment.update(change_environ or {}) self.request = webapp.Request(environment) self.response = webapp.Response() self.handler = self.CreateRequestHandler() self.handler.initialize(self.request, self.response) class SyncedWSGIServer(simple_server.WSGIServer): pass class ServerThread(threading.Thread): """Thread responsible for managing wsgi server. This server does not just attach to the socket and listen for requests. This is because the server classes in Python 2.5 or less have no way to shut them down. Instead, the thread must be notified of how many requests it will receive so that it listens for each one individually. Tests should tell how many requests to listen for using the handle_request method. """ def __init__(self, server, *args, **kwargs): """Constructor. Args: server: The WSGI server that is served by this thread. As per threading.Thread base class. State: __serving: Server is still expected to be serving. When False server knows to shut itself down. """ self.server = server # This timeout is for the socket when a connection is made. 
self.server.socket.settimeout(None) # This timeout is for when waiting for a connection. The allows # server.handle_request() to listen for a short time, then timeout, # allowing the server to check for shutdown. self.server.timeout = 0.05 self.__serving = True super(ServerThread, self).__init__(*args, **kwargs) def shutdown(self): """Notify server that it must shutdown gracefully.""" self.__serving = False def run(self): """Handle incoming requests until shutdown.""" while self.__serving: self.server.handle_request() self.server = None class TestService(remote.Service): """Service used to do end to end tests with.""" def __init__(self, message='uninitialized'): self.__message = message @remote.method(test_util.OptionalMessage, test_util.OptionalMessage) def optional_message(self, request): if request.string_value: request.string_value = '+%s' % request.string_value return request @remote.method(response_type=test_util.OptionalMessage) def init_parameter(self, request): return test_util.OptionalMessage(string_value=self.__message) @remote.method(test_util.NestedMessage, test_util.NestedMessage) def nested_message(self, request): request.string_value = '+%s' % request.string_value return request @remote.method() def raise_application_error(self, request): raise remote.ApplicationError('This is an application error', 'ERROR_NAME') @remote.method() def raise_unexpected_error(self, request): raise TypeError('Unexpected error') @remote.method() def raise_rpc_error(self, request): raise remote.NetworkError('Uncaught network error') @remote.method(response_type=test_util.NestedMessage) def return_bad_message(self, request): return test_util.NestedMessage() class AlternateService(remote.Service): """Service used to requesting non-existant methods.""" @remote.method() def does_not_exist(self, request): raise NotImplementedError('Not implemented') class WebServerTestBase(test_util.TestCase): SERVICE_PATH = '/my/service' def setUp(self): self.server = None self.schema = 
'http' self.ResetServer() self.bad_path_connection = self.CreateTransport(self.service_url + '_x') self.bad_path_stub = TestService.Stub(self.bad_path_connection) super(WebServerTestBase, self).setUp() def tearDown(self): self.server.shutdown() super(WebServerTestBase, self).tearDown() def ResetServer(self, application=None): """Reset web server. Shuts down existing server if necessary and starts a new one. Args: application: Optional WSGI function. If none provided will use tests CreateWsgiApplication method. """ if self.server: self.server.shutdown() self.port = test_util.pick_unused_port() self.server, self.application = self.StartWebServer(self.port, application) self.connection = self.CreateTransport(self.service_url) def CreateTransport(self, service_url, protocol=protojson): """Create a new transportation object.""" return transport.HttpTransport(service_url, protocol=protocol) def StartWebServer(self, port, application=None): """Start web server. Args: port: Port to start application on. application: Optional WSGI function. If none provided will use tests CreateWsgiApplication method. Returns: A tuple (server, application): server: An instance of ServerThread. application: Application that web server responds with. """ if not application: application = self.CreateWsgiApplication() validated_application = validate.validator(application) server = simple_server.make_server('localhost', port, validated_application) server = ServerThread(server) server.start() return server, application def make_service_url(self, path): """Make service URL using current schema and port.""" return '%s://localhost:%d%s' % (self.schema, self.port, path) @property def service_url(self): return self.make_service_url(self.SERVICE_PATH) class EndToEndTestBase(WebServerTestBase): # Sub-classes may override to create alternate configurations. 
DEFAULT_MAPPING = service_handlers.service_mapping( [('/my/service', TestService), ('/my/other_service', TestService.new_factory('initialized')), ]) def setUp(self): super(EndToEndTestBase, self).setUp() self.stub = TestService.Stub(self.connection) self.other_connection = self.CreateTransport(self.other_service_url) self.other_stub = TestService.Stub(self.other_connection) self.mismatched_stub = AlternateService.Stub(self.connection) @property def other_service_url(self): return 'http://localhost:%d/my/other_service' % self.port def CreateWsgiApplication(self): """Create WSGI application used on the server side for testing.""" return webapp.WSGIApplication(self.DEFAULT_MAPPING, True) def DoRawRequest(self, method, content='', content_type='application/json', headers=None): headers = headers or {} headers.update({'content-length': len(content or ''), 'content-type': content_type, }) request = urllib2.Request('%s.%s' % (self.service_url, method), content, headers) return urllib2.urlopen(request) def RawRequestError(self, method, content=None, content_type='application/json', headers=None): try: self.DoRawRequest(method, content, content_type, headers) self.fail('Expected HTTP error') except urllib2.HTTPError, err: return err.code, err.read(), err.headers protorpc-standalone-0.9.1/protorpc/wsgi/0000755000076500000240000000000012300027071021301 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc/wsgi/__init__.py0000755000076500000240000000113112277637135023436 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # protorpc-standalone-0.9.1/protorpc/wsgi/service.py0000755000076500000240000002361212300025203023315 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ProtoRPC WSGI service applications. Use functions in this module to configure ProtoRPC services for use with WSGI applications. For more information about WSGI, please see: http://wsgi.org/wsgi http://docs.python.org/library/wsgiref.html """ __author__ = 'rafek@google.com (Rafe Kaplan)' import cgi import httplib import logging import re from .. import messages from .. import registry from .. import remote from .. import util from . 
import util as wsgi_util __all__ = [ 'DEFAULT_REGISTRY_PATH', 'service_app', ] _METHOD_PATTERN = r'(?:\.([^?]+))' _REQUEST_PATH_PATTERN = r'^(%%s)%s$' % _METHOD_PATTERN _HTTP_BAD_REQUEST = wsgi_util.error(httplib.BAD_REQUEST) _HTTP_NOT_FOUND = wsgi_util.error(httplib.NOT_FOUND) _HTTP_UNSUPPORTED_MEDIA_TYPE = wsgi_util.error(httplib.UNSUPPORTED_MEDIA_TYPE) DEFAULT_REGISTRY_PATH = '/protorpc' @util.positional(2) def service_mapping(service_factory, service_path=r'.*', protocols=None): """WSGI application that handles a single ProtoRPC service mapping. Args: service_factory: Service factory for creating instances of service request handlers. Either callable that takes no parameters and returns a service instance or a service class whose constructor requires no parameters. service_path: Regular expression for matching requests against. Requests that do not have matching paths will cause a 404 (Not Found) response. protocols: remote.Protocols instance that configures supported protocols on server. """ service_class = getattr(service_factory, 'service_class', service_factory) remote_methods = service_class.all_remote_methods() path_matcher = re.compile(_REQUEST_PATH_PATTERN % service_path) def protorpc_service_app(environ, start_response): """Actual WSGI application function.""" path_match = path_matcher.match(environ['PATH_INFO']) if not path_match: return _HTTP_NOT_FOUND(environ, start_response) service_path = path_match.group(1) method_name = path_match.group(2) content_type = environ.get('CONTENT_TYPE') if not content_type: content_type = environ.get('HTTP_CONTENT_TYPE') if not content_type: return _HTTP_BAD_REQUEST(environ, start_response) # TODO(rafek): Handle alternate encodings. 
content_type = cgi.parse_header(content_type)[0] request_method = environ['REQUEST_METHOD'] if request_method != 'POST': content = ('%s.%s is a ProtoRPC method.\n\n' 'Service %s\n\n' 'More about ProtoRPC: ' '%s\n' % (service_path, method_name, service_class.definition_name().encode('utf-8'), util.PROTORPC_PROJECT_URL)) error_handler = wsgi_util.error( httplib.METHOD_NOT_ALLOWED, httplib.responses[httplib.METHOD_NOT_ALLOWED], content=content, content_type='text/plain; charset=utf-8') return error_handler(environ, start_response) local_protocols = protocols or remote.Protocols.get_default() try: protocol = local_protocols.lookup_by_content_type(content_type) except KeyError: return _HTTP_UNSUPPORTED_MEDIA_TYPE(environ,start_response) def send_rpc_error(status_code, state, message, error_name=None): """Helper function to send an RpcStatus message as response. Will create static error handler and begin response. Args: status_code: HTTP integer status code. state: remote.RpcState enum value to send as response. message: Helpful message to send in response. error_name: Error name if applicable. Returns: List containing encoded content response using the same content-type as the request. 
""" status = remote.RpcStatus(state=state, error_message=message, error_name=error_name) encoded_status = protocol.encode_message(status) error_handler = wsgi_util.error( status_code, content_type=protocol.default_content_type, content=encoded_status) return error_handler(environ, start_response) method = remote_methods.get(method_name) if not method: return send_rpc_error(httplib.BAD_REQUEST, remote.RpcState.METHOD_NOT_FOUND_ERROR, 'Unrecognized RPC method: %s' % method_name) content_length = int(environ.get('CONTENT_LENGTH') or '0') remote_info = method.remote try: request = protocol.decode_message( remote_info.request_type, environ['wsgi.input'].read(content_length)) except (messages.ValidationError, messages.DecodeError), err: return send_rpc_error(httplib.BAD_REQUEST, remote.RpcState.REQUEST_ERROR, 'Error parsing ProtoRPC request ' '(Unable to parse request content: %s)' % err) instance = service_factory() initialize_request_state = getattr( instance, 'initialize_request_state', None) if initialize_request_state: # TODO(rafek): This is not currently covered by tests. 
server_port = environ.get('SERVER_PORT', None) if server_port: server_port = int(server_port) headers = [] for name, value in environ.iteritems(): if name.startswith('HTTP_'): headers.append((name[len('HTTP_'):].lower().replace('_', '-'), value)) request_state = remote.HttpRequestState( remote_host=environ.get('REMOTE_HOST', None), remote_address=environ.get('REMOTE_ADDR', None), server_host=environ.get('SERVER_HOST', None), server_port=server_port, http_method=request_method, service_path=service_path, headers=headers) initialize_request_state(request_state) try: response = method(instance, request) encoded_response = protocol.encode_message(response) except remote.ApplicationError, err: return send_rpc_error(httplib.BAD_REQUEST, remote.RpcState.APPLICATION_ERROR, err.message, err.error_name) except Exception, err: logging.exception('Encountered unexpected error from ProtoRPC ' 'method implementation: %s (%s)' % (err.__class__.__name__, err)) return send_rpc_error(httplib.INTERNAL_SERVER_ERROR, remote.RpcState.SERVER_ERROR, 'Internal Server Error') response_headers = [('content-type', content_type)] start_response('%d %s' % (httplib.OK, httplib.responses[httplib.OK],), response_headers) return [encoded_response] # Return WSGI application. return protorpc_service_app @util.positional(1) def service_mappings(services, registry_path=DEFAULT_REGISTRY_PATH, service_prefix=None, append_wsgi_apps=None): """Create multiple service mappings with optional RegistryService. Use this function to create single WSGI application that maps to multiple ProtoRPC services plus an optional RegistryService. Example: services = service.service_mappings( [(r'/time', TimeService), (r'/weather', WeatherService) ]) In this example, the services WSGI application will map to two services, TimeService and WeatherService to the '/time' and '/weather' paths respectively. 
In addition, it will also add a ProtoRPC RegistryService configured to serve information about both services at the (default) path '/protorpc'. Args: services: If a dictionary is provided instead of a list of tuples, the dictionary item pairs are used as the mappings instead. Otherwise, a list of tuples (service_path, service_factory): service_path: The path to mount service on. service_factory: A service class or service instance factory. registry_path: A string to change where the registry is mapped (the default location is '/protorpc'). When None, no registry is created or mounted. service_prefix: Runs "first found" logic only when request paths begin with this prefix. When None, "first found" logic occurs for all requests. append_wsgi_apps: Additional WSGI apps to run. Returns: WSGI application that serves ProtoRPC services on their respective URLs plus optional RegistryService. """ if isinstance(services, dict): services = services.iteritems() final_mapping = [] paths = set() registry_map = {} if registry_path else None for service_path, service_factory in services: try: service_class = service_factory.service_class except AttributeError: service_class = service_factory if service_path not in paths: paths.add(service_path) else: raise remote.ServiceConfigurationError( 'Path %r is already defined in service mapping' % service_path.encode('utf-8')) if registry_map is not None: registry_map[service_path] = service_class final_mapping.append(service_mapping(service_factory, service_path)) if registry_map is not None: final_mapping.append(service_mapping( registry.RegistryService.new_factory(registry_map), registry_path)) if append_wsgi_apps is not None: final_mapping.extend(append_wsgi_apps) return wsgi_util.first_found(final_mapping, service_prefix=service_prefix) protorpc-standalone-0.9.1/protorpc/wsgi/service_test.py0000755000076500000240000001562012277637135024406 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """WSGI application tests.""" __author__ = 'rafek@google.com (Rafe Kaplan)' import unittest from protorpc import end2end_test from protorpc import protojson from protorpc import remote from protorpc import registry from protorpc import transport from protorpc import test_util from protorpc import webapp_test_util from protorpc.wsgi import service from protorpc.wsgi import util class ServiceMappingTest(end2end_test.EndToEndTest): def setUp(self): self.protocols = None remote.Protocols.set_default(remote.Protocols.new_default()) super(ServiceMappingTest, self).setUp() def CreateServices(self): return my_service, my_other_service def CreateWsgiApplication(self): """Create WSGI application used on the server side for testing.""" my_service = service.service_mapping(webapp_test_util.TestService, '/my/service') my_other_service = service.service_mapping( webapp_test_util.TestService.new_factory('initialized'), '/my/other_service', protocols=self.protocols) return util.first_found([my_service, my_other_service]) def testAlternateProtocols(self): self.protocols = remote.Protocols() self.protocols.add_protocol(protojson, 'altproto', 'image/png') global_protocols = remote.Protocols() global_protocols.add_protocol(protojson, 'server-side-name', 'image/png') remote.Protocols.set_default(global_protocols) self.ResetServer() self.connection = transport.HttpTransport( self.service_url, protocol=self.protocols.lookup_by_name('altproto')) self.stub 
= webapp_test_util.TestService.Stub(self.connection) self.stub.optional_message(string_value='alternate-protocol') def testAlwaysUseDefaults(self): new_protocols = remote.Protocols() new_protocols.add_protocol(protojson, 'altproto', 'image/png') self.connection = transport.HttpTransport( self.service_url, protocol=new_protocols.lookup_by_name('altproto')) self.stub = webapp_test_util.TestService.Stub(self.connection) self.assertRaisesWithRegexpMatch( remote.ServerError, 'HTTP Error 415: Unsupported Media Type', self.stub.optional_message, string_value='alternate-protocol') remote.Protocols.set_default(new_protocols) self.stub.optional_message(string_value='alternate-protocol') class ProtoServiceMappingsTest(ServiceMappingTest): def CreateWsgiApplication(self): """Create WSGI application used on the server side for testing.""" return service.service_mappings( [('/my/service', webapp_test_util.TestService), ('/my/other_service', webapp_test_util.TestService.new_factory('initialized')) ]) def GetRegistryStub(self, path='/protorpc'): service_url = self.make_service_url(path) transport = self.CreateTransport(service_url) return registry.RegistryService.Stub(transport) def testRegistry(self): registry_client = self.GetRegistryStub() response = registry_client.services() self.assertIterEqual([ registry.ServiceMapping( name='/my/other_service', definition='protorpc.webapp_test_util.TestService'), registry.ServiceMapping( name='/my/service', definition='protorpc.webapp_test_util.TestService'), ], response.services) def testRegistryDictionary(self): self.ResetServer(service.service_mappings( {'/my/service': webapp_test_util.TestService, '/my/other_service': webapp_test_util.TestService.new_factory('initialized'), })) registry_client = self.GetRegistryStub() response = registry_client.services() self.assertIterEqual([ registry.ServiceMapping( name='/my/other_service', definition='protorpc.webapp_test_util.TestService'), registry.ServiceMapping( name='/my/service', 
definition='protorpc.webapp_test_util.TestService'), ], response.services) def testNoRegistry(self): self.ResetServer(service.service_mappings( [('/my/service', webapp_test_util.TestService), ('/my/other_service', webapp_test_util.TestService.new_factory('initialized')) ], registry_path=None)) registry_client = self.GetRegistryStub() self.assertRaisesWithRegexpMatch( remote.ServerError, 'HTTP Error 404: Not Found', registry_client.services) def testAltRegistry(self): self.ResetServer(service.service_mappings( [('/my/service', webapp_test_util.TestService), ('/my/other_service', webapp_test_util.TestService.new_factory('initialized')) ], registry_path='/registry')) registry_client = self.GetRegistryStub('/registry') services = registry_client.services() self.assertTrue(isinstance(services, registry.ServicesResponse)) self.assertIterEqual( [registry.ServiceMapping( name='/my/other_service', definition='protorpc.webapp_test_util.TestService'), registry.ServiceMapping( name='/my/service', definition='protorpc.webapp_test_util.TestService'), ], services.services) def testDuplicateRegistryEntry(self): self.assertRaisesWithRegexpMatch( remote.ServiceConfigurationError, "Path '/my/service' is already defined in service mapping", service.service_mappings, [('/my/service', webapp_test_util.TestService), ('/my/service', webapp_test_util.TestService.new_factory('initialized')) ]) def testRegex(self): self.ResetServer(service.service_mappings( [('/my/[0-9]+', webapp_test_util.TestService.new_factory('service')), ('/my/[a-z]+', webapp_test_util.TestService.new_factory('other-service')), ])) my_service_url = 'http://localhost:%d/my/12345' % self.port my_other_service_url = 'http://localhost:%d/my/blarblar' % self.port my_service = webapp_test_util.TestService.Stub( transport.HttpTransport(my_service_url)) my_other_service = webapp_test_util.TestService.Stub( transport.HttpTransport(my_other_service_url)) response = my_service.init_parameter() self.assertEquals('service', 
response.string_value) response = my_other_service.init_parameter() self.assertEquals('other-service', response.string_value) def main(): unittest.main() if __name__ == '__main__': main() protorpc-standalone-0.9.1/protorpc/wsgi/util.py0000755000076500000240000001426012300025776022650 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """WSGI utilities Small collection of helpful utilities for working with WSGI. """ __author__ = 'rafek@google.com (Rafe Kaplan)' import httplib import re import webapp2 from .. import util __all__ = ['static_page', 'error', 'first_found', ] _STATUS_PATTERN = re.compile('^(\d{3})\s') @util.positional(1) def static_page(content='', status='200 OK', content_type='text/html; charset=utf-8', headers=None): """Create a WSGI application that serves static content. A static page is one that will be the same every time it receives a request. It will always serve the same status, content and headers. Args: content: Content to serve in response to HTTP request. status: Status to serve in response to HTTP request. If string, status is served as is without any error checking. If integer, will look up status message. Otherwise, parameter is tuple (status, description): status: Integer status of response. description: Brief text description of response. content_type: Convenient parameter for content-type header. 
Will appear before any content-type header that appears in 'headers' parameter. headers: Dictionary of headers or iterable of tuples (name, value): name: String name of header. value: String value of header. Returns: WSGI application that serves static content. """ if isinstance(status, (int, long)): status = '%d %s' % (status, httplib.responses.get(status, 'Unknown Error')) elif not isinstance(status, basestring): status = '%d %s' % tuple(status) if isinstance(headers, dict): headers = headers.iteritems() headers = [('content-length', str(len(content))), ('content-type', content_type), ] + list(headers or []) # Ensure all headers are str. for index, (key, value) in enumerate(headers): if isinstance(value, unicode): value = value.encode('utf-8') headers[index] = key, value if not isinstance(key, str): raise TypeError('Header key must be str, found: %r' % (key,)) if not isinstance(value, str): raise TypeError( 'Header %r must be type str or unicode, found: %r' % (key, value)) def static_page_application(environ, start_response): start_response(status, headers) return [content] return static_page_application @util.positional(2) def error(status_code, status_message=None, content_type='text/plain; charset=utf-8', headers=None, content=None): """Create WSGI application that statically serves an error page. Creates a static error page specifically for non-200 HTTP responses. Error pages that are not provided will content will contain the standard HTTP status message as their content. Args: status_code: Integer status code of error. status_message: Status message. Returns: Static WSGI application that sends static error response. 
""" if status_message is None: status_message = httplib.responses.get(status_code, 'Unknown Error') if content is None: content = status_message return static_page(content, status=(status_code, status_message), content_type=content_type, headers=headers) def first_found(apps, service_prefix=None): """Serve the first application that does not response with 404 Not Found. If no application serves content, will respond with generic 404 Not Found. Args: apps: List of WSGI applications to search through. Will serve the content of the first of these that does not return a 404 Not Found. Applications in this list must not modify the environment or any objects in it if they do not match. Applications that do not obey this restriction can create unpredictable results. service_prefix: Runs "first found" logic only when request paths begin with this prefix. When None, "first found" logic occurs for all requests. Returns: Compound application that serves the contents of the first application that does not response with 404 Not Found. """ apps = tuple(apps) not_found = error(httplib.NOT_FOUND) def first_found_app(environ, start_response): """Compound application returned from the first_found function.""" final_result = {} # Used in absence of Python local scoping. def first_found_start_response(status, response_headers): """Replacement for start_response as passed in to first_found_app. Called by each application in apps instead of the real start response. Checks the response status, and if anything other than 404, sets 'status' and 'response_headers' in final_result. """ status_match = _STATUS_PATTERN.match(status) assert status_match, ('Status must be a string beginning ' 'with 3 digit number. 
Found: %s' % status) status_code = status_match.group(0) if int(status_code) == httplib.NOT_FOUND: return final_result['status'] = status final_result['response_headers'] = response_headers if service_prefix is None or environ['PATH_INFO'].startswith(service_prefix): for app in apps: response = app(environ, first_found_start_response) if final_result: start_response(final_result['status'], final_result['response_headers']) return response else: for app in apps: if not isinstance(app, webapp2.WSGIApplication): continue response = app(environ, start_response) return response return not_found(environ, start_response) return first_found_app protorpc-standalone-0.9.1/protorpc/wsgi/util_test.py0000755000076500000240000002530412277637135023723 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2011 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """WSGI utility library tests.""" __author__ = 'rafe@google.com (Rafe Kaplan)' import httplib import unittest from protorpc import test_util from protorpc import util from protorpc import webapp_test_util from protorpc.wsgi import util as wsgi_util APP1 = wsgi_util.static_page('App1') APP2 = wsgi_util.static_page('App2') NOT_FOUND = wsgi_util.error(httplib.NOT_FOUND) class WsgiTestBase(webapp_test_util.WebServerTestBase): server_thread = None def CreateWsgiApplication(self): return None def DoHttpRequest(self, path='/', content=None, content_type='text/plain; charset=utf-8', headers=None): connection = httplib.HTTPConnection('localhost', self.port) if content is None: method = 'GET' else: method = 'POST' headers = {'content=type': content_type} headers.update(headers) connection.request(method, path, content, headers) response = connection.getresponse() not_date_or_server = lambda header: header[0] not in ('date', 'server') headers = filter(not_date_or_server, response.getheaders()) return response.status, response.reason, response.read(), dict(headers) class StaticPageBase(WsgiTestBase): def testDefault(self): default_page = wsgi_util.static_page() self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(200, status) self.assertEquals('OK', reason) self.assertEquals('', content) self.assertEquals({'content-length': '0', 'content-type': 'text/html; charset=utf-8', }, headers) def testHasContent(self): default_page = wsgi_util.static_page('my content') self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(200, status) self.assertEquals('OK', reason) self.assertEquals('my content', content) self.assertEquals({'content-length': str(len('my content')), 'content-type': 'text/html; charset=utf-8', }, headers) def testHasContentType(self): default_page = wsgi_util.static_page(content_type='text/plain') self.ResetServer(default_page) status, reason, content, headers = 
self.DoHttpRequest() self.assertEquals(200, status) self.assertEquals('OK', reason) self.assertEquals('', content) self.assertEquals({'content-length': '0', 'content-type': 'text/plain', }, headers) def testHasStatus(self): default_page = wsgi_util.static_page(status='400 Not Good Request') self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(400, status) self.assertEquals('Not Good Request', reason) self.assertEquals('', content) self.assertEquals({'content-length': '0', 'content-type': 'text/html; charset=utf-8', }, headers) def testHasStatusInt(self): default_page = wsgi_util.static_page(status=401) self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(401, status) self.assertEquals('Unauthorized', reason) self.assertEquals('', content) self.assertEquals({'content-length': '0', 'content-type': 'text/html; charset=utf-8', }, headers) def testHasStatusUnknown(self): default_page = wsgi_util.static_page(status=909) self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(909, status) self.assertEquals('Unknown Error', reason) self.assertEquals('', content) self.assertEquals({'content-length': '0', 'content-type': 'text/html; charset=utf-8', }, headers) def testHasStatusTuple(self): default_page = wsgi_util.static_page(status=(500, 'Bad Thing')) self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(500, status) self.assertEquals('Bad Thing', reason) self.assertEquals('', content) self.assertEquals({'content-length': '0', 'content-type': 'text/html; charset=utf-8', }, headers) def testHasHeaders(self): default_page = wsgi_util.static_page(headers=[('x', 'foo'), ('a', 'bar'), ('z', 'bin')]) self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(200, status) self.assertEquals('OK', reason) self.assertEquals('', content) 
self.assertEquals({'content-length': '0', 'content-type': 'text/html; charset=utf-8', 'x': 'foo', 'a': 'bar', 'z': 'bin', }, headers) def testHeadersUnicodeSafe(self): default_page = wsgi_util.static_page(headers=[('x', u'foo')]) self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(200, status) self.assertEquals('OK', reason) self.assertEquals('', content) self.assertEquals({'content-length': '0', 'content-type': 'text/html; charset=utf-8', 'x': 'foo', }, headers) self.assertTrue(isinstance(headers['x'], str)) def testHasHeadersDict(self): default_page = wsgi_util.static_page(headers={'x': 'foo', 'a': 'bar', 'z': 'bin'}) self.ResetServer(default_page) status, reason, content, headers = self.DoHttpRequest() self.assertEquals(200, status) self.assertEquals('OK', reason) self.assertEquals('', content) self.assertEquals({'content-length': '0', 'content-type': 'text/html; charset=utf-8', 'x': 'foo', 'a': 'bar', 'z': 'bin', }, headers) class FirstFoundTest(WsgiTestBase): def testEmptyConfiguration(self): self.ResetServer(wsgi_util.first_found([])) status, status_text, content, headers = self.DoHttpRequest('/') self.assertEquals(httplib.NOT_FOUND, status) self.assertEquals(httplib.responses[httplib.NOT_FOUND], status_text) self.assertEquals(util.pad_string(httplib.responses[httplib.NOT_FOUND]), content) self.assertEquals({'content-length': '512', 'content-type': 'text/plain; charset=utf-8', }, headers) def testOneApp(self): self.ResetServer(wsgi_util.first_found([APP1])) status, status_text, content, headers = self.DoHttpRequest('/') self.assertEquals(httplib.OK, status) self.assertEquals(httplib.responses[httplib.OK], status_text) self.assertEquals('App1', content) self.assertEquals({'content-length': '4', 'content-type': 'text/html; charset=utf-8', }, headers) def testIterator(self): self.ResetServer(wsgi_util.first_found(iter([APP1]))) status, status_text, content, headers = self.DoHttpRequest('/') 
self.assertEquals(httplib.OK, status) self.assertEquals(httplib.responses[httplib.OK], status_text) self.assertEquals('App1', content) self.assertEquals({'content-length': '4', 'content-type': 'text/html; charset=utf-8', }, headers) # Do request again to make sure iterator was properly copied. status, status_text, content, headers = self.DoHttpRequest('/') self.assertEquals(httplib.OK, status) self.assertEquals(httplib.responses[httplib.OK], status_text) self.assertEquals('App1', content) self.assertEquals({'content-length': '4', 'content-type': 'text/html; charset=utf-8', }, headers) def testTwoApps(self): self.ResetServer(wsgi_util.first_found([APP1, APP2])) status, status_text, content, headers = self.DoHttpRequest('/') self.assertEquals(httplib.OK, status) self.assertEquals(httplib.responses[httplib.OK], status_text) self.assertEquals('App1', content) self.assertEquals({'content-length': '4', 'content-type': 'text/html; charset=utf-8', }, headers) def testFirstNotFound(self): self.ResetServer(wsgi_util.first_found([NOT_FOUND, APP2])) status, status_text, content, headers = self.DoHttpRequest('/') self.assertEquals(httplib.OK, status) self.assertEquals(httplib.responses[httplib.OK], status_text) self.assertEquals('App2', content) self.assertEquals({'content-length': '4', 'content-type': 'text/html; charset=utf-8', }, headers) def testOnlyNotFound(self): def current_error(environ, start_response): """The variable current_status is defined in loop after ResetServer.""" headers = [('content-type', 'text/plain')] status_line = '%03d Whatever' % current_status start_response(status_line, headers) return [] self.ResetServer(wsgi_util.first_found([current_error, APP2])) statuses_to_check = sorted(httplib.responses.iterkeys()) # 100, 204 and 304 have slightly different expectations, so they are left # out of this test in order to keep the code simple. 
for dont_check in (100, 200, 204, 304, 404): statuses_to_check.remove(dont_check) for current_status in statuses_to_check: status, status_text, content, headers = self.DoHttpRequest('/') self.assertEquals(current_status, status) self.assertEquals('Whatever', status_text) if __name__ == '__main__': unittest.main() protorpc-standalone-0.9.1/protorpc_standalone.egg-info/0000755000076500000240000000000012300027071024232 5ustar jeremydwstaff00000000000000protorpc-standalone-0.9.1/protorpc_standalone.egg-info/dependency_links.txt0000644000076500000240000000000112300027071030300 0ustar jeremydwstaff00000000000000 protorpc-standalone-0.9.1/protorpc_standalone.egg-info/entry_points.txt0000644000076500000240000000006412300027071027530 0ustar jeremydwstaff00000000000000[console_scripts] gen_protorpc = gen_protorpc:main protorpc-standalone-0.9.1/protorpc_standalone.egg-info/PKG-INFO0000644000076500000240000000143412300027071025331 0ustar jeremydwstaff00000000000000Metadata-Version: 1.1 Name: protorpc-standalone Version: 0.9.1 Summary: Google Protocol RPC (modified to run outside Google App Engine) Home-page: https://github.com/jeremydw/protorpc-standalone Author: Google Inc. 
Author-email: rafek@google.com License: Apache 2.0 Description: UNKNOWN Keywords: google protocol rpc Platform: UNKNOWN Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Operating System :: MacOS :: MacOS X Classifier: Operating System :: Microsoft :: Windows Classifier: Operating System :: POSIX :: Linux Classifier: Programming Language :: Python :: 2.7 Classifier: Topic :: Software Development :: Libraries Classifier: Topic :: Software Development :: Libraries :: Python Modules Provides: protorpc (0.9.1) protorpc-standalone-0.9.1/protorpc_standalone.egg-info/SOURCES.txt0000644000076500000240000000324512300027071026122 0ustar jeremydwstaff00000000000000setup.py protorpc/__init__.py protorpc/definition.py protorpc/definition_test.py protorpc/descriptor.py protorpc/descriptor_test.py protorpc/end2end_test.py protorpc/generate.py protorpc/generate_proto.py protorpc/generate_proto_test.py protorpc/generate_python.py protorpc/generate_python_test.py protorpc/generate_test.py protorpc/google_imports.py protorpc/message_types.py protorpc/message_types_test.py protorpc/messages.py protorpc/messages_test.py protorpc/non_sdk_imports.py protorpc/protobuf.py protorpc/protobuf_test.py protorpc/protojson.py protorpc/protojson_test.py protorpc/protorpc_test_pb2.py protorpc/protourlencode.py protorpc/protourlencode_test.py protorpc/registry.py protorpc/registry_test.py protorpc/remote.py protorpc/remote_test.py protorpc/test_util.py protorpc/transport.py protorpc/transport_test.py protorpc/util.py protorpc/util_test.py protorpc/webapp_test_util.py protorpc/_google/__init__.py protorpc/_google/net/__init__.py protorpc/_google/net/proto/ProtocolBuffer.py protorpc/_google/net/proto/RawMessage.py protorpc/_google/net/proto/__init__.py protorpc/_google/net/proto/message_set.py protorpc/experimental/__init__.py protorpc/webapp/__init__.py protorpc/webapp/forms.py protorpc/webapp/forms_test.py 
protorpc/webapp/google_imports.py protorpc/webapp/service_handlers.py protorpc/webapp/service_handlers_test.py protorpc/wsgi/__init__.py protorpc/wsgi/service.py protorpc/wsgi/service_test.py protorpc/wsgi/util.py protorpc/wsgi/util_test.py protorpc_standalone.egg-info/PKG-INFO protorpc_standalone.egg-info/SOURCES.txt protorpc_standalone.egg-info/dependency_links.txt protorpc_standalone.egg-info/entry_points.txt protorpc_standalone.egg-info/top_level.txtprotorpc-standalone-0.9.1/protorpc_standalone.egg-info/top_level.txt0000644000076500000240000000001112300027071026754 0ustar jeremydwstaff00000000000000protorpc protorpc-standalone-0.9.1/setup.cfg0000644000076500000240000000007312300027071020301 0ustar jeremydwstaff00000000000000[egg_info] tag_build = tag_date = 0 tag_svn_revision = 0 protorpc-standalone-0.9.1/setup.py0000755000076500000240000000407412300026747020213 0ustar jeremydwstaff00000000000000#!/usr/bin/env python # # Copyright 2013 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Setup configuration.""" import platform import setuptools # Configure the required packages and scripts to install, depending on # Python version and OS. 
# Build the install configuration, choosing extra requirements based on the
# running interpreter, then hand everything to setuptools.
import sys

REQUIRED_PACKAGES = []

# NOTE(review): this archive's SOURCES.txt lists no gen_protorpc module.
# find_packages() below only collects packages, not top-level modules, so this
# console script may not resolve once installed. Confirm where
# gen_protorpc:main is meant to live.
CONSOLE_SCRIPTS = [
    'gen_protorpc = gen_protorpc:main',
]

# Compare numeric (major, minor) tuples rather than version strings. A
# lexicographic string comparison of platform.python_version() misorders
# versions such as '2.10' < '2.6'. simplejson is presumably needed as a
# stand-in for the json module, which joined the stdlib in Python 2.6.
if sys.version_info[:2] < (2, 6):
    REQUIRED_PACKAGES.append('simplejson')

_PROTORPC_VERSION = '0.9.1'

setuptools.setup(
    name='protorpc-standalone',
    version=_PROTORPC_VERSION,
    description='Google Protocol RPC (modified to run outside Google App Engine)',
    url='https://github.com/jeremydw/protorpc-standalone',
    author='Google Inc.',
    author_email='rafek@google.com',
    # Contained modules and scripts.
    packages=setuptools.find_packages(),
    entry_points={
        'console_scripts': CONSOLE_SCRIPTS,
    },
    install_requires=REQUIRED_PACKAGES,
    provides=[
        'protorpc (%s)' % (_PROTORPC_VERSION,),
    ],
    # PyPI package information.
    classifiers=[
        'Intended Audience :: Developers',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: MacOS :: MacOS X',
        'Operating System :: Microsoft :: Windows',
        'Operating System :: POSIX :: Linux',
        'Programming Language :: Python :: 2.7',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
    license='Apache 2.0',
    keywords='google protocol rpc',
)