calendarserver-5.2+dfsg/0000755000175000017500000000000012322625327014313 5ustar rahulrahulcalendarserver-5.2+dfsg/python0000755000175000017500000000015711615630672015570 0ustar rahulrahul#!/usr/bin/env bash wd="$(cd "$(dirname "$0")" && pwd)"; . "${wd}/support/shell.sh" exec "${python}" "$@"; calendarserver-5.2+dfsg/conf/0000755000175000017500000000000012322625306015235 5ustar rahulrahulcalendarserver-5.2+dfsg/conf/remoteservers.xml0000644000175000017500000000166512263343324020675 0ustar rahulrahul calendarserver-5.2+dfsg/conf/caldavd-partitioning-secondary.plist0000644000175000017500000000440012263343324024401 0ustar rahulrahul Servers Enabled ConfigFile localservers.xml MaxClients 5 ServerPartitionID 00002 ProxyDBService type twistedcaldav.directory.calendaruserproxy.ProxyPostgreSQLDB params host localhost database proxies Memcached Pools CommonToAllNodes ClientEnabled ServerEnabled BindAddress localhost Port 11311 HandleCacheTypes ProxyDB PrincipalToken DIGESTCREDENTIALS MaxClients 5 memcached ../memcached/_root/bin/memcached Options calendarserver-5.2+dfsg/conf/mime.types0000644000175000017500000003525410535615510017263 0ustar rahulrahul# This is a comment. I love comments. # This file controls what Internet media types are sent to the client for # given file extension(s). Sending the correct media type to the client # is important so they know how to handle the content of the file. # Extra types can either be added here or by using an AddType directive # in your config files. For more information about Internet media types, # please read RFC 2045, 2046, 2047, 2048, and 2077. The Internet media type # registry is at . 
# MIME type Extensions application/activemessage application/andrew-inset ez application/applefile application/atom+xml atom application/atomicmail application/batch-smtp application/beep+xml application/cals-1840 application/cnrp+xml application/commonground application/cpl+xml application/cybercash application/dca-rft application/dec-dx application/dvcs application/edi-consent application/edifact application/edi-x12 application/eshop application/font-tdpfr application/http application/hyperstudio application/iges application/index application/index.cmd application/index.obj application/index.response application/index.vnd application/iotp application/ipp application/isup application/mac-binhex40 hqx application/mac-compactpro cpt application/macwriteii application/marc application/mathematica application/mathml+xml mathml application/msword doc application/news-message-id application/news-transmission application/ocsp-request application/ocsp-response application/octet-stream bin dms lha lzh exe class so dll dmg application/oda oda application/ogg ogg application/parityfec application/pdf pdf application/pgp-encrypted application/pgp-keys application/pgp-signature application/pkcs10 application/pkcs7-mime application/pkcs7-signature application/pkix-cert application/pkix-crl application/pkixcmp application/postscript ai eps ps application/prs.alvestrand.titrax-sheet application/prs.cww application/prs.nprend application/prs.plucker application/qsig application/rdf+xml rdf application/reginfo+xml application/remote-printing application/riscos application/rtf application/sdp application/set-payment application/set-payment-initiation application/set-registration application/set-registration-initiation application/sgml application/sgml-open-catalog application/sieve application/slate application/smil smi smil application/srgs gram application/srgs+xml grxml application/timestamp-query application/timestamp-reply application/tve-trigger application/vemmi 
application/vnd.3gpp.pic-bw-large application/vnd.3gpp.pic-bw-small application/vnd.3gpp.pic-bw-var application/vnd.3gpp.sms application/vnd.3m.post-it-notes application/vnd.accpac.simply.aso application/vnd.accpac.simply.imp application/vnd.acucobol application/vnd.acucorp application/vnd.adobe.xfdf application/vnd.aether.imp application/vnd.amiga.ami application/vnd.anser-web-certificate-issue-initiation application/vnd.anser-web-funds-transfer-initiation application/vnd.audiograph application/vnd.blueice.multipass application/vnd.bmi application/vnd.businessobjects application/vnd.canon-cpdl application/vnd.canon-lips application/vnd.cinderella application/vnd.claymore application/vnd.commerce-battelle application/vnd.commonspace application/vnd.contact.cmsg application/vnd.cosmocaller application/vnd.criticaltools.wbs+xml application/vnd.ctc-posml application/vnd.cups-postscript application/vnd.cups-raster application/vnd.cups-raw application/vnd.curl application/vnd.cybank application/vnd.data-vision.rdz application/vnd.dna application/vnd.dpgraph application/vnd.dreamfactory application/vnd.dxr application/vnd.ecdis-update application/vnd.ecowin.chart application/vnd.ecowin.filerequest application/vnd.ecowin.fileupdate application/vnd.ecowin.series application/vnd.ecowin.seriesrequest application/vnd.ecowin.seriesupdate application/vnd.enliven application/vnd.epson.esf application/vnd.epson.msf application/vnd.epson.quickanime application/vnd.epson.salt application/vnd.epson.ssf application/vnd.ericsson.quickcall application/vnd.eudora.data application/vnd.fdf application/vnd.ffsns application/vnd.fints application/vnd.flographit application/vnd.framemaker application/vnd.fsc.weblaunch application/vnd.fujitsu.oasys application/vnd.fujitsu.oasys2 application/vnd.fujitsu.oasys3 application/vnd.fujitsu.oasysgp application/vnd.fujitsu.oasysprs application/vnd.fujixerox.ddd application/vnd.fujixerox.docuworks application/vnd.fujixerox.docuworks.binder 
application/vnd.fut-misnet application/vnd.grafeq application/vnd.groove-account application/vnd.groove-help application/vnd.groove-identity-message application/vnd.groove-injector application/vnd.groove-tool-message application/vnd.groove-tool-template application/vnd.groove-vcard application/vnd.hbci application/vnd.hhe.lesson-player application/vnd.hp-hpgl application/vnd.hp-hpid application/vnd.hp-hps application/vnd.hp-pcl application/vnd.hp-pclxl application/vnd.httphone application/vnd.hzn-3d-crossword application/vnd.ibm.afplinedata application/vnd.ibm.electronic-media application/vnd.ibm.minipay application/vnd.ibm.modcap application/vnd.ibm.rights-management application/vnd.ibm.secure-container application/vnd.informix-visionary application/vnd.intercon.formnet application/vnd.intertrust.digibox application/vnd.intertrust.nncp application/vnd.intu.qbo application/vnd.intu.qfx application/vnd.irepository.package+xml application/vnd.is-xpr application/vnd.japannet-directory-service application/vnd.japannet-jpnstore-wakeup application/vnd.japannet-payment-wakeup application/vnd.japannet-registration application/vnd.japannet-registration-wakeup application/vnd.japannet-setstore-wakeup application/vnd.japannet-verification application/vnd.japannet-verification-wakeup application/vnd.jisp application/vnd.kde.karbon application/vnd.kde.kchart application/vnd.kde.kformula application/vnd.kde.kivio application/vnd.kde.kontour application/vnd.kde.kpresenter application/vnd.kde.kspread application/vnd.kde.kword application/vnd.kenameaapp application/vnd.koan application/vnd.liberty-request+xml application/vnd.llamagraphics.life-balance.desktop application/vnd.llamagraphics.life-balance.exchange+xml application/vnd.lotus-1-2-3 application/vnd.lotus-approach application/vnd.lotus-freelance application/vnd.lotus-notes application/vnd.lotus-organizer application/vnd.lotus-screencam application/vnd.lotus-wordpro application/vnd.mcd application/vnd.mediastation.cdkey 
application/vnd.meridian-slingshot application/vnd.micrografx.flo application/vnd.micrografx.igx application/vnd.mif mif application/vnd.minisoft-hp3000-save application/vnd.mitsubishi.misty-guard.trustweb application/vnd.mobius.daf application/vnd.mobius.dis application/vnd.mobius.mbk application/vnd.mobius.mqy application/vnd.mobius.msl application/vnd.mobius.plc application/vnd.mobius.txf application/vnd.mophun.application application/vnd.mophun.certificate application/vnd.motorola.flexsuite application/vnd.motorola.flexsuite.adsi application/vnd.motorola.flexsuite.fis application/vnd.motorola.flexsuite.gotap application/vnd.motorola.flexsuite.kmr application/vnd.motorola.flexsuite.ttc application/vnd.motorola.flexsuite.wem application/vnd.mozilla.xul+xml xul application/vnd.ms-artgalry application/vnd.ms-asf application/vnd.ms-excel xls application/vnd.ms-lrm application/vnd.ms-powerpoint ppt application/vnd.ms-project application/vnd.ms-tnef application/vnd.ms-works application/vnd.ms-wpl application/vnd.mseq application/vnd.msign application/vnd.music-niff application/vnd.musician application/vnd.netfpx application/vnd.noblenet-directory application/vnd.noblenet-sealer application/vnd.noblenet-web application/vnd.novadigm.edm application/vnd.novadigm.edx application/vnd.novadigm.ext application/vnd.obn application/vnd.osa.netdeploy application/vnd.palm application/vnd.pg.format application/vnd.pg.osasli application/vnd.powerbuilder6 application/vnd.powerbuilder6-s application/vnd.powerbuilder7 application/vnd.powerbuilder7-s application/vnd.powerbuilder75 application/vnd.powerbuilder75-s application/vnd.previewsystems.box application/vnd.publishare-delta-tree application/vnd.pvi.ptid1 application/vnd.pwg-multiplexed application/vnd.pwg-xhtml-print+xml application/vnd.quark.quarkxpress application/vnd.rapid application/vnd.s3sms application/vnd.sealed.net application/vnd.seemail application/vnd.shana.informed.formdata 
application/vnd.shana.informed.formtemplate application/vnd.shana.informed.interchange application/vnd.shana.informed.package application/vnd.smaf application/vnd.sss-cod application/vnd.sss-dtf application/vnd.sss-ntf application/vnd.street-stream application/vnd.svd application/vnd.swiftview-ics application/vnd.triscape.mxs application/vnd.trueapp application/vnd.truedoc application/vnd.ufdl application/vnd.uplanet.alert application/vnd.uplanet.alert-wbxml application/vnd.uplanet.bearer-choice application/vnd.uplanet.bearer-choice-wbxml application/vnd.uplanet.cacheop application/vnd.uplanet.cacheop-wbxml application/vnd.uplanet.channel application/vnd.uplanet.channel-wbxml application/vnd.uplanet.list application/vnd.uplanet.list-wbxml application/vnd.uplanet.listcmd application/vnd.uplanet.listcmd-wbxml application/vnd.uplanet.signal application/vnd.vcx application/vnd.vectorworks application/vnd.vidsoft.vidconference application/vnd.visio application/vnd.visionary application/vnd.vividence.scriptfile application/vnd.vsf application/vnd.wap.sic application/vnd.wap.slc application/vnd.wap.wbxml wbxml application/vnd.wap.wmlc wmlc application/vnd.wap.wmlscriptc wmlsc application/vnd.webturbo application/vnd.wrq-hp3000-labelled application/vnd.wt.stf application/vnd.wv.csp+wbxml application/vnd.xara application/vnd.xfdl application/vnd.yamaha.hv-dic application/vnd.yamaha.hv-script application/vnd.yamaha.hv-voice application/vnd.yellowriver-custom-menu application/voicexml+xml vxml application/watcherinfo+xml application/whoispp-query application/whoispp-response application/wita application/wordperfect5.1 application/x-bcpio bcpio application/x-cdlink vcd application/x-chess-pgn pgn application/x-compress application/x-cpio cpio application/x-csh csh application/x-director dcr dir dxr application/x-dvi dvi application/x-futuresplash spl application/x-gtar gtar application/x-gzip application/x-hdf hdf application/x-javascript js application/x-koan skp skd skt skm 
application/x-latex latex application/x-netcdf nc cdf application/x-sh sh application/x-shar shar application/x-shockwave-flash swf application/x-stuffit sit application/x-sv4cpio sv4cpio application/x-sv4crc sv4crc application/x-tar tar application/x-tcl tcl application/x-tex tex application/x-texinfo texinfo texi application/x-troff t tr roff application/x-troff-man man application/x-troff-me me application/x-troff-ms ms application/x-ustar ustar application/x-wais-source src application/x400-bp application/xhtml+xml xhtml xht application/xslt+xml xslt application/xml xml xsl application/xml-dtd dtd application/xml-external-parsed-entity application/zip zip audio/32kadpcm audio/amr audio/amr-wb audio/basic au snd audio/cn audio/dat12 audio/dsr-es201108 audio/dvi4 audio/evrc audio/evrc0 audio/g722 audio/g.722.1 audio/g723 audio/g726-16 audio/g726-24 audio/g726-32 audio/g726-40 audio/g728 audio/g729 audio/g729D audio/g729E audio/gsm audio/gsm-efr audio/l8 audio/l16 audio/l20 audio/l24 audio/lpc audio/midi mid midi kar audio/mpa audio/mpa-robust audio/mp4a-latm audio/mpeg mpga mp2 mp3 audio/parityfec audio/pcma audio/pcmu audio/prs.sid audio/qcelp audio/red audio/smv audio/smv0 audio/telephone-event audio/tone audio/vdvi audio/vnd.3gpp.iufp audio/vnd.cisco.nse audio/vnd.cns.anp1 audio/vnd.cns.inf1 audio/vnd.digital-winds audio/vnd.everad.plj audio/vnd.lucent.voice audio/vnd.nortel.vbk audio/vnd.nuera.ecelp4800 audio/vnd.nuera.ecelp7470 audio/vnd.nuera.ecelp9600 audio/vnd.octel.sbc audio/vnd.qcelp audio/vnd.rhetorex.32kadpcm audio/vnd.vmx.cvsd audio/x-aiff aif aiff aifc audio/x-alaw-basic audio/x-mpegurl m3u audio/x-pn-realaudio ram ra audio/x-pn-realaudio-plugin application/vnd.rn-realmedia rm audio/x-wav wav chemical/x-pdb pdb chemical/x-xyz xyz image/bmp bmp image/cgm cgm image/g3fax image/gif gif image/ief ief image/jpeg jpeg jpg jpe image/naplps image/png png image/prs.btif image/prs.pti image/svg+xml svg image/t38 image/tiff tiff tif image/tiff-fx 
image/vnd.cns.inf2 image/vnd.djvu djvu djv image/vnd.dwg image/vnd.dxf image/vnd.fastbidsheet image/vnd.fpx image/vnd.fst image/vnd.fujixerox.edmics-mmr image/vnd.fujixerox.edmics-rlc image/vnd.globalgraphics.pgb image/vnd.mix image/vnd.ms-modi image/vnd.net-fpx image/vnd.svf image/vnd.wap.wbmp wbmp image/vnd.xiff image/x-cmu-raster ras image/x-icon ico image/x-portable-anymap pnm image/x-portable-bitmap pbm image/x-portable-graymap pgm image/x-portable-pixmap ppm image/x-rgb rgb image/x-xbitmap xbm image/x-xpixmap xpm image/x-xwindowdump xwd message/delivery-status message/disposition-notification message/external-body message/http message/news message/partial message/rfc822 message/s-http message/sip message/sipfrag model/iges igs iges model/mesh msh mesh silo model/vnd.dwf model/vnd.flatland.3dml model/vnd.gdl model/vnd.gs-gdl model/vnd.gtw model/vnd.mts model/vnd.parasolid.transmit.binary model/vnd.parasolid.transmit.text model/vnd.vtu model/vrml wrl vrml multipart/alternative multipart/appledouble multipart/byteranges multipart/digest multipart/encrypted multipart/form-data multipart/header-set multipart/mixed multipart/parallel multipart/related multipart/report multipart/signed multipart/voice-message text/calendar ics ifb text/css css text/directory text/enriched text/html html htm text/parityfec text/plain asc txt text/prs.lines.tag text/rfc822-headers text/richtext rtx text/rtf rtf text/sgml sgml sgm text/t140 text/tab-separated-values tsv text/uri-list text/vnd.abc text/vnd.curl text/vnd.dmclientscript text/vnd.fly text/vnd.fmi.flexstor text/vnd.in3d.3dml text/vnd.in3d.spot text/vnd.iptc.nitf text/vnd.iptc.newsml text/vnd.latex-z text/vnd.motorola.reflex text/vnd.ms-mediapackage text/vnd.net2phone.commcenter.command text/vnd.sun.j2me.app-descriptor text/vnd.wap.si text/vnd.wap.sl text/vnd.wap.wml wml text/vnd.wap.wmlscript wmls text/x-setext etx text/xml text/xml-external-parsed-entity video/bmpeg video/bt656 video/celb video/dv video/h261 video/h263 
video/h263-1998 video/h263-2000 video/jpeg video/mp1s video/mp2p video/mp2t video/mp4v-es video/mpv video/mpeg mpeg mpg mpe video/nv video/parityfec video/pointer video/quicktime qt mov video/smpte292m video/vnd.fvt video/vnd.motorola.video video/vnd.motorola.videop video/vnd.mpegurl mxu m4u video/vnd.nokia.interleaved-multimedia video/vnd.objectvideo video/vnd.vivo video/x-msvideo avi video/x-sgi-movie movie x-conference/x-cooltalk ice calendarserver-5.2+dfsg/conf/servers.dtd0000644000175000017500000000166212263343324017431 0ustar rahulrahul calendarserver-5.2+dfsg/conf/test/0000755000175000017500000000000012322625306016214 5ustar rahulrahulcalendarserver-5.2+dfsg/conf/test/accounts.xml0000644000175000017500000001122512263343324020557 0ustar rahulrahul admin admin admin Super User apprentice apprentice Apprentice Super User apprentice wsanchez wsanchez wsanchez@example.com Wilfredo Sanchez Vega test cdaboo cdaboo cdaboo@example.com Cyrus Daboo test sagen sagen sagen@example.com Morgen Sagen test andre dre dre@example.com Andre LaBranche test glyph glyph glyph@example.com Glyph Lefkowitz test i18nuser i18nuser i18nuser@example.com まだ i18nuser user%02d user%02d User %02d User %02d user%02d@example.com user%02d public%02d public%02d Public %02d public%02d group01 group01 Group 01 group01 user01 group02 group02 Group 02 group02 user06 user07 group03 group03 Group 03 group03 user08 user09 group04 group04 Group 04 group04 group02 group03 user10 group05 group05 Group 05 group05 group06 user20 group06 group06 Group 06 group06 user21 group07 group07 Group 07 group07 user22 user23 user24 disabledgroup disabledgroup Disabled Group disabledgroup user01 calendarserver-5.2+dfsg/conf/auth/0000755000175000017500000000000012322625306016176 5ustar rahulrahulcalendarserver-5.2+dfsg/conf/auth/accounts.xml0000644000175000017500000000231712263343324020543 0ustar rahulrahul admin admin Super User test test Test User users users Users Group test mercury mercury Mecury Conference Room, 
Building 1, 2nd Floor calendarserver-5.2+dfsg/conf/auth/accounts.dtd0000644000175000017500000000300112263343324020505 0ustar rahulrahul > calendarserver-5.2+dfsg/conf/auth/resources-test.xml0000755000175000017500000001524412262624042021717 0ustar rahulrahul fantastic 4D66A20A-1437-437D-8069-2F14E8322234 Fantastic Conference Room 63A2F949-2D8D-4C8D-B8A5-DCF2A94610F3 jupiter jupiter Jupiter Conference Room, Building 2, 1st Floor uranus uranus Uranus Conference Room, Building 3, 1st Floor morgensroom 03DFF660-8BCC-4198-8588-DD77F776F518 Morgen's Room mercury mercury Mercury Conference Room, Building 1, 2nd Floor location09 location09 Room 09 location08 location08 Room 08 location07 location07 Room 07 location06 location06 Room 06 location05 location05 Room 05 location04 location04 Room 04 location03 location03 Room 03 location02 location02 Room 02 location01 location01 Room 01 delegatedroom delegatedroom Delegated Conference Room mars redplanet Mars Conference Room, Building 1, 1st Floor sharissroom 80689D41-DAF8-4189-909C-DB017B271892 Shari's Room 6F9EE33B-78F6-481B-9289-3D0812FF0D64 pluto pluto Pluto Conference Room, Building 2, 1st Floor saturn saturn Saturn Conference Room, Building 2, 1st Floor location10 location10 Room 10 pretend 06E3BDCB-9C19-485A-B14E-F146A80ADDC6 Pretend Conference Room 76E7ECA6-08BC-4AE7-930D-F2E7453993A5 neptune neptune Neptune Conference Room, Building 2, 1st Floor Earth Earth Earth Conference Room, Building 1, 1st Floor venus venus Venus Conference Room, Building 1, 2nd Floor sharisotherresource CCE95217-A57B-481A-AC3D-FEC9AB6CE3A9 Shari's Other Resource resource15 resource15 Resource 15 resource14 resource14 Resource 14 resource17 resource17 Resource 17 resource16 resource16 Resource 16 resource11 resource11 Resource 11 resource10 resource10 Resource 10 resource13 resource13 Resource 13 resource12 resource12 Resource 12 resource19 resource19 Resource 19 resource18 resource18 Resource 18 sharisresource 
C38BEE7A-36EE-478C-9DCB-CBF4612AFE65 Shari's Resource resource20 resource20 Resource 20 resource06 resource06 Resource 06 resource07 resource07 Resource 07 resource04 resource04 Resource 04 resource05 resource05 Resource 05 resource02 resource02 Resource 02 resource03 resource03 Resource 03 resource01 resource01 Resource 01 sharisotherresource1 0CE0BF31-5F9E-4801-A489-8C70CF287F5F Shari's Other Resource1 resource08 resource08 Resource 08 resource09 resource09 Resource 09
testaddress1 6F9EE33B-78F6-481B-9289-3D0812FF0D64 Test Address One 20300 Stevens Creek Blvd, Cupertino, CA 95014 37.322281,-122.028345
il2 63A2F949-2D8D-4C8D-B8A5-DCF2A94610F3 IL2 2 Infinite Loop, Cupertino, CA 95014 37.332633,-122.030502
il1 76E7ECA6-08BC-4AE7-930D-F2E7453993A5 IL1 1 Infinite Loop, Cupertino, CA 95014 37.331741,-122.030333
calendarserver-5.2+dfsg/conf/auth/augments.dtd0000644000175000017500000000242312263343324020520 0ustar rahulrahul > calendarserver-5.2+dfsg/conf/auth/proxies.dtd0000644000175000017500000000156512263343324020374 0ustar rahulrahul > calendarserver-5.2+dfsg/conf/auth/accounts-test.xml0000644000175000017500000001204712263343324021521 0ustar rahulrahul admin admin admin Super User Super User apprentice apprentice apprentice Apprentice Super User Apprentice Super User wsanchez wsanchez wsanchez@example.com test Wilfredo Sanchez Vega Wilfredo Sanchez Vega cdaboo cdaboo cdaboo@example.com test Cyrus Daboo Cyrus Daboo sagen sagen sagen@example.com test Morgen Sagen Morgen Sagen dre andre dre@example.com test Andre LaBranche Andre LaBranche glyph glyph glyph@example.com test Glyph Lefkowitz Glyph Lefkowitz i18nuser i18nuser i18nuser@example.com i18nuser まだ user%02d User %02d user%02d user%02d User %02d User %02d user%02d@example.com public%02d public%02d public%02d Public %02d Public %02d group01 group01 group01 Group 01 user01 group02 group02 group02 Group 02 user06 user07 group03 group03 group03 Group 03 user08 user09 group04 group04 group04 Group 04 group02 group03 user10 group05 group05 group05 Group 05 group06 user20 group06 group06 group06 Group 06 user21 group07 group07 group07 Group 07 user22 user23 user24 disabledgroup disabledgroup disabledgroup Disabled Group user01 calendarserver-5.2+dfsg/conf/auth/augments-default.xml0000644000175000017500000000154012263343324022166 0ustar rahulrahul Default true true true calendarserver-5.2+dfsg/conf/auth/proxies-test.xml0000644000175000017500000000210312263343324021363 0ustar rahulrahul resource%02d user01 user03 delegatedroom group05 group07 calendarserver-5.2+dfsg/conf/auth/augments-test.xml0000644000175000017500000001403112262624042021516 0ustar rahulrahul Default true true true location%02d true true true true resource%02d true true true true resource05 true true true true none resource06 true true true true accept-always 
resource07 true true true true decline-always resource08 true true true true accept-if-free resource09 true true true true decline-if-busy resource10 true true true true automatic resource11 true true true true decline-always group01 group%02d true disabledgroup false delegatedroom true true false false 03DFF660-8BCC-4198-8588-DD77F776F518 true true true true true 80689D41-DAF8-4189-909C-DB017B271892 true true true true true default C38BEE7A-36EE-478C-9DCB-CBF4612AFE65 true true true true true default group01 CCE95217-A57B-481A-AC3D-FEC9AB6CE3A9 true true true true true 0CE0BF31-5F9E-4801-A489-8C70CF287F5F true true true true true 6F9EE33B-78F6-481B-9289-3D0812FF0D64 true true true true false default 76E7ECA6-08BC-4AE7-930D-F2E7453993A5 true true true true false default 63A2F949-2D8D-4C8D-B8A5-DCF2A94610F3 true true true true false default 06E3BDCB-9C19-485A-B14E-F146A80ADDC6 true true true true true default 4D66A20A-1437-437D-8069-2F14E8322234 true true true true true default calendarserver-5.2+dfsg/conf/sudoers.plist0000644000175000017500000000166410550570476020015 0ustar rahulrahul users username superuser password superuser calendarserver-5.2+dfsg/conf/caldavd-test.plist0000644000175000017500000006103612263344114020672 0ustar rahulrahul ServerHostName localhost EnableCalDAV EnableCardDAV HTTPPort 8008 SSLPort 8443 EnableSSL RedirectHTTPToHTTPS BindAddresses BindHTTPPorts 8008 8800 BindSSLPorts 8443 8843 ServerRoot ./data DataRoot Data DatabaseRoot Database DocumentRoot Documents ConfigRoot ./conf RunRoot Logs/state Aliases UserQuota 104857600 MaxCollectionsPerHome 50 MaxResourcesPerCollection 10000 MaxResourceSize 1048576 MaxAttendeesPerInstance 100 MaxAllowedInstances 3000 DirectoryService type twistedcaldav.directory.xmlfile.XMLDirectoryService params xmlFile ./conf/auth/accounts-test.xml ResourceService Enabled type twistedcaldav.directory.xmlfile.XMLDirectoryService params xmlFile ./conf/auth/resources-test.xml AugmentService type 
twistedcaldav.directory.augment.AugmentXMLDB params xmlFiles ./conf/auth/augments-test.xml ProxyDBService type twistedcaldav.directory.calendaruserproxy.ProxySqliteDB params dbpath proxies.sqlite ProxyLoadFromFile ./conf/auth/proxies-test.xml AdminPrincipals /principals/__uids__/admin/ ReadPrincipals EnableProxyPrincipals EnableAnonymousReadRoot EnableAnonymousReadNav EnablePrincipalListings EnableMonolithicCalendars Authentication Basic Enabled AllowedOverWireUnencrypted Digest Enabled AllowedOverWireUnencrypted Algorithm md5 Qop Kerberos Enabled AllowedOverWireUnencrypted ServicePrincipal Wiki Enabled Cookie sessionID URL http://127.0.0.1/RPC2 UserMethod userForSession WikiMethod accessLevelForUserWikiCalendar LogRoot Logs AccessLogFile access.log RotateAccessLog ErrorLogFile error.log DefaultLogLevel info LogLevels PIDFile caldavd.pid AccountingCategories iTIP HTTP AccountingPrincipals SSLCertificate twistedcaldav/test/data/server.pem SSLAuthorityChain SSLPrivateKey twistedcaldav/test/data/server.pem UserName GroupName ProcessType Combined MultiProcess ProcessCount 2 Notifications CoalesceSeconds 3 Services AMP Enabled Port 62311 EnableStaggering StaggerSeconds 3 Scheduling CalDAV EmailDomain HTTPDomain AddressPatterns OldDraftCompatibility ScheduleTagCompatibility EnablePrivateComments iSchedule Enabled AddressPatterns RemoteServers remoteservers-test.xml iMIP Enabled MailGatewayServer localhost MailGatewayPort 62310 Sending Server Port 587 UseSSL Username Password Address SupressionDays 7 Receiving Server Port 995 Type UseSSL Username Password PollingSeconds 30 AddressPatterns mailto:.* Options AllowGroupAsOrganizer AllowLocationAsOrganizer AllowResourceAsOrganizer AttendeeRefreshBatch 0 AttendeeRefreshCountLimit 50 AutoSchedule Enabled Always DefaultMode automatic FreeBusyURL Enabled TimePeriod 14 AnonymousAccess EnableDropBox EnableManagedAttachments EnablePrivateEvents RemoveDuplicatePrivateComments EnableTimezoneService TimezoneService Enabled Mode primary 
BasePath XMLInfoPath SecondaryService Host URI UpdateIntervalMinutes 1440 UsePackageTimezones EnableBatchUpload Sharing Enabled AllowExternalUsers Calendars Enabled AddressBooks Enabled EnableSACLs EnableReadOnlyServer EnableWebAdmin ResponseCompression HTTPRetryAfter 180 ControlSocket caldavd.sock Memcached MaxClients 5 memcached memcached Options EnableResponseCache ResponseCacheTimeout 30 Postgres Options QueryCaching Enabled MemcachedPool Default ExpireSeconds 3600 GroupCaching Enabled EnableUpdater MemcachedPool Default UpdateSeconds 300 ExpireSeconds 3600 LockSeconds 300 UseExternalProxies MaxPrincipalSearchReportResults 500 Twisted twistd ../Twisted/bin/twistd Localization TranslationsDirectory locales LocalesDirectory locales Language en calendarserver-5.2+dfsg/conf/resources.xml0000644000175000017500000000131212263343324017767 0ustar rahulrahul calendarserver-5.2+dfsg/conf/remoteservers-test.xml0000644000175000017500000000165012263343324021644 0ustar rahulrahul https://localhost:8543/inbox example.org 127.0.0.1 calendarserver-5.2+dfsg/conf/localservers-test.xml0000644000175000017500000000174612263343324021451 0ustar rahulrahul 00001 http://localhost:8008 00001 http://localhost:8008 00002 http://localhost:8108 calendarserver-5.2+dfsg/conf/servertoserver.dtd0000644000175000017500000000214012263343324021030 0ustar rahulrahul calendarserver-5.2+dfsg/conf/resources/0000755000175000017500000000000012322625306017247 5ustar rahulrahulcalendarserver-5.2+dfsg/conf/resources/users-groups.xml0000644000175000017500000000542712263343324022460 0ustar rahulrahul admin admin admin Super User Super User apprentice apprentice apprentice Apprentice Super User Apprentice Super User user%02d User %02d user%02d user%02d User %02d User %02d user%02d@example.com public%02d public%02d public%02d Public %02d Public %02d group01 group01 group01 Group 01 user01 group02 group02 group02 Group 02 user06 user07 group03 group03 group03 Group 03 user08 user09 group04 group04 group04 Group 
04 group02 group03 user10 disabledgroup disabledgroup disabledgroup Disabled Group user01 calendarserver-5.2+dfsg/conf/resources/caldavd-resources.plist0000644000175000017500000004306012263343324023736 0ustar rahulrahul ServerHostName HTTPPort 8008 SSLPort 8443 RedirectHTTPToHTTPS BindAddresses BindHTTPPorts BindSSLPorts DataRoot data/ DocumentRoot twistedcaldav/test/data/ Aliases UserQuota 104857600 MaximumAttachmentSize 1048576 MaxAttendeesPerInstance 100 MaxInstancesForRRULE 400 DirectoryService type twistedcaldav.directory.xmlfile.XMLDirectoryService params xmlFile conf/resources/users-groups.xml recordTypes users groups ResourceService Enabled type twistedcaldav.directory.xmlfile.XMLDirectoryService params xmlFile conf/resources/locations-resources.xml recordTypes locations resources AugmentService type twistedcaldav.directory.augment.AugmentXMLDB params xmlFiles conf/auth/augments-test.xml ProxyDBService type twistedcaldav.directory.calendaruserproxy.ProxySqliteDB params dbpath data/proxies.sqlite ProxyLoadFromFile conf/auth/proxies-test.xml AdminPrincipals /principals/__uids__/admin/ ReadPrincipals EnableProxyPrincipals EnableAnonymousReadRoot EnableAnonymousReadNav EnablePrincipalListings EnableMonolithicCalendars Authentication Basic Enabled Digest Enabled Algorithm md5 Qop Kerberos Enabled ServicePrincipal Wiki Enabled Cookie sessionID URL http://127.0.0.1/RPC2 UserMethod userForSession WikiMethod accessLevelForUserWikiCalendar AccessLogFile logs/access.log RotateAccessLog ErrorLogFile logs/error.log DefaultLogLevel info LogLevels ServerStatsFile logs/stats.plist PIDFile logs/caldavd.pid AccountingCategories iTIP HTTP AccountingPrincipals SSLCertificate twistedcaldav/test/data/server.pem SSLAuthorityChain SSLPrivateKey twistedcaldav/test/data/server.pem UserName GroupName ProcessType Combined MultiProcess ProcessCount 2 Notifications CoalesceSeconds 3 InternalNotificationHost localhost InternalNotificationPort 62309 Services SimpleLineNotifier Service 
twistedcaldav.notify.SimpleLineNotifierService Enabled Port 62308 XMPPNotifier Service twistedcaldav.notify.XMPPNotifierService Enabled Host xmpp.host.name Port 5222 JID jid@xmpp.host.name/resource Password password_goes_here ServiceAddress pubsub.xmpp.host.name NodeConfiguration pubsub#deliver_payloads 1 pubsub#persist_items 1 KeepAliveSeconds 120 HeartbeatMinutes 30 AllowedJIDs Scheduling CalDAV EmailDomain HTTPDomain AddressPatterns OldDraftCompatibility ScheduleTagCompatibility EnablePrivateComments iSchedule Enabled AddressPatterns Servers conf/servertoserver-test.xml iMIP Enabled MailGatewayServer localhost MailGatewayPort 62310 Sending Server Port 587 UseSSL Username Password Address Receiving Server Port 995 Type UseSSL Username Password PollingSeconds 30 AddressPatterns mailto:.* Options AllowGroupAsOrganizer AllowLocationAsOrganizer AllowResourceAsOrganizer FreeBusyURL Enabled TimePeriod 14 AnonymousAccess EnableDropBox EnablePrivateEvents EnableTimezoneService EnableSACLs EnableWebAdmin ResponseCompression HTTPRetryAfter 180 ControlSocket logs/caldavd.sock Memcached MaxClients 5 memcached memcached Options EnableResponseCache ResponseCacheTimeout 30 Twisted twistd ../Twisted/bin/twistd Localization LocalesDirectory locales Language English calendarserver-5.2+dfsg/conf/resources/locations-resources-orig.xml0000644000175000017500000000202512263343324024732 0ustar rahulrahul location%02d location%02d location%02d Room %02d resource%02d resource%02d resource%02d Resource %02d calendarserver-5.2+dfsg/conf/resources/locations-resources.xml0000644000175000017500000000202512263343324023774 0ustar rahulrahul location%02d location%02d location%02d Room %02d resource%02d resource%02d resource%02d Resource %02d calendarserver-5.2+dfsg/conf/caldavd-partitioning-primary.plist0000644000175000017500000000437712263343324024112 0ustar rahulrahul Servers Enabled ConfigFile localservers.xml MaxClients 5 ServerPartitionID 00001 ProxyDBService type 
twistedcaldav.directory.calendaruserproxy.ProxyPostgreSQLDB params host localhost database proxies Memcached Pools CommonToAllNodes ClientEnabled ServerEnabled BindAddress localhost Port 11311 HandleCacheTypes ProxyDB PrincipalToken DIGESTCREDENTIALS MaxClients 5 memcached ../memcached/_root/bin/memcached Options calendarserver-5.2+dfsg/conf/caldavd.plist0000644000175000017500000002707612263343324017725 0ustar rahulrahul ServerHostName HTTPPort 80 RedirectHTTPToHTTPS BindAddresses BindHTTPPorts BindSSLPorts ServerRoot /var/db/caldavd DataRoot Data DocumentRoot Documents ConfigRoot /etc/caldavd RunRoot /var/run Aliases UserQuota 104857600 MaxCollectionsPerHome 50 MaxResourcesPerCollection 10000 MaxResourceSize 1048576 MaxAttendeesPerInstance 100 MaxAllowedInstances 3000 DirectoryService type twistedcaldav.directory.xmlfile.XMLDirectoryService params xmlFile accounts.xml AdminPrincipals ReadPrincipals EnableProxyPrincipals EnableAnonymousReadRoot EnableAnonymousReadNav EnablePrincipalListings EnableMonolithicCalendars Authentication Basic Enabled Digest Enabled Algorithm md5 Qop Kerberos Enabled ServicePrincipal LogRoot /var/log/caldavd AccessLogFile access.log RotateAccessLog ErrorLogFile error.log DefaultLogLevel warn PIDFile caldavd.pid SSLCertificate SSLAuthorityChain SSLPrivateKey UserName daemon GroupName daemon ProcessType Combined MultiProcess ProcessCount 0 Notifications CoalesceSeconds 3 Services XMPPNotifier Service twistedcaldav.notify.XMPPNotifierService Enabled Host xmpp.host.name Port 5222 JID jid@xmpp.host.name/resource Password password_goes_here ServiceAddress pubsub.xmpp.host.name Scheduling CalDAV EmailDomain HTTPDomain AddressPatterns iSchedule Enabled AddressPatterns RemoteServers remoteservers.xml iMIP Enabled MailGatewayServer localhost MailGatewayPort 62310 Sending Server Port 587 UseSSL Username Password Address Receiving Server Port 995 Type UseSSL Username Password PollingSeconds 30 AddressPatterns mailto:.* FreeBusyURL Enabled TimePeriod 
14 AnonymousAccess EnablePrivateEvents Sharing Enabled EnableWebAdmin calendarserver-5.2+dfsg/conf/caldavd-apple.plist0000644000175000017500000003220012263344251021005 0ustar rahulrahul ServerHostName EnableCalDAV EnableCardDAV HTTPPort 8008 SSLPort 8443 EnableSSL RedirectHTTPToHTTPS BindAddresses BindHTTPPorts 8008 8800 BindSSLPorts 8443 8843 ServerRoot /Library/Server/Calendar and Contacts DBType DSN DBImportFile /Library/Server/Calendar and Contacts/DataDump.sql Postgres Ctl xpg_ctl Options -c log_lock_waits=TRUE -c deadlock_timeout=10 -c log_line_prefix='%m [%p] ' -c logging_collector=on -c log_truncate_on_rotation=on -c log_directory=/var/log/caldavd/postgresql -c log_filename=postgresql_%w.log -c log_rotation_age=1440 ExtraConnections 20 ClusterName cluster.pg LogFile xpg_ctl.log SocketDirectory /var/run/caldavd/PostgresSocket DataRoot Data DatabaseRoot Database.xpg DocumentRoot Documents ConfigRoot Config RunRoot /var/run/caldavd Aliases UserQuota 104857600 MaxCollectionsPerHome 50 MaxResourcesPerCollection 10000 MaxResourceSize 1048576 MaxAttendeesPerInstance 100 MaxAllowedInstances 3000 DirectoryService type twistedcaldav.directory.appleopendirectory.OpenDirectoryService params node /Search AdminPrincipals ReadPrincipals EnableProxyPrincipals EnableAnonymousReadRoot EnableAnonymousReadNav EnablePrincipalListings EnableMonolithicCalendars Authentication Basic Enabled Digest Enabled Algorithm md5 Qop Kerberos Enabled ServicePrincipal Wiki Enabled LogRoot /var/log/caldavd AccessLogFile access.log RotateAccessLog ErrorLogFile error.log DefaultLogLevel warn PIDFile caldavd.pid SSLCertificate SSLAuthorityChain SSLPrivateKey UserName calendar GroupName calendar ProcessType Combined MultiProcess ProcessCount 0 Notifications CoalesceSeconds 3 Services Scheduling CalDAV EmailDomain HTTPDomain AddressPatterns iSchedule Enabled AddressPatterns RemoteServers remoteservers.xml iMIP Enabled MailGatewayServer localhost MailGatewayPort 62310 Sending Server Port 587 UseSSL 
Username Password Address Receiving Server Port 995 Type UseSSL Username Password PollingSeconds 30 AddressPatterns mailto:.* FreeBusyURL Enabled TimePeriod 14 AnonymousAccess EnableDropBox EnableManagedAttachments EnablePrivateEvents EnableTimezoneService Sharing Enabled EnableSACLs EnableWebAdmin WebCalendarAuthPath /auth DirectoryAddressBook Enabled params queryUserRecords queryPeopleRecords EnableSearchAddressBook Includes /Library/Server/Calendar and Contacts/Config/caldavd-system.plist /Library/Server/Calendar and Contacts/Config/caldavd-user.plist WritableConfigFile /Library/Server/Calendar and Contacts/Config/caldavd-system.plist calendarserver-5.2+dfsg/conf/localservers.xml0000644000175000017500000000235112263343324020465 0ustar rahulrahul calendarserver-5.2+dfsg/twext/0000755000175000017500000000000012322625326015465 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/python/0000755000175000017500000000000012322625326017006 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/python/filepath.py0000644000175000017500000001050212263343324021151 0ustar rahulrahul# -*- test-case-name: twext.python.test.test_filepath -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extend L{twisted.python.filepath} to provide performance enhancements for calendar server. 
""" from os import listdir as _listdir from os.path import (join as _joinpath, basename as _basename, exists as _exists, dirname as _dirname) from time import sleep as _sleep from types import FunctionType, MethodType from errno import EINVAL from twisted.python.filepath import FilePath as _FilePath from stat import S_ISDIR class CachingFilePath(_FilePath, object): """ A descendent of L{_FilePath} which implements a more aggressive caching policy. """ _listdir = _listdir # integration points for tests _sleep = _sleep BACKOFF_MAX = 5.0 # Maximum time to wait between calls to # listdir() def __init__(self, path, alwaysCreate=False): super(CachingFilePath, self).__init__(path, alwaysCreate) self.existsCached = None self.isDirCached = None @property def siblingExtensionSearch(self): """ Dynamically create a version of L{_FilePath.siblingExtensionSearch} that uses a pluggable 'listdir' implementation. """ return MethodType(FunctionType( _FilePath.siblingExtensionSearch.im_func.func_code, {'listdir': self._retryListdir, 'basename': _basename, 'dirname': _dirname, 'joinpath': _joinpath, 'exists': _exists}), self, self.__class__) def changed(self): """ This path may have changed in the filesystem, so forget all cached information about it. """ self.statinfo = None self.existsCached = None self.isDirCached = None def _retryListdir(self, pathname): """ Implementation of retry logic for C{listdir} and C{siblingExtensionSearch}. """ delay = 0.1 while True: try: return self._listdir(pathname) except OSError, e: if e.errno == EINVAL: self._sleep(delay) delay = min(self.BACKOFF_MAX, delay * 2.0) else: raise raise RuntimeError("unreachable code.") def listdir(self): """ List the directory which C{self.path} points to, compensating for EINVAL from C{os.listdir}. """ return self._retryListdir(self.path) def restat(self, reraise=True): """ Re-cache stat information. 
""" try: return super(CachingFilePath, self).restat(reraise) finally: if self.statinfo: self.existsCached = True self.isDirCached = S_ISDIR(self.statinfo.st_mode) else: self.existsCached = False self.isDirCached = None def moveTo(self, destination, followLinks=True): """ Override L{_FilePath.moveTo}, updating extended cache information if necessary. """ result = super(CachingFilePath, self).moveTo(destination, followLinks) self.changed() # Work with vanilla FilePath destinations to pacify the tests. if hasattr(destination, "changed"): destination.changed() return result def remove(self): """ Override L{_FilePath.remove}, updating extended cache information if necessary. """ try: return super(CachingFilePath, self).remove() finally: self.changed() CachingFilePath.clonePath = CachingFilePath __all__ = ["CachingFilePath"] calendarserver-5.2+dfsg/twext/python/parallel.py0000644000175000017500000000632412263343324021160 0ustar rahulrahul# -*- test-case-name: twext.python.test.test_parallel -*- ## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Utilities for parallelizing tasks. """ from twisted.internet.defer import inlineCallbacks, DeferredList, returnValue class Parallelizer(object): """ Do some operation with a degree of parallelism, using a set of resources which may each only be used for one task at a time, given some underlying API that returns L{Deferreds}. 
@ivar available: A list of available resources from the C{resources} constructor parameter. @ivar busy: A list of resources which are currently being used by operations. """ def __init__(self, resources): """ Initialize a L{Parallelizer} with a list of objects that will be passed to the callables sent to L{Parallelizer.do}. @param resources: objects which may be of any arbitrary type. @type resources: C{list} """ self.available = list(resources) self.busy = [] self.activeDeferreds = [] @inlineCallbacks def do(self, operation): """ Call C{operation} with one of the resources in C{self.available}, removing that value for use by other callers of C{do} until the task performed by C{operation} is complete (in other words, the L{Deferred} returned by C{operation} has fired). @param operation: a 1-argument callable taking a resource from C{self.active} and returning a L{Deferred} when it's done using that resource. @type operation: C{callable} @return: a L{Deferred} that fires as soon as there are resources available such that this task can be I{started} - not completed. """ if not self.available: yield DeferredList(self.activeDeferreds, fireOnOneCallback=True, fireOnOneErrback=True) active = self.available.pop(0) self.busy.append(active) o = operation(active) def andFinally(whatever): self.activeDeferreds.remove(o) self.busy.remove(active) self.available.append(active) return whatever self.activeDeferreds.append(o) o.addBoth(andFinally) returnValue(None) def done(self): """ Wait until all operations started by L{Parallelizer.do} are completed. @return: a L{Deferred} that fires (with C{None}) when all the currently pending work on this L{Parallelizer} is completed and C{busy} is empty again. """ return (DeferredList(self.activeDeferreds) .addCallback(lambda ignored: None)) calendarserver-5.2+dfsg/twext/python/timezone.py0000644000175000017500000000307512263343324021216 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from twistedcaldav.config import config import twistedcaldav.timezones DEFAULT_TIMEZONE = "America/Los_Angeles" try: from Foundation import NSTimeZone def lookupSystemTimezone(): return NSTimeZone.localTimeZone().name().encode("utf-8") except: def lookupSystemTimezone(): return "" def getLocalTimezone(): """ Returns the default timezone for the server. The order of precedence is: config.DefaultTimezone, lookupSystemTimezone( ), DEFAULT_TIMEZONE. Also, if neither of the first two values in that list are in the timezone database, DEFAULT_TIMEZONE is returned. @return: The server's local timezone name @rtype: C{str} """ if config.DefaultTimezone: if twistedcaldav.timezones.hasTZ(config.DefaultTimezone): return config.DefaultTimezone systemTimezone = lookupSystemTimezone() if twistedcaldav.timezones.hasTZ(systemTimezone): return systemTimezone return DEFAULT_TIMEZONE calendarserver-5.2+dfsg/twext/python/test/0000755000175000017500000000000012322625326017765 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/python/test/test_filepath.py0000644000175000017500000001257612263343324023204 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for specialized behavior of L{CachingFilePath} """ from errno import EINVAL from os.path import join as pathjoin from twisted.internet.task import Clock from twisted.trial.unittest import TestCase from twext.python.filepath import CachingFilePath # Cheat and pull in the Twisted test cases for FilePath. XXX: Twisteds should # provide a supported way of doing this for exported interfaces. Also, it # should export IFilePath. --glyph from twisted.test.test_paths import FilePathTestCase class BaseVerification(FilePathTestCase): """ Make sure that L{CachingFilePath} doesn't break the contracts that L{FilePath} tries to provide. """ def setUp(self): """ Set up the test case to set the base attributes to point at L{AbstractFilePathTestCase}. """ FilePathTestCase.setUp(self) self.root = CachingFilePath(self.root.path) self.path = CachingFilePath(self.path.path) class EINVALTestCase(TestCase): """ Sometimes, L{os.listdir} will raise C{EINVAL}. This is a transient error, and L{CachingFilePath.listdir} should work around it by retrying the C{listdir} operation until it succeeds. """ def setUp(self): """ Create a L{CachingFilePath} for the test to use. """ self.cfp = CachingFilePath(self.mktemp()) self.clock = Clock() self.cfp._sleep = self.clock.advance def test_testValidity(self): """ If C{listdir} is replaced on a L{CachingFilePath}, we should be able to observe exceptions raised by the replacement. This verifies that the test patching done here is actually testing something. """ class CustomException(Exception): "Just for testing." 
def blowUp(dirname): raise CustomException() self.cfp._listdir = blowUp self.assertRaises(CustomException, self.cfp.listdir) self.assertRaises(CustomException, self.cfp.children) def test_retryLoop(self): """ L{CachingFilePath} should catch C{EINVAL} and respond by retrying the C{listdir} operation until it succeeds. """ calls = [] def raiseEINVAL(dirname): calls.append(dirname) if len(calls) < 5: raise OSError(EINVAL, "This should be caught by the test.") return ['a', 'b', 'c'] self.cfp._listdir = raiseEINVAL self.assertEquals(self.cfp.listdir(), ['a', 'b', 'c']) self.assertEquals(self.cfp.children(), [ CachingFilePath(pathjoin(self.cfp.path, 'a')), CachingFilePath(pathjoin(self.cfp.path, 'b')), CachingFilePath(pathjoin(self.cfp.path, 'c')),]) def requireTimePassed(self, filenames): """ Create a replacement for listdir() which only fires after a certain amount of time. """ self.calls = [] def thunk(dirname): now = self.clock.seconds() if now < 20.0: self.calls.append(now) raise OSError(EINVAL, "Not enough time has passed yet.") else: return filenames self.cfp._listdir = thunk def assertRequiredTimePassed(self): """ Assert that calls to the simulated time.sleep() installed by C{requireTimePassed} have been invoked the required number of times. """ # Waiting should be growing by *2 each time until the additional wait # exceeds BACKOFF_MAX (5), at which point we should wait for 5s each # time. def cumulative(values): current = 0.0 for value in values: current += value yield current self.assertEquals(self.calls, list(cumulative( [0.0, 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 5.0, 5.0]))) def test_backoff(self): """ L{CachingFilePath} will wait for an increasing interval up to C{BACKOFF_MAX} between calls to listdir(). 
""" self.requireTimePassed(['a', 'b', 'c']) self.assertEquals(self.cfp.listdir(), ['a', 'b', 'c']) def test_siblingExtensionSearch(self): """ L{FilePath.siblingExtensionSearch} is unfortunately not implemented in terms of L{FilePath.listdir}, so we need to verify that it will also retry. """ filenames = [self.cfp.basename()+'.a', self.cfp.basename() + '.b', self.cfp.basename() + '.c'] siblings = map(self.cfp.sibling, filenames) for sibling in siblings: sibling.touch() self.requireTimePassed(filenames) self.assertEquals(self.cfp.siblingExtensionSearch("*"), siblings[0]) self.assertRequiredTimePassed() calendarserver-5.2+dfsg/twext/python/test/test_log.py0000644000175000017500000007616212263343324022172 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from zope.interface.verify import verifyObject, BrokenMethodImplementation from twisted.python import log as twistedLogging from twisted.python.failure import Failure from twisted.trial import unittest from twext.python.log import ( LogLevel, InvalidLogLevelError, pythonLogLevelMapping, formatEvent, formatUnformattableEvent, formatWithCall, Logger, LegacyLogger, ILogObserver, LogPublisher, DefaultLogPublisher, FilteringLogObserver, PredicateResult, LogLevelFilterPredicate, OBSERVER_REMOVED ) defaultLogLevel = LogLevelFilterPredicate().defaultLogLevel clearLogLevels = Logger.publisher.levels.clearLogLevels logLevelForNamespace = Logger.publisher.levels.logLevelForNamespace setLogLevelForNamespace = Logger.publisher.levels.setLogLevelForNamespace class TestLogger(Logger): def emit(self, level, format=None, **kwargs): if False: print "*"*60 print "level =", level print "format =", format for key, value in kwargs.items(): print key, "=", value print "*"*60 def observer(event): self.event = event twistedLogging.addObserver(observer) try: Logger.emit(self, level, format, **kwargs) finally: twistedLogging.removeObserver(observer) self.emitted = { "level": level, "format": format, "kwargs": kwargs, } class TestLegacyLogger(LegacyLogger): def __init__(self, logger=TestLogger()): LegacyLogger.__init__(self, logger=logger) class LogComposedObject(object): """ Just a regular object. """ log = TestLogger() def __init__(self, state=None): self.state = state def __str__(self): return "".format(state=self.state) class SetUpTearDown(object): def setUp(self): super(SetUpTearDown, self).setUp() clearLogLevels() def tearDown(self): super(SetUpTearDown, self).tearDown() clearLogLevels() class LoggingTests(SetUpTearDown, unittest.TestCase): """ General module tests. """ def test_levelWithName(self): """ Look up log level by name. 
""" for level in LogLevel.iterconstants(): self.assertIdentical(LogLevel.levelWithName(level.name), level) def test_levelWithInvalidName(self): """ You can't make up log level names. """ bogus = "*bogus*" try: LogLevel.levelWithName(bogus) except InvalidLogLevelError as e: self.assertIdentical(e.level, bogus) else: self.fail("Expected InvalidLogLevelError.") def test_defaultLogLevel(self): """ Default log level is used. """ self.failUnless(logLevelForNamespace(None), defaultLogLevel) self.failUnless(logLevelForNamespace(""), defaultLogLevel) self.failUnless(logLevelForNamespace("rocker.cool.namespace"), defaultLogLevel) def test_setLogLevel(self): """ Setting and retrieving log levels. """ setLogLevelForNamespace(None, LogLevel.error) setLogLevelForNamespace("twext.web2", LogLevel.debug) setLogLevelForNamespace("twext.web2.dav", LogLevel.warn) self.assertEquals(logLevelForNamespace(None), LogLevel.error) self.assertEquals(logLevelForNamespace("twisted"), LogLevel.error) self.assertEquals(logLevelForNamespace("twext.web2"), LogLevel.debug) self.assertEquals(logLevelForNamespace("twext.web2.dav"), LogLevel.warn) self.assertEquals(logLevelForNamespace("twext.web2.dav.test"), LogLevel.warn) self.assertEquals(logLevelForNamespace("twext.web2.dav.test1.test2"), LogLevel.warn) def test_setInvalidLogLevel(self): """ Can't pass invalid log levels to setLogLevelForNamespace(). """ self.assertRaises(InvalidLogLevelError, setLogLevelForNamespace, "twext.web2", object()) # Level must be a constant, not the name of a constant self.assertRaises(InvalidLogLevelError, setLogLevelForNamespace, "twext.web2", "debug") def test_clearLogLevels(self): """ Clearing log levels. 
""" setLogLevelForNamespace("twext.web2", LogLevel.debug) setLogLevelForNamespace("twext.web2.dav", LogLevel.error) clearLogLevels() self.assertEquals(logLevelForNamespace("twisted"), defaultLogLevel) self.assertEquals(logLevelForNamespace("twext.web2"), defaultLogLevel) self.assertEquals(logLevelForNamespace("twext.web2.dav"), defaultLogLevel) self.assertEquals(logLevelForNamespace("twext.web2.dav.test"), defaultLogLevel) self.assertEquals(logLevelForNamespace("twext.web2.dav.test1.test2"), defaultLogLevel) def test_namespace_default(self): """ Default namespace is module name. """ log = Logger() self.assertEquals(log.namespace, __name__) def test_formatWithCall(self): """ L{formatWithCall} is an extended version of L{unicode.format} that will interpret a set of parentheses "C{()}" at the end of a format key to mean that the format key ought to be I{called} rather than stringified. """ self.assertEquals( formatWithCall( u"Hello, {world}. {callme()}.", dict(world="earth", callme=lambda: "maybe") ), "Hello, earth. maybe." ) self.assertEquals( formatWithCall( u"Hello, {repr()!r}.", dict(repr=lambda: "repr") ), "Hello, 'repr'." ) def test_formatEvent(self): """ L{formatEvent} will format an event according to several rules: - A string with no formatting instructions will be passed straight through. - PEP 3101 strings will be formatted using the keys and values of the event as named fields. - PEP 3101 keys ending with C{()} will be treated as instructions to call that key (which ought to be a callable) before formatting. L{formatEvent} will always return L{unicode}, and if given bytes, will always treat its format string as UTF-8 encoded. 
""" def format(log_format, **event): event["log_format"] = log_format result = formatEvent(event) self.assertIdentical(type(result), unicode) return result self.assertEquals(u"", format(b"")) self.assertEquals(u"", format(u"")) self.assertEquals(u"abc", format("{x}", x="abc")) self.assertEquals(u"no, yes.", format("{not_called}, {called()}.", not_called="no", called=lambda: "yes")) self.assertEquals(u'S\xe1nchez', format("S\xc3\xa1nchez")) self.assertIn(u"Unable to format event", format(b"S\xe1nchez")) self.assertIn(u"Unable to format event", format(b"S{a}nchez", a=b"\xe1")) self.assertIn(u"S'\\xe1'nchez", format(b"S{a!r}nchez", a=b"\xe1")) def test_formatEventNoFormat(self): """ Formatting an event with no format. """ event = dict(foo=1, bar=2) result = formatEvent(event) self.assertIn("Unable to format event", result) self.assertIn(repr(event), result) def test_formatEventWeirdFormat(self): """ Formatting an event with a bogus format. """ event = dict(log_format=object(), foo=1, bar=2) result = formatEvent(event) self.assertIn("Log format must be unicode or bytes", result) self.assertIn(repr(event), result) def test_formatUnformattableEvent(self): """ Formatting an event that's just plain out to get us. """ event = dict(log_format="{evil()}", evil=lambda: 1/0) result = formatEvent(event) self.assertIn("Unable to format event", result) self.assertIn(repr(event), result) def test_formatUnformattableEventWithUnformattableKey(self): """ Formatting an unformattable event that has an unformattable key. """ event = { "log_format": "{evil()}", "evil": lambda: 1/0, Unformattable(): "gurk", } result = formatEvent(event) self.assertIn("MESSAGE LOST: unformattable object logged:", result) self.assertIn("Recoverable data:", result) self.assertIn("Exception during formatting:", result) def test_formatUnformattableEventWithUnformattableValue(self): """ Formatting an unformattable event that has an unformattable value. 
""" event = dict( log_format="{evil()}", evil=lambda: 1/0, gurk=Unformattable(), ) result = formatEvent(event) self.assertIn("MESSAGE LOST: unformattable object logged:", result) self.assertIn("Recoverable data:", result) self.assertIn("Exception during formatting:", result) def test_formatUnformattableEventWithUnformattableErrorOMGWillItStop(self): """ Formatting an unformattable event that has an unformattable value. """ event = dict( log_format="{evil()}", evil=lambda: 1/0, recoverable="okay", ) # Call formatUnformattableEvent() directly with a bogus exception. result = formatUnformattableEvent(event, Unformattable()) self.assertIn("MESSAGE LOST: unformattable object logged:", result) self.assertIn(repr("recoverable") + " = " + repr("okay"), result) class LoggerTests(SetUpTearDown, unittest.TestCase): """ Tests for L{Logger}. """ def test_repr(self): """ repr() on Logger """ namespace = "bleargh" log = Logger(namespace) self.assertEquals(repr(log), "".format(repr(namespace))) def test_namespace_attribute(self): """ Default namespace for classes using L{Logger} as a descriptor is the class name they were retrieved from. """ obj = LogComposedObject() self.assertEquals(obj.log.namespace, "twext.python.test.test_log.LogComposedObject") self.assertEquals(LogComposedObject.log.namespace, "twext.python.test.test_log.LogComposedObject") self.assertIdentical(LogComposedObject.log.source, LogComposedObject) self.assertIdentical(obj.log.source, obj) self.assertIdentical(Logger().source, None) def test_sourceAvailableForFormatting(self): """ On instances that have a L{Logger} class attribute, the C{log_source} key is available to format strings. 
""" obj = LogComposedObject("hello") log = obj.log log.error("Hello, {log_source}.") self.assertIn("log_source", log.event) self.assertEquals(log.event["log_source"], obj) stuff = formatEvent(log.event) self.assertIn("Hello, .", stuff) def test_basic_Logger(self): """ Test that log levels and messages are emitted correctly for Logger. """ # FIXME: Need a basic test like this for logger attached to a class. # At least: source should not be None in that case. log = TestLogger() for level in LogLevel.iterconstants(): format = "This is a {level_name} message" message = format.format(level_name=level.name) method = getattr(log, level.name) method(format, junk=message, level_name=level.name) # Ensure that test_emit got called with expected arguments self.assertEquals(log.emitted["level"], level) self.assertEquals(log.emitted["format"], format) self.assertEquals(log.emitted["kwargs"]["junk"], message) if level >= logLevelForNamespace(log.namespace): self.assertTrue(hasattr(log, "event"), "No event observed.") self.assertEquals(log.event["log_format"], format) self.assertEquals(log.event["log_level"], level) self.assertEquals(log.event["log_namespace"], __name__) self.assertEquals(log.event["log_source"], None) self.assertEquals(log.event["logLevel"], pythonLogLevelMapping[level]) self.assertEquals(log.event["junk"], message) # FIXME: this checks the end of message because we do # formatting in emit() self.assertEquals( formatEvent(log.event), message ) else: self.assertFalse(hasattr(log, "event")) def test_defaultFailure(self): """ Test that log.failure() emits the right data. """ log = TestLogger() try: raise RuntimeError("baloney!") except RuntimeError: log.failure("Whoops") # # log.failure() will cause trial to complain, so here we check that # trial saw the correct error and remove it from the list of things to # complain about. 
# errors = self.flushLoggedErrors(RuntimeError) self.assertEquals(len(errors), 1) self.assertEquals(log.emitted["level"], LogLevel.error) self.assertEquals(log.emitted["format"], "Whoops") def test_conflicting_kwargs(self): """ Make sure that kwargs conflicting with args don't pass through. """ log = TestLogger() log.warn( "*", log_format="#", log_level=LogLevel.error, log_namespace="*namespace*", log_source="*source*", ) # FIXME: Should conflicts log errors? self.assertEquals(log.event["log_format"], "*") self.assertEquals(log.event["log_level"], LogLevel.warn) self.assertEquals(log.event["log_namespace"], log.namespace) self.assertEquals(log.event["log_source"], None) def test_logInvalidLogLevel(self): """ Test passing in a bogus log level to C{emit()}. """ log = TestLogger() log.emit("*bogus*") errors = self.flushLoggedErrors(InvalidLogLevelError) self.assertEquals(len(errors), 1) class LogPublisherTests(SetUpTearDown, unittest.TestCase): """ Tests for L{LogPublisher}. """ def test_interface(self): """ L{LogPublisher} is an L{ILogObserver}. """ publisher = LogPublisher() try: verifyObject(ILogObserver, publisher) except BrokenMethodImplementation as e: self.fail(e) def test_observers(self): """ L{LogPublisher.observers} returns the observers. """ o1 = lambda e: None o2 = lambda e: None publisher = LogPublisher(o1, o2) self.assertEquals(set((o1, o2)), set(publisher.observers)) def test_addObserver(self): """ L{LogPublisher.addObserver} adds an observer. """ o1 = lambda e: None o2 = lambda e: None o3 = lambda e: None publisher = LogPublisher(o1, o2) publisher.addObserver(o3) self.assertEquals(set((o1, o2, o3)), set(publisher.observers)) def test_removeObserver(self): """ L{LogPublisher.removeObserver} removes an observer. 
""" o1 = lambda e: None o2 = lambda e: None o3 = lambda e: None publisher = LogPublisher(o1, o2, o3) publisher.removeObserver(o2) self.assertEquals(set((o1, o3)), set(publisher.observers)) def test_removeObserverNotRegistered(self): """ L{LogPublisher.removeObserver} removes an observer that is not registered. """ o1 = lambda e: None o2 = lambda e: None o3 = lambda e: None publisher = LogPublisher(o1, o2) publisher.removeObserver(o3) self.assertEquals(set((o1, o2)), set(publisher.observers)) def test_fanOut(self): """ L{LogPublisher} calls its observers. """ event = dict(foo=1, bar=2) events1 = [] events2 = [] events3 = [] o1 = lambda e: events1.append(e) o2 = lambda e: events2.append(e) o3 = lambda e: events3.append(e) publisher = LogPublisher(o1, o2, o3) publisher(event) self.assertIn(event, events1) self.assertIn(event, events2) self.assertIn(event, events3) def test_observerRaises(self): nonTestEvents = [] Logger.publisher.addObserver(lambda e: nonTestEvents.append(e)) event = dict(foo=1, bar=2) exception = RuntimeError("ARGH! EVIL DEATH!") events = [] def observer(event): events.append(event) raise exception publisher = LogPublisher(observer) publisher(event) # Verify that the observer saw my event self.assertIn(event, events) # Verify that the observer raised my exception errors = self.flushLoggedErrors(exception.__class__) self.assertEquals(len(errors), 1) self.assertIdentical(errors[0].value, exception) # Verify that the exception was logged for event in nonTestEvents: if ( event.get("log_format", None) == OBSERVER_REMOVED and getattr(event.get("failure", None), "value") is exception ): break else: self.fail("Observer raised an exception " "and the exception was not logged.") def test_observerRaisesAndLoggerHatesMe(self): nonTestEvents = [] Logger.publisher.addObserver(lambda e: nonTestEvents.append(e)) event = dict(foo=1, bar=2) exception = RuntimeError("ARGH! 
EVIL DEATH!") def observer(event): raise RuntimeError("Sad panda") class GurkLogger(Logger): def failure(self, *args, **kwargs): raise exception publisher = LogPublisher(observer) publisher.log = GurkLogger() publisher(event) # Here, the lack of an exception thus far is a success, of sorts class DefaultLogPublisherTests(SetUpTearDown, unittest.TestCase): def test_addObserver(self): o1 = lambda e: None o2 = lambda e: None o3 = lambda e: None publisher = DefaultLogPublisher() publisher.addObserver(o1) publisher.addObserver(o2, filtered=True) publisher.addObserver(o3, filtered=False) self.assertEquals( set((o1, o2, publisher.legacyLogObserver)), set(publisher.filteredPublisher.observers), "Filtered observers do not match expected set" ) self.assertEquals( set((o3, publisher.filters)), set(publisher.rootPublisher.observers), "Root observers do not match expected set" ) def test_addObserverAgain(self): o1 = lambda e: None o2 = lambda e: None o3 = lambda e: None publisher = DefaultLogPublisher() publisher.addObserver(o1) publisher.addObserver(o2, filtered=True) publisher.addObserver(o3, filtered=False) # Swap filtered-ness of o2 and o3 publisher.addObserver(o1) publisher.addObserver(o2, filtered=False) publisher.addObserver(o3, filtered=True) self.assertEquals( set((o1, o3, publisher.legacyLogObserver)), set(publisher.filteredPublisher.observers), "Filtered observers do not match expected set" ) self.assertEquals( set((o2, publisher.filters)), set(publisher.rootPublisher.observers), "Root observers do not match expected set" ) def test_removeObserver(self): o1 = lambda e: None o2 = lambda e: None o3 = lambda e: None publisher = DefaultLogPublisher() publisher.addObserver(o1) publisher.addObserver(o2, filtered=True) publisher.addObserver(o3, filtered=False) publisher.removeObserver(o2) publisher.removeObserver(o3) self.assertEquals( set((o1, publisher.legacyLogObserver)), set(publisher.filteredPublisher.observers), "Filtered observers do not match expected set" ) 
self.assertEquals( set((publisher.filters,)), set(publisher.rootPublisher.observers), "Root observers do not match expected set" ) def test_filteredObserver(self): namespace = __name__ event_debug = dict(log_namespace=namespace, log_level=LogLevel.debug, log_format="") event_error = dict(log_namespace=namespace, log_level=LogLevel.error, log_format="") events = [] observer = lambda e: events.append(e) publisher = DefaultLogPublisher() publisher.addObserver(observer, filtered=True) publisher(event_debug) publisher(event_error) self.assertNotIn(event_debug, events) self.assertIn(event_error, events) def test_filteredObserverNoFilteringKeys(self): event_debug = dict(log_level=LogLevel.debug) event_error = dict(log_level=LogLevel.error) event_none = dict() events = [] observer = lambda e: events.append(e) publisher = DefaultLogPublisher() publisher.addObserver(observer, filtered=True) publisher(event_debug) publisher(event_error) publisher(event_none) self.assertNotIn(event_debug, events) self.assertNotIn(event_error, events) self.assertNotIn(event_none, events) def test_unfilteredObserver(self): namespace = __name__ event_debug = dict(log_namespace=namespace, log_level=LogLevel.debug, log_format="") event_error = dict(log_namespace=namespace, log_level=LogLevel.error, log_format="") events = [] observer = lambda e: events.append(e) publisher = DefaultLogPublisher() publisher.addObserver(observer, filtered=False) publisher(event_debug) publisher(event_error) self.assertIn(event_debug, events) self.assertIn(event_error, events) class FilteringLogObserverTests(SetUpTearDown, unittest.TestCase): """ Tests for L{FilteringLogObserver}. """ def test_interface(self): """ L{FilteringLogObserver} is an L{ILogObserver}. 
""" observer = FilteringLogObserver(lambda e: None, ()) try: verifyObject(ILogObserver, observer) except BrokenMethodImplementation as e: self.fail(e) def filterWith(self, *filters): events = [ dict(count=0), dict(count=1), dict(count=2), dict(count=3), ] class Filters(object): @staticmethod def twoMinus(event): if event["count"] <= 2: return PredicateResult.yes return PredicateResult.maybe @staticmethod def twoPlus(event): if event["count"] >= 2: return PredicateResult.yes return PredicateResult.maybe @staticmethod def notTwo(event): if event["count"] == 2: return PredicateResult.no return PredicateResult.maybe @staticmethod def no(event): return PredicateResult.no @staticmethod def bogus(event): return None predicates = (getattr(Filters, f) for f in filters) eventsSeen = [] trackingObserver = lambda e: eventsSeen.append(e) filteringObserver = FilteringLogObserver(trackingObserver, predicates) for e in events: filteringObserver(e) return [e["count"] for e in eventsSeen] def test_shouldLogEvent_noFilters(self): self.assertEquals(self.filterWith(), [0, 1, 2, 3]) def test_shouldLogEvent_noFilter(self): self.assertEquals(self.filterWith("notTwo"), [0, 1, 3]) def test_shouldLogEvent_yesFilter(self): self.assertEquals(self.filterWith("twoPlus"), [0, 1, 2, 3]) def test_shouldLogEvent_yesNoFilter(self): self.assertEquals(self.filterWith("twoPlus", "no"), [2, 3]) def test_shouldLogEvent_yesYesNoFilter(self): self.assertEquals(self.filterWith("twoPlus", "twoMinus", "no"), [0, 1, 2, 3]) def test_shouldLogEvent_badPredicateResult(self): self.assertRaises(TypeError, self.filterWith, "bogus") def test_call(self): e = dict(obj=object()) def callWithPredicateResult(result): seen = [] observer = FilteringLogObserver(lambda e: seen.append(e), (lambda e: result,)) observer(e) return seen self.assertIn(e, callWithPredicateResult(PredicateResult.yes)) self.assertIn(e, callWithPredicateResult(PredicateResult.maybe)) self.assertNotIn(e, callWithPredicateResult(PredicateResult.no)) class 
LegacyLoggerTests(SetUpTearDown, unittest.TestCase): """ Tests for L{LegacyLogger}. """ def test_namespace_default(self): """ Default namespace is module name. """ log = TestLegacyLogger(logger=None) self.assertEquals(log.newStyleLogger.namespace, __name__) def test_passThroughAttributes(self): """ C{__getattribute__} on L{LegacyLogger} is passing through to Twisted's logging module. """ log = TestLegacyLogger() # Not passed through self.assertIn("API-compatible", log.msg.__doc__) self.assertIn("API-compatible", log.err.__doc__) # Passed through self.assertIdentical(log.addObserver, twistedLogging.addObserver) def test_legacy_msg(self): """ Test LegacyLogger's log.msg() """ log = TestLegacyLogger() message = "Hi, there." kwargs = {"foo": "bar", "obj": object()} log.msg(message, **kwargs) self.assertIdentical(log.newStyleLogger.emitted["level"], LogLevel.info) self.assertEquals(log.newStyleLogger.emitted["format"], message) for key, value in kwargs.items(): self.assertIdentical(log.newStyleLogger.emitted["kwargs"][key], value) log.msg(foo="") self.assertIdentical(log.newStyleLogger.emitted["level"], LogLevel.info) self.assertIdentical(log.newStyleLogger.emitted["format"], None) def test_legacy_err_implicit(self): """ Test LegacyLogger's log.err() capturing the in-flight exception. """ log = TestLegacyLogger() exception = RuntimeError("Oh me, oh my.") kwargs = {"foo": "bar", "obj": object()} try: raise exception except RuntimeError: log.err(**kwargs) self.legacy_err(log, kwargs, None, exception) def test_legacy_err_exception(self): """ Test LegacyLogger's log.err() with a given exception. """ log = TestLegacyLogger() exception = RuntimeError("Oh me, oh my.") kwargs = {"foo": "bar", "obj": object()} why = "Because I said so." try: raise exception except RuntimeError as e: log.err(e, why, **kwargs) self.legacy_err(log, kwargs, why, exception) def test_legacy_err_failure(self): """ Test LegacyLogger's log.err() with a given L{Failure}. 
""" log = TestLegacyLogger() exception = RuntimeError("Oh me, oh my.") kwargs = {"foo": "bar", "obj": object()} why = "Because I said so." try: raise exception except RuntimeError: log.err(Failure(), why, **kwargs) self.legacy_err(log, kwargs, why, exception) def test_legacy_err_bogus(self): """ Test LegacyLogger's log.err() with a bogus argument. """ log = TestLegacyLogger() exception = RuntimeError("Oh me, oh my.") kwargs = {"foo": "bar", "obj": object()} why = "Because I said so." bogus = object() try: raise exception except RuntimeError: log.err(bogus, why, **kwargs) errors = self.flushLoggedErrors(exception.__class__) self.assertEquals(len(errors), 0) self.assertIdentical(log.newStyleLogger.emitted["level"], LogLevel.error) self.assertEquals(log.newStyleLogger.emitted["format"], repr(bogus)) self.assertIdentical(log.newStyleLogger.emitted["kwargs"]["why"], why) for key, value in kwargs.items(): self.assertIdentical(log.newStyleLogger.emitted["kwargs"][key], value) def legacy_err(self, log, kwargs, why, exception): # # log.failure() will cause trial to complain, so here we check that # trial saw the correct error and remove it from the list of things to # complain about. # errors = self.flushLoggedErrors(exception.__class__) self.assertEquals(len(errors), 1) self.assertIdentical(log.newStyleLogger.emitted["level"], LogLevel.error) self.assertEquals(log.newStyleLogger.emitted["format"], None) emittedKwargs = log.newStyleLogger.emitted["kwargs"] self.assertIdentical(emittedKwargs["failure"].__class__, Failure) self.assertIdentical(emittedKwargs["failure"].value, exception) self.assertIdentical(emittedKwargs["why"], why) for key, value in kwargs.items(): self.assertIdentical(log.newStyleLogger.emitted["kwargs"][key], value) class Unformattable(object): """ An object that raises an exception from C{__repr__}. 
""" def __repr__(self): return str(1/0) calendarserver-5.2+dfsg/twext/python/test/test_launchd.py0000644000175000017500000003070112263343324023014 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.python.launchd}. """ import sys, os, plistlib, socket, json if __name__ == '__main__': # This module is loaded as a launchd job by test-cases below; the following # code looks up an appropriate function to run. testID = sys.argv[1] a, b = testID.rsplit(".", 1) from twisted.python.reflect import namedAny try: namedAny(".".join([a, b.replace("test_", "job_")]))() finally: sys.stdout.flush() sys.stderr.flush() skt = socket.socket() skt.connect(("127.0.0.1", int(os.environ["TESTING_PORT"]))) sys.exit(0) try: from twext.python.launchd import ( lib, ffi, _LaunchDictionary, _LaunchArray, _managed, constants, plainPython, checkin, _launchify, getLaunchDSocketFDs ) except ImportError: skip = "LaunchD not available." else: skip = False from twisted.trial.unittest import TestCase from twisted.python.filepath import FilePath class LaunchDataStructures(TestCase): """ Tests for L{_launchify} converting data structures from launchd's internals to Python objects. """ def test_fd(self): """ Test converting a launchd FD to an integer. """ fd = _managed(lib.launch_data_new_fd(2)) self.assertEquals(_launchify(fd), 2) def test_bool(self): """ Test converting a launchd bool to a Python bool. 
""" t = _managed(lib.launch_data_new_bool(True)) f = _managed(lib.launch_data_new_bool(False)) self.assertEqual(_launchify(t), True) self.assertEqual(_launchify(f), False) def test_real(self): """ Test converting a launchd real to a Python float. """ notQuitePi = _managed(lib.launch_data_new_real(3.14158)) self.assertEqual(_launchify(notQuitePi), 3.14158) class DictionaryTests(TestCase): """ Tests for L{_LaunchDictionary} """ def setUp(self): """ Assemble a test dictionary. """ self.testDict = _managed( lib.launch_data_alloc(lib.LAUNCH_DATA_DICTIONARY) ) key1 = ffi.new("char[]", "alpha") val1 = lib.launch_data_new_string("alpha-value") key2 = ffi.new("char[]", "beta") val2 = lib.launch_data_new_string("beta-value") key3 = ffi.new("char[]", "gamma") val3 = lib.launch_data_new_integer(3) lib.launch_data_dict_insert(self.testDict, val1, key1) lib.launch_data_dict_insert(self.testDict, val2, key2) lib.launch_data_dict_insert(self.testDict, val3, key3) self.assertEquals(lib.launch_data_dict_get_count(self.testDict), 3) def test_len(self): """ C{len(_LaunchDictionary())} returns the number of keys in the dictionary. """ self.assertEquals(len(_LaunchDictionary(self.testDict)), 3) def test_keys(self): """ L{_LaunchDictionary.keys} returns keys present in a C{launch_data_dict}. """ dictionary = _LaunchDictionary(self.testDict) self.assertEquals(set(dictionary.keys()), set([b"alpha", b"beta", b"gamma"])) def test_values(self): """ L{_LaunchDictionary.values} returns keys present in a C{launch_data_dict}. """ dictionary = _LaunchDictionary(self.testDict) self.assertEquals(set(dictionary.values()), set([b"alpha-value", b"beta-value", 3])) def test_items(self): """ L{_LaunchDictionary.items} returns all (key, value) tuples present in a C{launch_data_dict}. 
""" dictionary = _LaunchDictionary(self.testDict) self.assertEquals(set(dictionary.items()), set([(b"alpha", b"alpha-value"), (b"beta", b"beta-value"), (b"gamma", 3)])) def test_plainPython(self): """ L{plainPython} will convert a L{_LaunchDictionary} into a Python dictionary. """ self.assertEquals({b"alpha": b"alpha-value", b"beta": b"beta-value", b"gamma": 3}, plainPython(_LaunchDictionary(self.testDict))) def test_plainPythonNested(self): """ L{plainPython} will convert a L{_LaunchDictionary} containing another L{_LaunchDictionary} into a nested Python dictionary. """ otherDict = lib.launch_data_alloc(lib.LAUNCH_DATA_DICTIONARY) lib.launch_data_dict_insert(otherDict, lib.launch_data_new_string("bar"), "foo") lib.launch_data_dict_insert(self.testDict, otherDict, "delta") self.assertEquals({b"alpha": b"alpha-value", b"beta": b"beta-value", b"gamma": 3, b"delta": {b"foo": b"bar"}}, plainPython(_LaunchDictionary(self.testDict))) class ArrayTests(TestCase): """ Tests for L{_LaunchArray} """ def setUp(self): """ Assemble a test array. """ self.testArray = ffi.gc( lib.launch_data_alloc(lib.LAUNCH_DATA_ARRAY), lib.launch_data_free ) lib.launch_data_array_set_index( self.testArray, lib.launch_data_new_string("test-string-1"), 0 ) lib.launch_data_array_set_index( self.testArray, lib.launch_data_new_string("another string."), 1 ) lib.launch_data_array_set_index( self.testArray, lib.launch_data_new_integer(4321), 2 ) def test_length(self): """ C{len(_LaunchArray(...))} returns the number of elements in the array. """ self.assertEquals(len(_LaunchArray(self.testArray)), 3) def test_indexing(self): """ C{_LaunchArray(...)[n]} returns the n'th element in the array. """ array = _LaunchArray(self.testArray) self.assertEquals(array[0], b"test-string-1") self.assertEquals(array[1], b"another string.") self.assertEquals(array[2], 4321) def test_indexTooBig(self): """ C{_LaunchArray(...)[n]}, where C{n} is greater than the length of the array, raises an L{IndexError}. 
""" array = _LaunchArray(self.testArray) self.assertRaises(IndexError, lambda: array[3]) def test_iterating(self): """ Iterating over a C{_LaunchArray} returns each item in sequence. """ array = _LaunchArray(self.testArray) i = iter(array) self.assertEquals(i.next(), b"test-string-1") self.assertEquals(i.next(), b"another string.") self.assertEquals(i.next(), 4321) self.assertRaises(StopIteration, i.next) def test_plainPython(self): """ L{plainPython} converts a L{_LaunchArray} into a Python list. """ array = _LaunchArray(self.testArray) self.assertEquals(plainPython(array), [b"test-string-1", b"another string.", 4321]) def test_plainPythonNested(self): """ L{plainPython} converts a L{_LaunchArray} containing another L{_LaunchArray} into a Python list. """ sub = lib.launch_data_alloc(lib.LAUNCH_DATA_ARRAY) lib.launch_data_array_set_index(sub, lib.launch_data_new_integer(7), 0) lib.launch_data_array_set_index(self.testArray, sub, 3) array = _LaunchArray(self.testArray) self.assertEqual(plainPython(array), [b"test-string-1", b"another string.", 4321, [7]]) class SimpleStringConstants(TestCase): """ Tests for bytestring-constants wrapping. """ def test_constant(self): """ C{launchd.constants.LAUNCH_*} will return a bytes object corresponding to a constant. """ self.assertEqual(constants.LAUNCH_JOBKEY_SOCKETS, b"Sockets") self.assertRaises(AttributeError, getattr, constants, "launch_data_alloc") self.assertEquals(constants.LAUNCH_DATA_ARRAY, 2) class CheckInTests(TestCase): """ Integration tests making sure that actual checkin with launchd results in the expected values. 
""" def setUp(self): fp = FilePath(self.mktemp()) fp.makedirs() from twisted.internet.protocol import Protocol, Factory from twisted.internet import reactor, defer d = defer.Deferred() class JustLetMeMoveOn(Protocol): def connectionMade(self): d.callback(None) self.transport.abortConnection() f = Factory() f.protocol = JustLetMeMoveOn port = reactor.listenTCP(0, f, interface="127.0.0.1") @self.addCleanup def goodbyePort(): return port.stopListening() env = dict(os.environ) env["TESTING_PORT"] = repr(port.getHost().port) self.stdout = fp.child("stdout.txt") self.stderr = fp.child("stderr.txt") self.launchLabel = ("org.calendarserver.UNIT-TESTS." + str(os.getpid()) + "." + self.id()) plist = { "Label": self.launchLabel, "ProgramArguments": [sys.executable, "-m", __name__, self.id()], "EnvironmentVariables": env, "KeepAlive": False, "StandardOutPath": self.stdout.path, "StandardErrorPath": self.stderr.path, "Sockets": { "Awesome": [{"SecureSocketWithKey": "GeneratedSocket"}] }, "RunAtLoad": True, } self.job = fp.child("job.plist") self.job.setContent(plistlib.writePlistToString(plist)) os.spawnlp(os.P_WAIT, "launchctl", "launchctl", "load", self.job.path) return d @staticmethod def job_test(): """ Do something observable in a subprocess. """ sys.stdout.write("Sample Value.") sys.stdout.flush() def test_test(self): """ Since this test framework is somewhat finicky, let's just make sure that a test can complete. """ self.assertEquals("Sample Value.", self.stdout.getContent()) @staticmethod def job_checkin(): """ Check in in the subprocess. """ sys.stdout.write(json.dumps(plainPython(checkin()))) def test_checkin(self): """ L{checkin} performs launchd checkin and returns a launchd data structure. 
""" d = json.loads(self.stdout.getContent()) self.assertEqual(d[constants.LAUNCH_JOBKEY_LABEL], self.launchLabel) self.assertIsInstance(d, dict) sockets = d[constants.LAUNCH_JOBKEY_SOCKETS] self.assertEquals(len(sockets), 1) self.assertEqual(['Awesome'], sockets.keys()) awesomeSocket = sockets['Awesome'] self.assertEqual(len(awesomeSocket), 1) self.assertIsInstance(awesomeSocket[0], int) @staticmethod def job_getFDs(): """ Check-in via the high-level C{getLaunchDSocketFDs} API, that just gives us listening FDs. """ sys.stdout.write(json.dumps(getLaunchDSocketFDs())) def test_getFDs(self): """ L{getLaunchDSocketFDs} returns a Python dictionary mapping the names of sockets specified in the property list to lists of integers representing FDs. """ sockets = json.loads(self.stdout.getContent()) self.assertEquals(len(sockets), 1) self.assertEqual(['Awesome'], sockets.keys()) awesomeSocket = sockets['Awesome'] self.assertEqual(len(awesomeSocket), 1) self.assertIsInstance(awesomeSocket[0], int) def tearDown(self): """ Un-load the launchd job and report any errors it encountered. """ os.spawnlp(os.P_WAIT, "launchctl", "launchctl", "unload", self.job.path) err = self.stderr.getContent() if 'Traceback' in err: self.fail(err) calendarserver-5.2+dfsg/twext/python/test/test_sendmsg.py0000644000175000017500000001220712263343324023037 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## import socket from os import pipe, read, close, environ from twext.python.filepath import CachingFilePath as FilePath import sys from twisted.internet.defer import Deferred from twisted.internet.error import ProcessDone from twisted.trial.unittest import TestCase from twisted.internet.defer import inlineCallbacks from twisted.internet import reactor from twext.python.sendmsg import sendmsg, recvmsg from twext.python.sendfd import sendfd from twisted.internet.protocol import ProcessProtocol class ExitedWithStderr(Exception): """ A process exited with some stderr. """ def __str__(self): """ Dump the errors in a pretty way in the event of a subprocess traceback. """ return '\n'.join([''] + list(self.args)) class StartStopProcessProtocol(ProcessProtocol): """ An L{IProcessProtocol} with a Deferred for events where the subprocess starts and stops. """ def __init__(self): self.started = Deferred() self.stopped = Deferred() self.output = '' self.errors = '' def connectionMade(self): self.started.callback(self.transport) def outReceived(self, data): self.output += data def errReceived(self, data): self.errors += data def processEnded(self, reason): if reason.check(ProcessDone): self.stopped.callback(self.output) else: self.stopped.errback(ExitedWithStderr( self.errors, self.output)) def bootReactor(): """ Yield this from a trial test to bootstrap the reactor in order to avoid PotentialZombieWarning, for tests that use subprocesses. This hack will no longer be necessary in Twisted 10.1, since U{the underlying bug was fixed }. """ d = Deferred() reactor.callLater(0, d.callback, None) return d class SendmsgTestCase(TestCase): """ Tests for sendmsg extension module and associated file-descriptor sending functionality in L{twext.python.sendfd}. """ def setUp(self): """ Create a pair of UNIX sockets. """ self.input, self.output = socket.socketpair(socket.AF_UNIX) def tearDown(self): """ Close the sockets opened by setUp. 
""" self.input.close() self.output.close() def test_roundtrip(self): """ L{recvmsg} will retrieve a message sent via L{sendmsg}. """ sendmsg(self.input.fileno(), "hello, world!", 0) result = recvmsg(fd=self.output.fileno()) self.assertEquals(result, ("hello, world!", 0, [])) def test_wrongTypeAncillary(self): """ L{sendmsg} will show a helpful exception message when given the wrong type of object for the 'ancillary' argument. """ error = self.assertRaises(TypeError, sendmsg, self.input.fileno(), "hello, world!", 0, 4321) self.assertEquals(str(error), "sendmsg argument 3 expected list, got int") def spawn(self, script): """ Start a script that is a peer of this test as a subprocess. @param script: the module name of the script in this directory (no package prefix, no '.py') @type script: C{str} @rtype: L{StartStopProcessProtocol} """ sspp = StartStopProcessProtocol() reactor.spawnProcess( sspp, sys.executable, [ sys.executable, FilePath(__file__).sibling(script + ".py").path, str(self.output.fileno()), ], environ, childFDs={0: "w", 1: "r", 2: "r", self.output.fileno(): self.output.fileno()} ) return sspp @inlineCallbacks def test_sendSubProcessFD(self): """ Calling L{sendsmsg} with SOL_SOCKET, SCM_RIGHTS, and a platform-endian packed file descriptor number should send that file descriptor to a different process, where it can be retrieved by using L{recvmsg}. """ yield bootReactor() sspp = self.spawn("pullpipe") yield sspp.started pipeOut, pipeIn = pipe() self.addCleanup(close, pipeOut) sendfd(self.input.fileno(), pipeIn, "blonk") close(pipeIn) yield sspp.stopped self.assertEquals(read(pipeOut, 1024), "Test fixture data: blonk.\n") # Make sure that the pipe is actually closed now. self.assertEquals(read(pipeOut, 1024), "") calendarserver-5.2+dfsg/twext/python/test/pullpipe.py0000644000175000017500000000160712263343324022174 0ustar rahulrahul#!/usr/bin/python # -*- test-case-name: twext.python.test.test_sendmsg -*- ## # Copyright (c) 2010-2014 Apple Inc. 
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## if __name__ == '__main__': from twext.python.sendfd import recvfd import sys, os fd, description = recvfd(int(sys.argv[1])) os.write(fd, "Test fixture data: %s.\n" % (description,)) os.close(fd) calendarserver-5.2+dfsg/twext/python/test/test_parallel.py0000644000175000017500000000376012263343324023177 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.python.parallel}. """ from twisted.internet.defer import Deferred from twext.python.parallel import Parallelizer from twisted.trial.unittest import TestCase class ParallelizerTests(TestCase): """ Tests for L{Parallelizer}. """ def test_doAndDone(self): """ Blanket catch-all test. (TODO: split this up into more nice fine-grained tests.) 
""" d1 = Deferred() d2 = Deferred() d3 = Deferred() d4 = Deferred() doing = [] done = [] allDone = [] p = Parallelizer(['a', 'b', 'c']) p.do(lambda a: doing.append(a) or d1).addCallback(done.append) p.do(lambda b: doing.append(b) or d2).addCallback(done.append) p.do(lambda c: doing.append(c) or d3).addCallback(done.append) p.do(lambda b1: doing.append(b1) or d4).addCallback(done.append) p.done().addCallback(allDone.append) self.assertEqual(allDone, []) self.assertEqual(doing, ['a', 'b', 'c']) self.assertEqual(done, [None, None, None]) d2.callback(1) self.assertEqual(doing, ['a', 'b', 'c', 'b']) self.assertEqual(done, [None, None, None, None]) self.assertEqual(allDone, []) d3.callback(2) d4.callback(3) d1.callback(4) self.assertEqual(done, [None, None, None, None]) self.assertEqual(allDone, [None]) calendarserver-5.2+dfsg/twext/python/test/test_timezone.py0000644000175000017500000000501112263343324023224 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from twistedcaldav.test.util import TestCase from twistedcaldav.config import config import twext.python.timezone import twistedcaldav.timezones from twext.python.timezone import getLocalTimezone, DEFAULT_TIMEZONE class DefaultTimezoneTests(TestCase): def stubLookup(self): return self._storedLookup def stubHasTZ(self, ignored): return self._storedHasTZ.pop() def setUp(self): self.patch(twext.python.timezone, "lookupSystemTimezone", self.stubLookup) self.patch(twistedcaldav.timezones, "hasTZ", self.stubHasTZ) def test_getLocalTimezone(self): # Empty config, system timezone known = use system timezone self.patch(config, "DefaultTimezone", "") self._storedLookup = "America/New_York" self._storedHasTZ = [True] self.assertEquals(getLocalTimezone(), "America/New_York") # Empty config, system timezone unknown = use DEFAULT_TIMEZONE self.patch(config, "DefaultTimezone", "") self._storedLookup = "Unknown/Unknown" self._storedHasTZ = [False] self.assertEquals(getLocalTimezone(), DEFAULT_TIMEZONE) # Known config value = use config value self.patch(config, "DefaultTimezone", "America/New_York") self._storedHasTZ = [True] self.assertEquals(getLocalTimezone(), "America/New_York") # Unknown config value, system timezone known = use system timezone self.patch(config, "DefaultTimezone", "Unknown/Unknown") self._storedLookup = "America/New_York" self._storedHasTZ = [True, False] self.assertEquals(getLocalTimezone(), "America/New_York") # Unknown config value, system timezone unknown = use DEFAULT_TIMEZONE self.patch(config, "DefaultTimezone", "Unknown/Unknown") self._storedLookup = "Unknown/Unknown" self._storedHasTZ = [False, False] self.assertEquals(getLocalTimezone(), DEFAULT_TIMEZONE) calendarserver-5.2+dfsg/twext/python/test/__init__.py0000644000175000017500000000121212263343324022071 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Test extensions to twisted.python. """ calendarserver-5.2+dfsg/twext/python/launchd.py0000644000175000017500000002004012263343324020771 0ustar rahulrahul# -*- test-case-name: twext.python.test.test_launchd -*- ## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Bindings for launchd check-in API. @see: U{SampleD.c } @var ffi: a L{cffi.FFI} instance wrapping the functions exposed by C{launch.h}. @var lib: a L{cffi} "U{dynamic library object }" wrapping the functions exposed by C{launch.h}. @var constants: Select C{LAUNCH_*} constants from C{launch.h}, exposed as plain Python values. Note that this is not a complete wrapping, but as the header file suggests, these APIs are only for use during check-in. 
""" from __future__ import print_function from cffi import FFI, VerificationError ffi = FFI() ffi.cdef(""" static const char* LAUNCH_KEY_CHECKIN; static const char* LAUNCH_JOBKEY_LABEL; static const char* LAUNCH_JOBKEY_SOCKETS; typedef enum { LAUNCH_DATA_DICTIONARY = 1, LAUNCH_DATA_ARRAY, LAUNCH_DATA_FD, LAUNCH_DATA_INTEGER, LAUNCH_DATA_REAL, LAUNCH_DATA_BOOL, LAUNCH_DATA_STRING, LAUNCH_DATA_OPAQUE, LAUNCH_DATA_ERRNO, LAUNCH_DATA_MACHPORT, } launch_data_type_t; typedef struct _launch_data *launch_data_t; bool launch_data_dict_insert(launch_data_t, const launch_data_t, const char *); launch_data_t launch_data_alloc(launch_data_type_t); launch_data_t launch_data_new_string(const char *); launch_data_t launch_data_new_integer(long long); launch_data_t launch_data_new_fd(int); launch_data_t launch_data_new_bool(bool); launch_data_t launch_data_new_real(double); launch_data_t launch_msg(const launch_data_t); launch_data_type_t launch_data_get_type(const launch_data_t); launch_data_t launch_data_dict_lookup(const launch_data_t, const char *); size_t launch_data_dict_get_count(const launch_data_t); long long launch_data_get_integer(const launch_data_t); void launch_data_dict_iterate( const launch_data_t, void (*)(const launch_data_t, const char *, void *), void *); int launch_data_get_fd(const launch_data_t); bool launch_data_get_bool(const launch_data_t); const char * launch_data_get_string(const launch_data_t); double launch_data_get_real(const launch_data_t); size_t launch_data_array_get_count(const launch_data_t); launch_data_t launch_data_array_get_index(const launch_data_t, size_t); bool launch_data_array_set_index(launch_data_t, const launch_data_t, size_t); void launch_data_free(launch_data_t); """) try: lib = ffi.verify(""" #include """, tag=__name__.replace(".", "_")) except VerificationError as ve: raise ImportError(ve) class _LaunchArray(object): def __init__(self, launchdata): self.launchdata = launchdata def __len__(self): return 
lib.launch_data_array_get_count(self.launchdata) def __getitem__(self, index): if index >= len(self): raise IndexError(index) return _launchify( lib.launch_data_array_get_index(self.launchdata, index) ) class _LaunchDictionary(object): def __init__(self, launchdata): self.launchdata = launchdata def keys(self): """ Return keys in the dictionary. """ keys = [] @ffi.callback("void (*)(const launch_data_t, const char *, void *)") def icb(v, k, n): keys.append(ffi.string(k)) lib.launch_data_dict_iterate(self.launchdata, icb, ffi.NULL) return keys def values(self): """ Return values in the dictionary. """ values = [] @ffi.callback("void (*)(const launch_data_t, const char *, void *)") def icb(v, k, n): values.append(_launchify(v)) lib.launch_data_dict_iterate(self.launchdata, icb, ffi.NULL) return values def items(self): """ Return items in the dictionary. """ values = [] @ffi.callback("void (*)(const launch_data_t, const char *, void *)") def icb(v, k, n): values.append((ffi.string(k), _launchify(v))) lib.launch_data_dict_iterate(self.launchdata, icb, ffi.NULL) return values def __getitem__(self, key): launchvalue = lib.launch_data_dict_lookup(self.launchdata, key) try: return _launchify(launchvalue) except LaunchErrno: raise KeyError(key) def __len__(self): return lib.launch_data_dict_get_count(self.launchdata) def plainPython(x): """ Convert a launchd python-like data structure into regular Python dictionaries and lists. """ if isinstance(x, _LaunchDictionary): result = {} for k, v in x.items(): result[k] = plainPython(v) return result elif isinstance(x, _LaunchArray): return map(plainPython, x) else: return x class LaunchErrno(Exception): """ Error from launchd. """ def _launchify(launchvalue): """ Convert a ctypes value wrapping a C{_launch_data} structure into the relevant Python object (integer, bytes, L{_LaunchDictionary}, L{_LaunchArray}). 
""" if launchvalue == ffi.NULL: return None dtype = lib.launch_data_get_type(launchvalue) if dtype == lib.LAUNCH_DATA_DICTIONARY: return _LaunchDictionary(launchvalue) elif dtype == lib.LAUNCH_DATA_ARRAY: return _LaunchArray(launchvalue) elif dtype == lib.LAUNCH_DATA_FD: return lib.launch_data_get_fd(launchvalue) elif dtype == lib.LAUNCH_DATA_INTEGER: return lib.launch_data_get_integer(launchvalue) elif dtype == lib.LAUNCH_DATA_REAL: return lib.launch_data_get_real(launchvalue) elif dtype == lib.LAUNCH_DATA_BOOL: return lib.launch_data_get_bool(launchvalue) elif dtype == lib.LAUNCH_DATA_STRING: cvalue = lib.launch_data_get_string(launchvalue) if cvalue == ffi.NULL: return None return ffi.string(cvalue) elif dtype == lib.LAUNCH_DATA_OPAQUE: return launchvalue elif dtype == lib.LAUNCH_DATA_ERRNO: raise LaunchErrno(launchvalue) elif dtype == lib.LAUNCH_DATA_MACHPORT: return lib.launch_data_get_machport(launchvalue) else: raise TypeError("Unknown Launch Data Type", dtype) def checkin(): """ Perform a launchd checkin, returning a Pythonic wrapped data structure representing the retrieved check-in plist. @return: a C{dict}-like object. """ lkey = lib.launch_data_new_string(lib.LAUNCH_KEY_CHECKIN) msgr = lib.launch_msg(lkey) return _launchify(msgr) def _managed(obj): """ Automatically free an object that was allocated with a launch_data_* function, or raise L{MemoryError} if it's C{NULL}. """ if obj == ffi.NULL: raise MemoryError() else: return ffi.gc(obj, lib.launch_data_free) class _Strings(object): """ Expose constants as Python-readable values rather than wrapped ctypes pointers. """ def __getattribute__(self, name): value = getattr(lib, name) if isinstance(value, int): return value if ffi.typeof(value) != ffi.typeof("char *"): raise AttributeError("no such constant", name) return ffi.string(value) constants = _Strings() def getLaunchDSocketFDs(): """ Perform checkin via L{checkin} and return just a dictionary mapping the sockets to file descriptors. 
""" return plainPython(checkin()[constants.LAUNCH_JOBKEY_SOCKETS]) __all__ = [ 'checkin', 'lib', 'ffi', 'plainPython', ] calendarserver-5.2+dfsg/twext/python/sendfd.py0000644000175000017500000000453712263343324020633 0ustar rahulrahul# -*- test-case-name: twext.python.test.test_sendmsg -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from struct import pack, unpack, calcsize from socket import SOL_SOCKET from twext.python.sendmsg import sendmsg, recvmsg, SCM_RIGHTS def sendfd(socketfd, fd, description): """ Send the given FD to another process via L{sendmsg} on the given C{AF_UNIX} socket. @param socketfd: An C{AF_UNIX} socket, attached to another process waiting to receive sockets via the ancillary data mechanism in L{sendmsg}. @type socketfd: C{int} @param fd: A file descriptor to be sent to the other process. @type fd: C{int} @param description: a string describing the socket that was passed. @type description: C{str} """ sendmsg( socketfd, description, 0, [(SOL_SOCKET, SCM_RIGHTS, pack("i", fd))] ) def recvfd(socketfd): """ Receive a file descriptor from a L{sendmsg} message on the given C{AF_UNIX} socket. @param socketfd: An C{AF_UNIX} socket, attached to another process waiting to send sockets via the ancillary data mechanism in L{sendmsg}. @param fd: C{int} @return: a 2-tuple of (new file descriptor, description). 
@rtype: 2-tuple of (C{int}, C{str}) """ data, _ignore_flags, ancillary = recvmsg(socketfd) [(_ignore_cmsg_level, _ignore_cmsg_type, packedFD)] = ancillary # cmsg_level and cmsg_type really need to be SOL_SOCKET / SCM_RIGHTS, but # since those are the *only* standard values, there's not much point in # checking. unpackedFD = 0 int_size = calcsize("i") if len(packedFD) > int_size: # [ar]happens on 64 bit architecture (FreeBSD) [unpackedFD] = unpack("i", packedFD[0:int_size]) else: [unpackedFD] = unpack("i", packedFD) return (unpackedFD, data) calendarserver-5.2+dfsg/twext/python/_plistlib.py0000644000175000017500000003531412113213176021341 0ustar rahulrahul# # Added to standard library in Python 2.6 (Mac only in prior versions) # from __future__ import print_function """plistlib.py -- a tool to generate and parse MacOSX .plist files. The PropertList (.plist) file format is a simple XML pickle supporting basic object types, like dictionaries, lists, numbers and strings. Usually the top level object is a dictionary. To write out a plist file, use the writePlist(rootObject, pathOrFile) function. 'rootObject' is the top level object, 'pathOrFile' is a filename or a (writable) file object. To parse a plist from a file, use the readPlist(pathOrFile) function, with a file name or a (readable) file object as the only argument. It returns the top level object (again, usually a dictionary). To work with plist data in strings, you can use readPlistFromString() and writePlistToString(). Values can be strings, integers, floats, booleans, tuples, lists, dictionaries, Data or datetime.datetime objects. String values (including dictionary keys) may be unicode strings -- they will be written out as UTF-8. The plist type is supported through the Data class. This is a thin wrapper around a Python string. 
Generate Plist example:: pl = dict( aString="Doodah", aList=["A", "B", 12, 32.1, [1, 2, 3]], aFloat = 0.1, anInt = 728, aDict=dict( anotherString="", aUnicodeValue=u'M\xe4ssig, Ma\xdf', aTrueValue=True, aFalseValue=False, ), someData = Data(""), someMoreData = Data("" * 10), aDate = datetime.datetime.fromtimestamp(time.mktime(time.gmtime())), ) # unicode keys are possible, but a little awkward to use: pl[u'\xc5benraa'] = "That was a unicode key." writePlist(pl, fileName) Parse Plist example:: pl = readPlist(pathOrFile) print(pl["aKey"]) """ __all__ = [ "readPlist", "writePlist", "readPlistFromString", "writePlistToString", "readPlistFromResource", "writePlistToResource", "Plist", "Data", "Dict" ] # Note: the Plist and Dict classes have been deprecated. import binascii import datetime from cStringIO import StringIO import re def readPlist(pathOrFile): """Read a .plist file. 'pathOrFile' may either be a file name or a (readable) file object. Return the unpacked root object (which usually is a dictionary). """ didOpen = 0 if isinstance(pathOrFile, (str, unicode)): pathOrFile = open(pathOrFile) didOpen = 1 p = PlistParser() rootObject = p.parse(pathOrFile) if didOpen: pathOrFile.close() return rootObject def writePlist(rootObject, pathOrFile): """Write 'rootObject' to a .plist file. 'pathOrFile' may either be a file name or a (writable) file object. """ didOpen = 0 if isinstance(pathOrFile, (str, unicode)): pathOrFile = open(pathOrFile, "w") didOpen = 1 writer = PlistWriter(pathOrFile) writer.writeln("") writer.writeValue(rootObject) writer.writeln("") if didOpen: pathOrFile.close() def readPlistFromString(data): """Read a plist data from a string. Return the root object. """ return readPlist(StringIO(data)) def writePlistToString(rootObject): """Return 'rootObject' as a plist-formatted string. """ f = StringIO() writePlist(rootObject, f) return f.getvalue() def readPlistFromResource(path, restype='plst', resid=0): """Read plst resource from the resource fork of path. 
""" from Carbon.File import FSRef, FSGetResourceForkName from Carbon.Files import fsRdPerm from Carbon import Res fsRef = FSRef(path) resNum = Res.FSOpenResourceFile(fsRef, FSGetResourceForkName(), fsRdPerm) Res.UseResFile(resNum) plistData = Res.Get1Resource(restype, resid).data Res.CloseResFile(resNum) return readPlistFromString(plistData) def writePlistToResource(rootObject, path, restype='plst', resid=0): """Write 'rootObject' as a plst resource to the resource fork of path. """ from Carbon.File import FSRef, FSGetResourceForkName from Carbon.Files import fsRdWrPerm from Carbon import Res plistData = writePlistToString(rootObject) fsRef = FSRef(path) resNum = Res.FSOpenResourceFile(fsRef, FSGetResourceForkName(), fsRdWrPerm) Res.UseResFile(resNum) try: Res.Get1Resource(restype, resid).RemoveResource() except Res.Error: pass res = Res.Resource(plistData) res.AddResource(restype, resid, '') res.WriteResource() Res.CloseResFile(resNum) class DumbXMLWriter: def __init__(self, file, indentLevel=0, indent="\t"): self.file = file self.stack = [] self.indentLevel = indentLevel self.indent = indent def beginElement(self, element): self.stack.append(element) self.writeln("<%s>" % element) self.indentLevel += 1 def endElement(self, element): assert self.indentLevel > 0 assert self.stack.pop() == element self.indentLevel -= 1 self.writeln("" % element) def simpleElement(self, element, value=None): if value is not None: value = _escapeAndEncode(value) self.writeln("<%s>%s" % (element, value, element)) else: self.writeln("<%s/>" % element) def writeln(self, line): if line: self.file.write(self.indentLevel * self.indent + line + "\n") else: self.file.write("\n") # Contents should conform to a subset of ISO 8601 # (in particular, YYYY '-' MM '-' DD 'T' HH ':' MM ':' SS 'Z'. 
Smaller units may be omitted with # a loss of precision) _dateParser = re.compile(r"(?P\d\d\d\d)(?:-(?P\d\d)(?:-(?P\d\d)(?:T(?P\d\d)(?::(?P\d\d)(?::(?P\d\d))?)?)?)?)?Z") def _dateFromString(s): order = ('year', 'month', 'day', 'hour', 'minute', 'second') gd = _dateParser.match(s).groupdict() lst = [] for key in order: val = gd[key] if val is None: break lst.append(int(val)) return datetime.datetime(*lst) def _dateToString(d): return '%04d-%02d-%02dT%02d:%02d:%02dZ' % ( d.year, d.month, d.day, d.hour, d.minute, d.second ) # Regex to find any control chars, except for \t \n and \r _controlCharPat = re.compile( r"[\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0b\x0c\x0e\x0f" r"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f]") def _escapeAndEncode(text): m = _controlCharPat.search(text) if m is not None: raise ValueError("strings can't contains control characters; " "use plistlib.Data instead") text = text.replace("\r\n", "\n") # convert DOS line endings text = text.replace("\r", "\n") # convert Mac line endings text = text.replace("&", "&") # escape '&' text = text.replace("<", "<") # escape '<' text = text.replace(">", ">") # escape '>' return text.encode("utf-8") # encode as UTF-8 PLISTHEADER = """\ """ class PlistWriter(DumbXMLWriter): def __init__(self, file, indentLevel=0, indent="\t", writeHeader=1): if writeHeader: file.write(PLISTHEADER) DumbXMLWriter.__init__(self, file, indentLevel, indent) def writeValue(self, value): if isinstance(value, (str, unicode)): self.simpleElement("string", value) elif isinstance(value, bool): # must switch for bool before int, as bool is a # subclass of int... 
if value: self.simpleElement("true") else: self.simpleElement("false") elif isinstance(value, int): self.simpleElement("integer", str(value)) elif isinstance(value, float): self.simpleElement("real", repr(value)) elif isinstance(value, dict): self.writeDict(value) elif isinstance(value, Data): self.writeData(value) elif isinstance(value, datetime.datetime): self.simpleElement("date", _dateToString(value)) elif isinstance(value, (tuple, list)): self.writeArray(value) else: raise TypeError("unsuported type: %s" % type(value)) def writeData(self, data): self.beginElement("data") self.indentLevel -= 1 maxlinelength = 76 - len(self.indent.replace("\t", " " * 8) * self.indentLevel) for line in data.asBase64(maxlinelength).split("\n"): if line: self.writeln(line) self.indentLevel += 1 self.endElement("data") def writeDict(self, d): self.beginElement("dict") for key, value in sorted(d.items()): if not isinstance(key, (str, unicode)): raise TypeError("keys must be strings") self.simpleElement("key", key) self.writeValue(value) self.endElement("dict") def writeArray(self, array): self.beginElement("array") for value in array: self.writeValue(value) self.endElement("array") class _InternalDict(dict): # This class is needed while Dict is scheduled for deprecation: # we only need to warn when a *user* instantiates Dict or when # the "attribute notation for dict keys" is used. 
def __getattr__(self, attr): try: value = self[attr] except KeyError: raise AttributeError, attr from warnings import warn warn("Attribute access from plist dicts is deprecated, use d[key] " "notation instead", PendingDeprecationWarning) return value def __setattr__(self, attr, value): from warnings import warn warn("Attribute access from plist dicts is deprecated, use d[key] " "notation instead", PendingDeprecationWarning) self[attr] = value def __delattr__(self, attr): try: del self[attr] except KeyError: raise AttributeError, attr from warnings import warn warn("Attribute access from plist dicts is deprecated, use d[key] " "notation instead", PendingDeprecationWarning) class Dict(_InternalDict): def __init__(self, **kwargs): from warnings import warn warn("The plistlib.Dict class is deprecated, use builtin dict instead", PendingDeprecationWarning) super(Dict, self).__init__(**kwargs) class Plist(_InternalDict): """This class has been deprecated. Use readPlist() and writePlist() functions instead, together with regular dict objects. """ def __init__(self, **kwargs): from warnings import warn warn("The Plist class is deprecated, use the readPlist() and " "writePlist() functions instead", PendingDeprecationWarning) super(Plist, self).__init__(**kwargs) def fromFile(cls, pathOrFile): """Deprecated. Use the readPlist() function instead.""" rootObject = readPlist(pathOrFile) plist = cls() plist.update(rootObject) return plist fromFile = classmethod(fromFile) def write(self, pathOrFile): """Deprecated. 
Use the writePlist() function instead.""" writePlist(self, pathOrFile) def _encodeBase64(s, maxlinelength=76): # copied from base64.encodestring(), with added maxlinelength argument maxbinsize = (maxlinelength//4)*3 pieces = [] for i in xrange(0, len(s), maxbinsize): chunk = s[i : i + maxbinsize] pieces.append(binascii.b2a_base64(chunk)) return "".join(pieces) class Data: """Wrapper for binary data.""" def __init__(self, data): self.data = data def fromBase64(cls, data): # base64.decodestring just calls binascii.a2b_base64; # it seems overkill to use both base64 and binascii. return cls(binascii.a2b_base64(data)) fromBase64 = classmethod(fromBase64) def asBase64(self, maxlinelength=76): return _encodeBase64(self.data, maxlinelength) def __cmp__(self, other): if isinstance(other, self.__class__): return cmp(self.data, other.data) elif isinstance(other, str): return cmp(self.data, other) else: return cmp(id(self), id(other)) def __repr__(self): return "%s(%s)" % (self.__class__.__name__, repr(self.data)) class PlistParser: def __init__(self): self.stack = [] self.currentKey = None self.root = None def parse(self, fileobj): from xml.parsers.expat import ParserCreate parser = ParserCreate() parser.StartElementHandler = self.handleBeginElement parser.EndElementHandler = self.handleEndElement parser.CharacterDataHandler = self.handleData parser.ParseFile(fileobj) return self.root def handleBeginElement(self, element, attrs): self.data = [] handler = getattr(self, "begin_" + element, None) if handler is not None: handler(attrs) def handleEndElement(self, element): handler = getattr(self, "end_" + element, None) if handler is not None: handler() def handleData(self, data): self.data.append(data) def addObject(self, value): if self.currentKey is not None: self.stack[-1][self.currentKey] = value self.currentKey = None elif not self.stack: # this is the root object self.root = value else: self.stack[-1].append(value) def getData(self): data = "".join(self.data) try: data = 
data.encode("ascii") except UnicodeError: pass self.data = [] return data # element handlers def begin_dict(self, attrs): d = _InternalDict() self.addObject(d) self.stack.append(d) def end_dict(self): self.stack.pop() def end_key(self): self.currentKey = self.getData() def begin_array(self, attrs): a = [] self.addObject(a) self.stack.append(a) def end_array(self): self.stack.pop() def end_true(self): self.addObject(True) def end_false(self): self.addObject(False) def end_integer(self): self.addObject(int(self.getData())) def end_real(self): self.addObject(float(self.getData())) def end_string(self): self.addObject(self.getData()) def end_data(self): self.addObject(Data.fromBase64(self.getData())) def end_date(self): self.addObject(_dateFromString(self.getData())) calendarserver-5.2+dfsg/twext/python/sendmsg.c0000644000175000017500000002761012263343324020617 0ustar rahulrahul/* * Copyright (c) 2010-2014 Apple Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define PY_SSIZE_T_CLEAN 1 #include #if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) /* This may cause some warnings, but if you want to get rid of them, upgrade * your Python version. */ typedef int Py_ssize_t; #endif #include #include #include /* * As per * : * * "To forestall portability problems, it is recommended that applications * not use values larger than (2**31)-1 for the socklen_t type." 
*/ #define SOCKLEN_MAX 0x7FFFFFFF PyObject *sendmsg_socket_error; static PyObject *sendmsg_sendmsg(PyObject *self, PyObject *args, PyObject *keywds); static PyObject *sendmsg_recvmsg(PyObject *self, PyObject *args, PyObject *keywds); static PyObject *sendmsg_getsockfam(PyObject *self, PyObject *args, PyObject *keywds); static PyMethodDef sendmsg_methods[] = { {"sendmsg", (PyCFunction) sendmsg_sendmsg, METH_VARARGS | METH_KEYWORDS, NULL}, {"recvmsg", (PyCFunction) sendmsg_recvmsg, METH_VARARGS | METH_KEYWORDS, NULL}, {"getsockfam", (PyCFunction) sendmsg_getsockfam, METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL} }; PyMODINIT_FUNC initsendmsg(void) { PyObject *module; sendmsg_socket_error = NULL; /* Make sure that this has a known value before doing anything that might exit. */ module = Py_InitModule("sendmsg", sendmsg_methods); if (!module) { return; } /* The following is the only value mentioned by POSIX: http://www.opengroup.org/onlinepubs/9699919799/basedefs/sys_socket.h.html */ if (-1 == PyModule_AddIntConstant(module, "SCM_RIGHTS", SCM_RIGHTS)) { return; } /* BSD, Darwin, Hurd */ #if defined(SCM_CREDS) if (-1 == PyModule_AddIntConstant(module, "SCM_CREDS", SCM_CREDS)) { return; } #endif /* Linux */ #if defined(SCM_CREDENTIALS) if (-1 == PyModule_AddIntConstant(module, "SCM_CREDENTIALS", SCM_CREDENTIALS)) { return; } #endif /* Apparently everywhere, but not standardized. 
*/ #if defined(SCM_TIMESTAMP) if (-1 == PyModule_AddIntConstant(module, "SCM_TIMESTAMP", SCM_TIMESTAMP)) { return; } #endif module = PyImport_ImportModule("socket"); if (!module) { return; } sendmsg_socket_error = PyObject_GetAttrString(module, "error"); if (!sendmsg_socket_error) { return; } } static PyObject *sendmsg_sendmsg(PyObject *self, PyObject *args, PyObject *keywds) { int fd; int flags = 0; Py_ssize_t sendmsg_result, iovec_length; struct msghdr message_header; struct iovec iov[1]; PyObject *ancillary = NULL; PyObject *ultimate_result = NULL; static char *kwlist[] = {"fd", "data", "flags", "ancillary", NULL}; if (!PyArg_ParseTupleAndKeywords( args, keywds, "it#|iO:sendmsg", kwlist, &fd, &iov[0].iov_base, &iovec_length, &flags, &ancillary)) { return NULL; } iov[0].iov_len = iovec_length; message_header.msg_name = NULL; message_header.msg_namelen = 0; message_header.msg_iov = iov; message_header.msg_iovlen = 1; message_header.msg_control = NULL; message_header.msg_controllen = 0; message_header.msg_flags = 0; if (ancillary) { if (!PyList_Check(ancillary)) { PyErr_Format(PyExc_TypeError, "sendmsg argument 3 expected list, got %s", ancillary->ob_type->tp_name); goto finished; } PyObject *iterator = PyObject_GetIter(ancillary); PyObject *item = NULL; if (iterator == NULL) { goto finished; } size_t all_data_len = 0; /* First we need to know how big the buffer needs to be in order to have enough space for all of the messages. 
*/ while ( (item = PyIter_Next(iterator)) ) { int type, level; Py_ssize_t data_len; size_t prev_all_data_len; char *data; if (!PyArg_ParseTuple( item, "iit#:sendmsg ancillary data (level, type, data)", &level, &type, &data, &data_len)) { Py_DECREF(item); Py_DECREF(iterator); goto finished; } prev_all_data_len = all_data_len; all_data_len += CMSG_SPACE(data_len); Py_DECREF(item); if (all_data_len < prev_all_data_len) { Py_DECREF(iterator); PyErr_Format(PyExc_OverflowError, "Too much msg_control to fit in a size_t: %zu", prev_all_data_len); goto finished; } } Py_DECREF(iterator); iterator = NULL; /* Allocate the buffer for all of the ancillary elements, if we have * any. */ if (all_data_len) { if (all_data_len > SOCKLEN_MAX) { PyErr_Format(PyExc_OverflowError, "Too much msg_control to fit in a socklen_t: %zu", all_data_len); goto finished; } message_header.msg_control = malloc(all_data_len); if (!message_header.msg_control) { PyErr_NoMemory(); goto finished; } } message_header.msg_controllen = (socklen_t) all_data_len; iterator = PyObject_GetIter(ancillary); /* again */ item = NULL; if (!iterator) { goto finished; } /* Unpack the tuples into the control message. */ struct cmsghdr *control_message = CMSG_FIRSTHDR(&message_header); while ( (item = PyIter_Next(iterator)) ) { int type, level; Py_ssize_t data_len; size_t data_size; unsigned char *data, *cmsg_data; /* We explicitly allocated enough space for all ancillary data above; if there isn't enough room, all bets are off. 
*/ assert(control_message); if (!PyArg_ParseTuple(item, "iit#:sendmsg ancillary data (level, type, data)", &level, &type, &data, &data_len)) { Py_DECREF(item); Py_DECREF(iterator); goto finished; } control_message->cmsg_level = level; control_message->cmsg_type = type; data_size = CMSG_LEN(data_len); if (data_size > SOCKLEN_MAX) { Py_DECREF(item); Py_DECREF(iterator); PyErr_Format(PyExc_OverflowError, "CMSG_LEN(%zd) > SOCKLEN_MAX", data_len); goto finished; } control_message->cmsg_len = (socklen_t) data_size; cmsg_data = CMSG_DATA(control_message); memcpy(cmsg_data, data, data_len); Py_DECREF(item); control_message = CMSG_NXTHDR(&message_header, control_message); } Py_DECREF(iterator); if (PyErr_Occurred()) { goto finished; } } sendmsg_result = sendmsg(fd, &message_header, flags); if (sendmsg_result < 0) { PyErr_SetFromErrno(sendmsg_socket_error); goto finished; } else { ultimate_result = Py_BuildValue("n", sendmsg_result); } finished: if (message_header.msg_control) { free(message_header.msg_control); } return ultimate_result; } static PyObject *sendmsg_recvmsg(PyObject *self, PyObject *args, PyObject *keywds) { int fd = -1; int flags = 0; int maxsize = 8192; int cmsg_size = 4*1024; size_t cmsg_space; Py_ssize_t recvmsg_result; struct msghdr message_header; struct cmsghdr *control_message; struct iovec iov[1]; char *cmsgbuf; PyObject *ancillary; PyObject *final_result = NULL; static char *kwlist[] = {"fd", "flags", "maxsize", "cmsg_size", NULL}; if (!PyArg_ParseTupleAndKeywords(args, keywds, "i|iii:recvmsg", kwlist, &fd, &flags, &maxsize, &cmsg_size)) { return NULL; } cmsg_space = CMSG_SPACE(cmsg_size); /* overflow check */ if (cmsg_space > SOCKLEN_MAX) { PyErr_Format(PyExc_OverflowError, "CMSG_SPACE(cmsg_size) greater than SOCKLEN_MAX: %d", cmsg_size); return NULL; } message_header.msg_name = NULL; message_header.msg_namelen = 0; iov[0].iov_len = maxsize; iov[0].iov_base = malloc(maxsize); if (!iov[0].iov_base) { PyErr_NoMemory(); return NULL; } 
message_header.msg_iov = iov; message_header.msg_iovlen = 1; cmsgbuf = malloc(cmsg_space); if (!cmsgbuf) { free(iov[0].iov_base); PyErr_NoMemory(); return NULL; } memset(cmsgbuf, 0, cmsg_space); message_header.msg_control = cmsgbuf; /* see above for overflow check */ message_header.msg_controllen = (socklen_t) cmsg_space; recvmsg_result = recvmsg(fd, &message_header, flags); if (recvmsg_result < 0) { PyErr_SetFromErrno(sendmsg_socket_error); goto finished; } ancillary = PyList_New(0); if (!ancillary) { goto finished; } for (control_message = CMSG_FIRSTHDR(&message_header); control_message; control_message = CMSG_NXTHDR(&message_header, control_message)) { PyObject *entry; /* Some platforms apparently always fill out the ancillary data structure with a single bogus value if none is provided; ignore it, if that is the case. */ if ((!(control_message->cmsg_level)) && (!(control_message->cmsg_type))) { continue; } entry = Py_BuildValue( "(iis#)", control_message->cmsg_level, control_message->cmsg_type, CMSG_DATA(control_message), (Py_ssize_t) (control_message->cmsg_len - sizeof(struct cmsghdr))); if (!entry) { Py_DECREF(ancillary); goto finished; } if (PyList_Append(ancillary, entry) < 0) { Py_DECREF(ancillary); Py_DECREF(entry); goto finished; } else { Py_DECREF(entry); } } final_result = Py_BuildValue( "s#iO", iov[0].iov_base, recvmsg_result, message_header.msg_flags, ancillary); Py_DECREF(ancillary); finished: free(iov[0].iov_base); free(cmsgbuf); return final_result; } static PyObject *sendmsg_getsockfam(PyObject *self, PyObject *args, PyObject *keywds) { int fd; struct sockaddr sa; static char *kwlist[] = {"fd", NULL}; if (!PyArg_ParseTupleAndKeywords(args, keywds, "i", kwlist, &fd)) { return NULL; } socklen_t sz = sizeof(sa); if (getsockname(fd, &sa, &sz)) { PyErr_SetFromErrno(sendmsg_socket_error); return NULL; } return Py_BuildValue("i", sa.sa_family); } calendarserver-5.2+dfsg/twext/python/plistlib.py0000644000175000017500000000136412263343324021205 0ustar 
rahulrahul## # Copyright (c) 2008-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## try: _plistlib = __import__("plistlib") except ImportError: from twext.python import _plistlib import sys sys.modules[__name__] = _plistlib calendarserver-5.2+dfsg/twext/python/log.py0000644000175000017500000006747612263343324020164 0ustar rahulrahul# -*- test-case-name: twext.python.test.test_log-*- ## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Classes and functions to do granular logging. Example usage in a module C{some.module}:: from twext.python.log import Logger log = Logger() def handleData(data): log.debug("Got data: {data!r}.", data=data) Or in a class:: from twext.python.log import Logger class Foo(object): log = Logger() def oops(self, data): self.log.error("Oops! Invalid data from server: {data!r}", data=data) C{Logger}s have namespaces, for which logging can be configured independently. 
Namespaces may be specified by passing in a C{namespace} argument to L{Logger} when instantiating it, but if none is given, the logger will derive its own namespace by using the module name of the callable that instantiated it, or, in the case of a class, by using the fully qualified name of the class. In the first example above, the namespace would be C{some.module}, and in the second example, it would be C{some.module.Foo}. """ __all__ = [ "InvalidLogLevelError", "LogLevel", "formatEvent", "Logger", "LegacyLogger", "ILogObserver", "ILegacyLogObserver", "LogPublisher", "PredicateResult", "ILogFilterPredicate", "FilteringLogObserver", "LogLevelFilterPredicate", "LegacyLogObserver", "replaceTwistedLoggers", #"StandardIOObserver", ] import sys from sys import stdout, stderr from string import Formatter import inspect import logging import time from zope.interface import Interface, implementer from twisted.python.constants import NamedConstant, Names from twisted.python.failure import Failure from twisted.python.reflect import safe_str, safe_repr import twisted.python.log from twisted.python.log import msg as twistedLogMessage from twisted.python.log import addObserver, removeObserver from twisted.python.log import ILogObserver as ILegacyLogObserver OBSERVER_REMOVED = ( "Temporarily removing observer {observer} due to exception: {e}" ) # # Log level definitions # class InvalidLogLevelError(Exception): """ Someone tried to use a L{LogLevel} that is unknown to the logging system. """ def __init__(self, level): """ @param level: a L{LogLevel} """ super(InvalidLogLevelError, self).__init__(str(level)) self.level = level class LogLevel(Names): """ Constants denoting log levels: - C{debug}: Information of use to a developer of the software, not generally of interest to someone running the software unless they are attempting to diagnose a software issue. 
- C{info}: Informational events: Routine information about the status of an application, such as incoming connections, startup of a subsystem, etc. - C{warn}: Warnings events: Events that may require greater attention than informational events but are not a systemic failure condition, such as authorization failures, bad data from a network client, etc. - C{error}: Error conditions: Events indicating a systemic failure, such as unhandled exceptions, loss of connectivity to a back-end database, etc. """ debug = NamedConstant() info = NamedConstant() warn = NamedConstant() error = NamedConstant() @classmethod def levelWithName(cls, name): """ @param name: the name of a L{LogLevel} @return: the L{LogLevel} with the specified C{name} """ try: return cls.lookupByName(name) except ValueError: raise InvalidLogLevelError(name) @classmethod def _priorityForLevel(cls, constant): """ We want log levels to have defined ordering - the order of definition - but they aren't value constants (the only value is the name). This is arguably a bug in Twisted, so this is just a workaround for U{until this is fixed in some way }. """ return cls._levelPriorities[constant] LogLevel._levelPriorities = dict( (constant, idx) for (idx, constant) in (enumerate(LogLevel.iterconstants())) ) # # Mappings to Python's logging module # pythonLogLevelMapping = { LogLevel.debug: logging.DEBUG, LogLevel.info: logging.INFO, LogLevel.warn: logging.WARNING, LogLevel.error: logging.ERROR, # LogLevel.critical: logging.CRITICAL, } ## # Loggers ## def formatEvent(event): """ Formats an event as a L{unicode}, using the format in C{event["log_format"]}. This implementation should never raise an exception; if the formatting cannot be done, the returned string will describe the event generically so that a useful message is emitted regardless. 
@param event: a logging event @return: a L{unicode} """ try: format = event.get("log_format", None) if format is None: raise ValueError("No log format provided") # Make sure format is unicode. if isinstance(format, bytes): # If we get bytes, assume it's UTF-8 bytes format = format.decode("utf-8") elif isinstance(format, unicode): pass else: raise TypeError("Log format must be unicode or bytes, not {0!r}" .format(format)) return formatWithCall(format, event) except BaseException as e: return formatUnformattableEvent(event, e) def formatUnformattableEvent(event, error): """ Formats an event as a L{unicode} that describes the event generically and a formatting error. @param event: a logging event @type dict: L{dict} @param error: the formatting error @type error: L{Exception} @return: a L{unicode} """ try: return ( u"Unable to format event {event!r}: {error}" .format(event=event, error=error) ) except BaseException: # Yikes, something really nasty happened. # # Try to recover as much formattable data as possible; hopefully at # least the namespace is sane, which will help you find the offending # logger. failure = Failure() text = ", ".join(" = ".join((safe_repr(key), safe_repr(value))) for key, value in event.items()) return ( u"MESSAGE LOST: unformattable object logged: {error}\n" u"Recoverable data: {text}\n" u"Exception during formatting:\n{failure}" .format(error=safe_repr(error), failure=failure, text=text) ) class Logger(object): """ Logging object. """ publisher = lambda e: None @staticmethod def _namespaceFromCallingContext(): """ Derive a namespace from the module containing the caller's caller. @return: a namespace """ return inspect.currentframe().f_back.f_back.f_globals["__name__"] def __init__(self, namespace=None, source=None): """ @param namespace: The namespace for this logger. Uses a dotted notation, as used by python modules. If not C{None}, then the name of the module of the caller is used. 
@param source: The object which is emitting events to this logger; this is automatically set on instances of a class if this L{Logger} is an attribute of that class. """ if namespace is None: namespace = self._namespaceFromCallingContext() self.namespace = namespace self.source = source def __get__(self, oself, type=None): """ When used as a descriptor, i.e.:: # athing.py class Something(object): log = Logger() def hello(self): self.log.info("Hello") a L{Logger}'s namespace will be set to the name of the class it is declared on. In the above example, the namespace would be C{athing.Something}. Additionally, it's source will be set to the actual object referring to the L{Logger}. In the above example, C{Something.log.source} would be C{Something}, and C{Something().log.source} would be an instance of C{Something}. """ if oself is None: source = type else: source = oself return self.__class__( '.'.join([type.__module__, type.__name__]), source ) def __repr__(self): return "<%s %r>" % (self.__class__.__name__, self.namespace) def emit(self, level, format=None, **kwargs): """ Emit a log event to all log observers at the given level. @param level: a L{LogLevel} @param format: a message format using new-style (PEP 3101) formatting. The logging event (which is a L{dict}) is used to render this format string. @param kwargs: additional keyword parameters to include with the event. """ # FIXME: Updated Twisted supports 'in' on constants container if level not in LogLevel.iterconstants(): self.failure( "Got invalid log level {invalidLevel!r} in {logger}.emit().", Failure(InvalidLogLevelError(level)), invalidLevel=level, logger=self, ) #level = LogLevel.error # FIXME: continue to emit? return kwargs.update( log_logger=self, log_level=level, log_namespace=self.namespace, log_source=self.source, log_format=format, log_time=time.time(), ) self.publisher(kwargs) def failure(self, format, failure=None, level=LogLevel.error, **kwargs): """ Log an failure and emit a traceback. 
For example:: try: frob(knob) except Exception: log.failure("While frobbing {knob}", knob=knob) or:: d = deferredFrob(knob) d.addErrback(lambda f: log.failure, "While frobbing {knob}", f, knob=knob) @param format: a message format using new-style (PEP 3101) formatting. The logging event (which is a L{dict}) is used to render this format string. @param failure: a L{Failure} to log. If C{None}, a L{Failure} is created from the exception in flight. @param level: a L{LogLevel} to use. @param kwargs: additional keyword parameters to include with the event. """ if failure is None: failure = Failure() self.emit(level, format, log_failure=failure, **kwargs) class LegacyLogger(object): """ A logging object that provides some compatibility with the L{twisted.python.log} module. """ def __init__(self, logger=None): if logger is None: self.newStyleLogger = Logger(Logger._namespaceFromCallingContext()) else: self.newStyleLogger = logger def __getattribute__(self, name): try: return super(LegacyLogger, self).__getattribute__(name) except AttributeError: return getattr(twisted.python.log, name) def msg(self, *message, **kwargs): """ This method is API-compatible with L{twisted.python.log.msg} and exists for compatibility with that API. """ if message: message = " ".join(map(safe_str, message)) else: message = None return self.newStyleLogger.emit(LogLevel.info, message, **kwargs) def err(self, _stuff=None, _why=None, **kwargs): """ This method is API-compatible with L{twisted.python.log.err} and exists for compatibility with that API. """ if _stuff is None: _stuff = Failure() elif isinstance(_stuff, Exception): _stuff = Failure(_stuff) if isinstance(_stuff, Failure): self.newStyleLogger.emit(LogLevel.error, failure=_stuff, why=_why, isError=1, **kwargs) else: # We got called with an invalid _stuff. self.newStyleLogger.emit(LogLevel.error, repr(_stuff), why=_why, isError=1, **kwargs) def bindEmit(level): doc = """ Emit a log event at log level L{{{level}}}. 
@param format: a message format using new-style (PEP 3101) formatting. The logging event (which is a L{{dict}}) is used to render this format string. @param kwargs: additional keyword parameters to include with the event. """.format(level=level.name) # # Attach methods to Logger # def log_emit(self, format=None, **kwargs): self.emit(level, format, **kwargs) log_emit.__doc__ = doc setattr(Logger, level.name, log_emit) def _bindLevels(): for level in LogLevel.iterconstants(): bindEmit(level) _bindLevels() # # Observers # class ILogObserver(Interface): """ An observer which can handle log events. """ def __call__(event): """ Log an event. @type event: C{dict} with (native) C{str} keys. @param event: A dictionary with arbitrary keys as defined by the application emitting logging events, as well as keys added by the logging system, with are: ... """ @implementer(ILogObserver) class LogPublisher(object): """ I{ILogObserver} that fans out events to other observers. Keeps track of a set of L{ILogObserver} objects and forwards events to each. """ log = Logger() def __init__(self, *observers): self._observers = set(observers) @property def observers(self): return frozenset(self._observers) def addObserver(self, observer): """ Registers an observer with this publisher. @param observer: An L{ILogObserver} to add. """ self._observers.add(observer) def removeObserver(self, observer): """ Unregisters an observer with this publisher. @param observer: An L{ILogObserver} to remove. """ try: self._observers.remove(observer) except KeyError: pass def __call__(self, event): for observer in self.observers: try: observer(event) except BaseException as e: # # We have to remove the offending observer because # we're going to badmouth it to all of its friends # (other observers) and it might get offended and # raise again, causing an infinite loop. 
# self.removeObserver(observer) try: self.log.failure(OBSERVER_REMOVED, observer=observer, e=e) except BaseException: pass finally: self.addObserver(observer) class PredicateResult(Names): """ Predicate results. """ yes = NamedConstant() # Log this no = NamedConstant() # Don't log this maybe = NamedConstant() # No opinion class ILogFilterPredicate(Interface): """ A predicate that determined whether an event should be logged. """ def __call__(event): """ Determine whether an event should be logged. @returns: a L{PredicateResult}. """ @implementer(ILogObserver) class FilteringLogObserver(object): """ L{ILogObserver} that wraps another L{ILogObserver}, but filters out events based on applying a series of L{ILogFilterPredicate}s. """ def __init__(self, observer, predicates): """ @param observer: an L{ILogObserver} to which this observer will forward events. @param predicates: an ordered iterable of predicates to apply to events before forwarding to the wrapped observer. """ self.observer = observer self.predicates = list(predicates) def shouldLogEvent(self, event): """ Determine whether an event should be logged, based C{self.predicates}. @param event: an event """ for predicate in self.predicates: result = predicate(event) if result == PredicateResult.yes: return True if result == PredicateResult.no: return False if result == PredicateResult.maybe: continue raise TypeError("Invalid predicate result: {0!r}".format(result)) return True def __call__(self, event): if self.shouldLogEvent(event): self.observer(event) @implementer(ILogFilterPredicate) class LogLevelFilterPredicate(object): """ L{ILogFilterPredicate} that filters out events with a log level lower than the log level for the event's namespace. Events that not not have a log level or namespace are also dropped. """ def __init__(self): # FIXME: Make this a class variable. But that raises an # _initializeEnumerants constants error in Twisted 12.2.0. 
self.defaultLogLevel = LogLevel.info self._logLevelsByNamespace = {} self.clearLogLevels() def logLevelForNamespace(self, namespace): """ @param namespace: a logging namespace, or C{None} for the default namespace. @return: the L{LogLevel} for the specified namespace. """ if not namespace: return self._logLevelsByNamespace[None] if namespace in self._logLevelsByNamespace: return self._logLevelsByNamespace[namespace] segments = namespace.split(".") index = len(segments) - 1 while index > 0: namespace = ".".join(segments[:index]) if namespace in self._logLevelsByNamespace: return self._logLevelsByNamespace[namespace] index -= 1 return self._logLevelsByNamespace[None] def setLogLevelForNamespace(self, namespace, level): """ Sets the global log level for a logging namespace. @param namespace: a logging namespace @param level: the L{LogLevel} for the given namespace. """ if level not in LogLevel.iterconstants(): raise InvalidLogLevelError(level) if namespace: self._logLevelsByNamespace[namespace] = level else: self._logLevelsByNamespace[None] = level def clearLogLevels(self): """ Clears all global log levels to the default. """ self._logLevelsByNamespace.clear() self._logLevelsByNamespace[None] = self.defaultLogLevel def __call__(self, event): level = event.get("log_level", None) namespace = event.get("log_namespace", None) if ( level is None or namespace is None or LogLevel._priorityForLevel(level) < LogLevel._priorityForLevel(self.logLevelForNamespace(namespace)) ): return PredicateResult.no return PredicateResult.maybe @implementer(ILogObserver) class LegacyLogObserver(object): """ L{ILogObserver} that wraps an L{ILegacyLogObserver}. """ def __init__(self, legacyObserver): """ @param legacyObserver: an L{ILegacyLogObserver} to which this observer will forward events. 
""" self.legacyObserver = legacyObserver def __call__(self, event): prefix = "[{log_namespace}#{log_level.name}] ".format(**event) level = event["log_level"] # # Twisted's logging supports indicating a python log level, so let's # provide the equivalent to our logging levels. # if level in pythonLogLevelMapping: event["logLevel"] = pythonLogLevelMapping[level] # Format new style -> old style if event["log_format"]: # # Create an object that implements __str__() in order to # defer the work of formatting until it's needed by a # legacy log observer. # class LegacyFormatStub(object): def __str__(oself): return formatEvent(event).encode("utf-8") event["format"] = prefix + "%(log_legacy)s" event["log_legacy"] = LegacyFormatStub() # log.failure() -> isError blah blah if "log_failure" in event: event["failure"] = event["log_failure"] event["isError"] = 1 event["why"] = "{prefix}{message}".format( prefix=prefix, message=formatEvent(event) ) self.legacyObserver(**event) # FIXME: This could have a better name. class DefaultLogPublisher(object): """ This observer sets up a set of chained observers as follows: 1. B{rootPublisher} - a L{LogPublisher} 2. B{filters}: a L{FilteringLogObserver} that filters out messages using a L{LogLevelFilterPredicate} 3. B{filteredPublisher} - a L{LogPublisher} 4. B{legacyLogObserver} - a L{LegacyLogObserver} wired up to L{twisted.python.log.msg}. This allows any observers registered with Twisted's logging (that is, most observers in presently use) to receive (filtered) events. The purpose of this class is to provide a default log observer with sufficient hooks to enable applications to add observers that can either receive all log messages, or only log messages that are configured to pass though the L{LogLevelFilterPredicate}:: from twext.python.log import Logger, ILogObserver log = Logger() @implementer(ILogObserver) class AMPObserver(object): def __call__(self, event): # eg.: Hold events in a ring buffer and expose them via AMP. ... 
@implementer(ILogObserver) class FileObserver(object): def __call__(self, event): # eg.: Take events and write them into a file. ... # Send all events to the AMPObserver log.publisher.addObserver(AMPObserver(), filtered=False) # Send filtered events to the FileObserver log.publisher.addObserver(AMPObserver()) With no observers added, the default behavior is that the legacy Twisted logging system sees messages as controlled by L{LogLevelFilterPredicate}. """ def __init__(self): self.legacyLogObserver = LegacyLogObserver(twistedLogMessage) self.filteredPublisher = LogPublisher(self.legacyLogObserver) self.levels = LogLevelFilterPredicate() self.filters = FilteringLogObserver(self.filteredPublisher, (self.levels,)) self.rootPublisher = LogPublisher(self.filters) def addObserver(self, observer, filtered=True): """ Registers an observer with this publisher. @param observer: An L{ILogObserver} to add. @param filtered: If true, registers C{observer} after filters are applied; otherwise C{observer} will get all events. """ if filtered: self.filteredPublisher.addObserver(observer) self.rootPublisher.removeObserver(observer) else: self.rootPublisher.addObserver(observer) self.filteredPublisher.removeObserver(observer) def removeObserver(self, observer): """ Unregisters an observer with this publisher. @param observer: An L{ILogObserver} to remove. 
""" self.rootPublisher.removeObserver(observer) self.filteredPublisher.removeObserver(observer) def __call__(self, event): self.rootPublisher(event) Logger.publisher = DefaultLogPublisher() # # Utilities # class CallMapping(object): def __init__(self, submapping): self._submapping = submapping def __getitem__(self, key): callit = key.endswith(u"()") realKey = key[:-2] if callit else key value = self._submapping[realKey] if callit: value = value() return value def formatWithCall(formatString, mapping): """ Format a string like L{unicode.format}, but: - taking only a name mapping; no positional arguments - with the additional syntax that an empty set of parentheses correspond to a formatting item that should be called, and its result C{str}'d, rather than calling C{str} on the element directly as normal. For example:: >>> formatWithCall("{string}, {function()}.", ... dict(string="just a string", ... function=lambda: "a function")) 'just a string, a function.' @param formatString: A PEP-3101 format string. @type formatString: L{unicode} @param mapping: A L{dict}-like object to format. @return: The string with formatted values interpolated. 
@rtype: L{unicode} """ return unicode( theFormatter.vformat(formatString, (), CallMapping(mapping)) ) theFormatter = Formatter() def replaceTwistedLoggers(): """ Visit all Python modules that have been loaded and: - replace L{twisted.python.log} with a L{LegacyLogger} - replace L{twisted.python.log.msg} with a L{LegacyLogger}'s C{msg} - replace L{twisted.python.log.err} with a L{LegacyLogger}'s C{err} """ log = Logger() for moduleName, module in sys.modules.iteritems(): # Oddly, this happens if module is None: continue # Don't patch Twisted's logging module if module in (twisted.python, twisted.python.log): continue # Don't patch this module if moduleName is __name__: continue for name, obj in module.__dict__.iteritems(): newLogger = Logger(namespace=module.__name__) legacyLogger = LegacyLogger(logger=newLogger) if obj is twisted.python.log: log.info("Replacing Twisted log module object {0} in {1}" .format(name, module.__name__)) setattr(module, name, legacyLogger) elif obj is twisted.python.log.msg: log.info("Replacing Twisted log.msg object {0} in {1}" .format(name, module.__name__)) setattr(module, name, legacyLogger.msg) elif obj is twisted.python.log.err: log.info("Replacing Twisted log.err object {0} in {1}" .format(name, module.__name__)) setattr(module, name, legacyLogger.err) ###################################################################### # FIXME: This may not be needed; look into removing it. class StandardIOObserver(object): """ (Legacy) log observer that writes to standard I/O. 
""" def emit(self, eventDict): text = None if eventDict["isError"]: output = stderr if "failure" in eventDict: text = eventDict["failure"].getTraceback() else: output = stdout if not text: text = " ".join([str(m) for m in eventDict["message"]]) + "\n" output.write(text) output.flush() def start(self): addObserver(self.emit) def stop(self): removeObserver(self.emit) calendarserver-5.2+dfsg/twext/python/memcacheclient.py0000644000175000017500000014310112113213176022313 0ustar rahulrahul#!/usr/bin/env python from __future__ import print_function """ client module for memcached (memory cache daemon) Overview ======== See U{the MemCached homepage} for more about memcached. Usage summary ============= This should give you a feel for how this module operates:: import memcacheclient mc = memcacheclient.Client(['127.0.0.1:11211'], debug=0) mc.set("some_key", "Some value") value = mc.get("some_key") mc.set("another_key", 3) mc.delete("another_key") mc.set("key", "1") # note that the key used for incr/decr must be a string. mc.incr("key") mc.decr("key") The standard way to use memcache with a database is like this:: key = derive_key(obj) obj = mc.get(key) if not obj: obj = backend_api.get(...) mc.set(obj) # we now have obj, and future passes through this code # will use the object from the cache. Detailed Documentation ====================== More detailed documentation is available in the L{Client} class. """ import sys import socket import time import os import re import types from twext.python.log import Logger from twistedcaldav.config import config log = Logger() try: import cPickle as pickle except ImportError: import pickle try: from zlib import compress, decompress _supports_compress = True except ImportError: _supports_compress = False # quickly define a decompress just in case we recv compressed data. 
def decompress(val): raise _Error("received compressed data but I don't support compession (import error)") try: from cStringIO import StringIO except ImportError: from StringIO import StringIO from binascii import crc32 # zlib version is not cross-platform serverHashFunction = crc32 __author__ = "Evan Martin " __version__ = "1.44" __copyright__ = "Copyright (C) 2003 Danga Interactive" __license__ = "Python" SERVER_MAX_KEY_LENGTH = 250 # Storing values larger than 1MB requires recompiling memcached. If you do, # this value can be changed by doing "memcacheclient.SERVER_MAX_VALUE_LENGTH = N" # after importing this module. SERVER_MAX_VALUE_LENGTH = 1024*1024 class _Error(Exception): pass class MemcacheError(_Error): """ Memcache connection error """ class NotFoundError(MemcacheError): """ NOT_FOUND error """ class TokenMismatchError(MemcacheError): """ Check-and-set token mismatch """ try: # Only exists in Python 2.4+ from threading import local except ImportError: # TODO: add the pure-python local implementation class local(object): pass class ClientFactory(object): # unit tests should set this to True to enable the fake test cache allowTestCache = False @classmethod def getClient(cls, servers, debug=0, pickleProtocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler, pload=None, pid=None): if cls.allowTestCache: return TestClient(servers, debug=debug, pickleProtocol=pickleProtocol, pickler=pickler, unpickler=unpickler, pload=pload, pid=pid) elif config.Memcached.Pools.Default.ClientEnabled: return Client(servers, debug=debug, pickleProtocol=pickleProtocol, pickler=pickler, unpickler=unpickler, pload=pload, pid=pid) else: return None class Client(local): """ Object representing a pool of memcache servers. See L{memcache} for an overview. In all cases where a key is used, the key can be either: 1. A simple hashable type (string, integer, etc.). 2. A tuple of C{(hashvalue, key)}. This is useful if you want to avoid making this module calculate a hash value. 
You may prefer, for example, to keep all of a given user's objects on the same memcache server, so you could use the user's unique id as the hash value. @group Setup: __init__, set_servers, forget_dead_hosts, disconnect_all, debuglog @group Insertion: set, add, replace, set_multi @group Retrieval: get, get_multi @group Integers: incr, decr @group Removal: delete, delete_multi @sort: __init__, set_servers, forget_dead_hosts, disconnect_all, debuglog,\ set, set_multi, add, replace, get, get_multi, incr, decr, delete, delete_multi """ _FLAG_PICKLE = 1<<0 _FLAG_INTEGER = 1<<1 _FLAG_LONG = 1<<2 _FLAG_COMPRESSED = 1<<3 _SERVER_RETRIES = 10 # how many times to try finding a free server. # exceptions for Client class MemcachedKeyError(Exception): pass class MemcachedKeyLengthError(MemcachedKeyError): pass class MemcachedKeyCharacterError(MemcachedKeyError): pass class MemcachedKeyNoneError(MemcachedKeyError): pass class MemcachedKeyTypeError(MemcachedKeyError): pass class MemcachedStringEncodingError(Exception): pass def __init__(self, servers, debug=0, pickleProtocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler, pload=None, pid=None): """ Create a new Client object with the given list of servers. @param servers: C{servers} is passed to L{set_servers}. @param debug: whether to display error messages when a server can't be contacted. @param pickleProtocol: number to mandate protocol used by (c)Pickle. @param pickler: optional override of default Pickler to allow subclassing. @param unpickler: optional override of default Unpickler to allow subclassing. @param pload: optional persistent_load function to call on pickle loading. Useful for cPickle since subclassing isn't allowed. @param pid: optional persistent_id function to call on pickle storing. Useful for cPickle since subclassing isn't allowed. 
""" local.__init__(self) self.set_servers(servers) self.debug = debug self.stats = {} # Allow users to modify pickling/unpickling behavior self.pickleProtocol = pickleProtocol self.pickler = pickler self.unpickler = unpickler self.persistent_load = pload self.persistent_id = pid # figure out the pickler style file = StringIO() try: pickler = self.pickler(file, protocol = self.pickleProtocol) self.picklerIsKeyword = True except TypeError: self.picklerIsKeyword = False def set_servers(self, servers): """ Set the pool of servers used by this client. @param servers: an array of servers. Servers can be passed in two forms: 1. Strings of the form C{"host:port"}, which implies a default weight of 1. 2. Tuples of the form C{("host:port", weight)}, where C{weight} is an integer weight value. """ self.servers = [_Host(s, self.debuglog) for s in servers] self._init_buckets() def get_stats(self): '''Get statistics from each of the servers. @return: A list of tuples ( server_identifier, stats_dictionary ). The dictionary contains a number of name/value pairs specifying the name of the status field and the string value associated with it. The values are not converted from strings. 
''' data = [] for s in self.servers: if not s.connect(): continue if s.family == socket.AF_INET: name = '%s:%s (%s)' % ( s.ip, s.port, s.weight ) else: name = 'unix:%s (%s)' % ( s.address, s.weight ) s.send_cmd('stats') serverData = {} data.append(( name, serverData )) readline = s.readline while 1: line = readline() if not line or line.strip() == 'END': break stats = line.split(' ', 2) serverData[stats[1]] = stats[2] return(data) def get_slabs(self): data = [] for s in self.servers: if not s.connect(): continue if s.family == socket.AF_INET: name = '%s:%s (%s)' % ( s.ip, s.port, s.weight ) else: name = 'unix:%s (%s)' % ( s.address, s.weight ) serverData = {} data.append(( name, serverData )) s.send_cmd('stats items') readline = s.readline while 1: line = readline() if not line or line.strip() == 'END': break item = line.split(' ', 2) #0 = STAT, 1 = ITEM, 2 = Value slab = item[1].split(':', 2) #0 = items, 1 = Slab #, 2 = Name if not serverData.has_key(slab[1]): serverData[slab[1]] = {} serverData[slab[1]][slab[2]] = item[2] return data def flush_all(self): 'Expire all data currently in the memcache servers.' for s in self.servers: if not s.connect(): continue s.send_cmd('flush_all') s.expect("OK") def debuglog(self, str): if self.debug: sys.stderr.write("MemCached: %s\n" % str) def _statlog(self, func): if not self.stats.has_key(func): self.stats[func] = 1 else: self.stats[func] += 1 def forget_dead_hosts(self): """ Reset every host in the pool to an "alive" state. 
""" for s in self.servers: s.deaduntil = 0 def _init_buckets(self): self.buckets = [] for server in self.servers: for i in range(server.weight): self.buckets.append(server) def _get_server(self, key): if type(key) == types.TupleType: serverhash, key = key else: serverhash = serverHashFunction(key) for i in range(Client._SERVER_RETRIES): server = self.buckets[serverhash % len(self.buckets)] if server.connect(): #print("(using server %s)" % server, end="") return server, key serverhash = serverHashFunction(str(serverhash) + str(i)) log.error("Memcacheclient _get_server( ) failed to connect") return None, None def disconnect_all(self): for s in self.servers: s.close_socket() def delete_multi(self, keys, time=0, key_prefix=''): ''' Delete multiple keys in the memcache doing just one query. >>> notset_keys = mc.set_multi({'key1' : 'val1', 'key2' : 'val2'}) >>> mc.get_multi(['key1', 'key2']) == {'key1' : 'val1', 'key2' : 'val2'} 1 >>> mc.delete_multi(['key1', 'key2']) 1 >>> mc.get_multi(['key1', 'key2']) == {} 1 This method is recommended over iterated regular L{delete}s as it reduces total latency, since your app doesn't have to wait for each round-trip of L{delete} before sending the next one. @param keys: An iterable of keys to clear @param time: number of seconds any subsequent set / update commands should fail. Defaults to 0 for no delay. @param key_prefix: Optional string to prepend to each key when sending to memcache. See docs for L{get_multi} and L{set_multi}. @return: 1 if no failure in communication with any memcacheds. 
@rtype: int ''' self._statlog('delete_multi') server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(keys, key_prefix) # send out all requests on each server before reading anything dead_servers = [] rc = 1 for server in server_keys.iterkeys(): bigcmd = [] write = bigcmd.append if time != None: for key in server_keys[server]: # These are mangled keys write("delete %s %d\r\n" % (key, time)) else: for key in server_keys[server]: # These are mangled keys write("delete %s\r\n" % key) try: server.send_cmds(''.join(bigcmd)) except socket.error, msg: rc = 0 if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) dead_servers.append(server) # if any servers died on the way, don't expect them to respond. for server in dead_servers: del server_keys[server] for server, keys in server_keys.iteritems(): try: for key in keys: server.expect("DELETED") except socket.error, msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) rc = 0 return rc def delete(self, key, time=0): '''Deletes a key from the memcache. @return: Nonzero on success. @param time: number of seconds any subsequent set / update commands should fail. Defaults to 0 for no delay. @rtype: int ''' check_key(key) server, key = self._get_server(key) if not server: return 0 self._statlog('delete') if time != None: cmd = "delete %s %d" % (key, time) else: cmd = "delete %s" % key try: server.send_cmd(cmd) server.expect("DELETED") except socket.error, msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) return 0 return 1 def incr(self, key, delta=1): """ Sends a command to the server to atomically increment the value for C{key} by C{delta}, or by 1 if C{delta} is unspecified. Returns None if C{key} doesn't exist on server, otherwise it returns the new value after incrementing. Note that the value for C{key} must already exist in the memcache, and it must be the string representation of an integer. 
>>> mc.set("counter", "20") # returns 1, indicating success 1 >>> mc.incr("counter") 21 >>> mc.incr("counter") 22 Overflow on server is not checked. Be aware of values approaching 2**32. See L{decr}. @param delta: Integer amount to increment by (should be zero or greater). @return: New value after incrementing. @rtype: int """ return self._incrdecr("incr", key, delta) def decr(self, key, delta=1): """ Like L{incr}, but decrements. Unlike L{incr}, underflow is checked and new values are capped at 0. If server value is 1, a decrement of 2 returns 0, not -1. @param delta: Integer amount to decrement by (should be zero or greater). @return: New value after decrementing. @rtype: int """ return self._incrdecr("decr", key, delta) def _incrdecr(self, cmd, key, delta): check_key(key) server, key = self._get_server(key) if not server: return 0 self._statlog(cmd) cmd = "%s %s %d" % (cmd, key, delta) try: server.send_cmd(cmd) line = server.readline() return int(line) except socket.error, msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) return None def add(self, key, val, time = 0, min_compress_len = 0): ''' Add new key with value. Like L{set}, but only stores in memcache if the key doesn't already exist. @return: Nonzero on success. @rtype: int ''' return self._set("add", key, val, time, min_compress_len) def append(self, key, val, time=0, min_compress_len=0): '''Append the value to the end of the existing key's value. Only stores in memcache if key already exists. Also see L{prepend}. @return: Nonzero on success. @rtype: int ''' return self._set("append", key, val, time, min_compress_len) def prepend(self, key, val, time=0, min_compress_len=0): '''Prepend the value to the beginning of the existing key's value. Only stores in memcache if key already exists. Also see L{append}. @return: Nonzero on success. 
@rtype: int ''' return self._set("prepend", key, val, time, min_compress_len) def replace(self, key, val, time=0, min_compress_len=0): '''Replace existing key with value. Like L{set}, but only stores in memcache if the key already exists. The opposite of L{add}. @return: Nonzero on success. @rtype: int ''' return self._set("replace", key, val, time, min_compress_len) def set(self, key, val, time=0, min_compress_len=0, token=None): '''Unconditionally sets a key to a given value in the memcache. The C{key} can optionally be an tuple, with the first element being the server hash value and the second being the key. If you want to avoid making this module calculate a hash value. You may prefer, for example, to keep all of a given user's objects on the same memcache server, so you could use the user's unique id as the hash value. @return: Nonzero on success. @rtype: int @param time: Tells memcached the time which this value should expire, either as a delta number of seconds, or an absolute unix time-since-the-epoch value. See the memcached protocol docs section "Storage Commands" for more info on . We default to 0 == cache forever. @param min_compress_len: The threshold length to kick in auto-compression of the value using the zlib.compress() routine. If the value being cached is a string, then the length of the string is measured, else if the value is an , then the length of the pickle result is measured. If the resulting attempt at compression yeilds a larger string than the input, then it is discarded. For backwards compatability, this parameter defaults to 0, indicating don't ever try to compress. ''' return self._set("set", key, val, time, min_compress_len, token=token) def _map_and_prefix_keys(self, key_iterable, key_prefix): """Compute the mapping of server (_Host instance) -> list of keys to stuff onto that server, as well as the mapping of prefixed key -> original key. """ # Check it just once ... 
key_extra_len=len(key_prefix) if key_prefix: check_key(key_prefix) # server (_Host) -> list of unprefixed server keys in mapping server_keys = {} prefixed_to_orig_key = {} # build up a list for each server of all the keys we want. for orig_key in key_iterable: if type(orig_key) is types.TupleType: # Tuple of hashvalue, key ala _get_server(). Caller is essentially telling us what server to stuff this on. # Ensure call to _get_server gets a Tuple as well. str_orig_key = str(orig_key[1]) server, key = self._get_server((orig_key[0], key_prefix + str_orig_key)) # Gotta pre-mangle key before hashing to a server. Returns the mangled key. else: str_orig_key = str(orig_key) # set_multi supports int / long keys. server, key = self._get_server(key_prefix + str_orig_key) # Now check to make sure key length is proper ... check_key(str_orig_key, key_extra_len=key_extra_len) if not server: continue if not server_keys.has_key(server): server_keys[server] = [] server_keys[server].append(key) prefixed_to_orig_key[key] = orig_key return (server_keys, prefixed_to_orig_key) def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0): ''' Sets multiple keys in the memcache doing just one query. >>> notset_keys = mc.set_multi({'key1' : 'val1', 'key2' : 'val2'}) >>> mc.get_multi(['key1', 'key2']) == {'key1' : 'val1', 'key2' : 'val2'} 1 This method is recommended over regular L{set} as it lowers the number of total packets flying around your network, reducing total latency, since your app doesn't have to wait for each round-trip of L{set} before sending the next one. @param mapping: A dict of key/value pairs to set. @param time: Tells memcached the time which this value should expire, either as a delta number of seconds, or an absolute unix time-since-the-epoch value. See the memcached protocol docs section "Storage Commands" for more info on . We default to 0 == cache forever. @param key_prefix: Optional string to prepend to each key when sending to memcache. 
Allows you to efficiently stuff these keys into a pseudo-namespace in memcache: >>> notset_keys = mc.set_multi({'key1' : 'val1', 'key2' : 'val2'}, key_prefix='subspace_') >>> len(notset_keys) == 0 True >>> mc.get_multi(['subspace_key1', 'subspace_key2']) == {'subspace_key1' : 'val1', 'subspace_key2' : 'val2'} True Causes key 'subspace_key1' and 'subspace_key2' to be set. Useful in conjunction with a higher-level layer which applies namespaces to data in memcache. In this case, the return result would be the list of notset original keys, prefix not applied. @param min_compress_len: The threshold length to kick in auto-compression of the value using the zlib.compress() routine. If the value being cached is a string, then the length of the string is measured, else if the value is an object, then the length of the pickle result is measured. If the resulting attempt at compression yeilds a larger string than the input, then it is discarded. For backwards compatability, this parameter defaults to 0, indicating don't ever try to compress. @return: List of keys which failed to be stored [ memcache out of memory, etc. ]. @rtype: list ''' self._statlog('set_multi') server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(mapping.iterkeys(), key_prefix) # send out all requests on each server before reading anything dead_servers = [] for server in server_keys.iterkeys(): bigcmd = [] write = bigcmd.append try: for key in server_keys[server]: # These are mangled keys store_info = self._val_to_store_info(mapping[prefixed_to_orig_key[key]], min_compress_len) write("set %s %d %d %d\r\n%s\r\n" % (key, store_info[0], time, store_info[1], store_info[2])) server.send_cmds(''.join(bigcmd)) except socket.error, msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) dead_servers.append(server) # if any servers died on the way, don't expect them to respond. 
for server in dead_servers: del server_keys[server] # short-circuit if there are no servers, just return all keys if not server_keys: return(mapping.keys()) notstored = [] # original keys. for server, keys in server_keys.iteritems(): try: for key in keys: line = server.readline() if line == 'STORED': continue else: notstored.append(prefixed_to_orig_key[key]) #un-mangle. except (_Error, socket.error), msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) return notstored def _val_to_store_info(self, val, min_compress_len): """ Transform val to a storable representation, returning a tuple of the flags, the length of the new value, and the new value itself. """ flags = 0 if isinstance(val, str): pass elif isinstance(val, int): flags |= Client._FLAG_INTEGER val = "%d" % val # force no attempt to compress this silly string. min_compress_len = 0 elif isinstance(val, long): flags |= Client._FLAG_LONG val = "%d" % val # force no attempt to compress this silly string. min_compress_len = 0 else: flags |= Client._FLAG_PICKLE file = StringIO() if self.picklerIsKeyword: pickler = self.pickler(file, protocol = self.pickleProtocol) else: pickler = self.pickler(file, self.pickleProtocol) if self.persistent_id: pickler.persistent_id = self.persistent_id pickler.dump(val) val = file.getvalue() lv = len(val) # We should try to compress if min_compress_len > 0 and we could # import zlib and this string is longer than our min threshold. if min_compress_len and _supports_compress and lv > min_compress_len: comp_val = compress(val) # Only retain the result if the compression result is smaller # than the original. 
if len(comp_val) < lv: flags |= Client._FLAG_COMPRESSED val = comp_val # silently do not store if value length exceeds maximum if len(val) >= SERVER_MAX_VALUE_LENGTH: return(0) return (flags, len(val), val) def _set(self, cmd, key, val, time, min_compress_len = 0, token=None): check_key(key) server, key = self._get_server(key) if not server: return 0 self._statlog(cmd) store_info = self._val_to_store_info(val, min_compress_len) if not store_info: return(0) if token is not None: cmd = "cas" fullcmd = "cas %s %d %d %d %s\r\n%s" % (key, store_info[0], time, store_info[1], token, store_info[2]) else: fullcmd = "%s %s %d %d %d\r\n%s" % (cmd, key, store_info[0], time, store_info[1], store_info[2]) try: server.send_cmd(fullcmd) result = server.expect("STORED") if (result == "STORED"): return True if (result == "NOT_FOUND"): raise NotFoundError(key) if token and result == "EXISTS": log.debug("Memcacheclient check-and-set failed") raise TokenMismatchError(key) log.error("Memcacheclient %s command failed with result (%s)" % (cmd, result)) return False except socket.error, msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) return 0 def get(self, key): '''Retrieves a key from the memcache. @return: The value or None. ''' check_key(key) server, key = self._get_server(key) if not server: raise MemcacheError("Memcache connection error") self._statlog('get') try: server.send_cmd("get %s" % key) rkey, flags, rlen, = self._expectvalue(server) if not rkey: return None value = self._recv_value(server, flags, rlen) server.expect("END") except (_Error, socket.error), msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) raise MemcacheError("Memcache connection error") return value def gets(self, key): '''Retrieves a key from the memcache. @return: The value or None. 
''' check_key(key) server, key = self._get_server(key) if not server: raise MemcacheError("Memcache connection error") self._statlog('get') try: server.send_cmd("gets %s" % key) rkey, flags, rlen, cas_token = self._expectvalue_cas(server) if not rkey: return (None, None) value = self._recv_value(server, flags, rlen) server.expect("END") except (_Error, socket.error), msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) raise MemcacheError("Memcache connection error") return (value, cas_token) def get_multi(self, keys, key_prefix=''): ''' Retrieves multiple keys from the memcache doing just one query. >>> success = mc.set("foo", "bar") >>> success = mc.set("baz", 42) >>> mc.get_multi(["foo", "baz", "foobar"]) == {"foo": "bar", "baz": 42} 1 >>> mc.set_multi({'k1' : 1, 'k2' : 2}, key_prefix='pfx_') == [] 1 This looks up keys 'pfx_k1', 'pfx_k2', ... . Returned dict will just have unprefixed keys 'k1', 'k2'. >>> mc.get_multi(['k1', 'k2', 'nonexist'], key_prefix='pfx_') == {'k1' : 1, 'k2' : 2} 1 get_mult [ and L{set_multi} ] can take str()-ables like ints / longs as keys too. Such as your db pri key fields. They're rotored through str() before being passed off to memcache, with or without the use of a key_prefix. In this mode, the key_prefix could be a table name, and the key itself a db primary key number. >>> mc.set_multi({42: 'douglass adams', 46 : 'and 2 just ahead of me'}, key_prefix='numkeys_') == [] 1 >>> mc.get_multi([46, 42], key_prefix='numkeys_') == {42: 'douglass adams', 46 : 'and 2 just ahead of me'} 1 This method is recommended over regular L{get} as it lowers the number of total packets flying around your network, reducing total latency, since your app doesn't have to wait for each round-trip of L{get} before sending the next one. See also L{set_multi}. @param keys: An array of keys. @param key_prefix: A string to prefix each key when we communicate with memcache. Facilitates pseudo-namespaces within memcache. 
Returned dictionary keys will not have this prefix. @return: A dictionary of key/value pairs that were available. If key_prefix was provided, the keys in the retured dictionary will not have it present. ''' self._statlog('get_multi') server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(keys, key_prefix) # send out all requests on each server before reading anything dead_servers = [] for server in server_keys.iterkeys(): try: server.send_cmd("get %s" % " ".join(server_keys[server])) except socket.error, msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) dead_servers.append(server) # if any servers died on the way, don't expect them to respond. for server in dead_servers: del server_keys[server] retvals = {} for server in server_keys.iterkeys(): try: line = server.readline() while line and line != 'END': rkey, flags, rlen = self._expectvalue(server, line) # Bo Yang reports that this can sometimes be None if rkey is not None: val = self._recv_value(server, flags, rlen) try: retvals[prefixed_to_orig_key[rkey]] = val # un-prefix returned key. except KeyError: pass line = server.readline() except (_Error, socket.error), msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) return retvals def gets_multi(self, keys, key_prefix=''): ''' Retrieves multiple keys from the memcache doing just one query. See also L{gets} and L{get_multi}. ''' self._statlog('gets_multi') server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(keys, key_prefix) # send out all requests on each server before reading anything dead_servers = [] for server in server_keys.iterkeys(): try: server.send_cmd("gets %s" % " ".join(server_keys[server])) except socket.error, msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) dead_servers.append(server) # if any servers died on the way, don't expect them to respond. 
for server in dead_servers: del server_keys[server] retvals = {} for server in server_keys.iterkeys(): try: line = server.readline() while line and line != 'END': rkey, flags, rlen, cas_token = self._expectvalue_cas(server, line) # Bo Yang reports that this can sometimes be None if rkey is not None: val = self._recv_value(server, flags, rlen) try: retvals[prefixed_to_orig_key[rkey]] = (val, cas_token) # un-prefix returned key. except KeyError: pass line = server.readline() except (_Error, socket.error), msg: if type(msg) is types.TupleType: msg = msg[1] server.mark_dead(msg) return retvals def _expectvalue(self, server, line=None): if not line: line = server.readline() if line[:5] == 'VALUE': resp, rkey, flags, len = line.split() flags = int(flags) rlen = int(len) return (rkey, flags, rlen) else: return (None, None, None) def _expectvalue_cas(self, server, line=None): if not line: line = server.readline() if line[:5] == 'VALUE': resp, rkey, flags, len, rtoken = line.split() flags = int(flags) rlen = int(len) return (rkey, flags, rlen, rtoken) else: return (None, None, None, None) def _recv_value(self, server, flags, rlen): rlen += 2 # include \r\n buf = server.recv(rlen) if len(buf) != rlen: raise _Error("received %d bytes when expecting %d" % (len(buf), rlen)) if len(buf) == rlen: buf = buf[:-2] # strip \r\n if flags & Client._FLAG_COMPRESSED: buf = decompress(buf) if flags == 0 or flags == Client._FLAG_COMPRESSED: # Either a bare string or a compressed string now decompressed... 
val = buf elif flags & Client._FLAG_INTEGER: val = int(buf) elif flags & Client._FLAG_LONG: val = long(buf) elif flags & Client._FLAG_PICKLE: try: file = StringIO(buf) unpickler = self.unpickler(file) if self.persistent_load: unpickler.persistent_load = self.persistent_load val = unpickler.load() except Exception, e: self.debuglog('Pickle error: %s\n' % e) val = None else: self.debuglog("unknown flags on get: %x\n" % flags) return val class TestClient(Client): """ Fake memcache client for unit tests """ def __init__(self, servers, debug=0, pickleProtocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler, pload=None, pid=None): local.__init__(self) super(TestClient, self).__init__(servers, debug=debug, pickleProtocol=pickleProtocol, pickler=pickler, unpickler=unpickler, pload=pload, pid=pid) self.data = {} self.token = 0 def get_stats(self): raise NotImplementedError() def get_slabs(self): raise NotImplementedError() def flush_all(self): raise NotImplementedError() def forget_dead_hosts(self): raise NotImplementedError() def delete_multi(self, keys, time=0, key_prefix=''): ''' Delete multiple keys in the memcache doing just one query. >>> notset_keys = mc.set_multi({'key1' : 'val1', 'key2' : 'val2'}) >>> mc.get_multi(['key1', 'key2']) == {'key1' : 'val1', 'key2' : 'val2'} 1 >>> mc.delete_multi(['key1', 'key2']) 1 >>> mc.get_multi(['key1', 'key2']) == {} 1 ''' self._statlog('delete_multi') for key in keys: key = key_prefix + key del self.data[key] return 1 def delete(self, key, time=0): '''Deletes a key from the memcache. @return: Nonzero on success. @param time: number of seconds any subsequent set / update commands should fail. Defaults to 0 for no delay. 
@rtype: int ''' check_key(key) del self.data[key] return 1 def incr(self, key, delta=1): raise NotImplementedError() def decr(self, key, delta=1): raise NotImplementedError() def add(self, key, val, time = 0, min_compress_len = 0): raise NotImplementedError() def append(self, key, val, time=0, min_compress_len=0): raise NotImplementedError() def prepend(self, key, val, time=0, min_compress_len=0): raise NotImplementedError() def replace(self, key, val, time=0, min_compress_len=0): raise NotImplementedError() def set(self, key, val, time=0, min_compress_len=0, token=None): self._statlog('set') return self._set("set", key, val, time, min_compress_len, token=token) def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0): self._statlog('set_multi') for key, val in mapping.iteritems(): key = key_prefix + key self._set("set", key, val, time, min_compress_len) return [] def _set(self, cmd, key, val, time, min_compress_len = 0, token=None): check_key(key) self._statlog(cmd) serialized = pickle.dumps(val, pickle.HIGHEST_PROTOCOL) if token is not None: if self.data.has_key(key): stored_val, stored_token = self.data[key] if token != stored_token: raise TokenMismatchError(key) self.data[key] = (serialized, str(self.token)) self.token += 1 return True def get(self, key): check_key(key) self._statlog('get') if self.data.has_key(key): stored_val, stored_token = self.data[key] val = pickle.loads(stored_val) return val return None def gets(self, key): check_key(key) if self.data.has_key(key): stored_val, stored_token = self.data[key] val = pickle.loads(stored_val) return (val, stored_token) return (None, None) def get_multi(self, keys, key_prefix=''): self._statlog('get_multi') results = {} for key in keys: key = key_prefix + key val = self.get(key) results[key] = val return results def gets_multi(self, keys, key_prefix=''): self._statlog('gets_multi') results = {} for key in keys: key = key_prefix + key result = self.gets(key) if result[1] is not None: results[key] 
= result return results class _Host: _DEAD_RETRY = 1 # number of seconds before retrying a dead server. _SOCKET_TIMEOUT = 3 # number of seconds before sockets timeout. def __init__(self, host, debugfunc=None): if isinstance(host, types.TupleType): host, self.weight = host else: self.weight = 1 # parse the connection string m = re.match(r'^(?Punix):(?P.*)$', host) if not m: m = re.match(r'^(?Pinet):' r'(?P[^:]+)(:(?P[0-9]+))?$', host) if not m: m = re.match(r'^(?P[^:]+):(?P[0-9]+)$', host) if not m: raise ValueError('Unable to parse connection string: "%s"' % host) hostData = m.groupdict() if hostData.get('proto') == 'unix': self.family = socket.AF_UNIX self.address = hostData['path'] else: self.family = socket.AF_INET self.ip = hostData['host'] self.port = int(hostData.get('port', 11211)) self.address = ( self.ip, self.port ) if not debugfunc: debugfunc = lambda x: x self.debuglog = debugfunc self.deaduntil = 0 self.socket = None self.buffer = '' def _check_dead(self): if self.deaduntil and self.deaduntil > time.time(): return 1 self.deaduntil = 0 return 0 def connect(self): if self._get_socket(): return 1 return 0 def mark_dead(self, reason): log.error("Memcacheclient socket marked dead (%s)" % (reason,)) self.debuglog("MemCache: %s: %s. Marking dead." 
% (self, reason)) self.deaduntil = time.time() + _Host._DEAD_RETRY self.close_socket() def _get_socket(self): if self._check_dead(): log.error("Memcacheclient _get_socket() found dead socket") return None if self.socket: return self.socket s = socket.socket(self.family, socket.SOCK_STREAM) if hasattr(s, 'settimeout'): s.settimeout(self._SOCKET_TIMEOUT) try: s.connect(self.address) except socket.timeout, msg: log.error("Memcacheclient _get_socket() connection timed out (%s)" % (msg,)) self.mark_dead("connect: %s" % msg) return None except socket.error, msg: if type(msg) is types.TupleType: msg = msg[1] log.error("Memcacheclient _get_socket() connection error (%s)" % (msg,)) self.mark_dead("connect: %s" % msg[1]) return None self.socket = s self.buffer = '' return s def close_socket(self): if self.socket: self.socket.close() self.socket = None def send_cmd(self, cmd): self.socket.sendall(cmd + '\r\n') def send_cmds(self, cmds): """ cmds already has trailing \r\n's applied """ self.socket.sendall(cmds) def readline(self): buf = self.buffer recv = self.socket.recv while True: index = buf.find('\r\n') if index >= 0: break data = recv(4096) if not data: self.mark_dead('Connection closed while reading from %s' % repr(self)) break buf += data if index >= 0: self.buffer = buf[index+2:] buf = buf[:index] else: self.buffer = '' return buf def expect(self, text): line = self.readline() if line != text: self.debuglog("while expecting '%s', got unexpected response '%s'" % (text, line)) return line def recv(self, rlen): self_socket_recv = self.socket.recv buf = self.buffer while len(buf) < rlen: foo = self_socket_recv(4096) buf += foo if len(foo) == 0: raise _Error, ( 'Read %d bytes, expecting %d, ' 'read returned 0 length bytes' % ( len(buf), rlen )) self.buffer = buf[rlen:] return buf[:rlen] def __str__(self): d = '' if self.deaduntil: d = " (dead until %d)" % self.deaduntil if self.family == socket.AF_INET: return "inet:%s:%d%s" % (self.address[0], self.address[1], d) else: 
return "unix:%s%s" % (self.address, d) def check_key(key, key_extra_len=0): """Checks sanity of key. Fails if: Key length is > SERVER_MAX_KEY_LENGTH (Raises MemcachedKeyLength). Contains control characters (Raises MemcachedKeyCharacterError). Is not a string (Raises MemcachedStringEncodingError) Is an unicode string (Raises MemcachedStringEncodingError) Is not a string (Raises MemcachedKeyError) Is None (Raises MemcachedKeyError) """ return # Short-circuit this expensive method if type(key) == types.TupleType: key = key[1] if not key: raise Client.MemcachedKeyNoneError, ("Key is None") if isinstance(key, unicode): raise Client.MemcachedStringEncodingError, ("Keys must be str()'s, not " "unicode. Convert your unicode strings using " "mystring.encode(charset)!") if not isinstance(key, str): raise Client.MemcachedKeyTypeError, ("Key must be str()'s") if isinstance(key, basestring): if len(key) + key_extra_len > SERVER_MAX_KEY_LENGTH: raise Client.MemcachedKeyLengthError, ("Key length is > %s" % SERVER_MAX_KEY_LENGTH) for char in key: if ord(char) < 32 or ord(char) == 127: raise Client.MemcachedKeyCharacterError, "Control characters not allowed" def _doctest(): import doctest, memcacheclient servers = ["127.0.0.1:11211"] mc = Client(servers, debug=1) globs = {"mc": mc} return doctest.testmod(memcacheclient, globs=globs) if __name__ == "__main__": print("Testing docstrings...") _doctest() print("Running tests:") print serverList = [["127.0.0.1:11211"]] if '--do-unix' in sys.argv: serverList.append([os.path.join(os.getcwd(), 'memcached.socket')]) for servers in serverList: mc = Client(servers, debug=1) def to_s(val): if not isinstance(val, types.StringTypes): return "%s (%s)" % (val, type(val)) return "%s" % val def test_setget(key, val): print("Testing set/get {'%s': %s} ..." 
% (to_s(key), to_s(val)), end="") mc.set(key, val) newval = mc.get(key) if newval == val: print("OK") return 1 else: print("FAIL") return 0 class FooStruct: def __init__(self): self.bar = "baz" def __str__(self): return "A FooStruct" def __eq__(self, other): if isinstance(other, FooStruct): return self.bar == other.bar return 0 test_setget("a_string", "some random string") test_setget("an_integer", 42) if test_setget("long", long(1<<30)): print("Testing delete ...", end="") if mc.delete("long"): print("OK") else: print("FAIL") print("Testing get_multi ...", end="") print(mc.get_multi(["a_string", "an_integer"])) print("Testing get(unknown value) ...", end="") print(to_s(mc.get("unknown_value"))) f = FooStruct() test_setget("foostruct", f) print("Testing incr ...", end="") x = mc.incr("an_integer", 1) if x == 43: print("OK") else: print("FAIL") print("Testing decr ...", end="") x = mc.decr("an_integer", 1) if x == 42: print("OK") else: print("FAIL") # sanity tests print("Testing sending spaces...", end="") try: x = mc.set("this has spaces", 1) except Client.MemcachedKeyCharacterError, msg: print("OK") else: print("FAIL") print("Testing sending control characters...", end="") try: x = mc.set("this\x10has\x11control characters\x02", 1) except Client.MemcachedKeyCharacterError, msg: print("OK") else: print("FAIL") print("Testing using insanely long key...", end="") try: x = mc.set('a'*SERVER_MAX_KEY_LENGTH + 'aaaa', 1) except Client.MemcachedKeyLengthError, msg: print("OK") else: print("FAIL") print("Testing sending a unicode-string key...", end="") try: x = mc.set(u'keyhere', 1) except Client.MemcachedStringEncodingError, msg: print("OK", end="") else: print("FAIL", end="") try: x = mc.set((u'a'*SERVER_MAX_KEY_LENGTH).encode('utf-8'), 1) except: print("FAIL", end="") else: print("OK", end="") import pickle s = pickle.loads('V\\u4f1a\np0\n.') try: x = mc.set((s*SERVER_MAX_KEY_LENGTH).encode('utf-8'), 1) except Client.MemcachedKeyLengthError: print("OK") else: 
print("FAIL") print("Testing using a value larger than the memcached value limit...", end="") x = mc.set('keyhere', 'a'*SERVER_MAX_VALUE_LENGTH) if mc.get('keyhere') == None: print("OK", end="") else: print("FAIL", end="") x = mc.set('keyhere', 'a'*SERVER_MAX_VALUE_LENGTH + 'aaa') if mc.get('keyhere') == None: print("OK") else: print("FAIL") print("Testing set_multi() with no memcacheds running", end="") mc.disconnect_all() errors = mc.set_multi({'keyhere' : 'a', 'keythere' : 'b'}) if errors != []: print("FAIL") else: print("OK") print("Testing delete_multi() with no memcacheds running", end="") mc.disconnect_all() ret = mc.delete_multi({'keyhere' : 'a', 'keythere' : 'b'}) if ret != 1: print("FAIL") else: print("OK") # vim: ts=4 sw=4 et : calendarserver-5.2+dfsg/twext/python/vcomponent.py0000644000175000017500000000175312263343324021555 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ iCalendar utilities """ __all__ = [ "VComponent", "VProperty", "InvalidICalendarDataError", ] # FIXME: Move twistedcaldav.ical here, but that module needs some # cleanup first. Perhaps after porting to libical? 
from twistedcaldav.ical import Component as VComponent from twistedcaldav.ical import Property as VProperty from twistedcaldav.ical import InvalidICalendarDataError calendarserver-5.2+dfsg/twext/python/__init__.py0000644000175000017500000000120512263343324021114 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extensions to twisted.python. """ calendarserver-5.2+dfsg/twext/python/clsprop.py0000644000175000017500000000261612263343324021046 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ A small utility for defining static class properties. """ class classproperty(object): """ Decorator for a method that wants to return a static class property. The decorated method will only be invoked once, for each class, and that value will be returned for that class. 
""" def __init__(self, thunk=None, cache=True): self.cache = cache self.thunk = thunk self._classcache = {} def __call__(self, thunk): return self.__class__(thunk, self.cache) def __get__(self, instance, owner): if not self.cache: return self.thunk(owner) cc = self._classcache if owner in cc: cached = cc[owner] else: cached = self.thunk(owner) cc[owner] = cached return cached calendarserver-5.2+dfsg/twext/patches.py0000644000175000017500000000527112263343324017472 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Patches for behavior in Twisted which calendarserver requires to be different. """ __all__ = [] import sys from twisted import version from twisted.python.versions import Version from twisted.python.modules import getModule def _hasIPv6ClientSupport(): """ Does the loaded version of Twisted have IPv6 client support? """ lastVersionWithoutIPv6Clients = Version("twisted", 12, 0, 0) if version > lastVersionWithoutIPv6Clients: return True elif version == lastVersionWithoutIPv6Clients: # It could be a snapshot of trunk or a branch with this bug fixed. # Don't load the module, though, as that would be a bunch of # unnecessary work. 
return "_resolveIPv6" in (getModule("twisted.internet.tcp") .filePath.getContent()) else: return False def _addBackports(): """ We currently require 2 backported bugfixes from a future release of Twisted, for IPv6 support: - U{IPv6 client support } - U{TCP endpoint cancellation } This function will activate those backports. (Note it must be run before any of the modules in question are imported or it will raise an exception.) This function, L{_hasIPv6ClientSupport}, and all the associated backports (i.e., all of C{twext/backport}) should be removed upon upgrading our minimum required Twisted version. """ from twext.backport import internet as bpinternet from twisted import internet internet.__path__[:] = bpinternet.__path__ + internet.__path__ # Make sure none of the backports are loaded yet. backports = getModule("twext.backport.internet") for submod in backports.iterModules(): subname = submod.name.split(".")[-1] tiname = 'twisted.internet.' + subname if tiname in sys.modules: raise RuntimeError( tiname + "already loaded, cannot load required backport") if not _hasIPv6ClientSupport(): _addBackports() from twisted.mail.imap4 import Command Command._1_RESPONSES += tuple(['BYE']) calendarserver-5.2+dfsg/twext/enterprise/0000755000175000017500000000000012322625326017645 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/enterprise/ienterprise.py0000644000175000017500000002423712263343324022557 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. ## """ Interfaces, mostly related to L{twext.enterprise.adbapi2}. """ __all__ = [ "IAsyncTransaction", "ISQLExecutor", "ICommandBlock", "IQueuer", "IDerivedParameter", "AlreadyFinishedError", "ConnectionError", "POSTGRES_DIALECT", "SQLITE_DIALECT", "ORACLE_DIALECT", "ORACLE_TABLE_NAME_MAX", ] from zope.interface import Interface, Attribute class AlreadyFinishedError(Exception): """ The transaction was already completed via an C{abort} or C{commit} and cannot be aborted or committed again. """ class ConnectionError(Exception): """ An error occurred with the underlying database connection. """ POSTGRES_DIALECT = 'postgres-dialect' ORACLE_DIALECT = 'oracle-dialect' SQLITE_DIALECT = 'sqlite-dialect' ORACLE_TABLE_NAME_MAX = 30 class ISQLExecutor(Interface): """ Base SQL-execution interface, for a group of commands or a transaction. """ paramstyle = Attribute( """ A copy of the 'paramstyle' attribute from a DB-API 2.0 module. """) dialect = Attribute( """ A copy of the 'dialect' attribute from the connection pool. One of the C{*_DIALECT} constants in this module, such as C{POSTGRES_DIALECT}. """) def execSQL(sql, args=(), raiseOnZeroRowCount=None): """ Execute some SQL. @param sql: an SQL string. @type sql: C{str} @param args: C{list} of arguments to interpolate into C{sql}. @param raiseOnZeroRowCount: a 0-argument callable which returns an exception to raise if the executed SQL does not affect any rows. @return: L{Deferred} which fires C{list} of C{tuple} @raise: C{raiseOnZeroRowCount} if it was specified and no rows were affected. """ class IAsyncTransaction(ISQLExecutor): """ Asynchronous execution of SQL. Note that there is no C{begin()} method; if an L{IAsyncTransaction} exists at all, it is assumed to have been started. """ def commit(): """ Commit changes caused by this transaction. 
@return: L{Deferred} which fires with C{None} upon successful completion of this transaction, or fails if this transaction could not be committed. It fails with L{AlreadyFinishedError} if the transaction has already been committed or rolled back. """ def preCommit(operation): """ Perform the given operation when this L{IAsyncTransaction}'s C{commit} method is called, but before the underlying transaction commits. If any exception is raised by this operation, underlying database commit will be blocked and rollback run instead. @param operation: a 0-argument callable that may return a L{Deferred}. If it does, then the subsequent operations added by L{postCommit} will not fire until that L{Deferred} fires. """ def postCommit(operation): """ Perform the given operation only after this L{IAsyncTransaction} commits. These will be invoked before the L{Deferred} returned by L{IAsyncTransaction.commit} fires. @param operation: a 0-argument callable that may return a L{Deferred}. If it does, then the subsequent operations added by L{postCommit} will not fire until that L{Deferred} fires. """ def abort(): """ Roll back changes caused by this transaction. @return: L{Deferred} which fires with C{None} upon successful rollback of this transaction. """ def postAbort(operation): """ Invoke a callback after abort. @see: L{IAsyncTransaction.postCommit} @param operation: 0-argument callable, potentially returning a L{Deferred}. """ def commandBlock(): """ Create an object which will cause the commands executed on it to be grouped together. This is useful when using database-specific features such as sub-transactions where order of execution is importnat, but where application code may need to perform I/O to determine what SQL, exactly, it wants to execute. 
Consider this fairly contrived example for an imaginary database:: def storeWebPage(url, block): block.execSQL("BEGIN SUB TRANSACTION") got = getPage(url) def gotPage(data): block.execSQL("INSERT INTO PAGES (TEXT) VALUES (?)", [data]) block.execSQL("INSERT INTO INDEX (TOKENS) VALUES (?)", [tokenize(data)]) lastStmt = block.execSQL("END SUB TRANSACTION") block.end() return lastStmt return got.addCallback(gotPage) gatherResults([storeWebPage(url, txn.commandBlock()) for url in urls]).addCallbacks( lambda x: txn.commit(), lambda f: txn.abort() ) This fires off all the C{getPage} requests in parallel, and prepares all the necessary SQL immediately as the results arrive, but executes those statements in order. In the above example, this makes sure to store the page and its tokens together, another use for this might be to store a computed aggregate (such as a sum) at a particular point in a transaction, without sacrificing parallelism. @rtype: L{ICommandBlock} """ class ICommandBlock(ISQLExecutor): """ This is a block of SQL commands that are grouped together. @see: L{IAsyncTransaction.commandBlock} """ def end(): """ End this command block, allowing other commands queued on the underlying transaction to end. @note: This is I{not} the same as either L{IAsyncTransaction.commit} or L{IAsyncTransaction.abort}, since it does not denote success or failure; merely that the command block has completed and other statements may now be executed. Since sub-transactions are a database-specific feature, they must be implemented at a higher-level than this facility provides (although this facility may be useful in their implementation). Also note that, unlike either of those methods, this does I{not} return a Deferred: if you want to know when the block has completed, simply add a callback to the last L{ICommandBlock.execSQL} call executed on this L{ICommandBlock}. (This may be changed in a future version for the sake of convenience, however.) 
""" class IDerivedParameter(Interface): """ A parameter which needs to be derived from the underlying DB-API cursor; implicitly, meaning that this must also interact with the actual thread manipulating said cursor. If a provider of this interface is passed in the C{args} argument to L{IAsyncTransaction.execSQL}, it will have its C{prequery} and C{postquery} methods invoked on it before and after executing the SQL query in question, respectively. @note: L{IDerivedParameter} providers must also always be I{pickleable}, because in some cases the actual database cursor objects will be on the other end of a network connection. For an explanation of why this might be, see L{twext.enterprise.adbapi2.ConnectionPoolConnection}. """ def preQuery(cursor): """ Before running a query, invoke this method with the cursor that the query will be run on. (This can be used, for example, to allocate a special database-specific variable based on the cursor, like an out parameter.) @param cursor: the DB-API cursor. @return: the concrete value which should be passed to the DB-API layer. """ def postQuery(cursor): """ After running a query, invoke this method in the DB-API thread. (This can be used, for example, to manipulate any state created in the preQuery method.) @param cursor: the DB-API cursor. @return: C{None} """ class IQueuer(Interface): """ An L{IQueuer} can enqueue work for later execution. """ def enqueueWork(self, transaction, workItemType, **kw): """ Perform some work, eventually. @param transaction: an L{IAsyncTransaction} within which to I{commit} to doing the work. Note that this work will likely be done later (but depending on various factors, may actually be done within this transaction as well). @param workItemType: the type of work item to create. @type workItemType: L{type}, specifically, a subtype of L{WorkItem } @param kw: The keyword parameters are relayed to C{workItemType.create} to create an appropriately initialized item. 
@return: a work proposal that allows tracking of the various phases of completion of the work item. @rtype: L{twext.enterprise.queue.WorkItem} """ def callWithNewProposals(self, callback): """ Tells the IQueuer to call a callback method whenever a new WorkProposal is created. @param callback: a callable which accepts a single parameter, a L{WorkProposal} """ def transferProposalCallbacks(self, newQueuer): """ Transfer the registered callbacks to the new queuer. """ calendarserver-5.2+dfsg/twext/enterprise/test/0000755000175000017500000000000012322625326020624 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/enterprise/test/test_fixtures.py0000644000175000017500000000323412263343324024107 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.enterprise.fixtures}. Quis custodiet ipsos custodes? This module, that's who. """ from twext.enterprise.fixtures import buildConnectionPool from twisted.trial.unittest import TestCase from twisted.trial.reporter import TestResult from twext.enterprise.adbapi2 import ConnectionPool class PoolTests(TestCase): """ Tests for fixtures that create a connection pool. """ def test_buildConnectionPool(self): """ L{buildConnectionPool} returns a L{ConnectionPool} which will be running only for the duration of the test. 
""" collect = [] class SampleTest(TestCase): def setUp(self): self.pool = buildConnectionPool(self) def test_sample(self): collect.append(self.pool.running) def tearDown(self): collect.append(self.pool.running) r = TestResult() t = SampleTest("test_sample") t.run(r) self.assertIsInstance(t.pool, ConnectionPool) self.assertEqual([True, False], collect) calendarserver-5.2+dfsg/twext/enterprise/test/test_queue.py0000644000175000017500000007235112276242656023402 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.enterprise.queue}. """ import datetime # TODO: There should be a store-building utility within twext.enterprise. 
from twisted.protocols.amp import Command from twisted.internet.task import Clock as _Clock from txdav.common.datastore.test.util import buildStore from twext.enterprise.dal.syntax import SchemaSyntax, Select from twext.enterprise.dal.record import fromTable from twext.enterprise.dal.test.test_parseschema import SchemaTestHelper from twext.enterprise.queue import ( inTransaction, PeerConnectionPool, WorkItem, astimestamp ) from twisted.trial.unittest import TestCase from twisted.python.failure import Failure from twisted.internet.defer import ( Deferred, inlineCallbacks, gatherResults, passthru#, returnValue ) from twisted.application.service import Service, MultiService from twext.enterprise.queue import ( LocalPerformer, _IWorkPerformer, WorkerConnectionPool, SchemaAMP, TableSyntaxByName ) from twext.enterprise.dal.record import Record from twext.enterprise.queue import ConnectionFromPeerNode from twext.enterprise.fixtures import buildConnectionPool from zope.interface.verify import verifyObject from twisted.test.proto_helpers import StringTransport, MemoryReactor from twext.enterprise.fixtures import SteppablePoolHelper from twisted.internet.defer import returnValue from twext.enterprise.queue import LocalQueuer from twext.enterprise.fixtures import ConnectionPoolHelper from twext.enterprise.queue import _BaseQueuer, NonPerformingQueuer import twext.enterprise.queue class Clock(_Clock): """ More careful L{IReactorTime} fake which mimics the exception behavior of the real reactor. """ def callLater(self, _seconds, _f, *args, **kw): if _seconds < 0: raise ValueError("%s<0: " % (_seconds,)) return super(Clock, self).callLater(_seconds, _f, *args, **kw) class MemoryReactorWithClock(MemoryReactor, Clock): """ Simulate a real reactor. """ def __init__(self): MemoryReactor.__init__(self) Clock.__init__(self) def transactionally(transactionCreator): """ Perform the decorated function immediately in a transaction, replacing its name with a L{Deferred}. 
Use like so:: @transactionally(connectionPool.connection) @inlineCallbacks def it(txn): yield txn.doSomething() it.addCallback(firedWhenDone) @param transactionCreator: A 0-arg callable that returns an L{IAsyncTransaction}. """ def thunk(operation): return inTransaction(transactionCreator, operation) return thunk class UtilityTests(TestCase): """ Tests for supporting utilities. """ def test_inTransactionSuccess(self): """ L{inTransaction} invokes its C{transactionCreator} argument, and then returns a L{Deferred} which fires with the result of its C{operation} argument when it succeeds. """ class faketxn(object): def __init__(self): self.commits = [] self.aborts = [] def commit(self): self.commits.append(Deferred()) return self.commits[-1] def abort(self): self.aborts.append(Deferred()) return self.aborts[-1] createdTxns = [] def createTxn(): createdTxns.append(faketxn()) return createdTxns[-1] dfrs = [] def operation(t): self.assertIdentical(t, createdTxns[-1]) dfrs.append(Deferred()) return dfrs[-1] d = inTransaction(createTxn, operation) x = [] d.addCallback(x.append) self.assertEquals(x, []) self.assertEquals(len(dfrs), 1) dfrs[0].callback(35) # Commit in progress, so still no result... self.assertEquals(x, []) createdTxns[0].commits[0].callback(42) # Committed, everything's done. 
self.assertEquals(x, [35]) class SimpleSchemaHelper(SchemaTestHelper): def id(self): return 'worker' SQL = passthru schemaText = SQL(""" create table DUMMY_WORK_ITEM (WORK_ID integer primary key, NOT_BEFORE timestamp, A integer, B integer, DELETE_ON_LOAD integer default 0); create table DUMMY_WORK_DONE (WORK_ID integer primary key, A_PLUS_B integer); """) nodeSchema = SQL(""" create table NODE_INFO (HOSTNAME varchar(255) not null, PID integer not null, PORT integer not null, TIME timestamp default current_timestamp not null, primary key (HOSTNAME, PORT)); """) schema = SchemaSyntax(SimpleSchemaHelper().schemaFromString(schemaText)) dropSQL = ["drop table {name}".format(name=table.model.name) for table in schema] class DummyWorkDone(Record, fromTable(schema.DUMMY_WORK_DONE)): """ Work result. """ class DummyWorkItem(WorkItem, fromTable(schema.DUMMY_WORK_ITEM)): """ Sample L{WorkItem} subclass that adds two integers together and stores them in another table. """ def doWork(self): return DummyWorkDone.create(self.transaction, workID=self.workID, aPlusB=self.a + self.b) @classmethod @inlineCallbacks def load(cls, txn, *a, **kw): """ Load L{DummyWorkItem} as normal... unless the loaded item has C{DELETE_ON_LOAD} set, in which case, do a deletion of this same row in a concurrent transaction, then commit it. """ self = yield super(DummyWorkItem, cls).load(txn, *a, **kw) if self.deleteOnLoad: otherTransaction = txn.concurrently() otherSelf = yield super(DummyWorkItem, cls).load(txn, *a, **kw) yield otherSelf.delete() yield otherTransaction.commit() returnValue(self) class SchemaAMPTests(TestCase): """ Tests for L{SchemaAMP} faithfully relaying tables across the wire. """ def test_sendTableWithName(self): """ You can send a reference to a table through a L{SchemaAMP} via L{TableSyntaxByName}. 
""" client = SchemaAMP(schema) class SampleCommand(Command): arguments = [('table', TableSyntaxByName())] class Receiver(SchemaAMP): @SampleCommand.responder def gotIt(self, table): self.it = table return {} server = Receiver(schema) clientT = StringTransport() serverT = StringTransport() client.makeConnection(clientT) server.makeConnection(serverT) client.callRemote(SampleCommand, table=schema.DUMMY_WORK_ITEM) server.dataReceived(clientT.io.getvalue()) self.assertEqual(server.it, schema.DUMMY_WORK_ITEM) class WorkItemTests(TestCase): """ A L{WorkItem} is an item of work that can be executed. """ def test_forTable(self): """ L{WorkItem.forTable} returns L{WorkItem} subclasses mapped to the given table. """ self.assertIdentical(WorkItem.forTable(schema.DUMMY_WORK_ITEM), DummyWorkItem) class WorkerConnectionPoolTests(TestCase): """ A L{WorkerConnectionPool} is responsible for managing, in a node's controller (master) process, the collection of worker (slave) processes that are capable of executing queue work. """ class WorkProposalTests(TestCase): """ Tests for L{WorkProposal}. """ def test_whenProposedSuccess(self): """ The L{Deferred} returned by L{WorkProposal.whenProposed} fires when the SQL sent to the database has completed. """ cph = ConnectionPoolHelper() cph.setUp(test=self) cph.pauseHolders() lq = LocalQueuer(cph.createTransaction) enqTxn = cph.createTransaction() wp = lq.enqueueWork(enqTxn, DummyWorkItem, a=3, b=4) d = wp.whenProposed() r = cph.resultOf(d) self.assertEquals(r, []) cph.flushHolders() self.assertEquals(len(r), 1) def test_whenProposedFailure(self): """ The L{Deferred} returned by L{WorkProposal.whenProposed} fails with an errback when the SQL executed to create the WorkItem row fails. """ cph = ConnectionPoolHelper() cph.setUp(self) cph.pauseHolders() firstConnection = cph.factory.willConnectTo() enqTxn = cph.createTransaction() # Execute some SQL on the connection before enqueueing the work-item so # that we don't get the initial-statement. 
enqTxn.execSQL("some sql") lq = LocalQueuer(cph.createTransaction) cph.flushHolders() cph.pauseHolders() wp = lq.enqueueWork(enqTxn, DummyWorkItem, a=3, b=4) firstConnection.executeWillFail(lambda: RuntimeError("foo")) d = wp.whenProposed() r = cph.resultOf(d) self.assertEquals(r, []) cph.flushHolders() self.assertEquals(len(r), 1) self.assertIsInstance(r[0], Failure) class PeerConnectionPoolUnitTests(TestCase): """ L{PeerConnectionPool} has many internal components. """ def setUp(self): """ Create a L{PeerConnectionPool} that is just initialized enough. """ self.pcp = PeerConnectionPool(None, None, 4321, schema) def checkPerformer(self, cls): """ Verify that the performer returned by L{PeerConnectionPool.choosePerformer}. """ performer = self.pcp.choosePerformer() self.failUnlessIsInstance(performer, cls) verifyObject(_IWorkPerformer, performer) def test_choosingPerformerWhenNoPeersAndNoWorkers(self): """ If L{PeerConnectionPool.choosePerformer} is invoked when no workers have spawned and no peers have established connections (either incoming or outgoing), then it chooses an implementation of C{performWork} that simply executes the work locally. """ self.checkPerformer(LocalPerformer) def test_choosingPerformerWithLocalCapacity(self): """ If L{PeerConnectionPool.choosePerformer} is invoked when some workers have spawned, then it should choose the worker pool as the local performer. """ # Give it some local capacity. wlf = self.pcp.workerListenerFactory() proto = wlf.buildProtocol(None) proto.makeConnection(StringTransport()) # Sanity check. self.assertEqual(len(self.pcp.workerPool.workers), 1) self.assertEqual(self.pcp.workerPool.hasAvailableCapacity(), True) # Now it has some capacity. self.checkPerformer(WorkerConnectionPool) def test_choosingPerformerFromNetwork(self): """ If L{PeerConnectionPool.choosePerformer} is invoked when no workers have spawned but some peers have connected, then it should choose a connection from the network to perform it. 
""" peer = PeerConnectionPool(None, None, 4322, schema) local = self.pcp.peerFactory().buildProtocol(None) remote = peer.peerFactory().buildProtocol(None) connection = Connection(local, remote) connection.start() self.checkPerformer(ConnectionFromPeerNode) def test_performingWorkOnNetwork(self): """ The L{PerformWork} command will get relayed to the remote peer controller. """ peer = PeerConnectionPool(None, None, 4322, schema) local = self.pcp.peerFactory().buildProtocol(None) remote = peer.peerFactory().buildProtocol(None) connection = Connection(local, remote) connection.start() d = Deferred() class DummyPerformer(object): def performWork(self, table, workID): self.table = table self.workID = workID return d # Doing real database I/O in this test would be tedious so fake the # first method in the call stack which actually talks to the DB. dummy = DummyPerformer() def chooseDummy(onlyLocally=False): return dummy peer.choosePerformer = chooseDummy performed = local.performWork(schema.DUMMY_WORK_ITEM, 7384) performResult = [] performed.addCallback(performResult.append) # Sanity check. self.assertEquals(performResult, []) connection.flush() self.assertEquals(dummy.table, schema.DUMMY_WORK_ITEM) self.assertEquals(dummy.workID, 7384) self.assertEquals(performResult, []) d.callback(128374) connection.flush() self.assertEquals(performResult, [None]) def test_choosePerformerSorted(self): """ If L{PeerConnectionPool.choosePerformer} is invoked make it return the peer with the least load. 
""" peer = PeerConnectionPool(None, None, 4322, schema) class DummyPeer(object): def __init__(self, name, load): self.name = name self.load = load def currentLoadEstimate(self): return self.load apeer = DummyPeer("A", 1) bpeer = DummyPeer("B", 0) cpeer = DummyPeer("C", 2) peer.addPeerConnection(apeer) peer.addPeerConnection(bpeer) peer.addPeerConnection(cpeer) performer = peer.choosePerformer(onlyLocally=False) self.assertEqual(performer, bpeer) bpeer.load = 2 performer = peer.choosePerformer(onlyLocally=False) self.assertEqual(performer, apeer) @inlineCallbacks def test_notBeforeWhenCheckingForLostWork(self): """ L{PeerConnectionPool._periodicLostWorkCheck} should execute any outstanding work items, but only those that are expired. """ dbpool = buildConnectionPool(self, schemaText + nodeSchema) # An arbitrary point in time. fakeNow = datetime.datetime(2012, 12, 12, 12, 12, 12) # *why* does datetime still not have .astimestamp() sinceEpoch = astimestamp(fakeNow) clock = Clock() clock.advance(sinceEpoch) qpool = PeerConnectionPool(clock, dbpool.connection, 0, schema) # Let's create a couple of work items directly, not via the enqueue # method, so that they exist but nobody will try to immediately execute # them. @transactionally(dbpool.connection) @inlineCallbacks def setup(txn): # First, one that's right now. yield DummyWorkItem.create(txn, a=1, b=2, notBefore=fakeNow) # Next, create one that's actually far enough into the past to run. yield DummyWorkItem.create( txn, a=3, b=4, notBefore=( # Schedule it in the past so that it should have already # run. fakeNow - datetime.timedelta( seconds=qpool.queueProcessTimeout + 20 ) ) ) # Finally, one that's actually scheduled for the future. 
yield DummyWorkItem.create( txn, a=10, b=20, notBefore=fakeNow + datetime.timedelta(1000) ) yield setup yield qpool._periodicLostWorkCheck() @transactionally(dbpool.connection) def check(txn): return DummyWorkDone.all(txn) every = yield check self.assertEquals([x.aPlusB for x in every], [7]) @inlineCallbacks def test_notBeforeWhenEnqueueing(self): """ L{PeerConnectionPool.enqueueWork} enqueues some work immediately, but only executes it when enough time has elapsed to allow the C{notBefore} attribute of the given work item to have passed. """ dbpool = buildConnectionPool(self, schemaText + nodeSchema) fakeNow = datetime.datetime(2012, 12, 12, 12, 12, 12) sinceEpoch = astimestamp(fakeNow) clock = Clock() clock.advance(sinceEpoch) qpool = PeerConnectionPool(clock, dbpool.connection, 0, schema) realChoosePerformer = qpool.choosePerformer performerChosen = [] def catchPerformerChoice(): result = realChoosePerformer() performerChosen.append(True) return result qpool.choosePerformer = catchPerformerChoice @transactionally(dbpool.connection) def check(txn): return qpool.enqueueWork( txn, DummyWorkItem, a=3, b=9, notBefore=datetime.datetime(2012, 12, 12, 12, 12, 20) ).whenProposed() proposal = yield check # This is going to schedule the work to happen with some asynchronous # I/O in the middle; this is a problem because how do we know when it's # time to check to see if the work has started? We need to intercept # the thing that kicks off the work; we can then wait for the work # itself. self.assertEquals(performerChosen, []) # Advance to exactly the appointed second. clock.advance(20 - 12) self.assertEquals(performerChosen, [True]) # FIXME: if this fails, it will hang, but that's better than no # notification that it is broken at all. 
result = yield proposal.whenExecuted() self.assertIdentical(result, proposal) @inlineCallbacks def test_notBeforeBefore(self): """ L{PeerConnectionPool.enqueueWork} will execute its work immediately if the C{notBefore} attribute of the work item in question is in the past. """ dbpool = buildConnectionPool(self, schemaText + nodeSchema) fakeNow = datetime.datetime(2012, 12, 12, 12, 12, 12) sinceEpoch = astimestamp(fakeNow) clock = Clock() clock.advance(sinceEpoch) qpool = PeerConnectionPool(clock, dbpool.connection, 0, schema) realChoosePerformer = qpool.choosePerformer performerChosen = [] def catchPerformerChoice(): result = realChoosePerformer() performerChosen.append(True) return result qpool.choosePerformer = catchPerformerChoice @transactionally(dbpool.connection) def check(txn): return qpool.enqueueWork( txn, DummyWorkItem, a=3, b=9, notBefore=datetime.datetime(2012, 12, 12, 12, 12, 0) ).whenProposed() proposal = yield check clock.advance(1000) # Advance far beyond the given timestamp. self.assertEquals(performerChosen, [True]) result = yield proposal.whenExecuted() self.assertIdentical(result, proposal) def test_workerConnectionPoolPerformWork(self): """ L{WorkerConnectionPool.performWork} performs work by selecting a L{ConnectionFromWorker} and sending it a L{PerformWork} command. """ clock = Clock() peerPool = PeerConnectionPool(clock, None, 4322, schema) factory = peerPool.workerListenerFactory() def peer(): p = factory.buildProtocol(None) t = StringTransport() p.makeConnection(t) return p, t worker1, _ignore_trans1 = peer() worker2, _ignore_trans2 = peer() # Ask the worker to do something. 
worker1.performWork(schema.DUMMY_WORK_ITEM, 1) self.assertEquals(worker1.currentLoad, 1) self.assertEquals(worker2.currentLoad, 0) # Now ask the pool to do something peerPool.workerPool.performWork(schema.DUMMY_WORK_ITEM, 2) self.assertEquals(worker1.currentLoad, 1) self.assertEquals(worker2.currentLoad, 1) def test_poolStartServiceChecksForWork(self): """ L{PeerConnectionPool.startService} kicks off the idle work-check loop. """ reactor = MemoryReactorWithClock() cph = SteppablePoolHelper(nodeSchema + schemaText) then = datetime.datetime(2012, 12, 12, 12, 12, 0) reactor.advance(astimestamp(then)) cph.setUp(self) pcp = PeerConnectionPool(reactor, cph.pool.connection, 4321, schema) now = then + datetime.timedelta(seconds=pcp.queueProcessTimeout * 2) @transactionally(cph.pool.connection) def createOldWork(txn): one = DummyWorkItem.create(txn, workID=1, a=3, b=4, notBefore=then) two = DummyWorkItem.create(txn, workID=2, a=7, b=9, notBefore=now) return gatherResults([one, two]) pcp.startService() cph.flushHolders() reactor.advance(pcp.queueProcessTimeout * 2) self.assertEquals( cph.rows("select * from DUMMY_WORK_DONE"), [(1, 7)] ) cph.rows("delete from DUMMY_WORK_DONE") reactor.advance(pcp.queueProcessTimeout * 2) self.assertEquals( cph.rows("select * from DUMMY_WORK_DONE"), [(2, 16)] ) class HalfConnection(object): def __init__(self, protocol): self.protocol = protocol self.transport = StringTransport() def start(self): """ Hook up the protocol and the transport. """ self.protocol.makeConnection(self.transport) def extract(self): """ Extract the data currently present in this protocol's output buffer. """ io = self.transport.io value = io.getvalue() io.seek(0) io.truncate() return value def deliver(self, data): """ Deliver the given data to this L{HalfConnection}'s protocol's C{dataReceived} method. @return: a boolean indicating whether any data was delivered. 
@rtype: L{bool} """ if data: self.protocol.dataReceived(data) return True return False class Connection(object): def __init__(self, local, remote): """ Connect two protocol instances to each other via string transports. """ self.receiver = HalfConnection(local) self.sender = HalfConnection(remote) def start(self): """ Start up the connection. """ self.sender.start() self.receiver.start() def pump(self): """ Relay data in one direction between the two connections. """ result = self.receiver.deliver(self.sender.extract()) self.receiver, self.sender = self.sender, self.receiver return result def flush(self, turns=10): """ Keep relaying data until there's no more. """ for _ignore_x in range(turns): if not (self.pump() or self.pump()): return class PeerConnectionPoolIntegrationTests(TestCase): """ L{PeerConnectionPool} is the service responsible for coordinating eventually-consistent task queuing within a cluster. """ @inlineCallbacks def setUp(self): """ L{PeerConnectionPool} requires access to a database and the reactor. """ self.store = yield buildStore(self, None) def doit(txn): return txn.execSQL(schemaText) yield inTransaction(lambda: self.store.newTransaction("bonus schema"), doit) def indirectedTransactionFactory(*a): """ Allow tests to replace 'self.store.newTransaction' to provide fixtures with extra methods on a test-by-test basis. 
""" return self.store.newTransaction(*a) def deschema(): @inlineCallbacks def deletestuff(txn): for stmt in dropSQL: yield txn.execSQL(stmt) return inTransaction(lambda *a: self.store.newTransaction(*a), deletestuff) self.addCleanup(deschema) from twisted.internet import reactor self.node1 = PeerConnectionPool( reactor, indirectedTransactionFactory, 0, schema) self.node2 = PeerConnectionPool( reactor, indirectedTransactionFactory, 0, schema) class FireMeService(Service, object): def __init__(self, d): super(FireMeService, self).__init__() self.d = d def startService(self): self.d.callback(None) d1 = Deferred() d2 = Deferred() FireMeService(d1).setServiceParent(self.node1) FireMeService(d2).setServiceParent(self.node2) ms = MultiService() self.node1.setServiceParent(ms) self.node2.setServiceParent(ms) ms.startService() self.addCleanup(ms.stopService) yield gatherResults([d1, d2]) self.store.queuer = self.node1 def test_currentNodeInfo(self): """ There will be two C{NODE_INFO} rows in the database, retrievable as two L{NodeInfo} objects, once both nodes have started up. """ @inlineCallbacks def check(txn): self.assertEquals(len((yield self.node1.activeNodes(txn))), 2) self.assertEquals(len((yield self.node2.activeNodes(txn))), 2) return inTransaction(self.store.newTransaction, check) @inlineCallbacks def test_enqueueHappyPath(self): """ When a L{WorkItem} is scheduled for execution via L{PeerConnectionPool.enqueueWork} its C{doWork} method will be invoked by the time the L{Deferred} returned from the resulting L{WorkProposal}'s C{whenExecuted} method has fired. """ # TODO: this exact test should run against LocalQueuer as well. def operation(txn): # TODO: how does 'enqueue' get associated with the transaction? # This is not the fact with a raw t.w.enterprise transaction. # Should probably do something with components. 
return txn.enqueue(DummyWorkItem, a=3, b=4, workID=4321, notBefore=datetime.datetime.utcnow()) result = yield inTransaction(self.store.newTransaction, operation) # Wait for it to be executed. Hopefully this does not time out :-\. yield result.whenExecuted() def op2(txn): return Select([schema.DUMMY_WORK_DONE.WORK_ID, schema.DUMMY_WORK_DONE.A_PLUS_B], From=schema.DUMMY_WORK_DONE).on(txn) rows = yield inTransaction(self.store.newTransaction, op2) self.assertEquals(rows, [[4321, 7]]) @inlineCallbacks def test_noWorkDoneWhenConcurrentlyDeleted(self): """ When a L{WorkItem} is concurrently deleted by another transaction, it should I{not} perform its work. """ # Provide access to a method called 'concurrently' everything using original = self.store.newTransaction def decorate(*a, **k): result = original(*a, **k) result.concurrently = self.store.newTransaction return result self.store.newTransaction = decorate def operation(txn): return txn.enqueue(DummyWorkItem, a=30, b=40, workID=5678, deleteOnLoad=1, notBefore=datetime.datetime.utcnow()) proposal = yield inTransaction(self.store.newTransaction, operation) yield proposal.whenExecuted() # Sanity check on the concurrent deletion. 
def op2(txn): return Select([schema.DUMMY_WORK_ITEM.WORK_ID], From=schema.DUMMY_WORK_ITEM).on(txn) rows = yield inTransaction(self.store.newTransaction, op2) self.assertEquals(rows, []) def op3(txn): return Select([schema.DUMMY_WORK_DONE.WORK_ID, schema.DUMMY_WORK_DONE.A_PLUS_B], From=schema.DUMMY_WORK_DONE).on(txn) rows = yield inTransaction(self.store.newTransaction, op3) self.assertEquals(rows, []) class DummyProposal(object): def __init__(self, *ignored): pass def _start(self): pass class BaseQueuerTests(TestCase): def setUp(self): self.proposal = None self.patch(twext.enterprise.queue, "WorkProposal", DummyProposal) def _proposalCallback(self, proposal): self.proposal = proposal def test_proposalCallbacks(self): queuer = _BaseQueuer() queuer.callWithNewProposals(self._proposalCallback) self.assertEqual(self.proposal, None) queuer.enqueueWork(None, None) self.assertNotEqual(self.proposal, None) class NonPerformingQueuerTests(TestCase): @inlineCallbacks def test_choosePerformer(self): queuer = NonPerformingQueuer() performer = queuer.choosePerformer() result = (yield performer.performWork(None, None)) self.assertEquals(result, None) calendarserver-5.2+dfsg/twext/enterprise/test/test_util.py0000644000175000017500000000231312263343324023210 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## import datetime from twisted.trial.unittest import TestCase from twext.enterprise.util import parseSQLTimestamp class TimestampTests(TestCase): """ Tests for date-related functions. """ def test_parseSQLTimestamp(self): """ L{parseSQLTimestamp} parses the traditional SQL timestamp. """ tests = ( ("2012-04-04 12:34:56", datetime.datetime(2012, 4, 4, 12, 34, 56)), ("2012-12-31 01:01:01", datetime.datetime(2012, 12, 31, 1, 1, 1)), ) for sqlStr, result in tests: self.assertEqual(parseSQLTimestamp(sqlStr), result) calendarserver-5.2+dfsg/twext/enterprise/test/test_locking.py0000644000175000017500000000536612263346572023704 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for mutual exclusion locks. """ from twisted.internet.defer import inlineCallbacks from twisted.trial.unittest import TestCase from twext.enterprise.fixtures import buildConnectionPool from twext.enterprise.locking import NamedLock, LockTimeout from twext.enterprise.dal.syntax import Select from twext.enterprise.locking import LockSchema schemaText = """ create table NAMED_LOCK (LOCK_NAME varchar(255) unique primary key); """ class TestLocking(TestCase): """ Test locking and unlocking a database row. """ def setUp(self): """ Build a connection pool for the tests to use. """ self.pool = buildConnectionPool(self, schemaText) @inlineCallbacks def test_acquire(self): """ Acquiring a lock adds a row in that transaction. 
""" txn = self.pool.connection() yield NamedLock.acquire(txn, u"a test lock") rows = yield Select(From=LockSchema.NAMED_LOCK).on(txn) self.assertEquals(rows, [tuple([u"a test lock"])]) @inlineCallbacks def test_release(self): """ Releasing an acquired lock removes the row. """ txn = self.pool.connection() lck = yield NamedLock.acquire(txn, u"a test lock") yield lck.release() rows = yield Select(From=LockSchema.NAMED_LOCK).on(txn) self.assertEquals(rows, []) @inlineCallbacks def test_autoRelease(self): """ Committing a transaction automatically releases all of its locks. """ txn = self.pool.connection() yield NamedLock.acquire(txn, u"something") yield txn.commit() txn2 = self.pool.connection() rows = yield Select(From=LockSchema.NAMED_LOCK).on(txn2) self.assertEquals(rows, []) @inlineCallbacks def test_timeout(self): """ Trying to acquire second lock times out. """ txn1 = self.pool.connection() yield NamedLock.acquire(txn1, u"a test lock") txn2 = self.pool.connection() yield self.assertFailure(NamedLock.acquire(txn2, u"a test lock"), LockTimeout) yield txn2.abort() self.flushLoggedErrors() calendarserver-5.2+dfsg/twext/enterprise/test/test_adbapi2.py0000644000175000017500000014037712263343324023552 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.enterprise.adbapi2}. 
""" import gc from zope.interface.verify import verifyObject from twisted.python.failure import Failure from twisted.trial.unittest import TestCase from twisted.internet.defer import Deferred, fail, succeed, inlineCallbacks from twisted.test.proto_helpers import StringTransport from twext.enterprise.ienterprise import ConnectionError from twext.enterprise.ienterprise import AlreadyFinishedError from twext.enterprise.adbapi2 import ConnectionPoolClient from twext.enterprise.adbapi2 import ConnectionPoolConnection from twext.enterprise.ienterprise import IAsyncTransaction from twext.enterprise.ienterprise import ICommandBlock from twext.enterprise.adbapi2 import FailsafeException from twext.enterprise.adbapi2 import ConnectionPool from twext.enterprise.fixtures import ConnectionPoolHelper from twext.enterprise.fixtures import resultOf from twext.enterprise.fixtures import ClockWithThreads from twext.enterprise.fixtures import FakeConnectionError from twext.enterprise.fixtures import RollbackFail from twext.enterprise.fixtures import CommitFail from twext.enterprise.adbapi2 import Commit from twext.enterprise.adbapi2 import _HookableOperation class TrashCollector(object): """ Test helper for monitoring gc.garbage. """ def __init__(self, testCase): self.testCase = testCase testCase.addCleanup(self.checkTrash) self.start() def start(self): gc.collect() self.garbageStart = len(gc.garbage) def checkTrash(self): """ Ensure that the test has added no additional garbage. """ gc.collect() newGarbage = gc.garbage[self.garbageStart:] if newGarbage: # Don't clean up twice. self.start() self.testCase.fail("New garbage: " + repr(newGarbage)) class AssertResultHelper(object): """ Mixin for asserting about synchronous Deferred results. """ def assertResultList(self, resultList, expected): """ Assert that a list created with L{resultOf} contais the expected result. @param resultList: The return value of L{resultOf}. 
@type resultList: L{list} @param expected: The expected value that should be present in the list; a L{Failure} if an exception is expected to be raised. """ if not resultList: self.fail("No result; Deferred didn't fire yet.") else: if isinstance(resultList[0], Failure): if isinstance(expected, Failure): resultList[0].trap(expected.type) else: resultList[0].raiseException() else: self.assertEqual(resultList, [expected]) class ConnectionPoolBootTests(TestCase): """ Tests for the start-up phase of L{ConnectionPool}. """ def test_threadCount(self): """ The reactor associated with a L{ConnectionPool} will have its maximum thread count adjusted when L{ConnectionPool.startService} is called, to accomodate for L{ConnectionPool.maxConnections} additional threads. Stopping the service should restore it to its original value, so that a repeatedly re-started L{ConnectionPool} will not cause the thread ceiling to grow without bound. """ defaultMax = 27 connsMax = 45 combinedMax = defaultMax + connsMax pool = ConnectionPool(None, maxConnections=connsMax) pool.reactor = ClockWithThreads() threadpool = pool.reactor.getThreadPool() pool.reactor.suggestThreadPoolSize(defaultMax) self.assertEquals(threadpool.max, defaultMax) pool.startService() self.assertEquals(threadpool.max, combinedMax) justChecking = [] pool.stopService().addCallback(justChecking.append) # No SQL run, so no threads started, so this deferred should fire # immediately. If not, we're in big trouble, so sanity check. self.assertEquals(justChecking, [None]) self.assertEquals(threadpool.max, defaultMax) def test_isRunning(self): """ L{ConnectionPool.startService} should set its C{running} attribute to true. """ pool = ConnectionPool(None) pool.reactor = ClockWithThreads() self.assertEquals(pool.running, False) pool.startService() self.assertEquals(pool.running, True) class ConnectionPoolTests(ConnectionPoolHelper, TestCase, AssertResultHelper): """ Tests for L{ConnectionPool}. 
""" def test_tooManyConnections(self): """ When the number of outstanding busy transactions exceeds the number of slots specified by L{ConnectionPool.maxConnections}, L{ConnectionPool.connection} will return a pooled transaction that is not backed by any real database connection; this object will queue its SQL statements until an existing connection becomes available. """ a = self.createTransaction() alphaResult = self.resultOf(a.execSQL("alpha")) [[counter, echo]] = alphaResult[0] b = self.createTransaction() # 'b' should have opened a connection. self.assertEquals(len(self.factory.connections), 2) betaResult = self.resultOf(b.execSQL("beta")) [[bcounter, becho]] = betaResult[0] # both 'a' and 'b' are holding open a connection now; let's try to open # a third one. (The ordering will be deterministic even if this fails, # because those threads are already busy.) c = self.createTransaction() gammaResult = self.resultOf(c.execSQL("gamma")) # Did 'c' open a connection? Let's hope not... self.assertEquals(len(self.factory.connections), 2) # SQL shouldn't be executed too soon... self.assertEquals(gammaResult, []) commitResult = self.resultOf(b.commit()) # Now that 'b' has committed, 'c' should be able to complete. [[ccounter, cecho]] = gammaResult[0] # The connection for 'a' ought to still be busy, so let's make sure # we're using the one for 'c'. self.assertEquals(ccounter, bcounter) # Sanity check: the commit should have succeded! self.assertEquals(commitResult, [None]) def test_stopService(self): """ L{ConnectionPool.stopService} stops all the associated L{ThreadHolder}s and thereby frees up the resources it is holding. 
""" a = self.createTransaction() alphaResult = self.resultOf(a.execSQL("alpha")) [[[counter, echo]]] = alphaResult self.assertEquals(len(self.factory.connections), 1) self.assertEquals(len(self.holders), 1) [holder] = self.holders self.assertEquals(holder.started, True) self.assertEquals(holder.stopped, False) self.pool.stopService() self.assertEquals(self.pool.running, False) self.assertEquals(len(self.holders), 1) self.assertEquals(holder.started, True) self.assertEquals(holder.stopped, True) # Closing fake connections removes them from the list. self.assertEquals(len(self.factory.connections), 1) self.assertEquals(self.factory.connections[0].closed, True) def test_retryAfterConnectError(self): """ When the C{connectionFactory} passed to L{ConnectionPool} raises an exception, the L{ConnectionPool} will log the exception and delay execution of a new connection's SQL methods until an attempt succeeds. """ self.factory.willFail() self.factory.willFail() self.factory.willConnect() c = self.createTransaction() def checkOneFailure(): errors = self.flushLoggedErrors(FakeConnectionError) self.assertEquals(len(errors), 1) checkOneFailure() d = c.execSQL("alpha") happened = [] d.addBoth(happened.append) self.assertEquals(happened, []) self.clock.advance(self.pool.RETRY_TIMEOUT + 0.01) checkOneFailure() self.assertEquals(happened, []) self.clock.advance(self.pool.RETRY_TIMEOUT + 0.01) self.flushHolders() self.assertEquals(happened, [[[1, "alpha"]]]) def test_shutdownDuringRetry(self): """ If a L{ConnectionPool} is attempting to shut down while it's in the process of re-trying a connection attempt that received an error, the connection attempt should be cancelled and the shutdown should complete as normal. 
""" self.factory.defaultFail() self.createTransaction() errors = self.flushLoggedErrors(FakeConnectionError) self.assertEquals(len(errors), 1) stopd = [] self.pool.stopService().addBoth(stopd.append) self.assertResultList(stopd, None) self.assertEquals(self.clock.calls, []) [holder] = self.holders self.assertEquals(holder.started, True) self.assertEquals(holder.stopped, True) def test_shutdownDuringAttemptSuccess(self): """ If L{ConnectionPool.stopService} is called while a connection attempt is outstanding, the resulting L{Deferred} won't be fired until the connection attempt has finished; in this case, succeeded. """ self.pauseHolders() self.createTransaction() stopd = [] self.pool.stopService().addBoth(stopd.append) self.assertEquals(stopd, []) self.flushHolders() self.assertResultList(stopd, None) [holder] = self.holders self.assertEquals(holder.started, True) self.assertEquals(holder.stopped, True) def test_shutdownDuringAttemptFailed(self): """ If L{ConnectionPool.stopService} is called while a connection attempt is outstanding, the resulting L{Deferred} won't be fired until the connection attempt has finished; in this case, failed. """ self.factory.defaultFail() self.pauseHolders() self.createTransaction() stopd = [] self.pool.stopService().addBoth(stopd.append) self.assertEquals(stopd, []) self.flushHolders() errors = self.flushLoggedErrors(FakeConnectionError) self.assertEquals(len(errors), 1) self.assertResultList(stopd, None) [holder] = self.holders self.assertEquals(holder.started, True) self.assertEquals(holder.stopped, True) def test_stopServiceMidAbort(self): """ When L{ConnectionPool.stopService} is called with deferreds from C{abort} still outstanding, it will wait for the currently-aborting transaction to fully abort before firing the L{Deferred} returned from C{stopService}. """ # TODO: commit() too? 
self.pauseHolders() c = self.createTransaction() abortResult = self.resultOf(c.abort()) # Should abort instantly, as it hasn't managed to unspool anything yet. # FIXME: kill all Deferreds associated with this thing, make sure that # any outstanding query callback chains get nuked. self.assertEquals(abortResult, [None]) stopResult = self.resultOf(self.pool.stopService()) self.assertEquals(stopResult, []) self.flushHolders() #self.assertEquals(abortResult, [None]) self.assertResultList(stopResult, None) def test_stopServiceWithSpooled(self): """ When L{ConnectionPool.stopService} is called when spooled transactions are outstanding, any pending L{Deferreds} returned by those transactions will be failed with L{ConnectionError}. """ # Use up the free slots so we have to spool. hold = [] hold.append(self.createTransaction()) hold.append(self.createTransaction()) c = self.createTransaction() se = self.resultOf(c.execSQL("alpha")) ce = self.resultOf(c.commit()) self.assertEquals(se, []) self.assertEquals(ce, []) self.resultOf(self.pool.stopService()) self.assertEquals(se[0].type, self.translateError(ConnectionError)) self.assertEquals(ce[0].type, self.translateError(ConnectionError)) def test_repoolSpooled(self): """ Regression test for a somewhat tricky-to-explain bug: when a spooled transaction which has already had commit() called on it before it's received a real connection to start executing on, it will not leave behind any detritus that prevents stopService from working. 
""" self.pauseHolders() c = self.createTransaction() c2 = self.createTransaction() c3 = self.createTransaction() c.commit() c2.commit() c3.commit() self.flushHolders() self.assertEquals(len(self.factory.connections), 2) stopResult = self.resultOf(self.pool.stopService()) self.assertEquals(stopResult, [None]) self.assertEquals(len(self.factory.connections), 2) self.assertEquals(self.factory.connections[0].closed, True) self.assertEquals(self.factory.connections[1].closed, True) def test_connectAfterStop(self): """ Calls to connection() after stopService() result in transactions which immediately fail all operations. """ stopResults = self.resultOf(self.pool.stopService()) self.assertEquals(stopResults, [None]) self.pauseHolders() postClose = self.createTransaction() queryResult = self.resultOf(postClose.execSQL("hello")) self.assertEquals(len(queryResult), 1) self.assertEquals(queryResult[0].type, self.translateError(ConnectionError)) def test_connectAfterStartedStopping(self): """ Calls to connection() after stopService() has been called but before it has completed will result in transactions which immediately fail all operations. """ self.pauseHolders() preClose = self.createTransaction() preCloseResult = self.resultOf(preClose.execSQL('statement')) stopResult = self.resultOf(self.pool.stopService()) postClose = self.createTransaction() queryResult = self.resultOf(postClose.execSQL("hello")) self.assertEquals(stopResult, []) self.assertEquals(len(queryResult), 1) self.assertEquals(queryResult[0].type, self.translateError(ConnectionError)) self.assertEquals(len(preCloseResult), 1) self.assertEquals(preCloseResult[0].type, self.translateError(ConnectionError)) def test_abortFailsDuringStopService(self): """ L{IAsyncTransaction.abort} might fail, most likely because the underlying database connection has already been disconnected. If this happens, shutdown should continue. 
""" txns = [] txns.append(self.createTransaction()) txns.append(self.createTransaction()) for txn in txns: # Make sure rollback will actually be executed. results = self.resultOf(txn.execSQL("maybe change something!")) [[[counter, echo]]] = results self.assertEquals("maybe change something!", echo) # Fail one (and only one) call to rollback(). self.factory.rollbackFail = True stopResult = self.resultOf(self.pool.stopService()) self.assertEquals(stopResult, [None]) self.assertEquals(len(self.flushLoggedErrors(RollbackFail)), 1) self.assertEquals(self.factory.connections[0].closed, True) self.assertEquals(self.factory.connections[1].closed, True) def test_abortRecycledTransaction(self): """ L{ConnectionPool.stopService} will shut down if a recycled transaction is still pending. """ recycled = self.createTransaction() self.resultOf(recycled.commit()) remember = [] remember.append(self.createTransaction()) self.assertEquals(self.resultOf(self.pool.stopService()), [None]) def test_abortSpooled(self): """ Aborting a still-spooled transaction (one which has no statements being executed) will result in all of its Deferreds immediately failing and none of the queued statements being executed. """ active = [] # Use up the available connections ... for i in xrange(self.pool.maxConnections): active.append(self.createTransaction()) # ... so that this one has to be spooled. spooled = self.createTransaction() result = self.resultOf(spooled.execSQL("alpha")) # sanity check, it would be bad if this actually executed. self.assertEqual(result, []) self.resultOf(spooled.abort()) self.assertEqual(result[0].type, self.translateError(ConnectionError)) def test_waitForAlreadyAbortedTransaction(self): """ L{ConnectionPool.stopService} will wait for all transactions to shut down before exiting, including those which have already been stopped. 
""" it = self.createTransaction() self.pauseHolders() abortResult = self.resultOf(it.abort()) # steal it from the queue so we can do it out of order d, work = self.holders[0]._q.get() # that should be the only work unit so don't continue if something else # got in there self.assertEquals(list(self.holders[0]._q.queue), []) self.assertEquals(len(self.holders), 1) self.flushHolders() stopResult = self.resultOf(self.pool.stopService()) # Sanity check that we haven't actually stopped it yet self.assertEquals(abortResult, []) # We haven't fired it yet, so the service had better not have # stopped... self.assertEquals(stopResult, []) d.callback(None) self.flushHolders() self.assertEquals(abortResult, [None]) self.assertEquals(stopResult, [None]) def test_garbageCollectedTransactionAborts(self): """ When an L{IAsyncTransaction} is garbage collected, it ought to abort itself. """ t = self.createTransaction() self.resultOf(t.execSQL("echo", [])) conns = self.factory.connections self.assertEquals(len(conns), 1) self.assertEquals(conns[0]._rollbackCount, 0) del t gc.collect() self.flushHolders() self.assertEquals(len(conns), 1) self.assertEquals(conns[0]._rollbackCount, 1) self.assertEquals(conns[0]._commitCount, 0) def circularReferenceTest(self, finish, hook): """ Collecting a completed (committed or aborted) L{IAsyncTransaction} should not leak any circular references. """ tc = TrashCollector(self) commitExecuted = [] def carefullyManagedScope(): t = self.createTransaction() def holdAReference(): """ This is a hook that holds a reference to 't'. """ commitExecuted.append(True) return t.execSQL("teardown", []) hook(t, holdAReference) finish(t) self.failIf(commitExecuted, "Commit hook executed.") carefullyManagedScope() tc.checkTrash() def test_noGarbageOnCommit(self): """ Committing a transaction does not cause gc garbage. 
""" self.circularReferenceTest(lambda txn: txn.commit(), lambda txn, hook: txn.preCommit(hook)) def test_noGarbageOnCommitWithAbortHook(self): """ Committing a transaction does not cause gc garbage. """ self.circularReferenceTest(lambda txn: txn.commit(), lambda txn, hook: txn.postAbort(hook)) def test_noGarbageOnAbort(self): """ Aborting a transaction does not cause gc garbage. """ self.circularReferenceTest(lambda txn: txn.abort(), lambda txn, hook: txn.preCommit(hook)) def test_noGarbageOnAbortWithPostCommitHook(self): """ Aborting a transaction does not cause gc garbage. """ self.circularReferenceTest(lambda txn: txn.abort(), lambda txn, hook: txn.postCommit(hook)) def test_tooManyConnectionsWhileOthersFinish(self): """ L{ConnectionPool.connection} will not spawn more than the maximum connections if there are finishing transactions outstanding. """ a = self.createTransaction() b = self.createTransaction() self.pauseHolders() a.abort() b.abort() # Remove the holders for the existing connections, so that the 'extra' # connection() call wins the race and gets executed first. self.holders[:] = [] self.createTransaction() self.flushHolders() self.assertEquals(len(self.factory.connections), 2) def setParamstyle(self, paramstyle): """ Change the paramstyle of the transaction under test. """ self.pool.paramstyle = paramstyle def test_propagateParamstyle(self): """ Each different type of L{ISQLExecutor} relays the C{paramstyle} attribute from the L{ConnectionPool}. 
""" TEST_PARAMSTYLE = "justtesting" self.setParamstyle(TEST_PARAMSTYLE) normaltxn = self.createTransaction() self.assertEquals(normaltxn.paramstyle, TEST_PARAMSTYLE) self.assertEquals(normaltxn.commandBlock().paramstyle, TEST_PARAMSTYLE) self.pauseHolders() extra = [] extra.append(self.createTransaction()) waitingtxn = self.createTransaction() self.assertEquals(waitingtxn.paramstyle, TEST_PARAMSTYLE) self.flushHolders() self.pool.stopService() notxn = self.createTransaction() self.assertEquals(notxn.paramstyle, TEST_PARAMSTYLE) def setDialect(self, dialect): """ Change the dialect of the transaction under test. """ self.pool.dialect = dialect def test_propagateDialect(self): """ Each different type of L{ISQLExecutor} relays the C{dialect} attribute from the L{ConnectionPool}. """ TEST_DIALECT = "otherdialect" self.setDialect(TEST_DIALECT) normaltxn = self.createTransaction() self.assertEquals(normaltxn.dialect, TEST_DIALECT) self.assertEquals(normaltxn.commandBlock().dialect, TEST_DIALECT) self.pauseHolders() extra = [] extra.append(self.createTransaction()) waitingtxn = self.createTransaction() self.assertEquals(waitingtxn.dialect, TEST_DIALECT) self.flushHolders() self.pool.stopService() notxn = self.createTransaction() self.assertEquals(notxn.dialect, TEST_DIALECT) def test_reConnectWhenFirstExecFails(self): """ Generally speaking, DB-API 2.0 adapters do not provide information about the cause of a failed 'execute' method; they definitely don't provide it in a way which can be identified as related to the syntax of the query, the state of the database itself, the state of the connection, etc. Therefore the best general heuristic for whether the connection to the database has been lost and needs to be re-established is to catch exceptions which are raised by the I{first} statement executed in a transaction. """ # Allow 'connect' to succeed. 
This should behave basically the same # whether connect() happened to succeed in some previous transaction # and it's recycling the underlying transaction, or connect() just # succeeded. Either way you just have a _SingleTxn wrapping a # _ConnectedTxn. txn = self.createTransaction() self.assertEquals(len(self.factory.connections), 1, "Sanity check failed.") class CustomExecuteFailed(Exception): """ Custom 'execute-failed' exception. """ self.factory.connections[0].executeWillFail(CustomExecuteFailed) results = self.resultOf(txn.execSQL("hello, world!")) [[[counter, echo]]] = results self.assertEquals("hello, world!", echo) # Two execution attempts should have been made, one on each connection. # The first failed with a RuntimeError, but that is deliberately # obscured, because then we tried again and it succeeded. self.assertEquals(len(self.factory.connections), 2, "No new connection opened.") self.assertEquals(self.factory.connections[0].executions, 1) self.assertEquals(self.factory.connections[1].executions, 1) self.assertEquals(self.factory.connections[0].closed, True) self.assertEquals(self.factory.connections[1].closed, False) # Nevertheless, since there is currently no classification of 'safe' # errors, we should probably log these messages when they occur. self.assertEquals(len(self.flushLoggedErrors(CustomExecuteFailed)), 1) def test_reConnectWhenFirstExecOnExistingConnectionFails( self, moreFailureSetup=lambda factory: None): """ Another situation that might arise is that a connection will be successfully connected, executed and recycled into the connection pool; then, the database server will shut down and the connections will die, but we will be none the wiser until we try to use them. 
""" txn = self.createTransaction() moreFailureSetup(self.factory) self.assertEquals(len(self.factory.connections), 1, "Sanity check failed.") results = self.resultOf(txn.execSQL("hello, world!")) txn.commit() [[[counter, echo]]] = results self.assertEquals("hello, world!", echo) txn2 = self.createTransaction() self.assertEquals(len(self.factory.connections), 1, "Sanity check failed.") class CustomExecFail(Exception): """ Custom 'execute()' failure. """ self.factory.connections[0].executeWillFail(CustomExecFail) results = self.resultOf(txn2.execSQL("second try!")) txn2.commit() [[[counter, echo]]] = results self.assertEquals("second try!", echo) self.assertEquals(len(self.flushLoggedErrors(CustomExecFail)), 1) def test_closeExceptionDoesntHinderReconnection(self): """ In some database bindings, if the server closes the connection, C{close()} will fail. If C{close} fails, there's not much that could mean except that the connection is already closed, so similar to the condition described in L{test_reConnectWhenFirstExecOnExistingConnectionFails}, the failure should be logged, but transparent to application code. """ class BindingSpecificException(Exception): """ Exception that's a placeholder for something that a database binding might raise. """ def alsoFailClose(factory): factory.childCloseWillFail(BindingSpecificException()) t = self.test_reConnectWhenFirstExecOnExistingConnectionFails( alsoFailClose ) errors = self.flushLoggedErrors(BindingSpecificException) self.assertEquals(len(errors), 1) return t def test_preCommitSuccess(self): """ Callables passed to L{IAsyncTransaction.preCommit} will be invoked upon commit. 
""" txn = self.createTransaction() def simple(): simple.done = True simple.done = False txn.preCommit(simple) self.assertEquals(simple.done, False) result = self.resultOf(txn.commit()) self.assertEquals(len(result), 1) self.assertEquals(simple.done, True) def test_deferPreCommit(self): """ If callables passed to L{IAsyncTransaction.preCommit} return L{Deferred}s, they will defer the actual commit operation until it has fired. """ txn = self.createTransaction() d = Deferred() def wait(): wait.started = True def executed(it): wait.sqlResult = it # To make sure the _underlying_ commit operation was Deferred, we # have to execute some SQL to make sure it happens. return (d.addCallback(lambda ignored: txn.execSQL("some test sql")) .addCallback(executed)) wait.started = False wait.sqlResult = None txn.preCommit(wait) result = self.resultOf(txn.commit()) self.flushHolders() self.assertEquals(wait.started, True) self.assertEquals(wait.sqlResult, None) self.assertEquals(result, []) d.callback(None) # allow network I/O for pooled / networked implementation; there should # be the commit message now. self.flushHolders() self.assertEquals(len(result), 1) self.assertEquals(wait.sqlResult, [[1, "some test sql"]]) def test_failPreCommit(self): """ If callables passed to L{IAsyncTransaction.preCommit} raise an exception or return a Failure, subsequent callables will not be run, and the transaction will be aborted. """ def test(flawedCallable, exc): # Set up. test.committed = False test.aborted = False # Create transaction and add monitoring hooks. 
txn = self.createTransaction() def didCommit(): test.committed = True def didAbort(): test.aborted = True txn.postCommit(didCommit) txn.postAbort(didAbort) txn.preCommit(flawedCallable) result = self.resultOf(txn.commit()) self.flushHolders() self.assertResultList(result, Failure(exc())) self.assertEquals(test.committed, False) self.assertEquals(test.aborted, True) def failer(): return fail(ZeroDivisionError()) def raiser(): raise EOFError() test(failer, ZeroDivisionError) test(raiser, EOFError) def test_noOpCommitDoesntHinderReconnection(self): """ Until you've executed a query or performed a statement on an ADBAPI connection, the connection is semantically idle (between transactions). A .commit() or .rollback() followed immediately by a .commit() is therefore pointless, and can be ignored. Furthermore, actually executing the commit and propagating a possible connection-oriented error causes clients to see errors, when, if those clients had actually executed any statements, the connection would have been recycled and the statement transparently re-executed by the logic tested by L{test_reConnectWhenFirstExecFails}. """ txn = self.createTransaction() self.factory.commitFail = True self.factory.rollbackFail = True [x] = self.resultOf(txn.commit()) # No statements have been executed, so 'commit' will *not* be executed. 
self.assertEquals(self.factory.commitFail, True) self.assertIdentical(x, None) self.assertEquals(len(self.pool._free), 1) self.assertEquals(self.pool._finishing, []) self.assertEquals(len(self.factory.connections), 1) self.assertEquals(self.factory.connections[0].closed, False) def test_reConnectWhenSecondExecFailsThenFirstExecFails(self): """ Other connection-oriented errors might raise exceptions if they occur in the middle of a transaction, but that should cause the error to be caught, the transaction to be aborted, and the (closed) connection to be recycled, where the next transaction that attempts to do anything with it will encounter the error immediately and discover it needs to be recycled. It would be better if this behavior were invisible, but that could only be accomplished with more precise database exceptions. We may come up with support in the future for more precisely identifying exceptions, but I{unknown} exceptions should continue to be treated in this manner, relaying the exception back to application code but attempting a re-connection on the next try. """ txn = self.createTransaction() [[[counter, echo]]] = self.resultOf(txn.execSQL("hello, world!", [])) self.factory.connections[0].executeWillFail(ZeroDivisionError) [f] = self.resultOf(txn.execSQL("divide by zero", [])) f.trap(self.translateError(ZeroDivisionError)) self.assertEquals(self.factory.connections[0].executions, 2) # Reconnection should work exactly as before. self.assertEquals(self.factory.connections[0].closed, False) # Application code has to roll back its transaction at this point, # since it failed (and we don't necessarily know why it failed: not # enough information). 
self.resultOf(txn.abort()) self.factory.connections[0].executions = 0 # re-set for next test self.assertEquals(len(self.factory.connections), 1) self.test_reConnectWhenFirstExecFails() def test_disconnectOnFailedRollback(self): """ When C{rollback} fails for any reason on a connection object, then we don't know what state it's in. Most likely, it's already been disconnected, so the connection should be closed and the transaction de-pooled instead of recycled. Also, a new connection will immediately be established to keep the pool size the same. """ txn = self.createTransaction() results = self.resultOf(txn.execSQL("maybe change something!")) [[[counter, echo]]] = results self.assertEquals("maybe change something!", echo) self.factory.rollbackFail = True [x] = self.resultOf(txn.abort()) # Abort does not propagate the error on, the transaction merely gets # disposed of. self.assertIdentical(x, None) self.assertEquals(len(self.pool._free), 1) self.assertEquals(self.pool._finishing, []) self.assertEquals(len(self.factory.connections), 2) self.assertEquals(self.factory.connections[0].closed, True) self.assertEquals(self.factory.connections[1].closed, False) self.assertEquals(len(self.flushLoggedErrors(RollbackFail)), 1) def test_exceptionPropagatesFailedCommit(self): """ A failed C{rollback} is fine (the premature death of the connection without C{commit} means that the changes are surely gone), but a failed C{commit} has to be relayed to client code, since that actually means some changes didn't hit the database. 
""" txn = self.createTransaction() self.factory.commitFail = True results = self.resultOf(txn.execSQL("maybe change something!")) [[[counter, echo]]] = results self.assertEquals("maybe change something!", echo) [x] = self.resultOf(txn.commit()) x.trap(self.translateError(CommitFail)) self.assertEquals(len(self.pool._free), 1) self.assertEquals(self.pool._finishing, []) self.assertEquals(len(self.factory.connections), 2) self.assertEquals(self.factory.connections[0].closed, True) self.assertEquals(self.factory.connections[1].closed, False) def test_commandBlock(self): """ L{IAsyncTransaction.commandBlock} returns an L{IAsyncTransaction} provider which ensures that a block of commands are executed together. """ txn = self.createTransaction() a = self.resultOf(txn.execSQL("a")) cb = txn.commandBlock() verifyObject(ICommandBlock, cb) b = self.resultOf(cb.execSQL("b")) d = self.resultOf(txn.execSQL("d")) c = self.resultOf(cb.execSQL("c")) cb.end() e = self.resultOf(txn.execSQL("e")) self.assertEquals(self.factory.connections[0].cursors[0].allExecutions, [("a", []), ("b", []), ("c", []), ("d", []), ("e", [])]) self.assertEquals(len(a), 1) self.assertEquals(len(b), 1) self.assertEquals(len(c), 1) self.assertEquals(len(d), 1) self.assertEquals(len(e), 1) def test_commandBlockWithLatency(self): """ A block returned by L{IAsyncTransaction.commandBlock} won't start executing until all SQL statements scheduled before it have completed. 
""" self.pauseHolders() txn = self.createTransaction() a = self.resultOf(txn.execSQL("a")) b = self.resultOf(txn.execSQL("b")) cb = txn.commandBlock() c = self.resultOf(cb.execSQL("c")) d = self.resultOf(cb.execSQL("d")) e = self.resultOf(txn.execSQL("e")) cb.end() self.flushHolders() self.assertEquals(self.factory.connections[0].cursors[0].allExecutions, [("a", []), ("b", []), ("c", []), ("d", []), ("e", [])]) self.assertEquals(len(a), 1) self.assertEquals(len(b), 1) self.assertEquals(len(c), 1) self.assertEquals(len(d), 1) self.assertEquals(len(e), 1) def test_twoCommandBlocks(self, flush=lambda: None): """ When execution of one command block is complete, it will proceed to the next queued block, then to regular SQL executed on the transaction. """ txn = self.createTransaction() cb1 = txn.commandBlock() cb2 = txn.commandBlock() txn.execSQL("e") cb1.execSQL("a") cb2.execSQL("c") cb1.execSQL("b") cb2.execSQL("d") cb2.end() cb1.end() flush() self.flushHolders() self.assertEquals(self.factory.connections[0].cursors[0].allExecutions, [("a", []), ("b", []), ("c", []), ("d", []), ("e", [])]) def test_twoCommandBlocksLatently(self): """ Same as L{test_twoCommandBlocks}, but with slower callbacks. """ self.pauseHolders() self.test_twoCommandBlocks(self.flushHolders) def test_commandBlockEndTwice(self): """ L{CommandBlock.end} will raise L{AlreadyFinishedError} when called more than once. """ txn = self.createTransaction() block = txn.commandBlock() block.end() self.assertRaises(AlreadyFinishedError, block.end) def test_commandBlockDelaysCommit(self): """ Some command blocks need to run asynchronously, without the overall transaction-managing code knowing how far they've progressed. Therefore when you call {IAsyncTransaction.commit}(), it should not actually take effect if there are any pending command blocks. 
""" txn = self.createTransaction() block = txn.commandBlock() commitResult = self.resultOf(txn.commit()) self.resultOf(block.execSQL("in block")) self.assertEquals(commitResult, []) self.assertEquals(self.factory.connections[0].cursors[0].allExecutions, [("in block", [])]) block.end() self.flushHolders() self.assertEquals(commitResult, [None]) def test_commandBlockDoesntDelayAbort(self): """ A L{CommandBlock} can't possibly have anything interesting to say about a transaction that gets rolled back, so C{abort} applies immediately; all outstanding C{execSQL}s will fail immediately, on both command blocks and on the transaction itself. """ txn = self.createTransaction() block = txn.commandBlock() block2 = txn.commandBlock() abortResult = self.resultOf(txn.abort()) self.assertEquals(abortResult, [None]) self.assertRaises(AlreadyFinishedError, block2.execSQL, "bar") self.assertRaises(AlreadyFinishedError, block.execSQL, "foo") self.assertRaises(AlreadyFinishedError, txn.execSQL, "baz") self.assertEquals(self.factory.connections[0].cursors[0].allExecutions, []) # end() should _not_ raise an exception, because this is the sort of # thing that might be around a try/finally or try/except; it's just # putting the commandBlock itself into a state consistent with the # transaction. block.end() block2.end() def test_endedBlockDoesntExecuteMoreSQL(self): """ Attempting to execute SQL on a L{CommandBlock} which has had C{end} called on it will result in an L{AlreadyFinishedError}. """ txn = self.createTransaction() block = txn.commandBlock() block.end() self.assertRaises(AlreadyFinishedError, block.execSQL, "hello") self.assertEquals(self.factory.connections[0].cursors[0].allExecutions, []) def test_commandBlockAfterCommitRaises(self): """ Once an L{IAsyncTransaction} has been committed, L{commandBlock} raises an exception. 
""" txn = self.createTransaction() txn.commit() self.assertRaises(AlreadyFinishedError, txn.commandBlock) def test_commandBlockAfterAbortRaises(self): """ Once an L{IAsyncTransaction} has been committed, L{commandBlock} raises an exception. """ txn = self.createTransaction() self.resultOf(txn.abort()) self.assertRaises(AlreadyFinishedError, txn.commandBlock) def test_raiseOnZeroRowCount(self): """ L{IAsyncTransaction.execSQL} will return a L{Deferred} failing with the exception passed as its raiseOnZeroRowCount argument if the underlying query returns no rows. """ self.factory.hasResults = False txn = self.createTransaction() f = self.resultOf( txn.execSQL("hello", raiseOnZeroRowCount=ZeroDivisionError) )[0] self.assertRaises(ZeroDivisionError, f.raiseException) txn.commit() def test_raiseOnZeroRowCountWithUnreliableRowCount(self): """ As it turns out, some databases can't reliably tell you how many rows they're going to fetch via the C{rowcount} attribute before the rows have actually been fetched, so the C{raiseOnZeroRowCount} will I{not} raise an exception if C{rowcount} is zero but C{description} and C{fetchall} indicates the presence of some rows. """ self.factory.hasResults = True self.factory.shouldUpdateRowcount = False txn = self.createTransaction() r = self.resultOf( txn.execSQL("some-rows", raiseOnZeroRowCount=RuntimeError) ) [[[counter, echo]]] = r self.assertEquals(echo, "some-rows") class IOPump(object): """ Connect a client and a server. 
@ivar client: a client protocol @ivar server: a server protocol """ def __init__(self, client, server): self.client = client self.server = server self.clientTransport = StringTransport() self.serverTransport = StringTransport() self.client.makeConnection(self.clientTransport) self.server.makeConnection(self.serverTransport) self.c2s = [self.clientTransport, self.server] self.s2c = [self.serverTransport, self.client] def moveData(self, (outTransport, inProtocol)): """ Move data from a L{StringTransport} to an L{IProtocol}. @return: C{True} if any data was moved, C{False} if no data was moved. """ data = outTransport.io.getvalue() outTransport.io.seek(0) outTransport.io.truncate() if data: inProtocol.dataReceived(data) return True else: return False def pump(self): """ Deliver all input from the client to the server, then from the server to the client. """ a = self.moveData(self.c2s) b = self.moveData(self.s2c) return a or b def flush(self, maxTurns=100): """ Continue pumping until no more data is flowing. """ turns = 0 while self.pump(): turns += 1 if turns > maxTurns: raise RuntimeError("Ran too long!") class NetworkedPoolHelper(ConnectionPoolHelper): """ An extension of L{ConnectionPoolHelper} that can set up a L{ConnectionPoolClient} and L{ConnectionPoolConnection} attached to each other. """ def setUp(self): """ Do the same setup from L{ConnectionPoolBase}, but also establish a loopback connection between a L{ConnectionPoolConnection} and a L{ConnectionPoolClient}. """ super(NetworkedPoolHelper, self).setUp() self.pump = IOPump(ConnectionPoolClient(dialect=self.dialect, paramstyle=self.paramstyle), ConnectionPoolConnection(self.pool)) def flushHolders(self): """ In addition to flushing the L{ThreadHolder} stubs, also flush any pending network I/O. 
""" self.pump.flush() super(NetworkedPoolHelper, self).flushHolders() self.pump.flush() def createTransaction(self): txn = self.pump.client.newTransaction() self.pump.flush() return txn def translateError(self, err): """ All errors raised locally will unfortunately be translated into UnknownRemoteError, since AMP requires specific enumeration of all of them. Flush the locally logged error of the given type and return L{UnknownRemoteError}. """ if err in Commit.errors: return err self.flushLoggedErrors(err) return FailsafeException def resultOf(self, it): result = resultOf(it) self.pump.flush() return result class NetworkedConnectionPoolTests(NetworkedPoolHelper, ConnectionPoolTests): """ Tests for L{ConnectionPoolConnection} and L{ConnectionPoolClient} interacting with each other. """ def setParamstyle(self, paramstyle): """ Change the paramstyle on both the pool and the client. """ super(NetworkedConnectionPoolTests, self).setParamstyle(paramstyle) self.pump.client.paramstyle = paramstyle def setDialect(self, dialect): """ Change the dialect on both the pool and the client. """ super(NetworkedConnectionPoolTests, self).setDialect(dialect) self.pump.client.dialect = dialect def test_newTransaction(self): """ L{ConnectionPoolClient.newTransaction} returns a provider of L{IAsyncTransaction}, and creates a new transaction on the server side. """ txn = self.pump.client.newTransaction() verifyObject(IAsyncTransaction, txn) self.pump.flush() self.assertEquals(len(self.factory.connections), 1) class HookableOperationTests(TestCase): """ Tests for L{_HookableOperation}. """ @inlineCallbacks def test_clearPreventsSubsequentAddHook(self): """ After clear() or runHooks() are called, subsequent calls to addHook() are NO-OPs. 
""" def hook(): return succeed(None) hookOp = _HookableOperation() hookOp.addHook(hook) self.assertEquals(len(hookOp._hooks), 1) hookOp.clear() self.assertEquals(hookOp._hooks, None) hookOp = _HookableOperation() hookOp.addHook(hook) yield hookOp.runHooks() self.assertEquals(hookOp._hooks, None) hookOp.addHook(hook) self.assertEquals(hookOp._hooks, None) calendarserver-5.2+dfsg/twext/enterprise/test/__init__.py0000644000175000017500000000120712263343324022734 0ustar rahulrahul ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.enterprise}. """ calendarserver-5.2+dfsg/twext/enterprise/fixtures.py0000644000175000017500000003525412263343324022100 0ustar rahulrahul# -*- test-case-name: twext.enterprise.test.test_fixtures -*- ## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Fixtures for testing code that uses ADBAPI2. 
""" import sqlite3 from Queue import Empty from itertools import count from zope.interface import implementer from zope.interface.verify import verifyClass from twisted.internet.interfaces import IReactorThreads from twisted.python.threadpool import ThreadPool from twisted.internet.task import Clock from twext.enterprise.adbapi2 import ConnectionPool from twext.enterprise.ienterprise import SQLITE_DIALECT from twext.enterprise.ienterprise import POSTGRES_DIALECT from twext.enterprise.adbapi2 import DEFAULT_PARAM_STYLE from twext.internet.threadutils import ThreadHolder def buildConnectionPool(testCase, schemaText="", dialect=SQLITE_DIALECT): """ Build a L{ConnectionPool} for testing purposes, with the given C{testCase}. @param testCase: the test case to attach the resulting L{ConnectionPool} to. @type testCase: L{twisted.trial.unittest.TestCase} @param schemaText: The text of the schema with which to initialize the database. @type schemaText: L{str} @return: a L{ConnectionPool} service whose C{startService} method has already been invoked. @rtype: L{ConnectionPool} """ sqlitename = testCase.mktemp() seqs = {} def connectionFactory(label=testCase.id()): conn = sqlite3.connect(sqlitename) def nextval(seq): result = seqs[seq] = seqs.get(seq, 0) + 1 return result conn.create_function("nextval", 1, nextval) return conn con = connectionFactory() con.executescript(schemaText) con.commit() pool = ConnectionPool(connectionFactory, paramstyle='numeric', dialect=SQLITE_DIALECT) pool.startService() testCase.addCleanup(pool.stopService) return pool def resultOf(deferred, propagate=False): """ Add a callback and errback which will capture the result of a L{Deferred} in a list, and return that list. If 'propagate' is True, pass through the results. 
""" results = [] if propagate: def cb(r): results.append(r) return r else: cb = results.append deferred.addBoth(cb) return results class FakeThreadHolder(ThreadHolder): """ Run things to submitted this ThreadHolder on the main thread, so that execution is easier to control. """ def __init__(self, test): super(FakeThreadHolder, self).__init__(self) self.test = test self.started = False self.stopped = False self._workerIsRunning = False def start(self): self.started = True return super(FakeThreadHolder, self).start() def stop(self): result = super(FakeThreadHolder, self).stop() self.stopped = True return result @property def _get_q(self): return self._q_ @_get_q.setter def _q(self, newq): if newq is not None: oget = newq.get newq.get = lambda: oget(timeout=0) oput = newq.put def putit(x): p = oput(x) if not self.test.paused: self.flush() return p newq.put = putit self._q_ = newq def callFromThread(self, f, *a, **k): result = f(*a, **k) return result def callInThread(self, f, *a, **k): """ This should be called only once, to start the worker function that dedicates a thread to this L{ThreadHolder}. """ self._workerIsRunning = True def flush(self): """ Fire all deferreds previously returned from submit. """ try: while self._workerIsRunning and self._qpull(): pass else: self._workerIsRunning = False except Empty: pass @implementer(IReactorThreads) class ClockWithThreads(Clock): """ A testing reactor that supplies L{IReactorTime} and L{IReactorThreads}. """ def __init__(self): super(ClockWithThreads, self).__init__() self._pool = ThreadPool() def getThreadPool(self): """ Get the threadpool. """ return self._pool def suggestThreadPoolSize(self, size): """ Approximate the behavior of a 'real' reactor. """ self._pool.adjustPoolsize(maxthreads=size) def callInThread(self, thunk, *a, **kw): """ No implementation. """ def callFromThread(self, thunk, *a, **kw): """ No implementation. 
""" verifyClass(IReactorThreads, ClockWithThreads) class ConnectionPoolHelper(object): """ Connection pool setting-up facilities for tests that need a L{ConnectionPool}. """ dialect = POSTGRES_DIALECT paramstyle = DEFAULT_PARAM_STYLE def setUp(self, test=None, connect=None): """ Support inheritance by L{TestCase} classes. """ if test is None: test = self if connect is None: self.factory = ConnectionFactory() connect = self.factory.connect self.connect = connect self.paused = False self.holders = [] self.pool = ConnectionPool(connect, maxConnections=2, dialect=self.dialect, paramstyle=self.paramstyle) self.pool._createHolder = self.makeAHolder self.clock = self.pool.reactor = ClockWithThreads() self.pool.startService() test.addCleanup(self.flushHolders) def flushHolders(self): """ Flush all pending C{submit}s since C{pauseHolders} was called. This makes sure the service is stopped and the fake ThreadHolders are all executing their queues so failed tsets can exit cleanly. """ self.paused = False for holder in self.holders: holder.flush() def pauseHolders(self): """ Pause all L{FakeThreadHolder}s, causing C{submit} to return an unfired L{Deferred}. """ self.paused = True def makeAHolder(self): """ Make a ThreadHolder-alike. """ fth = FakeThreadHolder(self) self.holders.append(fth) return fth def resultOf(self, it): return resultOf(it) def createTransaction(self): return self.pool.connection() def translateError(self, err): return err class SteppablePoolHelper(ConnectionPoolHelper): """ A version of L{ConnectionPoolHelper} that can set up a connection pool capable of firing all its L{Deferred}s on demand, synchronously, by using SQLite. 
""" dialect = SQLITE_DIALECT paramstyle = sqlite3.paramstyle def __init__(self, schema): self.schema = schema def setUp(self, test): connect = synchronousConnectionFactory(test) con = connect() cur = con.cursor() cur.executescript(self.schema) con.commit() super(SteppablePoolHelper, self).setUp(test, connect) def rows(self, sql): """ Get some rows from the database to compare in a test. """ con = self.connect() cur = con.cursor() cur.execute(sql) result = cur.fetchall() con.commit() return result def synchronousConnectionFactory(test): tmpdb = test.mktemp() def connect(): return sqlite3.connect(tmpdb) return connect class Child(object): """ An object with a L{Parent}, in its list of C{children}. """ def __init__(self, parent): self.closed = False self.parent = parent self.parent.children.append(self) def close(self): if self.parent._closeFailQueue: raise self.parent._closeFailQueue.pop(0) self.closed = True class Parent(object): """ An object with a list of L{Child}ren. """ def __init__(self): self.children = [] self._closeFailQueue = [] def childCloseWillFail(self, exception): """ Closing children of this object will result in the given exception. @see: L{ConnectionFactory} """ self._closeFailQueue.append(exception) class FakeConnection(Parent, Child): """ Fake Stand-in for DB-API 2.0 connection. @ivar executions: the number of statements which have been executed. """ executions = 0 def __init__(self, factory): """ Initialize list of cursors """ Parent.__init__(self) Child.__init__(self, factory) self.id = factory.idcounter.next() self._executeFailQueue = [] self._commitCount = 0 self._rollbackCount = 0 def executeWillFail(self, thunk): """ The next call to L{FakeCursor.execute} will fail with an exception returned from the given callable. """ self._executeFailQueue.append(thunk) @property def cursors(self): "Alias to make tests more readable." 
return self.children def cursor(self): return FakeCursor(self) def commit(self): self._commitCount += 1 if self.parent.commitFail: self.parent.commitFail = False raise CommitFail() def rollback(self): self._rollbackCount += 1 if self.parent.rollbackFail: self.parent.rollbackFail = False raise RollbackFail() class RollbackFail(Exception): """ Sample rollback-failure exception. """ class CommitFail(Exception): """ Sample Commit-failure exception. """ class FakeCursor(Child): """ Fake stand-in for a DB-API 2.0 cursor. """ def __init__(self, connection): Child.__init__(self, connection) self.rowcount = 0 # not entirely correct, but all we care about is its truth value. self.description = False self.variables = [] self.allExecutions = [] @property def connection(self): "Alias to make tests more readable." return self.parent def execute(self, sql, args=()): self.connection.executions += 1 if self.connection._executeFailQueue: raise self.connection._executeFailQueue.pop(0)() self.allExecutions.append((sql, args)) self.sql = sql factory = self.connection.parent self.description = factory.hasResults if factory.hasResults and factory.shouldUpdateRowcount: self.rowcount = 1 else: self.rowcount = 0 return def var(self, type, *args): """ Return a database variable in the style of the cx_Oracle bindings. """ v = FakeVariable(self, type, args) self.variables.append(v) return v def fetchall(self): """ Just echo the SQL that was executed in the last query. 
""" if self.connection.parent.hasResults: return [[self.connection.id, self.sql]] if self.description: return [] return None class FakeVariable(object): def __init__(self, cursor, type, args): self.cursor = cursor self.type = type self.args = args def getvalue(self): vv = self.cursor.connection.parent.varvals if vv: return vv.pop(0) return self.cursor.variables.index(self) + 300 def __reduce__(self): raise RuntimeError("Not pickleable (since oracle vars aren't)") class ConnectionFactory(Parent): """ A factory for L{FakeConnection} objects. @ivar shouldUpdateRowcount: Should C{execute} on cursors produced by connections produced by this factory update their C{rowcount} or just their C{description} attribute? @ivar hasResults: should cursors produced by connections by this factory have any results returned by C{fetchall()}? """ rollbackFail = False commitFail = False def __init__(self, shouldUpdateRowcount=True, hasResults=True): Parent.__init__(self) self.idcounter = count(1) self._connectResultQueue = [] self.defaultConnect() self.varvals = [] self.shouldUpdateRowcount = shouldUpdateRowcount self.hasResults = hasResults @property def connections(self): "Alias to make tests more readable." return self.children def connect(self): """ Implement the C{ConnectionFactory} callable expected by L{ConnectionPool}. """ if self._connectResultQueue: thunk = self._connectResultQueue.pop(0) else: thunk = self._default return thunk() def willConnect(self): """ Used by tests to queue a successful result for connect(). """ def thunk(): return FakeConnection(self) self._connectResultQueue.append(thunk) def willConnectTo(self): """ Queue a successful result for connect() and immediately add it as a child to this L{ConnectionFactory}. 
@return: a connection object @rtype: L{FakeConnection} """ aConnection = FakeConnection(self) def thunk(): return aConnection self._connectResultQueue.append(thunk) return aConnection def willFail(self): """ Used by tests to queue a successful result for connect(). """ def thunk(): raise FakeConnectionError() self._connectResultQueue.append(thunk) def defaultConnect(self): """ By default, connection attempts will succeed. """ self.willConnect() self._default = self._connectResultQueue.pop() def defaultFail(self): """ By default, connection attempts will fail. """ self.willFail() self._default = self._connectResultQueue.pop() class FakeConnectionError(Exception): """ Synthetic error that might occur during connection. """ calendarserver-5.2+dfsg/twext/enterprise/adbapi2.py0000644000175000017500000015243012263343324021525 0ustar rahulrahul# -*- test-case-name: twext.enterprise.test.test_adbapi2 -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Asynchronous multi-process connection pool. This is similar to L{twisted.enterprise.adbapi}, but can hold a transaction (and thereby a thread) open across multiple asynchronous operations, rather than forcing the transaction to be completed entirely in a thread and/or entirely in a single SQL statement. Also, this module includes an AMP protocol for multiplexing connections through a single choke-point host. 
This is not currently in use, however, as AMP needs some optimization before it can be low-impact enough for this to be an improvement. """ import sys import weakref from cStringIO import StringIO from cPickle import dumps, loads from itertools import count from zope.interface import implements from twisted.internet.defer import inlineCallbacks from twisted.internet.defer import returnValue from twisted.internet.defer import DeferredList from twisted.internet.defer import Deferred from twisted.protocols.amp import Boolean from twisted.python.failure import Failure from twisted.protocols.amp import Argument, String, Command, AMP, Integer from twisted.internet import reactor as _reactor from twisted.application.service import Service from twisted.python import log from twisted.internet.defer import maybeDeferred from twisted.python.components import proxyForInterface from twext.internet.threadutils import ThreadHolder from twisted.internet.defer import succeed from twext.enterprise.ienterprise import ConnectionError from twext.enterprise.ienterprise import IDerivedParameter from twisted.internet.defer import fail from twext.enterprise.ienterprise import ( AlreadyFinishedError, IAsyncTransaction, POSTGRES_DIALECT, ICommandBlock ) # FIXME: there should be no defaults for connection metadata, it should be # discovered dynamically everywhere. Right now it's specified as an explicit # argument to the ConnectionPool but it should probably be determined # automatically from the database binding. DEFAULT_PARAM_STYLE = 'pyformat' DEFAULT_DIALECT = POSTGRES_DIALECT def _forward(thunk): """ Forward an attribute to the connection pool. """ @property def getter(self): return getattr(self._pool, thunk.func_name) return getter def _destructively(aList): """ Destructively iterate a list, popping elements from the beginning. 
""" while aList: yield aList.pop(0) def _deriveParameters(cursor, args): """ Some DB-API extensions need to call special extension methods on the cursor itself before executing. @param cursor: The DB-API cursor object to derive parameters from. @param args: the parameters being specified to C{execSQL}. This list will be modified to present parameters suitable to pass to the C{cursor}'s C{execute} method. @return: a list of L{IDerivedParameter} providers which had C{preQuery} executed on them, so that after the query they may have C{postQuery} executed. This may also be C{None} if no parameters were derived. @see: {IDerivedParameter} """ # TODO: have the underlying connection report whether it has any # IDerivedParameters that it knows about, so we can skip even inspecting # the arguments if none of them could possibly provide # IDerivedParameter. derived = None for n, arg in enumerate(args): if IDerivedParameter.providedBy(arg): if derived is None: # Be as sparing as possible with extra allocations, as this # usually isn't needed, and we're doing a ton of extra work to # support it. derived = [] derived.append(arg) args[n] = arg.preQuery(cursor) return derived def _deriveQueryEnded(cursor, derived): """ A query which involved some L{IDerivedParameter}s just ended. Execute any post-query cleanup or tasks that those parameters have to do. @param cursor: The DB-API object that derived the query. @param derived: The L{IDerivedParameter} providers that were given C{preQuery} notifications when the query started. @return: C{None} """ for arg in derived: arg.postQuery(cursor) class _ConnectedTxn(object): """ L{IAsyncTransaction} implementation based on a L{ThreadHolder} in the current process. 
""" implements(IAsyncTransaction) noisy = False def __init__(self, pool, threadHolder, connection, cursor): self._pool = pool self._completed = "idle" self._cursor = cursor self._connection = connection self._holder = threadHolder self._first = True @_forward def paramstyle(self): """ The paramstyle attribute is mirrored from the connection pool. """ @_forward def dialect(self): """ The dialect attribute is mirrored from the connection pool. """ def _reallyExecSQL(self, sql, args=None, raiseOnZeroRowCount=None): """ Execute the given SQL on a thread, using a DB-API 2.0 cursor. This method is invoked internally on a non-reactor thread, one dedicated to and associated with the current cursor. It executes the given SQL, re-connecting first if necessary, re-cycling the old connection if necessary, and then, if there are results from the statement (as determined by the DB-API 2.0 'description' attribute) it will fetch all the rows and return them, leaving them to be relayed to L{_ConnectedTxn.execSQL} via the L{ThreadHolder}. The rules for possibly reconnecting automatically are: if this is the very first statement being executed in this transaction, and an error occurs in C{execute}, close the connection and try again. We will ignore any errors from C{close()} (or C{rollback()}) and log them during this process. This is OK because adbapi2 always enforces transaction discipline: connections are never in autocommit mode, so if the first statement in a transaction fails, nothing can have happened to the database; as per the ADBAPI spec, a lost connection is a rolled-back transaction. In the cases where some databases fail to enforce transaction atomicity (i.e. schema manipulations), re-executing the same statement will result, at worst, in a spurious and harmless error (like "table already exists"), not corruption. @param sql: The SQL string to execute. @type sql: C{str} @param args: The bind parameters to pass to adbapi, if any. 
@type args: C{list} or C{None} @param raiseOnZeroRowCount: If specified, an exception to raise when no rows are found. @return: all the rows that resulted from execution of the given C{sql}, or C{None}, if the statement is one which does not produce results. @rtype: C{list} of C{tuple}, or C{NoneType} @raise Exception: this function may raise any exception raised by the underlying C{dbapi.connect}, C{cursor.execute}, L{IDerivedParameter.preQuery}, C{connection.cursor}, or C{cursor.fetchall}. @raise raiseOnZeroRowCount: if the argument was specified and no rows were returned by the executed statement. """ wasFirst = self._first # If this is the first time this cursor has been used in this # transaction, remember that, but mark it as now used. self._first = False if args is None: args = [] # Note: as of this writing, derived parameters are only used to support # cx_Oracle's "host variable" feature (i.e. cursor.var()), and creating # a host variable will never be a connection-oriented error (a # disconnected cursor can happily create variables of all types). # However, this may need to move into the 'try' below if other database # features need to compute database arguments based on runtime state. derived = _deriveParameters(self._cursor, args) try: self._cursor.execute(sql, args) except: # If execute() raised an exception, and this was the first thing to # happen in the transaction, then the connection has probably gone # bad in the meanwhile, and we should try again. if wasFirst: # Report the error before doing anything else, since doing # other things may cause the traceback stack to be eliminated # if they raise exceptions (even internally). log.err( Failure(), "Exception from execute() on first statement in " "transaction. Possibly caused by a database server " "restart. Automatically reconnecting now." ) try: self._connection.close() except: # close() may raise an exception to alert us of an error as # well. 
Right now the only type of error we know about is # "the connection is already closed", which obviously # doesn't need to be handled specially. Unfortunately the # reporting of this type of error is not consistent or # predictable across different databases, or even different # bindings to the same database, so we have to do a # catch-all here. While I can't imagine another type of # error at the moment, bare 'except:'s are notorious for # making debugging surprising error conditions very # difficult, so let's make sure that the error is logged # just in case. log.err( Failure(), "Exception from close() while automatically " "reconnecting. (Probably not serious.)" ) # Now, if either of *these* things fail, there's an error here # that we cannot workaround or address automatically, so no # try:except: for them. self._connection = self._pool.connectionFactory() self._cursor = self._connection.cursor() # Note that although this method is being invoked recursively, # the '_first' flag is re-set at the very top, so we will _not_ # be re-entering it more than once. result = self._reallyExecSQL(sql, args, raiseOnZeroRowCount) return result else: raise if derived is not None: _deriveQueryEnded(self._cursor, derived) if self._cursor.description: # see test_raiseOnZeroRowCountWithUnreliableRowCount rows = self._cursor.fetchall() if not rows: if raiseOnZeroRowCount is not None: raise raiseOnZeroRowCount() return rows else: if raiseOnZeroRowCount is not None and self._cursor.rowcount == 0: raise raiseOnZeroRowCount() return None def execSQL(self, *args, **kw): result = self._holder.submit( lambda: self._reallyExecSQL(*args, **kw) ) if self.noisy: def reportResult(results): sys.stdout.write("\n".join([ "", "SQL: %r %r" % (args, kw), "Results: %r" % (results,), "", ])) return results result.addBoth(reportResult) return result def _end(self, really): """ Common logic for commit or abort. Executed in the main reactor thread. 
@param really: the callable to execute in the cursor thread to actually do the commit or rollback. @return: a L{Deferred} which fires when the database logic has completed. @raise: L{AlreadyFinishedError} if the transaction has already been committed or aborted. """ if not self._completed: self._completed = "ended" def reallySomething(): """ Do the database work and set appropriate flags. Executed in the cursor thread. """ if self._cursor is None or self._first: return really() self._first = True result = self._holder.submit(reallySomething) self._pool._repoolAfter(self, result) return result else: raise AlreadyFinishedError(self._completed) def commit(self): return self._end(self._connection.commit) def abort(self): return self._end(self._connection.rollback).addErrback(log.err) def reset(self): """ Call this when placing this transaction back into the pool. @raise RuntimeError: if the transaction has not been committed or aborted. """ if not self._completed: raise RuntimeError("Attempt to re-set active transaction.") self._completed = False def _releaseConnection(self): """ Release the thread and database connection associated with this transaction. """ self._completed = "released" self._stopped = True holder = self._holder self._holder = None def _reallyClose(): if self._cursor is None: return self._connection.close() holder.submit(_reallyClose) return holder.stop() class _NoTxn(object): """ An L{IAsyncTransaction} that indicates a local failure before we could even communicate any statements (or possibly even any connection attempts) to the server. """ implements(IAsyncTransaction) def __init__(self, pool, reason): self.paramstyle = pool.paramstyle self.dialect = pool.dialect self.reason = reason def _everything(self, *a, **kw): """ Everything fails with a L{ConnectionError}. 
""" return fail(ConnectionError(self.reason)) execSQL = _everything commit = _everything abort = _everything class _WaitingTxn(object): """ A L{_WaitingTxn} is an implementation of L{IAsyncTransaction} which cannot yet actually execute anything, so it waits and spools SQL requests for later execution. When a L{_ConnectedTxn} becomes available later, it can be unspooled onto that. """ implements(IAsyncTransaction) def __init__(self, pool): """ Initialize a L{_WaitingTxn} based on a L{ConnectionPool}. (The C{pool} is used only to reflect C{dialect} and C{paramstyle} attributes; not remembered or modified in any way.) """ self._spool = [] self.paramstyle = pool.paramstyle self.dialect = pool.dialect def _enspool(self, cmd, a=(), kw={}): d = Deferred() self._spool.append((d, cmd, a, kw)) return d def _iterDestruct(self): """ Iterate the spool list destructively, while popping items from the beginning. This allows code which executes more SQL in the callback of a Deferred to not interfere with the originally submitted order of commands. """ return _destructively(self._spool) def _unspool(self, other): """ Unspool this transaction onto another transaction. @param other: another provider of L{IAsyncTransaction} which will actually execute the SQL statements we have been buffering. """ for (d, cmd, a, kw) in self._iterDestruct(): self._relayCommand(other, d, cmd, a, kw) def _relayCommand(self, other, d, cmd, a, kw): """ Relay a single command to another transaction. """ maybeDeferred(getattr(other, cmd), *a, **kw).chainDeferred(d) def execSQL(self, *a, **kw): return self._enspool('execSQL', a, kw) def commit(self): return self._enspool('commit') def abort(self): """ Succeed and do nothing. The actual logic for this method is mostly implemented by L{_SingleTxn._stopWaiting}. """ return succeed(None) class _HookableOperation(object): def __init__(self): self._hooks = [] @inlineCallbacks def runHooks(self, ignored=None): """ Callback for C{commit} and C{abort} Deferreds. 
"""
        for operation in _destructively(self._hooks):
            yield operation()
        self.clear()
        returnValue(ignored)


    def addHook(self, operation):
        """
        Register C{operation} to be run by L{runHooks}.  This backs the
        C{preCommit}, C{postCommit} and C{postAbort} methods of
        L{_CommitAndAbortHooks}.  Once L{clear} has been called, the hook is
        silently dropped.
        """
        if self._hooks is not None:
            self._hooks.append(operation)


    def clear(self):
        """
        Remove all hooks from this operation.  Once this is called, no
        more hooks can be added ever again.
        """
        self._hooks = None



class _CommitAndAbortHooks(object):
    """
    Shared implementation of pre-commit, post-commit and post-abort hooks.
    """
    # FIXME: this functionality needs direct tests, although it's pretty well-
    # covered by txdav's test suite.

    def __init__(self):
        self._preCommit = _HookableOperation()
        self._commit = _HookableOperation()
        self._abort = _HookableOperation()


    def _commitWithHooks(self, doCommit):
        """
        Run the pre-commit hooks, then the real DB commit, and then the
        post-commit hooks.  If the pre-commit hooks fail, abort instead and
        propagate that failure.

        @param doCommit: a 0-argument callable performing the real commit and
            returning a L{Deferred}.
        """
        pre = self._preCommit.runHooks()
        def ok(ignored):
            # Committing: post-abort hooks will never be needed.
            self._abort.clear()
            return doCommit().addCallback(self._commit.runHooks)
        def failed(why):
            return self.abort().addCallback(lambda ignored: why)
        return pre.addCallbacks(ok, failed)


    def preCommit(self, operation):
        return self._preCommit.addHook(operation)


    def postCommit(self, operation):
        return self._commit.addHook(operation)


    def postAbort(self, operation):
        return self._abort.addHook(operation)



class _SingleTxn(_CommitAndAbortHooks,
                 proxyForInterface(iface=IAsyncTransaction,
                                   originalAttribute='_baseTxn')):
    """
    A L{_SingleTxn} is a single-use wrapper for the longer-lived
    L{_ConnectedTxn}, so that if a badly-behaved API client accidentally hangs
    on to one of these and, for example C{.abort()}s it multiple times once
    another client is using that connection, it will get some harmless
    tracebacks.

    It's a wrapper around a "real" implementation; either a L{_ConnectedTxn},
    L{_NoTxn}, or L{_WaitingTxn} depending on the availability of real
    underlying database connections.

    This is the only L{IAsyncTransaction} implementation exposed to
    application code.
It's also the only implementor of the C{commandBlock} method for grouping
    commands together.
    """

    def __init__(self, pool, baseTxn):
        super(_SingleTxn, self).__init__()
        self._pool = pool
        self._baseTxn = baseTxn
        self._completed = False
        # The CommandBlock currently allowed to execute, if any.
        self._currentBlock = None
        # While blocks are active, free-standing statements are spooled here.
        self._blockedQueue = None
        self._pendingBlocks = []
        self._stillExecuting = []


    def __repr__(self):
        """
        Reveal the backend in the string representation.
        """
        return '_SingleTxn(%r)' % (self._baseTxn,)


    def _unspoolOnto(self, baseTxn):
        """
        Replace my C{_baseTxn}, currently a L{_WaitingTxn}, with a new
        implementation of L{IAsyncTransaction} that will actually do the work;
        either a L{_ConnectedTxn} or a L{_NoTxn}.
        """
        spooledBase = self._baseTxn
        self._baseTxn = baseTxn
        spooledBase._unspool(baseTxn)


    def execSQL(self, sql, args=None, raiseOnZeroRowCount=None):
        return self._execSQLForBlock(sql, args, raiseOnZeroRowCount, None)


    def _execSQLForBlock(self, sql, args, raiseOnZeroRowCount, block):
        """
        Execute some SQL for a particular L{CommandBlock}; or, if the given
        C{block} is C{None}, execute it in the outermost transaction context.
        """
        self._checkComplete()
        if block is None and self._blockedQueue is not None:
            return self._blockedQueue.execSQL(sql, args, raiseOnZeroRowCount)
        # 'block' should always be _currentBlock at this point.
        d = super(_SingleTxn, self).execSQL(sql, args, raiseOnZeroRowCount)
        self._stillExecuting.append(d)
        def itsDone(result):
            self._stillExecuting.remove(d)
            self._checkNextBlock()
            return result
        d.addBoth(itsDone)
        return d


    def _checkNextBlock(self):
        """
        Check to see if there are any blocks pending statements waiting to
        execute, and execute the next one if there are no outstanding execute
        calls.
        """
        if self._stillExecuting:
            # If we're still executing statements, never mind.  We'll get
            # called again by the 'itsDone' callback above.
            return

        if self._currentBlock is not None:
            # If there's still a current block, then keep it going.  We'll be
            # called by the '_finishExecuting' callback below.
            return

        # There's no block executing now.  What to do?
        if self._pendingBlocks:
            # If there are pending blocks, start one of them.
            self._currentBlock = self._pendingBlocks.pop(0)
            d = self._currentBlock._startExecuting()
            d.addCallback(self._finishExecuting)
        elif self._blockedQueue is not None:
            # If there aren't any pending blocks any more, and there are
            # spooled statements that aren't part of a block, unspool all the
            # statements that have been held up until this point.
            bq = self._blockedQueue
            self._blockedQueue = None
            bq._unspool(self)


    def _finishExecuting(self, result):
        """
        The active block just finished executing.  Clear it and see if there are
        more blocks to execute, or if all the blocks are done and we should
        execute any queued free statements.
        """
        self._currentBlock = None
        self._checkNextBlock()


    def commit(self):
        if self._blockedQueue is not None:
            # We're in the process of executing a block of commands.  Wait
            # until they're done.  (The spooled commit is relayed back to this
            # method when _checkNextBlock unspools the queue.)
            return self._blockedQueue.commit()
        def reallyCommit():
            self._markComplete()
            return super(_SingleTxn, self).commit()
        return self._commitWithHooks(reallyCommit)


    def abort(self):
        self._markComplete()
        # Aborting: commit hooks will never run.
        self._commit.clear()
        self._preCommit.clear()
        result = super(_SingleTxn, self).abort()
        if self in self._pool._waiting:
            self._stopWaiting()
        result.addCallback(self._abort.runHooks)
        return result


    def _stopWaiting(self):
        """
        Stop waiting for a free transaction and fail.
        """
        self._pool._waiting.remove(self)
        self._completed = True
        self._unspoolOnto(_NoTxn(self._pool,
                                 "connection pool shut down while txn "
                                 "waiting for database connection."))


    def _checkComplete(self):
        """
        If the transaction is complete, raise L{AlreadyFinishedError}
        """
        if self._completed:
            raise AlreadyFinishedError()


    def _markComplete(self):
        """
        Mark the transaction as complete, raising L{AlreadyFinishedError} if it
        has already been marked complete.
"""
        self._checkComplete()
        self._completed = True


    def commandBlock(self):
        """
        Create a L{CommandBlock} which will wait for all currently spooled
        commands to complete before executing its own.
        """
        self._checkComplete()
        block = CommandBlock(self)
        if self._currentBlock is None:
            # Start diverting free-standing statements until blocks finish.
            self._blockedQueue = _WaitingTxn(self._pool)
            # FIXME: test the case where it's ready immediately.
            self._checkNextBlock()
        return block


    def __del__(self):
        """
        When garbage collected, a L{_SingleTxn} recycles itself.
        """
        try:
            if not self._completed:
                self.abort()
        except AlreadyFinishedError:
            # The underlying transaction might already be completed without us
            # knowing; for example if the service shuts down.
            pass



class _Unspooler(object):
    """
    Adapter passed to L{_WaitingTxn._unspool} by L{CommandBlock}: it replays
    spooled statements onto the original L{CommandBlock} with C{track=False},
    so they are not tracked a second time.
    """
    def __init__(self, orig):
        self.orig = orig


    def execSQL(self, sql, args=None, raiseOnZeroRowCount=None):
        """
        Execute some SQL, but don't track a new Deferred.
        """
        return self.orig.execSQL(sql, args, raiseOnZeroRowCount, False)



class CommandBlock(object):
    """
    A partial implementation of L{IAsyncTransaction} that will group execSQL
    calls together.

    Does not implement commit() or abort(), because this will simply group
    commands.  In order to implement sub-transactions or checkpoints, some
    understanding of the SQL dialect in use by the underlying connection is
    required.  Instead, it provides 'end'.
    """
    implements(ICommandBlock)

    def __init__(self, singleTxn):
        self._singleTxn = singleTxn
        self.paramstyle = singleTxn.paramstyle
        self.dialect = singleTxn.dialect
        self._spool = _WaitingTxn(singleTxn._pool)
        self._started = False
        self._ended = False
        self._waitingForEnd = []
        self._endDeferred = Deferred()
        singleTxn._pendingBlocks.append(self)


    def _startExecuting(self):
        self._started = True
        self._spool._unspool(_Unspooler(self))
        return self._endDeferred


    def execSQL(self, sql, args=None, raiseOnZeroRowCount=None, track=True):
        """
        Execute some SQL within this command block.

        @param sql: the SQL string to execute.

        @param args: the SQL arguments.
@param raiseOnZeroRowCount: see L{IAsyncTransaction.execSQL}

        @param track: an internal parameter; was this called by application
            code or as part of unspooling some previously-queued requests?
            True if application code, False if unspooling.
        """
        if track and self._ended:
            raise AlreadyFinishedError()
        self._singleTxn._checkComplete()
        if self._singleTxn._currentBlock is self and self._started:
            # This block is running: execute directly.
            d = self._singleTxn._execSQLForBlock(
                sql, args, raiseOnZeroRowCount, self)
        else:
            # Not this block's turn yet: queue until _startExecuting.
            d = self._spool.execSQL(sql, args, raiseOnZeroRowCount)
        if track:
            self._trackForEnd(d)
        return d


    def _trackForEnd(self, d):
        """
        Watch the following L{Deferred}, since we need to watch it to determine
        when C{end} should be considered done, and the next CommandBlock or
        regular SQL statement should be unqueued.
        """
        self._waitingForEnd.append(d)


    def end(self):
        """
        The block of commands has completed.  Allow other SQL to run on the
        underlying L{IAsyncTransaction}.
        """
        # FIXME: test the case where end() is called when it's not the current
        # executing block.
        if self._ended:
            raise AlreadyFinishedError()
        self._ended = True
        # TODO: maybe this should return a Deferred that's a clone of
        # _endDeferred, so that callers can determine when the block is really
        # complete?  Struggling for an actual use-case on that one.
        DeferredList(self._waitingForEnd).chainDeferred(self._endDeferred)



class _ConnectingPseudoTxn(object):
    """
    This is a pseudo-Transaction for bookkeeping purposes.

    When a connection has asked to connect, but has not yet completed
    connecting, the L{ConnectionPool} still needs a way to shut it down.  This
    object provides that tracking handle, and will be present in the pool's
    C{busy} list while the connection attempt is in progress.
    """

    _retry = None

    def __init__(self, pool, holder):
        """
        Initialize the L{_ConnectingPseudoTxn}; get ready to connect.

        @param pool: The pool that this connection attempt is participating in.
@type pool: L{ConnectionPool}

        @param holder: the L{ThreadHolder} allocated to this connection attempt
            and subsequent SQL executions for this connection.
        @type holder: L{ThreadHolder}
        """
        self._pool = pool
        self._holder = holder
        self._aborted = False


    def abort(self):
        """
        Ignore the result of attempting to connect to this database: cancel any
        scheduled retry and stop the L{ThreadHolder} allocated for the attempt.
        A connection that completes afterwards is ignored because
        C{_aborted} is set.

        @return: the L{Deferred} from the holder's C{stop}; once it fires this
            object is removed from the pool's C{_busy} list.
        """
        self._aborted = True
        if self._retry is not None:
            self._retry.cancel()
        d = self._holder.stop()
        def removeme(ignored):
            if self in self._pool._busy:
                self._pool._busy.remove(self)
        d.addCallback(removeme)
        return d



def _fork(x):
    """
    Produce a L{Deferred} that will fire when another L{Deferred} fires without
    disturbing its results.
    """
    d = Deferred()
    def fired(result):
        d.callback(result)
        # Pass the result through unchanged for x's own callers.
        return result

    x.addBoth(fired)
    return d



class ConnectionPool(Service, object):
    """
    This is a central service that has a threadpool and executes SQL statements
    asynchronously, in a pool.

    @ivar connectionFactory: a 0-or-1-argument callable that returns a DB-API
        connection.  The optional argument can be used as a label for
        diagnostic purposes.

    @ivar maxConnections: The connection pool will not attempt to make more
        than this many concurrent connections to the database.

    @type maxConnections: C{int}

    @ivar reactor: The reactor used for scheduling threads as well as retries
        for failed connect() attempts.

    @type reactor: L{IReactorTime} and L{IReactorThreads} provider.

    @ivar _free: The list of free L{_ConnectedTxn} objects which are not
        currently attached to a L{_SingleTxn} object, and have active
        connections ready for processing a new transaction.

    @ivar _busy: The list of busy L{_ConnectedTxn} objects; those currently
        servicing an unfinished L{_SingleTxn} object.
@ivar _finishing: The list of 2-tuples of L{_ConnectedTxn} objects which
        have had C{abort} or C{commit} called on them, but are not done
        executing that method, and the L{Deferred} returned from that method
        that will be fired when its execution has completed.

    @ivar _waiting: The list of L{_SingleTxn} objects attached to a
        L{_WaitingTxn}; i.e. those which are awaiting a connection to become
        free so that they can be executed.

    @ivar _stopping: Is this L{ConnectionPool} in the process of shutting
        down?  (If so, new connections will not be established.)
    """

    reactor = _reactor

    # Seconds to wait before retrying a failed connection attempt.
    RETRY_TIMEOUT = 10.0


    def __init__(self, connectionFactory, maxConnections=10,
                 paramstyle=DEFAULT_PARAM_STYLE, dialect=DEFAULT_DIALECT):

        super(ConnectionPool, self).__init__()
        self.connectionFactory = connectionFactory
        self.maxConnections = maxConnections
        self.paramstyle = paramstyle
        self.dialect = dialect

        self._free = []
        self._busy = []
        self._waiting = []
        self._finishing = []
        self._stopping = False


    def startService(self):
        """
        Increase the thread pool size of the reactor by the number of threads
        that this service may consume.  This is important because unlike most
        L{IReactorThreads} users, the connection work units are very long-lived
        and block until this service has been stopped.
        """
        super(ConnectionPool, self).startService()
        tp = self.reactor.getThreadPool()
        self.reactor.suggestThreadPoolSize(tp.max + self.maxConnections)


    @inlineCallbacks
    def stopService(self):
        """
        Forcibly abort any outstanding transactions, and release all resources
        (notably, threads).
        """
        super(ConnectionPool, self).stopService()
        self._stopping = True

        # Phase 1: Cancel any transactions that are waiting so they won't try
        # to eagerly acquire new connections as they flow into the free-list.
        while self._waiting:
            waiting = self._waiting[0]
            waiting._stopWaiting()

        # Phase 2: Wait for all the Deferreds from the L{_ConnectedTxn}s that
        # have *already* been stopped.
        while self._finishing:
            yield _fork(self._finishing[0][1])

        # Phase 3: All of the busy transactions must be aborted first.  As each
        # one is aborted, it will remove itself from the list.
        while self._busy:
            yield self._busy[0].abort()

        # Phase 4: All transactions should now be in the free list, since
        # 'abort()' will have put them there.  Shut down all the associated
        # ThreadHolders.
        while self._free:
            # Releasing a L{_ConnectedTxn} doesn't automatically recycle it /
            # remove it the way aborting a _SingleTxn does, so we need to
            # .pop() here.  L{_ConnectedTxn._releaseConnection} really
            # shouldn't be able to fail, as it's just stopping the thread, and
            # the holder's stop() is independently submitted from .abort() /
            # .close().
            yield self._free.pop()._releaseConnection()

        tp = self.reactor.getThreadPool()
        self.reactor.suggestThreadPoolSize(tp.max - self.maxConnections)


    def _createHolder(self):
        """
        Create a L{ThreadHolder}.  (Test hook.)
        """
        return ThreadHolder(self.reactor)


    def connection(self, label=""):
        """
        Find and immediately return an L{IAsyncTransaction} object.  Execution
        of statements, commit and abort on that transaction may be delayed
        until a real underlying database connection is available.

        @param label: accepted for interface compatibility; not used by this
            method.

        @return: an L{IAsyncTransaction}
        """
        if self._stopping:
            # FIXME: should be wrapping a _SingleTxn around this to get
            # .commandBlock()
            return _NoTxn(self, "txn created while DB pool shutting down")

        if self._free:
            basetxn = self._free.pop(0)
            self._busy.append(basetxn)
            txn = _SingleTxn(self, basetxn)
        else:
            txn = _SingleTxn(self, _WaitingTxn(self))
            self._waiting.append(txn)
            # Only open another connection while under the limit.  Free
            # connections need not be counted, since this branch is reached
            # only when there are none.  (TESTME)
            if self._activeConnectionCount() < self.maxConnections:
                self._startOneMore()
        return txn


    def _activeConnectionCount(self):
        """
        @return: the number of active outgoing connections to the database.
        """
        return len(self._busy) + len(self._finishing)


    def _startOneMore(self):
        """
        Start one more _ConnectedTxn.
"""
        holder = self._createHolder()
        holder.start()
        txn = _ConnectingPseudoTxn(self, holder)
        # take up a slot in the 'busy' list, sit there so we can be aborted.
        self._busy.append(txn)
        def initCursor():
            # support threadlevel=1; we can't necessarily cursor() in a
            # different thread than we do transactions in.
            connection = self.connectionFactory()
            cursor = connection.cursor()
            return (connection, cursor)
        def finishInit((connection, cursor)):
            if txn._aborted:
                # NOTE(review): the freshly opened connection is neither
                # closed nor used on this path; confirm whether it should be
                # closed here.
                return
            baseTxn = _ConnectedTxn(
                pool=self,
                threadHolder=holder,
                connection=connection,
                cursor=cursor
            )
            self._busy.remove(txn)
            self._repoolNow(baseTxn)
        def maybeTryAgain(f):
            log.err(f, "Re-trying connection due to connection failure")
            txn._retry = self.reactor.callLater(self.RETRY_TIMEOUT, resubmit)
        def resubmit():
            d = holder.submit(initCursor)
            d.addCallbacks(finishInit, maybeTryAgain)
        resubmit()


    def _repoolAfter(self, txn, d):
        """
        Re-pool the given L{_ConnectedTxn} after the given L{Deferred} has
        fired.  If C{d} fails instead, the connection is released and a
        replacement connection is started.
        """
        self._busy.remove(txn)
        finishRecord = (txn, d)
        self._finishing.append(finishRecord)
        def repool(result):
            self._finishing.remove(finishRecord)
            self._repoolNow(txn)
            return result
        def discard(result):
            self._finishing.remove(finishRecord)
            txn._releaseConnection()
            self._startOneMore()
            return result
        return d.addCallbacks(repool, discard)


    def _repoolNow(self, txn):
        """
        Recycle a L{_ConnectedTxn} into the free list, or hand it straight to
        the oldest waiting L{_SingleTxn} if there is one.
        """
        txn.reset()
        if self._waiting:
            waiting = self._waiting.pop(0)
            self._busy.append(txn)
            waiting._unspoolOnto(txn)
        else:
            self._free.append(txn)



def txnarg():
    return [('transactionID', Integer())]



CHUNK_MAX = 0xffff

class BigArgument(Argument):
    """
    An argument whose payload can be larger than L{CHUNK_MAX}, by splitting
    across multiple AMP keys.
"""

    def fromBox(self, name, strings, objects, proto):
        # Reassemble "<name>.0", "<name>.1", ... until a key is missing.
        value = StringIO()
        for counter in count():
            chunk = strings.get("%s.%d" % (name, counter))
            if chunk is None:
                break
            value.write(chunk)
        objects[name] = self.fromString(value.getvalue())


    def toBox(self, name, strings, objects, proto):
        # Split the serialized value into CHUNK_MAX-sized numbered keys.
        value = StringIO(self.toString(objects[name]))
        for counter in count():
            nextChunk = value.read(CHUNK_MAX)
            if not nextChunk:
                break
            strings["%s.%d" % (name, counter)] = nextChunk



class Pickle(BigArgument):
    """
    A pickle sent over AMP.  This is to serialize the 'args' argument to
    C{execSQL}, which is the dynamically-typed 'args' list argument to a DB-API
    C{execute} function, as well as its dynamically-typed result ('rows').

    This should be cleaned up into a nicer structure, but this is not a network
    protocol, so we can be a little relaxed about security.

    This is a L{BigArgument} rather than a regular L{Argument} because
    individual arguments and query results need to contain entire vCard or
    iCalendar documents, which can easily be greater than 64k.
    """

    def toString(self, inObject):
        return dumps(inObject)

    def fromString(self, inString):
        # NOTE(review): unpickling data from the peer can execute arbitrary
        # code; this is only acceptable while both ends of the AMP connection
        # are trusted local processes, as the class docstring assumes.
        return loads(inString)



class FailsafeException(Exception):
    """
    Exception raised by all responders.
    """



_quashErrors = {
    FailsafeException: "SOMETHING_UNKNOWN",
    AlreadyFinishedError: "ALREADY_FINISHED",
    ConnectionError: "CONNECTION_ERROR",
}



def failsafeResponder(command):
    """
    Wrap an AMP command responder in some fail-safe logic, to make it so that
    unknown errors won't drop the connection, as AMP's default behavior would.
    """
    def wrap(inner):
        @inlineCallbacks
        def innerinner(*a, **k):
            try:
                val = yield inner(*a, **k)
            except:
                f = Failure()
                if f.type in command.errors:
                    # A declared error: hand the Failure back so AMP reports
                    # it to the peer by its error code.
                    returnValue(f)
                else:
                    # Anything else is logged locally and replaced by an
                    # exception that the commands in this module declare.
                    log.err(Failure(), "shared database connection pool error")
                    raise FailsafeException()
            else:
                returnValue(val)
        return command.responder(innerinner)
    return wrap



class StartTxn(Command):
    """
    Start a transaction, identified with an ID generated by the client.
"""

    arguments = txnarg()
    errors = _quashErrors



class ExecSQL(Command):
    """
    Execute an SQL statement.
    """

    arguments = [('sql', String()),
                 ('queryID', String()),
                 ('args', Pickle()),
                 ('blockID', String()),
                 ('reportZeroRowCount', Boolean())] + txnarg()
    errors = _quashErrors



class StartBlock(Command):
    """
    Create a new SQL command block.
    """

    arguments = [("blockID", String())] + txnarg()
    errors = _quashErrors



class EndBlock(Command):
    """
    End an SQL command block previously created with L{StartBlock}.
    """

    arguments = [("blockID", String())] + txnarg()
    errors = _quashErrors



class Row(Command):
    """
    A row has been returned.  Sent from server to client in response to
    L{ExecSQL}.
    """

    arguments = [('queryID', String()),
                 ('row', Pickle())]
    errors = _quashErrors



class QueryComplete(Command):
    """
    A query issued with L{ExecSQL} is complete.
    """

    arguments = [('queryID', String()),
                 ('norows', Boolean()),
                 ('derived', Pickle()),
                 ('noneResult', Boolean())]
    errors = _quashErrors



class Commit(Command):
    """
    Commit the transaction identified by C{transactionID}.
    """
    arguments = txnarg()
    errors = _quashErrors



class Abort(Command):
    """
    Abort (roll back) the transaction identified by C{transactionID}.
    """
    arguments = txnarg()
    errors = _quashErrors



class _NoRows(Exception):
    """
    Placeholder exception to report zero rows.
    """



class ConnectionPoolConnection(AMP):
    """
    A L{ConnectionPoolConnection} is a single connection to a
    L{ConnectionPool}.  This is the server side of the connection-pool-sharing
    protocol; it implements all the AMP responders necessary.
    """

    def __init__(self, pool):
        """
        Initialize a mapping of transaction IDs to transaction objects.
        """
        super(ConnectionPoolConnection, self).__init__()
        self.pool = pool
        self._txns = {}
        self._blocks = {}


    def stopReceivingBoxes(self, why):
        log.msg("(S) Stopped receiving boxes: " + why.getTraceback())


    def unhandledError(self, failure):
        """
        An unhandled error has occurred.  Since we can't really classify errors
        well on this protocol, log it and forget it.
""" log.err(failure, "Shared connection pool server encountered an error.") @failsafeResponder(StartTxn) def start(self, transactionID): self._txns[transactionID] = self.pool.connection() return {} @failsafeResponder(StartBlock) def startBlock(self, transactionID, blockID): self._blocks[blockID] = self._txns[transactionID].commandBlock() return {} @failsafeResponder(EndBlock) def endBlock(self, transactionID, blockID): self._blocks[blockID].end() return {} @failsafeResponder(ExecSQL) @inlineCallbacks def receivedSQL(self, transactionID, queryID, sql, args, blockID, reportZeroRowCount): derived = None noneResult = False for param in args: if IDerivedParameter.providedBy(param): if derived is None: derived = [] derived.append(param) if blockID: txn = self._blocks[blockID] else: txn = self._txns[transactionID] if reportZeroRowCount: rozrc = _NoRows else: rozrc = None try: rows = yield txn.execSQL(sql, args, rozrc) except _NoRows: norows = True else: norows = False if rows is not None: for row in rows: # Either this should be yielded or it should be # requiresAnswer=False self.callRemote(Row, queryID=queryID, row=row) else: noneResult = True self.callRemote(QueryComplete, queryID=queryID, norows=norows, derived=derived, noneResult=noneResult) returnValue({}) def _complete(self, transactionID, thunk): txn = self._txns.pop(transactionID) return thunk(txn).addCallback(lambda ignored: {}) @failsafeResponder(Commit) def commit(self, transactionID): """ Successfully complete the given transaction. """ def commitme(x): return x.commit() return self._complete(transactionID, commitme) @failsafeResponder(Abort) def abort(self, transactionID): """ Roll back the given transaction. """ def abortme(x): return x.abort() return self._complete(transactionID, abortme) class ConnectionPoolClient(AMP): """ A client which can execute SQL. """ def __init__(self, dialect=POSTGRES_DIALECT, paramstyle=DEFAULT_PARAM_STYLE): # See DEFAULT_PARAM_STYLE FIXME above. 
super(ConnectionPoolClient, self).__init__() self._nextID = count().next self._txns = weakref.WeakValueDictionary() self._queries = {} self.dialect = dialect self.paramstyle = paramstyle def unhandledError(self, failure): """ An unhandled error has occurred. Since we can't really classify errors well on this protocol, log it and forget it. """ log.err(failure, "Shared connection pool client encountered an error.") def stopReceivingBoxes(self, why): log.msg("(C) Stopped receiving boxes: " + why.getTraceback()) def newTransaction(self): """ Create a new networked provider of L{IAsyncTransaction}. (This will ultimately call L{ConnectionPool.connection} on the other end of the wire.) @rtype: L{IAsyncTransaction} """ txnid = str(self._nextID()) txn = _NetTransaction(client=self, transactionID=txnid) self._txns[txnid] = txn self.callRemote(StartTxn, transactionID=txnid) return txn @failsafeResponder(Row) def row(self, queryID, row): self._queries[queryID].row(row) return {} @failsafeResponder(QueryComplete) def complete(self, queryID, norows, derived, noneResult): self._queries.pop(queryID).done(norows, derived, noneResult) return {} class _Query(object): def __init__(self, sql, raiseOnZeroRowCount, args): self.sql = sql self.args = args self.results = [] self.deferred = Deferred() self.raiseOnZeroRowCount = raiseOnZeroRowCount def row(self, row): """ A row was received. """ self.results.append(row) def done(self, norows, derived, noneResult): """ The query is complete. @param norows: A boolean. True if there were not any rows. @param derived: either C{None} or a C{list} of L{IDerivedParameter} providers initially passed into the C{execSQL} that started this query. The values of these object swill mutate the original input parameters to resemble them. Although L{IDerivedParameter.preQuery} and L{IDerivedParameter.postQuery} are invoked on the other end of the wire, the local objects will be made to appear as though they were called here. 
@param noneResult: should the result of the query be C{None} (i.e. did it not have a C{description} on the cursor). """ if noneResult and not self.results: results = None else: results = self.results if derived is not None: # 1) Bleecchh. # 2) FIXME: add some direct tests in test_adbapi2, the unit test # for this crosses some abstraction boundaries so it's a little # integration-y and in the tests for twext.enterprise.dal for remote, local in zip(derived, self._deriveDerived()): local.__dict__ = remote.__dict__ if norows and (self.raiseOnZeroRowCount is not None): exc = self.raiseOnZeroRowCount() self.deferred.errback(Failure(exc)) else: self.deferred.callback(results) def _deriveDerived(self): derived = None for param in self.args: if IDerivedParameter.providedBy(param): if derived is None: derived = [] derived.append(param) return derived class _NetTransaction(_CommitAndAbortHooks): """ A L{_NetTransaction} is an L{AMP}-protocol-based provider of the L{IAsyncTransaction} interface. It sends SQL statements, query results, and commit/abort commands via an AMP socket to a pooling process. """ implements(IAsyncTransaction) def __init__(self, client, transactionID): """ Initialize a transaction with a L{ConnectionPoolClient} and a unique transaction identifier. """ super(_NetTransaction, self).__init__() self._client = client self._transactionID = transactionID self._completed = False self._committing = False self._committed = False @property def paramstyle(self): """ Forward 'paramstyle' attribute to the client. """ return self._client.paramstyle @property def dialect(self): """ Forward 'dialect' attribute to the client. 
""" return self._client.dialect def execSQL(self, sql, args=None, raiseOnZeroRowCount=None, blockID=""): if not blockID: if self._completed: raise AlreadyFinishedError() if args is None: args = [] client = self._client queryID = str(client._nextID()) query = client._queries[queryID] = _Query(sql, raiseOnZeroRowCount, args) result = ( client.callRemote( ExecSQL, queryID=queryID, sql=sql, args=args, transactionID=self._transactionID, blockID=blockID, reportZeroRowCount=raiseOnZeroRowCount is not None, ) .addCallback(lambda nothing: query.deferred) ) return result def _complete(self, command): if self._completed: raise AlreadyFinishedError() self._completed = True return self._client.callRemote( command, transactionID=self._transactionID ).addCallback(lambda x: None) def commit(self): def reallyCommit(): self._committing = True def done(whatever): self._committed = True return whatever return self._complete(Commit).addBoth(done) return self._commitWithHooks(reallyCommit) def abort(self): self._commit.clear() self._preCommit.clear() return self._complete(Abort).addCallback(self._abort.runHooks) def commandBlock(self): if self._completed: raise AlreadyFinishedError() blockID = str(self._client._nextID()) self._client.callRemote( StartBlock, blockID=blockID, transactionID=self._transactionID ) return _NetCommandBlock(self, blockID) def __del__(self): """ When a L{_NetTransaction} is garabage collected, it aborts itself. """ if not self._completed: def shush(f): f.trap(ConnectionError, AlreadyFinishedError) self.abort().addErrback(shush) class _NetCommandBlock(object): """ Net command block. """ implements(ICommandBlock) def __init__(self, transaction, blockID): self._transaction = transaction self._blockID = blockID self._ended = False @property def paramstyle(self): """ Forward 'paramstyle' attribute to the transaction. """ return self._transaction.paramstyle @property def dialect(self): """ Forward 'dialect' attribute to the transaction. 
""" return self._transaction.dialect def execSQL(self, sql, args=None, raiseOnZeroRowCount=None): """ Execute some SQL on this command block. """ if ( self._ended or self._transaction._completed and not self._transaction._committing or self._transaction._committed ): raise AlreadyFinishedError() return self._transaction.execSQL(sql, args, raiseOnZeroRowCount, self._blockID) def end(self): """ End this block. """ if self._ended: raise AlreadyFinishedError() self._ended = True self._transaction._client.callRemote( EndBlock, blockID=self._blockID, transactionID=self._transaction._transactionID ) calendarserver-5.2+dfsg/twext/enterprise/locking.py0000644000175000017500000000654612263346572021667 0ustar rahulrahul# -*- test-case-name: twext.enterprise.test.test_locking -*- ## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Utilities to restrict concurrency based on mutual exclusion. """ from twext.enterprise.dal.model import Table from twext.enterprise.dal.model import SQLType from twext.enterprise.dal.model import Constraint from twext.enterprise.dal.syntax import SchemaSyntax from twext.enterprise.dal.model import Schema from twext.enterprise.dal.record import Record from twext.enterprise.dal.record import fromTable class AlreadyUnlocked(Exception): """ The lock you were trying to unlock was already unlocked. """ class LockTimeout(Exception): """ The lock you were trying to lock was already locked causing a timeout. 
""" def makeLockSchema(inSchema): """ Create a self-contained schema just for L{Locker} use, in C{inSchema}. @param inSchema: a L{Schema} to add the locks table to. @type inSchema: L{Schema} @return: inSchema """ LockTable = Table(inSchema, 'NAMED_LOCK') LockTable.addColumn("LOCK_NAME", SQLType("varchar", 255)) LockTable.tableConstraint(Constraint.NOT_NULL, ["LOCK_NAME"]) LockTable.tableConstraint(Constraint.UNIQUE, ["LOCK_NAME"]) LockTable.primaryKey = [LockTable.columnNamed("LOCK_NAME")] return inSchema LockSchema = SchemaSyntax(makeLockSchema(Schema(__file__))) class NamedLock(Record, fromTable(LockSchema.NAMED_LOCK)): """ An L{AcquiredLock} lock against a shared data store that the current process holds via the referenced transaction. """ @classmethod def acquire(cls, txn, name): """ Acquire a lock with the given name. @param name: The name of the lock to acquire. Against the same store, no two locks may be acquired. @type name: L{unicode} @return: a L{Deferred} that fires with an L{AcquiredLock} when the lock has fired, or fails when the lock has not been acquired. """ def autoRelease(self): txn.preCommit(lambda: self.release(True)) return self def lockFailed(f): raise LockTimeout(name) return cls.create(txn, lockName=name).addCallback(autoRelease).addErrback(lockFailed) def release(self, ignoreAlreadyUnlocked=False): """ Release this lock. @param ignoreAlreadyUnlocked: If you don't care about the current status of this lock, and just want to release it if it is still acquired, pass this parameter as L{True}. Otherwise this method will raise an exception if it is invoked when the lock has already been released. @raise: L{AlreadyUnlocked} @return: A L{Deferred} that fires with L{None} when the lock has been unlocked. """ return self.delete() calendarserver-5.2+dfsg/twext/enterprise/queue.py0000644000175000017500000015016512276242656021364 0ustar rahulrahul# -*- test-case-name: twext.enterprise.test.test_queue -*- ## # Copyright (c) 2012-2014 Apple Inc. 
 All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ L{twext.enterprise.queue} is an U{eventually consistent <https://en.wikipedia.org/wiki/Eventual_consistency>} task-queueing system for use by applications with multiple front-end servers talking to a single database instance that want to defer and parallelize work that involves storing the results of computation. By enqueuing with L{twext.enterprise.queue}, you may guarantee that the work will I{eventually} be done, and reliably commit to doing it in the future, but defer it if it does not need to be done I{now}. To pick a hypothetical example, let's say that you have a store which wants to issue a promotional coupon based on a customer loyalty program, in response to an administrator clicking on a button. Determining the list of customers to send the coupon to is quick: a simple query will get you all their names. However, analyzing each user's historical purchase data is (A) time-consuming and (B) relatively isolated, so it would be good to do that in parallel, and it would also be acceptable to have that happen at a later time, outside the critical path.
Such an application might be implemented with this queueing system like so:: from twisted.internet.defer import inlineCallbacks from twext.enterprise.queue import WorkItem, queueFromTransaction from twext.enterprise.dal.model import Schema from twext.enterprise.dal.parseschema import addSQLToSchema from twext.enterprise.dal.record import Record, fromTable from twext.enterprise.dal.syntax import SchemaSyntax, Select schemaModel = Schema() addSQLToSchema(schemaModel, ''' create table CUSTOMER (NAME varchar(255), ID integer primary key); create table PRODUCT (NAME varchar(255), ID integer primary key); create table PURCHASE (NAME varchar(255), WHEN timestamp, CUSTOMER_ID integer references CUSTOMER, PRODUCT_ID integer references PRODUCT); create table COUPON_WORK (WORK_ID integer primary key, CUSTOMER_ID integer references CUSTOMER); create table COUPON (ID integer primary key, CUSTOMER_ID integer references customer, AMOUNT integer); ''') schema = SchemaSyntax(schemaModel) class Coupon(Record, fromTable(schema.COUPON)): pass class CouponWork(WorkItem, fromTable(schema.COUPON_WORK)): @inlineCallbacks def doWork(self): purchases = yield Select(From=schema.PURCHASE, Where=schema.PURCHASE.CUSTOMER_ID == self.customerID).on(self.transaction) couponAmount = yield doSomeMathThatTakesAWhile(purchases) yield Coupon.create(self.transaction, customerID=self.customerID, amount=couponAmount) @inlineCallbacks def makeSomeCoupons(txn): # Note, txn was started before, will be committed later... for [customerID] in (yield Select([schema.CUSTOMER.ID], From=schema.CUSTOMER).on(txn)): # queuer is a provider of IQueuer, of which there are several # implementations in this module.
queuer.enqueueWork(txn, CouponWork, customerID=customerID) """ from functools import wraps from datetime import datetime from zope.interface import implements from twisted.application.service import MultiService from twisted.internet.protocol import Factory from twisted.internet.defer import ( inlineCallbacks, returnValue, Deferred, passthru, succeed ) from twisted.internet.endpoints import TCP4ClientEndpoint from twisted.protocols.amp import AMP, Command, Integer, Argument, String from twisted.python.reflect import qual from twisted.python import log from twext.enterprise.dal.syntax import SchemaSyntax, Lock, NamedValue from twext.enterprise.dal.model import ProcedureCall from twext.enterprise.dal.record import Record, fromTable, NoSuchRecord from twisted.python.failure import Failure from twext.enterprise.dal.model import Table, Schema, SQLType, Constraint from twisted.internet.endpoints import TCP4ServerEndpoint from twext.enterprise.ienterprise import IQueuer from zope.interface.interface import Interface from twext.enterprise.locking import NamedLock class _IWorkPerformer(Interface): """ An object that can perform work. Internal interface; implemented by several classes here since work has to (in the worst case) pass from worker->controller->controller->worker. """ def performWork(table, workID): #@NoSelf """ @param table: The table where work is waiting. @type table: L{TableSyntax} @param workID: The primary key identifier of the given work. @type workID: L{int} @return: a L{Deferred} firing with an empty dictionary when the work is complete. @rtype: L{Deferred} firing L{dict} """ def makeNodeSchema(inSchema): """ Create a self-contained schema for L{NodeInfo} to use, in C{inSchema}. @param inSchema: a L{Schema} to add the node-info table to. @type inSchema: L{Schema} @return: a schema with just the one table. 
""" # Initializing this duplicate schema avoids a circular dependency, but this # should really be accomplished with independent schema objects that the # transaction is made aware of somehow. NodeTable = Table(inSchema, 'NODE_INFO') NodeTable.addColumn("HOSTNAME", SQLType("varchar", 255)) NodeTable.addColumn("PID", SQLType("integer", None)) NodeTable.addColumn("PORT", SQLType("integer", None)) NodeTable.addColumn("TIME", SQLType("timestamp", None)).setDefaultValue( # Note: in the real data structure, this is actually a not-cleaned-up # sqlparse internal data structure, but it *should* look closer to # this. ProcedureCall("timezone", ["UTC", NamedValue('CURRENT_TIMESTAMP')]) ) for column in NodeTable.columns: NodeTable.tableConstraint(Constraint.NOT_NULL, [column.name]) NodeTable.primaryKey = [NodeTable.columnNamed("HOSTNAME"), NodeTable.columnNamed("PORT")] return inSchema NodeInfoSchema = SchemaSyntax(makeNodeSchema(Schema(__file__))) @inlineCallbacks def inTransaction(transactionCreator, operation): """ Perform the given operation in a transaction, committing or aborting as required. @param transactionCreator: a 0-arg callable that returns an L{IAsyncTransaction} @param operation: a 1-arg callable that takes an L{IAsyncTransaction} and returns a value. @return: a L{Deferred} that fires with C{operation}'s result or fails with its error, unless there is an error creating, aborting or committing the transaction. """ txn = transactionCreator() try: result = yield operation(txn) except: f = Failure() yield txn.abort() returnValue(f) else: yield txn.commit() returnValue(result) def astimestamp(v): """ Convert the given datetime to a POSIX timestamp. """ return (v - datetime.utcfromtimestamp(0)).total_seconds() class TableSyntaxByName(Argument): """ Serialize and deserialize L{TableSyntax} objects for an AMP protocol with an attached schema. 
""" def fromStringProto(self, inString, proto): """ Convert the name of the table into a table, given a C{proto} with an attached C{schema}. @param inString: the name of a table, as utf-8 encoded bytes @type inString: L{bytes} @param proto: an L{SchemaAMP} """ return getattr(proto.schema, inString.decode("UTF-8")) def toString(self, inObject): """ Convert a L{TableSyntax} object into just its name for wire transport. @param inObject: a table. @type inObject: L{TableSyntax} @return: the name of that table @rtype: L{bytes} """ return inObject.model.name.encode("UTF-8") class NodeInfo(Record, fromTable(NodeInfoSchema.NODE_INFO)): """ A L{NodeInfo} is information about a currently-active Node process. """ def endpoint(self, reactor): """ Create an L{IStreamServerEndpoint} that will talk to the node process that is described by this L{NodeInfo}. @return: an endpoint that will connect to this host. @rtype: L{IStreamServerEndpoint} """ return TCP4ClientEndpoint(reactor, self.hostname, self.port) def abstract(thunk): """ The decorated function is abstract. @note: only methods are currently supported. """ @classmethod @wraps(thunk) def inner(cls, *a, **k): raise NotImplementedError(qual(cls) + " does not implement " + thunk.func_name) return inner class WorkItem(Record): """ A L{WorkItem} is an item of work which may be stored in a database, then executed later. L{WorkItem} is an abstract class, since it is a L{Record} with no table associated via L{fromTable}. Concrete subclasses must associate a specific table by inheriting like so:: class MyWorkItem(WorkItem, fromTable(schema.MY_TABLE)): Concrete L{WorkItem}s should generally not be created directly; they are both created and thereby implicitly scheduled to be executed by calling L{enqueueWork } with the appropriate L{WorkItem} concrete subclass. There are different queue implementations (L{PeerConnectionPool} and L{LocalQueuer}, for example), so the exact timing and location of the work execution may differ. 
L{WorkItem}s may be constrained in the ordering and timing of their execution, to control concurrency and for performance reasons respectively. Although all the usual database mutual-exclusion rules apply to work executed in L{WorkItem.doWork}, implicit database row locking is not always the best way to manage concurrency. Implicit locks have some problems, including: - implicit locks are easy to accidentally acquire out of order, which can lead to deadlocks - implicit locks are easy to forget to acquire correctly - for example, any read operation which subsequently turns into a write operation must have been acquired with C{Select(..., ForUpdate=True)}, but it is difficult to consistently indicate that methods which abstract out read operations must pass this flag in certain cases and not others. - implicit locks are held until the transaction ends, which means that if expensive (long-running) queue operations share the same lock with cheap (short-running) queue operations or user interactions, the cheap operations all have to wait for the expensive ones to complete, while continuing to consume whatever database resources they were using. In order to ameliorate these problems with potentially concurrent work that uses the same resources, L{WorkItem} provides a database-wide mutex that is automatically acquired at the beginning of the transaction and released at the end. To use it, simply align the C{group} attribute on your L{WorkItem} subclass with a column holding a string (varchar). L{WorkItem} subclasses with the same value for C{group} will not execute their C{doWork} methods concurrently. Furthermore, if the lock cannot be quickly acquired, database resources associated with the transaction attempting it will be released, and the transaction rolled back until a future transaction I{can} acquire it quickly. If you do not want any limits to concurrency, simply leave it set to C{None}.
In some applications it's possible to coalesce work together: to grab multiple L{WorkItem}s in one C{doWork} transaction. All you need to do is to delete the rows which back other L{WorkItem}s from the database, and they won't be processed. Using the C{group} attribute, you can prevent concurrency so that you can easily group these items together and remove them as a set (otherwise, other workers might be attempting to concurrently work on them and you'll get deletion errors). Additionally, if doing more work at once is less expensive, and you want to avoid processing lots of individual rows in tiny transactions, you may also delay the execution of a L{WorkItem} by setting its C{notBefore} attribute. This must be backed by a database timestamp, so that processes which happen to be restarting and examining the work to be done in the database don't jump the gun and do it too early. @cvar workID: the unique identifier (primary key) for items of this type. On an instance of a concrete L{WorkItem} subclass, this attribute must be an integer; on the concrete L{WorkItem} subclass itself, this attribute must be a L{twext.enterprise.dal.syntax.ColumnSyntax}. Note that this is automatically taken care of if you simply have a corresponding C{work_id} column in the associated L{fromTable} on your L{WorkItem} subclass. This column must be unique, and it must be an integer. In almost all cases, this column really ought to be filled in by a database-defined sequence; if not, you need some other mechanism for establishing a cluster-wide sequence. @type workID: L{int} on instance, L{twext.enterprise.dal.syntax.ColumnSyntax} on class. @cvar notBefore: the timestamp before which this item should I{not} be processed. If unspecified, this should be the date and time of the creation of the L{WorkItem}. @type notBefore: L{datetime.datetime} on instance, L{twext.enterprise.dal.syntax.ColumnSyntax} on class.
@ivar group: If not C{None}, a unique-to-the-database identifier for which only one L{WorkItem} will execute at a time. @type group: L{unicode} or L{NoneType} """ group = None @abstract def doWork(self): """ Subclasses must implement this to actually perform the queued work. This method will be invoked in a worker process. This method does I{not} need to delete the row referencing it; that will be taken care of by the job queueing machinery. """ @classmethod def forTable(cls, table): """ Look up a work-item class given a particular L{TableSyntax}. Factoring this correctly may place it into L{twext.enterprise.record.Record} instead; it is probably generally useful to be able to look up a mapped class from a table. @param table: the table to look up @type table: L{twext.enterprise.dal.model.Table} @return: the relevant subclass @rtype: L{type} """ tableName = table.model.name for subcls in cls.__subclasses__(): clstable = getattr(subcls, "table", None) if table == clstable: return subcls raise KeyError("No mapped {0} class for {1}.".format( cls, tableName )) class PerformWork(Command): """ Notify another process that it must do some work that has been persisted to the database, by informing it of the table and the ID where said work has been persisted. """ arguments = [ ("table", TableSyntaxByName()), ("workID", Integer()), ] response = [] class ReportLoad(Command): """ Notify another node of the total, current load for this whole node (all of its workers). """ arguments = [ ("load", Integer()) ] response = [] class IdentifyNode(Command): """ Identify this node to its peer. The connector knows which hostname it's looking for, and which hostname it considers itself to be, only the initiator (not the listener) issues this command. 
This command is necessary because we don't want to rely on DNS; if reverse DNS weren't set up perfectly, the listener would not be able to identify its peer, and it is easier to modify local configuration so that L{socket.getfqdn} returns the right value than to ensure that DNS doesself. """ arguments = [ ("host", String()), ("port", Integer()), ] class SchemaAMP(AMP): """ An AMP instance which also has a L{Schema} attached to it. @ivar schema: The schema to look up L{TableSyntaxByName} arguments in. @type schema: L{Schema} """ def __init__(self, schema, boxReceiver=None, locator=None): self.schema = schema super(SchemaAMP, self).__init__(boxReceiver, locator) class ConnectionFromPeerNode(SchemaAMP): """ A connection to a peer node. Symmetric; since the 'client' and the 'server' both serve the same role, the logic is the same in every node. @ivar localWorkerPool: the pool of local worker procesess that can process queue work. @type localWorkerPool: L{WorkerConnectionPool} @ivar _reportedLoad: The number of outstanding requests being processed by the peer of this connection, from all requestors (both the host of this connection and others), as last reported by the most recent L{ReportLoad} message received from the peer. @type _reportedLoad: L{int} @ivar _bonusLoad: The number of additional outstanding requests being processed by the peer of this connection; the number of requests made by the host of this connection since the last L{ReportLoad} message. @type _bonusLoad: L{int} """ implements(_IWorkPerformer) def __init__(self, peerPool, boxReceiver=None, locator=None): """ Initialize this L{ConnectionFromPeerNode} with a reference to a L{PeerConnectionPool}, as well as required initialization arguments for L{AMP}. @param peerPool: the connection pool within which this L{ConnectionFromPeerNode} is a participant. 
@type peerPool: L{PeerConnectionPool} @see: L{AMP.__init__} """ self.peerPool = peerPool self._bonusLoad = 0 self._reportedLoad = 0 super(ConnectionFromPeerNode, self).__init__(peerPool.schema, boxReceiver, locator) def reportCurrentLoad(self): """ Report the current load for the local worker pool to this peer. """ return self.callRemote(ReportLoad, load=self.totalLoad()) @ReportLoad.responder def reportedLoad(self, load): """ The peer reports its load. """ self._reportedLoad = (load - self._bonusLoad) return {} def startReceivingBoxes(self, sender): """ Connection is up and running; add this to the list of active peers. """ r = super(ConnectionFromPeerNode, self).startReceivingBoxes(sender) self.peerPool.addPeerConnection(self) return r def stopReceivingBoxes(self, reason): """ The connection has shut down; remove this from the list of active peers. """ self.peerPool.removePeerConnection(self) r = super(ConnectionFromPeerNode, self).stopReceivingBoxes(reason) return r def currentLoadEstimate(self): """ What is the current load estimate for this peer? @return: The number of full "slots", i.e. currently-being-processed queue items (and other items which may contribute to this process's load, such as currently-being-processed client requests). @rtype: L{int} """ return self._reportedLoad + self._bonusLoad def performWork(self, table, workID): """ A L{local worker connection } is asking this specific peer node-controller process to perform some work, having already determined that it's appropriate. @see: L{_IWorkPerformer.performWork} """ d = self.callRemote(PerformWork, table=table, workID=workID) self._bonusLoad += 1 @d.addBoth def performed(result): self._bonusLoad -= 1 return result @d.addCallback def success(result): return None return d @PerformWork.responder def dispatchToWorker(self, table, workID): """ A remote peer node has asked this node to do some work; dispatch it to a local worker on this node. @param table: the table to work on. 
@type table: L{TableSyntax} @param workID: the identifier within the table. @type workID: L{int} @return: a L{Deferred} that fires when the work has been completed. """ return self.peerPool.performWorkForPeer(table, workID).addCallback( lambda ignored: {} ) @IdentifyNode.responder def identifyPeer(self, host, port): self.peerPool.mapPeer(host, port, self) return {} class WorkerConnectionPool(object): """ A pool of L{ConnectionFromWorker}s. L{WorkerConnectionPool} also implements the same implicit protocol as a L{ConnectionFromPeerNode}, but one that dispenses work to the local worker processes rather than to a remote connection pool. """ implements(_IWorkPerformer) def __init__(self, maximumLoadPerWorker=5): self.workers = [] self.maximumLoadPerWorker = maximumLoadPerWorker def addWorker(self, worker): """ Add a L{ConnectionFromWorker} to this L{WorkerConnectionPool} so that it can be selected. """ self.workers.append(worker) def removeWorker(self, worker): """ Remove a L{ConnectionFromWorker} from this L{WorkerConnectionPool} that was previously added. """ self.workers.remove(worker) def hasAvailableCapacity(self): """ Does this worker connection pool have any local workers who have spare hasAvailableCapacity to process another queue item? """ for worker in self.workers: if worker.currentLoad < self.maximumLoadPerWorker: return True return False def allWorkerLoad(self): """ The total load of all currently connected workers. """ return sum(worker.currentLoad for worker in self.workers) def _selectLowestLoadWorker(self): """ Select the local connection with the lowest current load, or C{None} if all workers are too busy. @return: a worker connection with the lowest current load. @rtype: L{ConnectionFromWorker} """ return sorted(self.workers[:], key=lambda w: w.currentLoad)[0] def performWork(self, table, workID): """ Select a local worker that is idle enough to perform the given work, then ask them to perform it. @param table: The table where work is waiting. 
@type table: L{TableSyntax} @param workID: The primary key identifier of the given work. @type workID: L{int} @return: a L{Deferred} firing with an empty dictionary when the work is complete. @rtype: L{Deferred} firing L{dict} """ preferredWorker = self._selectLowestLoadWorker() result = preferredWorker.performWork(table, workID) return result class ConnectionFromWorker(SchemaAMP): """ An individual connection from a worker, as seem from the master's perspective. L{ConnectionFromWorker}s go into a L{WorkerConnectionPool}. """ def __init__(self, peerPool, boxReceiver=None, locator=None): super(ConnectionFromWorker, self).__init__(peerPool.schema, boxReceiver, locator) self.peerPool = peerPool self._load = 0 @property def currentLoad(self): """ What is the current load of this worker? """ return self._load def startReceivingBoxes(self, sender): """ Start receiving AMP boxes from the peer. Initialize all necessary state. """ result = super(ConnectionFromWorker, self).startReceivingBoxes(sender) self.peerPool.workerPool.addWorker(self) return result def stopReceivingBoxes(self, reason): """ AMP boxes will no longer be received. """ result = super(ConnectionFromWorker, self).stopReceivingBoxes(reason) self.peerPool.workerPool.removeWorker(self) return result @PerformWork.responder def performWork(self, table, workID): """ Dispatch work to this worker. @see: The responder for this should always be L{ConnectionFromController.actuallyReallyExecuteWorkHere}. """ d = self.callRemote(PerformWork, table=table, workID=workID) self._load += 1 @d.addBoth def f(result): self._load -= 1 return result return d class ConnectionFromController(SchemaAMP): """ A L{ConnectionFromController} is the connection to a node-controller process, in a worker process. It processes requests from its own controller to do work. It is the opposite end of the connection from L{ConnectionFromWorker}. 
""" implements(IQueuer) def __init__(self, transactionFactory, schema, whenConnected, boxReceiver=None, locator=None): super(ConnectionFromController, self).__init__(schema, boxReceiver, locator) self.transactionFactory = transactionFactory self.whenConnected = whenConnected # FIXME: Glyph it appears WorkProposal expects this to have reactor... from twisted.internet import reactor self.reactor = reactor def startReceivingBoxes(self, sender): super(ConnectionFromController, self).startReceivingBoxes(sender) self.whenConnected(self) def choosePerformer(self): """ To conform with L{WorkProposal}'s expectations, which may run in either a controller (against a L{PeerConnectionPool}) or in a worker (against a L{ConnectionFromController}), this is implemented to always return C{self}, since C{self} is also an object that has a C{performWork} method. """ return self def performWork(self, table, workID): """ Ask the controller to perform some work on our behalf. """ return self.callRemote(PerformWork, table=table, workID=workID) def enqueueWork(self, txn, workItemType, **kw): """ There is some work to do. Do it, ideally someplace else, ideally in parallel. Later, let the caller know that the work has been completed by firing a L{Deferred}. @param workItemType: The type of work item to be enqueued. @type workItemType: A subtype of L{WorkItem} @param kw: The parameters to construct a work item. @type kw: keyword parameters to C{workItemType.create}, i.e. C{workItemType.__init__} @return: an object that can track the enqueuing and remote execution of this work. @rtype: L{WorkProposal} """ wp = WorkProposal(self, txn, workItemType, kw) wp._start() return wp @PerformWork.responder def actuallyReallyExecuteWorkHere(self, table, workID): """ This is where it's time to actually do the work. The controller process has instructed this worker to do it; so, look up the data in the row, and do it. 
""" return (ultimatelyPerform(self.transactionFactory, table, workID) .addCallback(lambda ignored: {})) def ultimatelyPerform(txnFactory, table, workID): """ Eventually, after routing the work to the appropriate place, somebody actually has to I{do} it. @param txnFactory: a 0- or 1-argument callable that creates an L{IAsyncTransaction} @type txnFactory: L{callable} @param table: the table object that corresponds to the necessary work item @type table: L{twext.enterprise.dal.syntax.TableSyntax} @param workID: the ID of the work to be performed @type workID: L{int} @return: a L{Deferred} which fires with C{None} when the work has been performed, or fails if the work can't be performed. """ @inlineCallbacks def work(txn): workItemClass = WorkItem.forTable(table) try: workItem = yield workItemClass.load(txn, workID) if workItem.group is not None: yield NamedLock.acquire(txn, workItem.group) # TODO: what if we fail? error-handling should be recorded # someplace, the row should probably be marked, re-tries should be # triggerable administratively. yield workItem.delete() # TODO: verify that workID is the primary key someplace. yield workItem.doWork() except NoSuchRecord: # The record has already been removed pass return inTransaction(txnFactory, work) class LocalPerformer(object): """ Implementor of C{performWork} that does its work in the local process, regardless of other conditions. """ implements(_IWorkPerformer) def __init__(self, txnFactory): """ Create this L{LocalPerformer} with a transaction factory. """ self.txnFactory = txnFactory def performWork(self, table, workID): """ Perform the given work right now. """ return ultimatelyPerform(self.txnFactory, table, workID) class WorkerFactory(Factory, object): """ Factory, to be used as the client to connect from the worker to the controller. """ def __init__(self, transactionFactory, schema, whenConnected): """ Create a L{WorkerFactory} with a transaction factory and a schema. 
""" self.transactionFactory = transactionFactory self.schema = schema self.whenConnected = whenConnected def buildProtocol(self, addr): """ Create a L{ConnectionFromController} connected to the transactionFactory and store. """ return ConnectionFromController(self.transactionFactory, self.schema, self.whenConnected) class TransactionFailed(Exception): """ A transaction failed. """ def _cloneDeferred(d): """ Make a new Deferred, adding callbacks to C{d}. @return: another L{Deferred} that fires with C{d's} result when C{d} fires. @rtype: L{Deferred} """ d2 = Deferred() d.chainDeferred(d2) return d2 class WorkProposal(object): """ A L{WorkProposal} is a proposal for work that will be executed, perhaps on another node, perhaps in the future. @ivar _chooser: The object which will choose where the work in this proposal gets performed. This must have both a C{choosePerformer} method and a C{reactor} attribute, providing an L{IReactorTime}. @type _chooser: L{PeerConnectionPool} or L{LocalQueuer} @ivar txn: The transaction where the work will be enqueued. @type txn: L{IAsyncTransaction} @ivar workItemType: The type of work to be enqueued by this L{WorkProposal} @type workItemType: L{WorkItem} subclass @ivar kw: The keyword arguments to pass to C{self.workItemType.create} to construct it. @type kw: L{dict} """ def __init__(self, chooser, txn, workItemType, kw): self._chooser = chooser self.txn = txn self.workItemType = workItemType self.kw = kw self._whenProposed = Deferred() self._whenExecuted = Deferred() self._whenCommitted = Deferred() def _start(self): """ Execute this L{WorkProposal} by creating the work item in the database, waiting for the transaction where that addition was completed to commit, and asking the local node controller process to do the work. 
""" created = self.workItemType.create(self.txn, **self.kw) def whenCreated(item): self._whenProposed.callback(self) @self.txn.postCommit def whenDone(): self._whenCommitted.callback(self) def maybeLater(): performer = self._chooser.choosePerformer() @passthru(performer.performWork(item.table, item.workID) .addCallback) def performed(result): self._whenExecuted.callback(self) @performed.addErrback def notPerformed(why): self._whenExecuted.errback(why) reactor = self._chooser.reactor when = max(0, astimestamp(item.notBefore) - reactor.seconds()) # TODO: Track the returned DelayedCall so it can be stopped # when the service stops. self._chooser.reactor.callLater(when, maybeLater) @self.txn.postAbort def whenFailed(): self._whenCommitted.errback(TransactionFailed) def whenNotCreated(failure): self._whenProposed.errback(failure) created.addCallbacks(whenCreated, whenNotCreated) def whenExecuted(self): """ Let the caller know when the proposed work has been fully executed. @note: The L{Deferred} returned by C{whenExecuted} should be used with extreme caution. If an application decides to do any database-persistent work as a result of this L{Deferred} firing, that work I{may be lost} as a result of a service being normally shut down between the time that the work is scheduled and the time that it is executed. So, the only things that should be added as callbacks to this L{Deferred} are those which are ephemeral, in memory, and reflect only presentation state associated with the user's perception of the completion of work, not logical chains of work which need to be completed in sequence; those should all be completed within the transaction of the L{WorkItem.doWork} that gets executed. @return: a L{Deferred} that fires with this L{WorkProposal} when the work has been completed remotely. """ return _cloneDeferred(self._whenExecuted) def whenProposed(self): """ Let the caller know when the work has been proposed; i.e. when the work is first transmitted to the database. 
@return: a L{Deferred} that fires with this L{WorkProposal} when the relevant commands have been sent to the database to create the L{WorkItem}, and fails if those commands do not succeed for some reason. """ return _cloneDeferred(self._whenProposed) def whenCommitted(self): """ Let the caller know when the work has been committed to; i.e. when the transaction where the work was proposed has been committed to the database. @return: a L{Deferred} that fires with this L{WorkProposal} when the relevant transaction has been committed, or fails if the transaction is not committed for any reason. """ return _cloneDeferred(self._whenCommitted) class _BaseQueuer(object): implements(IQueuer) def __init__(self): super(_BaseQueuer, self).__init__() self.proposalCallbacks = set() def callWithNewProposals(self, callback): self.proposalCallbacks.add(callback) def transferProposalCallbacks(self, newQueuer): newQueuer.proposalCallbacks = self.proposalCallbacks return newQueuer def enqueueWork(self, txn, workItemType, **kw): """ There is some work to do. Do it, someplace else, ideally in parallel. Later, let the caller know that the work has been completed by firing a L{Deferred}. @param workItemType: The type of work item to be enqueued. @type workItemType: A subtype of L{WorkItem} @param kw: The parameters to construct a work item. @type kw: keyword parameters to C{workItemType.create}, i.e. C{workItemType.__init__} @return: an object that can track the enqueuing and remote execution of this work. @rtype: L{WorkProposal} """ wp = WorkProposal(self, txn, workItemType, kw) wp._start() for callback in self.proposalCallbacks: callback(wp) return wp class PeerConnectionPool(_BaseQueuer, MultiService, object): """ Each node has a L{PeerConnectionPool} connecting it to all the other nodes currently active on the same database. @ivar hostname: The hostname where this node process is running, as reported by the local host's configuration. 
Possibly this should be obtained via C{config.ServerHostName} instead of C{socket.getfqdn()}; although hosts within a cluster may be configured with the same C{ServerHostName}; TODO need to confirm. @type hostname: L{bytes} @ivar thisProcess: a L{NodeInfo} representing this process, which is initialized when this L{PeerConnectionPool} service is started via C{startService}. May be C{None} if this service is not fully started up or if it is shutting down. @type thisProcess: L{NodeInfo} @ivar queueProcessTimeout: The amount of time after a L{WorkItem} is scheduled to be processed (its C{notBefore} attribute) that it is considered to be "orphaned" and will be run by a lost-work check rather than waiting for it to be requested. By default, 10 minutes. @type queueProcessTimeout: L{float} (in seconds) @ivar queueDelayedProcessInterval: The amount of time between database pings, i.e. checks for over-due queue items that might have been orphaned by a controller process that died mid-transaction. This is how often the shared database should be pinged by I{all} nodes (i.e., all controller processes, or each instance of L{PeerConnectionPool}); each individual node will ping commensurately less often as more nodes join the database. @type queueDelayedProcessInterval: L{float} (in seconds) @ivar reactor: The reactor used for scheduling timed events. @type reactor: L{IReactorTime} provider. @ivar peers: The list of currently connected peers. @type peers: L{list} of L{PeerConnectionPool} """ implements(IQueuer) from socket import getfqdn from os import getpid getfqdn = staticmethod(getfqdn) getpid = staticmethod(getpid) queueProcessTimeout = (10.0 * 60.0) queueDelayedProcessInterval = (60.0) def __init__(self, reactor, transactionFactory, ampPort, schema): """ Initialize a L{PeerConnectionPool}. @param ampPort: The AMP TCP port number to listen on for inter-host communication. 
This must be an integer (and not, say, an endpoint, or an endpoint description) because we need to communicate it to the other peers in the cluster in a way that will be meaningful to them as clients. @type ampPort: L{int} @param transactionFactory: a 0- or 1-argument callable that produces an L{IAsyncTransaction} @param schema: The schema which contains all the tables associated with the L{WorkItem}s that this L{PeerConnectionPool} will process. @type schema: L{Schema} """ super(PeerConnectionPool, self).__init__() self.reactor = reactor self.transactionFactory = transactionFactory self.hostname = self.getfqdn() self.pid = self.getpid() self.ampPort = ampPort self.thisProcess = None self.workerPool = WorkerConnectionPool() self.peers = [] self.mappedPeers = {} self.schema = schema self._startingUp = None self._listeningPort = None self._lastSeenTotalNodes = 1 self._lastSeenNodeIndex = 1 def addPeerConnection(self, peer): """ Add a L{ConnectionFromPeerNode} to the active list of peers. """ self.peers.append(peer) def totalLoad(self): return self.workerPool.allWorkerLoad() def workerListenerFactory(self): """ Factory that listens for connections from workers. """ f = Factory() f.buildProtocol = lambda addr: ConnectionFromWorker(self) return f def removePeerConnection(self, peer): """ Remove a L{ConnectionFromPeerNode} to the active list of peers. """ self.peers.remove(peer) def choosePerformer(self, onlyLocally=False): """ Choose a peer to distribute work to based on the current known slot occupancy of the other nodes. Note that this will prefer distributing work to local workers until the current node is full, because that should be lower-latency. Also, if no peers are available, work will be submitted locally even if the worker pool is already over-subscribed. @return: the chosen peer. 
@rtype: L{_IWorkPerformer} L{ConnectionFromPeerNode} or L{WorkerConnectionPool} """ if self.workerPool.hasAvailableCapacity(): return self.workerPool if self.peers and not onlyLocally: return sorted(self.peers, key=lambda p: p.currentLoadEstimate())[0] else: return LocalPerformer(self.transactionFactory) def performWorkForPeer(self, table, workID): """ A peer has requested us to perform some work; choose a work performer local to this node, and then execute it. """ performer = self.choosePerformer(onlyLocally=True) return performer.performWork(table, workID) def allWorkItemTypes(self): """ Load all the L{WorkItem} types that this node can process and return them. @return: L{list} of L{type} """ # TODO: For completeness, this may need to involve a plugin query to # make sure that all WorkItem subclasses are imported first. for workItemSubclass in WorkItem.__subclasses__(): # TODO: It might be a good idea to offload this table-filtering to # SchemaSyntax.__contains__, adding in some more structure- # comparison of similarly-named tables. For now a name check is # sufficient. if workItemSubclass.table.model.name in set([x.model.name for x in self.schema]): yield workItemSubclass def totalNumberOfNodes(self): """ How many nodes are there, total? @return: the maximum number of other L{PeerConnectionPool} instances that may be connected to the database described by C{self.transactionFactory}. Note that this is not the current count by connectivity, but the count according to the database. @rtype: L{int} """ # TODO return self._lastSeenTotalNodes def nodeIndex(self): """ What ordinal does this node, i.e. this instance of L{PeerConnectionPool}, occupy within the ordered set of all nodes connected to the database described by C{self.transactionFactory}? @return: the index of this node within the total collection. For example, if this L{PeerConnectionPool} is 6 out of 30, this method will return C{6}. 
@rtype: L{int} """ # TODO return self._lastSeenNodeIndex def _periodicLostWorkCheck(self): """ Periodically, every node controller has to check to make sure that work hasn't been dropped on the floor by someone. In order to do that it queries each work-item table. """ @inlineCallbacks def workCheck(txn): if self.thisProcess: nodes = [(node.hostname, node.port) for node in (yield self.activeNodes(txn))] nodes.sort() self._lastSeenTotalNodes = len(nodes) self._lastSeenNodeIndex = nodes.index( (self.thisProcess.hostname, self.thisProcess.port) ) for itemType in self.allWorkItemTypes(): tooLate = datetime.utcfromtimestamp( self.reactor.seconds() - self.queueProcessTimeout ) overdueItems = (yield itemType.query( txn, (itemType.notBefore < tooLate)) ) for overdueItem in overdueItems: peer = self.choosePerformer() yield peer.performWork(overdueItem.table, overdueItem.workID) return inTransaction(self.transactionFactory, workCheck) _currentWorkDeferred = None _lostWorkCheckCall = None def _lostWorkCheckLoop(self): """ While the service is running, keep checking for any overdue / lost work items and re-submit them to the cluster for processing. Space out those checks in time based on the size of the cluster. """ self._lostWorkCheckCall = None @passthru(self._periodicLostWorkCheck().addErrback(log.err) .addCallback) def scheduleNext(result): self._currentWorkDeferred = None if not self.running: return index = self.nodeIndex() now = self.reactor.seconds() interval = self.queueDelayedProcessInterval count = self.totalNumberOfNodes() when = (now - (now % interval)) + (interval * (count + index)) delay = when - now self._lostWorkCheckCall = self.reactor.callLater( delay, self._lostWorkCheckLoop ) self._currentWorkDeferred = scheduleNext def startService(self): """ Register ourselves with the database and establish all outgoing connections to other servers in the cluster. 
""" @inlineCallbacks def startup(txn): endpoint = TCP4ServerEndpoint(self.reactor, self.ampPort) # If this fails, the failure mode is going to be ugly, just like # all conflicted-port failures. But, at least it won't proceed. self._listeningPort = yield endpoint.listen(self.peerFactory()) self.ampPort = self._listeningPort.getHost().port yield Lock.exclusive(NodeInfo.table).on(txn) nodes = yield self.activeNodes(txn) selves = [node for node in nodes if ((node.hostname == self.hostname) and (node.port == self.ampPort))] if selves: self.thisProcess = selves[0] nodes.remove(self.thisProcess) yield self.thisProcess.update(pid=self.pid, time=datetime.now()) else: self.thisProcess = yield NodeInfo.create( txn, hostname=self.hostname, port=self.ampPort, pid=self.pid, time=datetime.now() ) for node in nodes: self._startConnectingTo(node) self._startingUp = inTransaction(self.transactionFactory, startup) @self._startingUp.addBoth def done(result): self._startingUp = None super(PeerConnectionPool, self).startService() self._lostWorkCheckLoop() return result @inlineCallbacks def stopService(self): """ Stop this service, terminating any incoming or outgoing connections. """ yield super(PeerConnectionPool, self).stopService() if self._startingUp is not None: yield self._startingUp if self._listeningPort is not None: yield self._listeningPort.stopListening() if self._lostWorkCheckCall is not None: self._lostWorkCheckCall.cancel() if self._currentWorkDeferred is not None: yield self._currentWorkDeferred for peer in self.peers: peer.transport.abortConnection() def activeNodes(self, txn): """ Load information about all other nodes. """ return NodeInfo.all(txn) def mapPeer(self, host, port, peer): """ A peer has been identified as belonging to the given host/port combination. Disconnect any other peer that claims to be connected for the same peer. 
""" # if (host, port) in self.mappedPeers: # TODO: think about this for race conditions # self.mappedPeers.pop((host, port)).transport.loseConnection() self.mappedPeers[(host, port)] = peer def _startConnectingTo(self, node): """ Start an outgoing connection to another master process. @param node: a description of the master to connect to. @type node: L{NodeInfo} """ connected = node.endpoint(self.reactor).connect(self.peerFactory()) def whenConnected(proto): self.mapPeer(node.hostname, node.port, proto) proto.callRemote(IdentifyNode, host=self.thisProcess.hostname, port=self.thisProcess.port).addErrback( noted, "identify" ) def noted(err, x="connect"): log.msg("Could not {0} to cluster peer {1} because {2}" .format(x, node, str(err.value))) connected.addCallbacks(whenConnected, noted) def peerFactory(self): """ Factory for peer connections. @return: a L{Factory} that will produce L{ConnectionFromPeerNode} protocols attached to this L{PeerConnectionPool}. """ return _PeerPoolFactory(self) class _PeerPoolFactory(Factory, object): """ Protocol factory responsible for creating L{ConnectionFromPeerNode} connections, both client and server. """ def __init__(self, peerConnectionPool): self.peerConnectionPool = peerConnectionPool def buildProtocol(self, addr): return ConnectionFromPeerNode(self.peerConnectionPool) class LocalQueuer(_BaseQueuer): """ When work is enqueued with this queuer, it is just executed locally. """ implements(IQueuer) def __init__(self, txnFactory, reactor=None): super(LocalQueuer, self).__init__() self.txnFactory = txnFactory if reactor is None: from twisted.internet import reactor self.reactor = reactor def choosePerformer(self): """ Choose to perform the work locally. """ return LocalPerformer(self.txnFactory) class NonPerformer(object): """ Implementor of C{performWork} that doesn't actual perform any work. 
This is used in the case where you want to be able to enqueue work for someone else to do, but not take on any work yourself (such as a command line tool). """ implements(_IWorkPerformer) def performWork(self, table, workID): """ Don't perform work. """ return succeed(None) class NonPerformingQueuer(_BaseQueuer): """ When work is enqueued with this queuer, it is never executed locally. It's expected that the polling machinery will find the work and perform it. """ implements(IQueuer) def __init__(self, reactor=None): super(NonPerformingQueuer, self).__init__() if reactor is None: from twisted.internet import reactor self.reactor = reactor def choosePerformer(self): """ Choose to perform the work locally. """ return NonPerformer() calendarserver-5.2+dfsg/twext/enterprise/__init__.py0000644000175000017500000000133312263343324021755 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extensions in the spirit of Twisted's "enterprise" package; things related to database connectivity and management. """ calendarserver-5.2+dfsg/twext/enterprise/dal/0000755000175000017500000000000012322625326020405 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/enterprise/dal/model.py0000644000175000017500000004017112263343324022061 0ustar rahulrahul# -*- test-case-name: twext.enterprise.dal.test.test_parseschema -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Model classes for SQL. """ from twisted.python.util import FancyEqMixin class SQLType(object): """ A data-type as defined in SQL; like "integer" or "real" or "varchar(255)". @ivar name: the name of this type. @type name: C{str} @ivar length: the length of this type, if it is a type like 'varchar' or 'character' that comes with a parenthetical length. @type length: C{int} or C{NoneType} """ def __init__(self, name, length): _checkstr(name) self.name = name self.length = length def __eq__(self, other): """ Compare equal to other L{SQLTypes} with matching name and length. """ if not isinstance(other, SQLType): return NotImplemented return (self.name, self.length) == (other.name, other.length) def __ne__(self, other): """ (Inequality is the opposite of equality.) """ if not isinstance(other, SQLType): return NotImplemented return not self.__eq__(other) def __repr__(self): """ A useful string representation which includes the name and length if present. """ if self.length: lendesc = '(%s)' % (self.length) else: lendesc = '' return '' % (self.name, lendesc) class Constraint(object): """ A constraint on a set of columns. @ivar type: the type of constraint. Currently, only C{'UNIQUE'} and C{'NOT NULL'} are supported. @type type: C{str} @ivar affectsColumns: Columns affected by this constraint. 
@type affectsColumns: C{list} of L{Column} """ # Values for 'type' attribute: NOT_NULL = 'NOT NULL' UNIQUE = 'UNIQUE' def __init__(self, type, affectsColumns, name=None): self.affectsColumns = affectsColumns # XXX: possibly different constraint types should have different # classes? self.type = type self.name = name class Check(Constraint): """ A 'check' constraint, which evaluates an SQL expression. @ivar expression: the expression that should evaluate to True. @type expression: L{twext.enterprise.dal.syntax.ExpressionSyntax} """ # XXX TODO: model for expression, rather than def __init__(self, syntaxExpression, name=None): self.expression = syntaxExpression super(Check, self).__init__( 'CHECK', [c.model for c in self.expression.allColumns()], name ) class ProcedureCall(object): """ An invocation of a stored procedure or built-in function. """ def __init__(self, name, args): _checkstr(name) self.name = name self.args = args class NO_DEFAULT(object): """ Placeholder value for not having a default. (C{None} would not be suitable, as that would imply a default of C{NULL}). """ def _checkstr(x): """ Verify that C{x} is a C{str}. Raise a L{ValueError} if not. This is to prevent pollution with unicode values. """ if not isinstance(x, str): raise ValueError("%r is not a str." % (x,)) class Column(FancyEqMixin, object): """ A column from a table. @ivar table: The L{Table} to which this L{Column} belongs. @type table: L{Table} @ivar name: The unqualified name of this column. For example, in the case of a column BAR in a table FOO, this would be the string C{'BAR'}. @type name: C{str} @ivar type: The declared type of this column. @type type: L{SQLType} @ivar references: If this column references a foreign key on another table, this will be a reference to that table; otherwise (normally) C{None}. 
@type references: L{Table} or C{NoneType} @ivar deleteAction: If this column references another table, home will this column's row be altered when the matching row in that other table is deleted? Possible values are None - for 'on delete no action' 'cascade' - for 'on delete cascade' 'set null' - for 'on delete set null' 'set default' - for 'on delete set default' @type deleteAction: C{bool} """ compareAttributes = 'table name'.split() def __init__(self, table, name, type): _checkstr(name) self.table = table self.name = name self.type = type self.default = NO_DEFAULT self.references = None self.deleteAction = None def __repr__(self): return '' % (self.name, self.type) def compare(self, other): """ Return the differences between two columns. @param other: the column to compare with @type other: L{Column} """ results = [] # TODO: sql_dump does not do types write now - so ignore this # if self.type != other.type: # results.append("Table: %s, mismatched column type: %s" % (self.table.name, self.name)) # TODO: figure out how to compare default, references and deleteAction return results def canBeNull(self): """ Can this column ever be C{NULL}, i.e. C{None}? In other words, is it free of any C{NOT NULL} constraints? @return: C{True} if so, C{False} if not. """ for constraint in self.table.constraints: if self in constraint.affectsColumns: if constraint.type is Constraint.NOT_NULL: return False return True def setDefaultValue(self, value): """ Change the default value of this column. (Should only be called during schema parsing.) """ self.default = value def needsValue(self): """ Does this column require a value in C{INSERT} statements which create rows? @return: C{True} for L{Column}s with no default specified which also cannot be NULL, C{False} otherwise. @rtype: C{bool} """ return not (self.canBeNull() or (self.default not in (None, NO_DEFAULT))) def doesReferenceName(self, name): """ Change this column to refer to a table in the schema. 
(Should only be called during schema parsing.) @param name: the name of a L{Table} in this L{Column}'s L{Schema}. @type name: L{str} """ self.references = self.table.schema.tableNamed(name) class Table(FancyEqMixin, object): """ A set of columns. @ivar descriptiveComment: A docstring for the table. Parsed from a '--' comment preceding this table in the SQL schema file that was parsed, if any. @type descriptiveComment: C{str} @ivar schema: a reference to the L{Schema} to which this table belongs. @ivar primaryKey: a C{list} of L{Column} objects representing the primary key of this table, or C{None} if no primary key has been specified. """ compareAttributes = 'schema name'.split() def __init__(self, schema, name): _checkstr(name) self.descriptiveComment = '' self.schema = schema self.name = name self.columns = [] self.constraints = [] self.schemaRows = [] self.primaryKey = None self.schema.tables.append(self) def __repr__(self): return '' % (self.name, self.columns) def compare(self, other): """ Return the differences between two tables. @param other: the table to compare with @type other: L{Table} """ results = [] myColumns = dict([(item.name.lower(), item) for item in self.columns]) otherColumns = dict([(item.name.lower(), item) for item in other.columns]) for item in set(myColumns.keys()) ^ set(otherColumns.keys()): results.append("Table: %s, missing column: %s" % (self.name, item,)) for name in set(myColumns.keys()) & set(otherColumns.keys()): results.extend(myColumns[name].compare(otherColumns[name])) # TODO: figure out how to compare schemaRows return results def columnNamed(self, name): """ Retrieve a column from this table with a given name. @raise KeyError: if no such table exists. @return: a column @rtype: L{Column} """ for column in self.columns: if column.name == name: return column raise KeyError("no such column: %r" % (name,)) def addColumn(self, name, type): """ A new column was parsed for this table. @param name: The unqualified name of the column. 
@type name: C{str} @param type: The L{SQLType} describing the column's type. """ column = Column(self, name, type) self.columns.append(column) return column def tableConstraint(self, constraintType, columnNames): """ This table is affected by a constraint. (Should only be called during schema parsing.) @param constraintType: the type of constraint; either L{Constraint.NOT_NULL} or L{Constraint.UNIQUE}, currently. """ affectsColumns = [] for name in columnNames: affectsColumns.append(self.columnNamed(name)) self.constraints.append(Constraint(constraintType, affectsColumns)) def checkConstraint(self, protoExpression, name=None): """ This table is affected by a 'check' constraint. (Should only be called during schema parsing.) @param protoExpression: proto expression. """ self.constraints.append(Check(protoExpression, name)) def insertSchemaRow(self, values): """ A statically-defined row was inserted as part of the schema itself. This is used for tables that want to track static enumerations, for example, but want to be referred to by a foreign key in other tables for proper referential integrity. Append this data to this L{Table}'s L{Table.schemaRows}. (Should only be called during schema parsing.) @param values: a C{list} of data items, one for each column in this table's current list of L{Column}s. """ row = {} for column, value in zip(self.columns, values): row[column] = value self.schemaRows.append(row) def addComment(self, comment): """ Add a comment to C{descriptiveComment}. @param comment: some additional descriptive text @type comment: C{str} """ self.descriptiveComment = comment def uniques(self): """ Get the groups of unique columns for this L{Table}. @return: an iterable of C{list}s of C{Column}s which are unique within this table. """ for constraint in self.constraints: if constraint.type is Constraint.UNIQUE: yield list(constraint.affectsColumns) class Index(object): """ An L{Index} is an SQL index. 
""" def __init__(self, schema, name, table, unique=False): self.name = name self.table = table self.unique = unique self.columns = [] schema.indexes.append(self) def addColumn(self, column): self.columns.append(column) class PseudoIndex(object): """ A class used to represent explicit and implicit indexes. An implicit index is one the DB creates for primary key and unique columns in a table. An explicit index is one created by a CREATE [UNIQUE] INDEX statement. Because the name of an implicit index is implementation defined, instead we create a name based on the table name, uniqueness and column names. """ def __init__(self, table, columns, unique=False): self.name = "%s%s:(%s)" % (table.name, "-unique" if unique else "", ",".join([col.name for col in columns])) self.table = table self.unique = unique self.columns = columns def compare(self, other): """ Return the differences between two indexes. @param other: the index to compare with @type other: L{Index} """ # Nothing to do as name comparison will catch differences return [] class Sequence(FancyEqMixin, object): """ A sequence object. """ compareAttributes = 'name'.split() def __init__(self, schema, name): _checkstr(name) self.name = name self.referringColumns = [] schema.sequences.append(self) def __repr__(self): return '' % (self.name,) def compare(self, other): """ Return the differences between two sequences. @param other: the sequence to compare with @type other: L{Sequence} """ # TODO: figure out whether to compare referringColumns attribute return [] def _namedFrom(name, sequence): """ Retrieve an item with a given name attribute from a given sequence, or raise a L{KeyError}. """ for item in sequence: if item.name == name: return item raise KeyError(name) class Schema(object): """ A schema containing tables, indexes, and sequences. 
""" def __init__(self, filename=''): self.filename = filename self.tables = [] self.indexes = [] self.sequences = [] def __repr__(self): return '' % (self.filename,) def compare(self, other): """ Return the differences between two schemas. @param other: the schema to compare with @type other: L{Schema} """ results = [] def _compareLists(list1, list2, descriptor): myItems = dict([(item.name.lower()[:63], item) for item in list1]) otherItems = dict([(item.name.lower()[:63], item) for item in list2]) for item in set(myItems.keys()) - set(otherItems.keys()): results.append("Schema: %s, missing %s: %s" % (other.filename, descriptor, item,)) for item in set(otherItems.keys()) - set(myItems.keys()): results.append("Schema: %s, missing %s: %s" % (self.filename, descriptor, item,)) for name in set(myItems.keys()) & set(otherItems.keys()): results.extend(myItems[name].compare(otherItems[name])) _compareLists(self.tables, other.tables, "table") _compareLists(self.pseudoIndexes(), other.pseudoIndexes(), "index") _compareLists(self.sequences, other.sequences, "sequence") return results def pseudoIndexes(self): """ Return a set of indexes that include "implicit" indexes from table/column constraints. The name of the index is formed from the table name and then list of columns. 
""" results = [] # First add the list of explicit indexes we have for index in self.indexes: results.append(PseudoIndex(index.table, index.columns, index.unique)) # Now do implicit index for each table for table in self.tables: if table.primaryKey is not None: results.append(PseudoIndex(table, table.primaryKey, True)) for constraint in table.constraints: if constraint.type == Constraint.UNIQUE: results.append(PseudoIndex(table, constraint.affectsColumns, True)) return results def tableNamed(self, name): return _namedFrom(name, self.tables) def sequenceNamed(self, name): return _namedFrom(name, self.sequences) def indexNamed(self, name): return _namedFrom(name, self.indexes) calendarserver-5.2+dfsg/twext/enterprise/dal/test/0000755000175000017500000000000012322625326021364 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/enterprise/dal/test/test_parseschema.py0000644000175000017500000003242212263343324025272 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for parsing an SQL schema, which cover L{twext.enterprise.dal.model} and L{twext.enterprise.dal.parseschema}. """ from twext.enterprise.dal.model import Schema from twext.enterprise.dal.syntax import CompoundComparison, ColumnSyntax from twext.enterprise.dal.parseschema import addSQLToSchema from twisted.trial.unittest import TestCase class SchemaTestHelper(object): """ Mix-in that can parse a schema from a string. 
""" def schemaFromString(self, string): """ Createa a L{Schema} """ s = Schema(self.id()) addSQLToSchema(s, string) return s class ParsingExampleTests(TestCase, SchemaTestHelper): """ Tests for parsing some sample schemas. """ def test_simplest(self): """ Parse an extremely simple schema with one table in it. """ s = self.schemaFromString("create table foo (bar integer);") self.assertEquals(len(s.tables), 1) foo = s.tableNamed('foo') self.assertEquals(len(foo.columns), 1) bar = foo.columns[0] self.assertEquals(bar.name, "bar") self.assertEquals(bar.type.name, "integer") def test_stringTypes(self): """ Table and column names should be byte strings. """ s = self.schemaFromString("create table foo (bar integer);") self.assertEquals(len(s.tables), 1) foo = s.tableNamed('foo') self.assertIsInstance(foo.name, str) self.assertIsInstance(foo.columnNamed('bar').name, str) def test_typeWithLength(self): """ Parse a type with a length. """ s = self.schemaFromString("create table foo (bar varchar(6543))") bar = s.tableNamed('foo').columnNamed('bar') self.assertEquals(bar.type.name, "varchar") self.assertEquals(bar.type.length, 6543) def test_sequence(self): """ Parsing a 'create sequence' statement adds a L{Sequence} to the L{Schema}. """ s = self.schemaFromString("create sequence myseq;") self.assertEquals(len(s.sequences), 1) self.assertEquals(s.sequences[0].name, "myseq") def test_sequenceColumn(self): """ Parsing a 'create sequence' statement adds a L{Sequence} to the L{Schema}, and then a table that contains a column which uses the SQL C{nextval()} function to retrieve its default value from that sequence, will cause the L{Column} object to refer to the L{Sequence} and vice versa. 
""" s = self.schemaFromString( """ create sequence thingy; create table thetable ( thecolumn integer default nextval('thingy') ); """) self.assertEquals(len(s.sequences), 1) self.assertEquals(s.sequences[0].name, "thingy") self.assertEquals(s.tables[0].columns[0].default, s.sequences[0]) self.assertEquals(s.sequences[0].referringColumns, [s.tables[0].columns[0]]) def test_sequenceDefault(self): """ Default sequence column. """ s = self.schemaFromString( """ create sequence alpha; create table foo ( bar integer default nextval('alpha') not null, qux integer not null ); """) self.assertEquals(s.tableNamed("foo").columnNamed("bar").needsValue(), False) def test_sequenceDefaultWithParens(self): """ SQLite requires 'default' expression to be in parentheses, and that should be equivalent on other databases; we should be able to parse that too. """ s = self.schemaFromString( """ create sequence alpha; create table foo ( bar integer default (nextval('alpha')) not null, qux integer not null ); """ ) self.assertEquals(s.tableNamed("foo").columnNamed("bar").needsValue(), False) def test_defaultConstantColumns(self): """ Parsing a 'default' column with an appropriate type in it will return that type as the 'default' attribute of the Column object. """ s = self.schemaFromString( """ create table a ( b integer default 4321, c boolean default false, d boolean default true, e varchar(255) default 'sample value', f varchar(255) default null ); """) table = s.tableNamed("a") self.assertEquals(table.columnNamed("b").default, 4321) self.assertEquals(table.columnNamed("c").default, False) self.assertEquals(table.columnNamed("d").default, True) self.assertEquals(table.columnNamed("e").default, 'sample value') self.assertEquals(table.columnNamed("f").default, None) def test_needsValue(self): """ Columns with defaults, or with a 'not null' constraint don't need a value; columns without one don't. 
""" s = self.schemaFromString( """ create table a ( b integer default 4321 not null, c boolean default false, d integer not null, e integer ) """) table = s.tableNamed("a") # Has a default, NOT NULL. self.assertEquals(table.columnNamed("b").needsValue(), False) # Has a default _and_ nullable. self.assertEquals(table.columnNamed("c").needsValue(), False) # No default, not nullable. self.assertEquals(table.columnNamed("d").needsValue(), True) # Just nullable. self.assertEquals(table.columnNamed("e").needsValue(), False) def test_notNull(self): """ A column with a NOT NULL constraint in SQL will be parsed as a constraint which returns False from its C{canBeNull()} method. """ s = self.schemaFromString( "create table alpha (beta integer, gamma integer not null);" ) t = s.tableNamed('alpha') self.assertEquals(True, t.columnNamed('beta').canBeNull()) self.assertEquals(False, t.columnNamed('gamma').canBeNull()) def test_unique(self): """ A column with a UNIQUE constraint in SQL will result in the table listing that column as a unique set. """ for identicalSchema in [ "create table sample (example integer unique);", "create table sample (example integer, unique (example));", "create table sample " "(example integer, constraint unique_example unique (example))"]: s = self.schemaFromString(identicalSchema) table = s.tableNamed('sample') column = table.columnNamed('example') self.assertEquals(list(table.uniques()), [[column]]) def test_checkExpressionConstraint(self): """ A column with a CHECK constraint in SQL that uses an inequality will result in a L{Check} constraint being added to the L{Table} object. 
""" def checkOneConstraint(sqlText, checkName=None): s = self.schemaFromString(sqlText) table = s.tableNamed('sample') self.assertEquals(len(table.constraints), 1) constraint = table.constraints[0] expr = constraint.expression self.assertIsInstance(expr, CompoundComparison) self.assertEqual(expr.a.model, table.columnNamed('example')) self.assertEqual(expr.b.value, 5) self.assertEqual(expr.op, '>') self.assertEqual(constraint.name, checkName) checkOneConstraint( "create table sample (example integer check (example > 5));" ) checkOneConstraint( "create table sample (example integer, check (example > 5));" ) checkOneConstraint( "create table sample " "(example integer, constraint gt_5 check (example>5))", "gt_5" ) def test_checkKeywordConstraint(self): """ A column with a CHECK constraint in SQL that compares with a keyword expression such as 'lower' will result in a L{Check} constraint being added to the L{Table} object. """ def checkOneConstraint(sqlText): s = self.schemaFromString(sqlText) table = s.tableNamed('sample') self.assertEquals(len(table.constraints), 1) expr = table.constraints[0].expression self.assertEquals(expr.a.model, table.columnNamed("example")) self.assertEquals(expr.op, "=") self.assertEquals(expr.b.function.name, "lower") self.assertEquals( expr.b.args, tuple([ColumnSyntax(table.columnNamed("example"))]) ) checkOneConstraint( "create table sample " "(example integer check (example = lower (example)));" ) def test_multiUnique(self): """ A column with a UNIQUE constraint in SQL will result in the table listing that column as a unique set. 
""" s = self.schemaFromString( "create table a (b integer, c integer, unique (b, c), unique (c));" ) a = s.tableNamed('a') b = a.columnNamed('b') c = a.columnNamed('c') self.assertEquals(list(a.uniques()), [[b, c], [c]]) def test_singlePrimaryKey(self): """ A table with a multi-column PRIMARY KEY clause will be parsed as a list of a single L{Column} object and stored as a C{primaryKey} attribute on the L{Table} object. """ s = self.schemaFromString( "create table a (b integer primary key, c integer)" ) a = s.tableNamed("a") self.assertEquals(a.primaryKey, [a.columnNamed("b")]) def test_multiPrimaryKey(self): """ A table with a multi-column PRIMARY KEY clause will be parsed as a list C{primaryKey} attribute on the Table object. """ s = self.schemaFromString( "create table a (b integer, c integer, primary key (b, c))" ) a = s.tableNamed("a") self.assertEquals( a.primaryKey, [a.columnNamed("b"), a.columnNamed("c")] ) def test_deleteAction(self): """ A column with an 'on delete cascade' constraint will have its C{cascade} attribute set to True. """ s = self.schemaFromString( """ create table a1 (b1 integer primary key); create table c2 (d2 integer references a1 on delete cascade); create table e3 (f3 integer references a1 on delete set null); create table g4 (h4 integer references a1 on delete set default); """) self.assertEquals(s.tableNamed("a1").columnNamed("b1").deleteAction, None) self.assertEquals(s.tableNamed("c2").columnNamed("d2").deleteAction, "cascade") self.assertEquals(s.tableNamed("e3").columnNamed("f3").deleteAction, "set null") self.assertEquals(s.tableNamed("g4").columnNamed("h4").deleteAction, "set default") def test_indexes(self): """ A 'create index' statement will add an L{Index} object to a L{Schema}'s C{indexes} list. 
""" s = self.schemaFromString( """ create table q (b integer); -- noise create table a (b integer primary key, c integer); create table z (c integer); -- make sure we get the right table create index idx_a_b on a(b); create index idx_a_b_c on a (c, b); create index idx_c on z using btree (c); """) a = s.tableNamed("a") b = s.indexNamed("idx_a_b") bc = s.indexNamed('idx_a_b_c') self.assertEquals(b.table, a) self.assertEquals(b.columns, [a.columnNamed("b")]) self.assertEquals(bc.table, a) self.assertEquals(bc.columns, [a.columnNamed("c"), a.columnNamed("b")]) def test_pseudoIndexes(self): """ A implicit and explicit indexes are listed. """ s = self.schemaFromString( """ create table q (b integer); -- noise create table a (b integer primary key, c integer); create table z (c integer, unique(c) ); create unique index idx_a_c on a(c); create index idx_a_b_c on a (c, b); """) self.assertEqual(set([pseudo.name for pseudo in s.pseudoIndexes()]), set(( "a-unique:(c)", "a:(c,b)", "a-unique:(b)", "z-unique:(c)", ))) calendarserver-5.2+dfsg/twext/enterprise/dal/test/test_record.py0000644000175000017500000003060112263343324024252 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Test cases for L{twext.enterprise.dal.record}. 
""" import datetime from twisted.internet.defer import inlineCallbacks from twisted.trial.unittest import TestCase from twext.enterprise.dal.record import ( Record, fromTable, ReadOnly, NoSuchRecord ) from twext.enterprise.dal.test.test_parseschema import SchemaTestHelper from twext.enterprise.dal.syntax import SchemaSyntax from twisted.internet.defer import gatherResults from twisted.internet.defer import returnValue from twext.enterprise.fixtures import buildConnectionPool # from twext.enterprise.dal.syntax import sth = SchemaTestHelper() sth.id = lambda : __name__ schemaString = """ create table ALPHA (BETA integer primary key, GAMMA text); create table DELTA (PHI integer primary key default (nextval('myseq')), EPSILON text not null, ZETA timestamp not null default '2012-12-12 12:12:12' ); """ # sqlite can be made to support nextval() as a function, but 'create sequence' # is syntax and can't. parseableSchemaString = """ create sequence myseq; """ + schemaString testSchema = SchemaSyntax(sth.schemaFromString(parseableSchemaString)) class TestRecord(Record, fromTable(testSchema.ALPHA)): """ A sample test record. """ class TestAutoRecord(Record, fromTable(testSchema.DELTA)): """ A sample test record with default values specified. """ class TestCRUD(TestCase): """ Tests for creation, mutation, and deletion operations. """ def setUp(self): self.pool = buildConnectionPool(self, schemaString) @inlineCallbacks def test_simpleLoad(self): """ Loading an existing row from the database by its primary key will populate its attributes from columns of the corresponding row in the database. 
""" txn = self.pool.connection() yield txn.execSQL("insert into ALPHA values (:1, :2)", [234, "one"]) yield txn.execSQL("insert into ALPHA values (:1, :2)", [456, "two"]) rec = yield TestRecord.load(txn, 456) self.assertIsInstance(rec, TestRecord) self.assertEquals(rec.beta, 456) self.assertEquals(rec.gamma, "two") rec2 = yield TestRecord.load(txn, 234) self.assertIsInstance(rec2, TestRecord) self.assertEqual(rec2.beta, 234) self.assertEqual(rec2.gamma, "one") @inlineCallbacks def test_missingLoad(self): """ Try loading an row which doesn't exist """ txn = self.pool.connection() yield txn.execSQL("insert into ALPHA values (:1, :2)", [234, "one"]) self.assertFailure(TestRecord.load(txn, 456), NoSuchRecord) @inlineCallbacks def test_simpleCreate(self): """ When a record object is created, a row with matching column values will be created in the database. """ txn = self.pool.connection() rec = yield TestRecord.create(txn, beta=3, gamma=u'epsilon') self.assertEquals(rec.beta, 3) self.assertEqual(rec.gamma, u'epsilon') rows = yield txn.execSQL("select BETA, GAMMA from ALPHA") self.assertEqual(rows, [tuple([3, u'epsilon'])]) @inlineCallbacks def test_simpleDelete(self): """ When a record object is deleted, a row with a matching primary key will be deleted in the database. """ txn = self.pool.connection() def mkrow(beta, gamma): return txn.execSQL("insert into ALPHA values (:1, :2)", [beta, gamma]) yield gatherResults([mkrow(123, u"one"), mkrow(234, u"two"), mkrow(345, u"three")]) tr = yield TestRecord.load(txn, 234) yield tr.delete() rows = yield txn.execSQL("select BETA, GAMMA from ALPHA order by BETA") self.assertEqual(rows, [(123, u"one"), (345, u"three")]) @inlineCallbacks def oneRowCommitted(self, beta=123, gamma=u'456'): """ Create, commit, and return one L{TestRecord}. 
""" txn = self.pool.connection(self.id()) row = yield TestRecord.create(txn, beta=beta, gamma=gamma) yield txn.commit() returnValue(row) @inlineCallbacks def test_deleteWhenDeleted(self): """ When a record object is deleted, if it's already been deleted, it will raise L{NoSuchRecord}. """ row = yield self.oneRowCommitted() txn = self.pool.connection(self.id()) newRow = yield TestRecord.load(txn, row.beta) yield newRow.delete() self.failUnlessFailure(newRow.delete(), NoSuchRecord) @inlineCallbacks def test_cantCreateWithoutRequiredValues(self): """ When a L{Record} object is created without required values, it raises a L{TypeError}. """ txn = self.pool.connection() te = yield self.failUnlessFailure(TestAutoRecord.create(txn), TypeError) self.assertIn("required attribute 'epsilon' not passed", str(te)) @inlineCallbacks def test_datetimeType(self): """ When a L{Record} references a timestamp column, it retrieves the date as UTC. """ txn = self.pool.connection() # Create ... rec = yield TestAutoRecord.create(txn, epsilon=1) self.assertEquals(rec.zeta, datetime.datetime(2012, 12, 12, 12, 12, 12)) yield txn.commit() # ... should have the same effect as loading. txn = self.pool.connection() rec = (yield TestAutoRecord.all(txn))[0] self.assertEquals(rec.zeta, datetime.datetime(2012, 12, 12, 12, 12, 12)) @inlineCallbacks def test_tooManyAttributes(self): """ When a L{Record} object is created with unknown attributes (those which don't map to any column), it raises a L{TypeError}. """ txn = self.pool.connection() te = yield self.failUnlessFailure(TestRecord.create( txn, beta=3, gamma=u'three', extraBonusAttribute=u'nope', otherBonusAttribute=4321, ), TypeError) self.assertIn("extraBonusAttribute, otherBonusAttribute", str(te)) @inlineCallbacks def test_createFillsInPKey(self): """ If L{Record.create} is called without an auto-generated primary key value for its row, that value will be generated and set on the returned object. 
""" txn = self.pool.connection() tr = yield TestAutoRecord.create(txn, epsilon=u'specified') tr2 = yield TestAutoRecord.create(txn, epsilon=u'also specified') self.assertEquals(tr.phi, 1) self.assertEquals(tr2.phi, 2) @inlineCallbacks def test_attributesArentMutableYet(self): """ Changing attributes on a database object is not supported yet, because it's not entirely clear when to flush the SQL to the database. Instead, for the time being, use C{.update}. When you attempt to set an attribute, an error will be raised informing you of this fact, so that the error is clear. """ txn = self.pool.connection() rec = yield TestRecord.create(txn, beta=7, gamma=u'what') def setit(): rec.beta = 12 ro = self.assertRaises(ReadOnly, setit) self.assertEqual(rec.beta, 7) self.assertIn("SQL-backed attribute 'TestRecord.beta' is read-only. " "Use '.update(...)' to modify attributes.", str(ro)) @inlineCallbacks def test_simpleUpdate(self): """ L{Record.update} will change the values on the record and in te database. """ txn = self.pool.connection() rec = yield TestRecord.create(txn, beta=3, gamma=u'epsilon') yield rec.update(gamma=u'otherwise') self.assertEqual(rec.gamma, u'otherwise') yield txn.commit() # Make sure that it persists. txn = self.pool.connection() rec = yield TestRecord.load(txn, 3) self.assertEqual(rec.gamma, u'otherwise') @inlineCallbacks def test_simpleQuery(self): """ L{Record.query} will allow you to query for a record by its class attributes as columns. 
""" txn = self.pool.connection() for beta, gamma in [(123, u"one"), (234, u"two"), (345, u"three"), (356, u"three"), (456, u"four")]: yield txn.execSQL("insert into ALPHA values (:1, :2)", [beta, gamma]) records = yield TestRecord.query(txn, TestRecord.gamma == u"three") self.assertEqual(len(records), 2) records.sort(key=lambda x: x.beta) self.assertEqual(records[0].beta, 345) self.assertEqual(records[1].beta, 356) @inlineCallbacks def test_all(self): """ L{Record.all} will return all instances of the record, sorted by primary key. """ txn = self.pool.connection() data = [(123, u"one"), (456, u"four"), (345, u"three"), (234, u"two"), (356, u"three")] for beta, gamma in data: yield txn.execSQL("insert into ALPHA values (:1, :2)", [beta, gamma]) self.assertEqual( [(x.beta, x.gamma) for x in (yield TestRecord.all(txn))], sorted(data) ) @inlineCallbacks def test_repr(self): """ The C{repr} of a L{Record} presents all its values. """ txn = self.pool.connection() yield txn.execSQL("insert into ALPHA values (:1, :2)", [789, u'nine']) rec = list((yield TestRecord.all(txn)))[0] self.assertIn(" beta=789", repr(rec)) self.assertIn(" gamma=u'nine'", repr(rec)) @inlineCallbacks def test_orderedQuery(self): """ L{Record.query} takes an 'order' argument which will allow the objects returned to be ordered. """ txn = self.pool.connection() for beta, gamma in [(123, u"one"), (234, u"two"), (345, u"three"), (356, u"three"), (456, u"four")]: yield txn.execSQL("insert into ALPHA values (:1, :2)", [beta, gamma]) records = yield TestRecord.query(txn, TestRecord.gamma == u"three", TestRecord.beta) self.assertEqual([record.beta for record in records], [345, 356]) records = yield TestRecord.query(txn, TestRecord.gamma == u"three", TestRecord.beta, ascending=False) self.assertEqual([record.beta for record in records], [356, 345]) @inlineCallbacks def test_pop(self): """ A L{Record} may be loaded and deleted atomically, with L{Record.pop}. 
""" txn = self.pool.connection() for beta, gamma in [(123, u"one"), (234, u"two"), (345, u"three"), (356, u"three"), (456, u"four")]: yield txn.execSQL("insert into ALPHA values (:1, :2)", [beta, gamma]) rec = yield TestRecord.pop(txn, 234) self.assertEqual(rec.gamma, u'two') self.assertEqual((yield txn.execSQL("select count(*) from ALPHA " "where BETA = :1", [234])), [tuple([0])]) yield self.failUnlessFailure(TestRecord.pop(txn, 234), NoSuchRecord) def test_columnNamingConvention(self): """ The naming convention maps columns C{LIKE_THIS} to be attributes C{likeThis}. """ self.assertEqual(Record.namingConvention(u"like_this"), "likeThis") self.assertEqual(Record.namingConvention(u"LIKE_THIS"), "likeThis") self.assertEqual(Record.namingConvention(u"LIKE_THIS_ID"), "likeThisID") calendarserver-5.2+dfsg/twext/enterprise/dal/test/test_sqlsyntax.py0000644000175000017500000017524212263343324025055 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## """ Tests for L{twext.enterprise.dal.syntax} """ from twext.enterprise.dal import syntax from twext.enterprise.dal.parseschema import addSQLToSchema from twext.enterprise.dal.syntax import ( Select, Insert, Update, Delete, Lock, SQLFragment, TableMismatch, Parameter, Max, Len, NotEnoughValues, Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction, Union, Intersect, Except, SetExpression, DALError, ResultAliasSyntax, Count, QueryGenerator, ALL_COLUMNS, DatabaseLock, DatabaseUnlock) from twext.enterprise.dal.syntax import FixedPlaceholder, NumericPlaceholder from twext.enterprise.dal.syntax import Function from twext.enterprise.dal.syntax import SchemaSyntax from twext.enterprise.dal.test.test_parseschema import SchemaTestHelper from twext.enterprise.ienterprise import (POSTGRES_DIALECT, ORACLE_DIALECT, SQLITE_DIALECT) from twext.enterprise.test.test_adbapi2 import ConnectionPoolHelper from twext.enterprise.test.test_adbapi2 import NetworkedPoolHelper from twext.enterprise.test.test_adbapi2 import resultOf, AssertResultHelper from twisted.internet.defer import succeed from twisted.trial.unittest import TestCase from twext.enterprise.dal.syntax import Tuple from twext.enterprise.dal.syntax import Constant class _FakeTransaction(object): """ An L{IAsyncTransaction} that provides the relevant metadata for SQL generation. """ def __init__(self, paramstyle): self.paramstyle = 'qmark' class FakeCXOracleModule(object): NUMBER = 'the NUMBER type' STRING = 'a string type (for varchars)' NCLOB = 'the NCLOB type. (for text)' TIMESTAMP = 'for timestamps!' class CatchSQL(object): """ L{IAsyncTransaction} emulator that records the SQL executed on it. """ counter = 0 def __init__(self, dialect=SQLITE_DIALECT, paramstyle='numeric'): self.execed = [] self.pendingResults = [] self.dialect = SQLITE_DIALECT self.paramstyle = 'numeric' def nextResult(self, result): """ Make it so that the next result from L{execSQL} will be the argument. 
""" self.pendingResults.append(result) def execSQL(self, sql, args, rozrc): """ Implement L{IAsyncTransaction} by recording C{sql} and C{args} in C{self.execed}, and return a L{Deferred} firing either an integer or a value pre-supplied by L{CatchSQL.nextResult}. """ self.execed.append([sql, args]) self.counter += 1 if self.pendingResults: result = self.pendingResults.pop(0) else: result = self.counter return succeed(result) class NullTestingOracleTxn(object): """ Fake transaction for testing oracle NULL behavior. """ dialect = ORACLE_DIALECT paramstyle = 'numeric' def execSQL(self, text, params, exc): return succeed([[None, None]]) EXAMPLE_SCHEMA = """ create sequence A_SEQ; create table FOO (BAR integer, BAZ varchar(255)); create table BOZ (QUX integer, QUUX integer); create table OTHER (BAR integer, FOO_BAR integer not null); create table TEXTUAL (MYTEXT varchar(255)); create table LEVELS (ACCESS integer, USERNAME varchar(255)); create table NULLCHECK (ASTRING varchar(255) not null, ANUMBER integer); """ class ExampleSchemaHelper(SchemaTestHelper): """ setUp implementor. """ def setUp(self): self.schema = SchemaSyntax(self.schemaFromString(EXAMPLE_SCHEMA)) class GenerationTests(ExampleSchemaHelper, TestCase, AssertResultHelper): """ Tests for syntactic helpers to generate SQL queries. """ def test_simplestSelect(self): """ L{Select} generates a 'select' statement, by default, asking for all rows in a table. """ self.assertEquals(Select(From=self.schema.FOO).toSQL(), SQLFragment("select * from FOO", [])) def test_tableSyntaxFromSchemaSyntaxCompare(self): """ One L{TableSyntax} is equivalent to another wrapping the same table; one wrapping a different table is different. """ self.assertEquals(self.schema.FOO, self.schema.FOO) self.assertNotEquals(self.schema.FOO, self.schema.BOZ) def test_simpleWhereClause(self): """ L{Select} generates a 'select' statement with a 'where' clause containing an expression. 
""" self.assertEquals(Select(From=self.schema.FOO, Where=self.schema.FOO.BAR == 1).toSQL(), SQLFragment("select * from FOO where BAR = ?", [1])) def test_alternateMetadata(self): """ L{Select} generates a 'select' statement with the specified placeholder syntax when explicitly given L{ConnectionMetadata} which specifies a placeholder. """ self.assertEquals(Select(From=self.schema.FOO, Where=self.schema.FOO.BAR == 1).toSQL( QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("$$"))), SQLFragment("select * from FOO where BAR = $$", [1])) def test_columnComparison(self): """ L{Select} generates a 'select' statement which compares columns. """ self.assertEquals(Select(From=self.schema.FOO, Where=self.schema.FOO.BAR == self.schema.FOO.BAZ).toSQL(), SQLFragment("select * from FOO where BAR = BAZ", [])) def test_comparisonTestErrorPrevention(self): """ The comparison object between SQL expressions raises an exception when compared for a truth value, so that code will not accidentally operate on SQL objects and get a truth value. (Note that this has a caveat, in test_columnsAsDictKeys and test_columnEqualityTruth.) """ def sampleComparison(): if self.schema.FOO.BAR > self.schema.FOO.BAZ: return 'comparison should not succeed' self.assertRaises(DALError, sampleComparison) def test_compareWithNULL(self): """ Comparing a column with None results in the generation of an 'is null' or 'is not null' SQL statement. """ self.assertEquals(Select(From=self.schema.FOO, Where=self.schema.FOO.BAR == None).toSQL(), SQLFragment( "select * from FOO where BAR is null", [])) self.assertEquals(Select(From=self.schema.FOO, Where=self.schema.FOO.BAR != None).toSQL(), SQLFragment( "select * from FOO where BAR is not null", [])) def test_compareWithEmptyStringOracleSpecialCase(self): """ Oracle considers the empty string to be a NULL value, so comparisons with the empty string should be 'is NULL' comparisons. """ # Sanity check: let's make sure that the non-oracle case looks normal. 
self.assertEquals(Select( From=self.schema.FOO, Where=self.schema.FOO.BAR == '').toSQL(), SQLFragment( "select * from FOO where BAR = ?", [""])) self.assertEquals(Select( From=self.schema.FOO, Where=self.schema.FOO.BAR != '').toSQL(), SQLFragment( "select * from FOO where BAR != ?", [""])) self.assertEquals(Select( From=self.schema.FOO, Where=self.schema.FOO.BAR == '' ).toSQL(QueryGenerator(ORACLE_DIALECT, NumericPlaceholder())), SQLFragment( "select * from FOO where BAR is null", [])) self.assertEquals(Select( From=self.schema.FOO, Where=self.schema.FOO.BAR != '' ).toSQL(QueryGenerator(ORACLE_DIALECT, NumericPlaceholder())), SQLFragment( "select * from FOO where BAR is not null", [])) def test_compoundWhere(self): """ L{Select.And} and L{Select.Or} will return compound columns. """ self.assertEquals( Select(From=self.schema.FOO, Where=(self.schema.FOO.BAR < 2).Or( self.schema.FOO.BAR > 5)).toSQL(), SQLFragment("select * from FOO where BAR < ? or BAR > ?", [2, 5])) def test_orderBy(self): """ L{Select}'s L{OrderBy} parameter generates an 'order by' clause for a 'select' statement. """ self.assertEquals( Select(From=self.schema.FOO, OrderBy=self.schema.FOO.BAR).toSQL(), SQLFragment("select * from FOO order by BAR") ) def test_orderByOrder(self): """ L{Select}'s L{Ascending} parameter specifies an ascending/descending order for query results with an OrderBy clause. 
""" self.assertEquals( Select(From=self.schema.FOO, OrderBy=self.schema.FOO.BAR, Ascending=False).toSQL(), SQLFragment("select * from FOO order by BAR desc") ) self.assertEquals( Select(From=self.schema.FOO, OrderBy=self.schema.FOO.BAR, Ascending=True).toSQL(), SQLFragment("select * from FOO order by BAR asc") ) self.assertEquals( Select(From=self.schema.FOO, OrderBy=[self.schema.FOO.BAR, self.schema.FOO.BAZ], Ascending=True).toSQL(), SQLFragment("select * from FOO order by BAR, BAZ asc") ) def test_orderByParens(self): """ L{Select}'s L{OrderBy} paraneter, if specified as a L{Tuple}, generates an SQL expression I{without} parentheses, since the standard format does not allow an arbitrary sort expression but rather a list of columns. """ self.assertEquals( Select(From=self.schema.FOO, OrderBy=Tuple([self.schema.FOO.BAR, self.schema.FOO.BAZ])).toSQL(), SQLFragment("select * from FOO order by BAR, BAZ") ) def test_forUpdate(self): """ L{Select}'s L{ForUpdate} parameter generates a 'for update' clause at the end of the query. """ self.assertEquals( Select(From=self.schema.FOO, ForUpdate=True).toSQL(), SQLFragment("select * from FOO for update") ) def test_groupBy(self): """ L{Select}'s L{GroupBy} parameter generates a 'group by' clause for a 'select' statement. """ self.assertEquals( Select(From=self.schema.FOO, GroupBy=self.schema.FOO.BAR).toSQL(), SQLFragment("select * from FOO group by BAR") ) def test_groupByMulti(self): """ L{Select}'s L{GroupBy} parameter can accept multiple columns in a list. """ self.assertEquals( Select(From=self.schema.FOO, GroupBy=[self.schema.FOO.BAR, self.schema.FOO.BAZ]).toSQL(), SQLFragment("select * from FOO group by BAR, BAZ") ) def test_joinClause(self): """ A table's .join() method returns a join statement in a SELECT. 
""" self.assertEquals( Select(From=self.schema.FOO.join( self.schema.BOZ, self.schema.FOO.BAR == self.schema.BOZ.QUX)).toSQL(), SQLFragment("select * from FOO join BOZ on BAR = QUX", []) ) def test_crossJoin(self): """ A join with no clause specified will generate a cross join. (This is an explicit synonym for an implicit join: i.e. 'select * from FOO, BAR'.) """ self.assertEquals( Select(From=self.schema.FOO.join(self.schema.BOZ)).toSQL(), SQLFragment("select * from FOO cross join BOZ") ) def test_joinJoin(self): """ L{Join.join} will result in a multi-table join. """ self.assertEquals( Select([self.schema.FOO.BAR, self.schema.BOZ.QUX], From=self.schema.FOO .join(self.schema.BOZ).join(self.schema.OTHER)).toSQL(), SQLFragment( "select FOO.BAR, QUX from FOO " "cross join BOZ cross join OTHER") ) def test_multiJoin(self): """ L{Join.join} has the same signature as L{TableSyntax.join} and supports the same 'on' and 'type' arguments. """ self.assertEquals( Select([self.schema.FOO.BAR], From=self.schema.FOO.join( self.schema.BOZ).join( self.schema.OTHER, self.schema.OTHER.BAR == self.schema.FOO.BAR, 'left outer')).toSQL(), SQLFragment( "select FOO.BAR from FOO cross join BOZ left outer join OTHER " "on OTHER.BAR = FOO.BAR") ) def test_tableAliasing(self): """ Tables may be given aliases, in order to facilitate self-joins. """ sfoo = self.schema.FOO sfoo2 = sfoo.alias() self.assertEqual( Select(From=self.schema.FOO.join(sfoo2)).toSQL(), SQLFragment("select * from FOO cross join FOO alias1") ) def test_columnsOfAliasedTable(self): """ The columns of aliased tables will always be prefixed with their alias in the generated SQL. """ sfoo = self.schema.FOO sfoo2 = sfoo.alias() self.assertEquals( Select([sfoo2.BAR], From=sfoo2).toSQL(), SQLFragment("select alias1.BAR from FOO alias1") ) def test_multipleTableAliases(self): """ When multiple aliases are used for the same table, they will be unique within the query. 
""" foo = self.schema.FOO fooPrime = foo.alias() fooPrimePrime = foo.alias() self.assertEquals( Select([fooPrime.BAR, fooPrimePrime.BAR], From=fooPrime.join(fooPrimePrime)).toSQL(), SQLFragment("select alias1.BAR, alias2.BAR " "from FOO alias1 cross join FOO alias2") ) def test_columnSelection(self): """ If a column is specified by the argument to L{Select}, those will be output by the SQL statement rather than the all-columns wildcard. """ self.assertEquals( Select([self.schema.FOO.BAR], From=self.schema.FOO).toSQL(), SQLFragment("select BAR from FOO") ) def test_tableIteration(self): """ Iterating a L{TableSyntax} iterates its columns, in the order that they are defined. """ self.assertEquals(list(self.schema.FOO), [self.schema.FOO.BAR, self.schema.FOO.BAZ]) def test_noColumn(self): """ Accessing an attribute that is not a defined column on a L{TableSyntax} raises an L{AttributeError}. """ self.assertRaises(AttributeError, lambda : self.schema.FOO.NOT_A_COLUMN) def test_columnAliases(self): """ When attributes are set on a L{TableSyntax}, they will be remembered as column aliases, and their alias names may be retrieved via the L{TableSyntax.columnAliases} method. """ self.assertEquals(self.schema.FOO.columnAliases(), {}) self.schema.FOO.ALIAS = self.schema.FOO.BAR # you comparing ColumnSyntax object results in a ColumnComparison, which # you can't test for truth. fixedForEquality = dict([(k, v.model) for k, v in self.schema.FOO.columnAliases().items()]) self.assertEquals(fixedForEquality, {'ALIAS': self.schema.FOO.BAR.model}) self.assertIdentical(self.schema.FOO.ALIAS.model, self.schema.FOO.BAR.model) def test_multiColumnSelection(self): """ If multiple columns are specified by the argument to L{Select}, those will be output by the SQL statement rather than the all-columns wildcard. 
""" self.assertEquals( Select([self.schema.FOO.BAZ, self.schema.FOO.BAR], From=self.schema.FOO).toSQL(), SQLFragment("select BAZ, BAR from FOO") ) def test_joinColumnSelection(self): """ If multiple columns are specified by the argument to L{Select} that uses a L{TableSyntax.join}, those will be output by the SQL statement. """ self.assertEquals( Select([self.schema.FOO.BAZ, self.schema.BOZ.QUX], From=self.schema.FOO.join(self.schema.BOZ, self.schema.FOO.BAR == self.schema.BOZ.QUX)).toSQL(), SQLFragment("select BAZ, QUX from FOO join BOZ on BAR = QUX") ) def test_tableMismatch(self): """ When a column in the 'columns' argument does not match the table from the 'From' argument, L{Select} raises a L{TableMismatch}. """ self.assertRaises(TableMismatch, Select, [self.schema.BOZ.QUX], From=self.schema.FOO) def test_qualifyNames(self): """ When two columns in the FROM clause requested from different tables have the same name, the emitted SQL should explicitly disambiguate them. """ self.assertEquals( Select([self.schema.FOO.BAR, self.schema.OTHER.BAR], From=self.schema.FOO.join(self.schema.OTHER, self.schema.OTHER.FOO_BAR == self.schema.FOO.BAR)).toSQL(), SQLFragment( "select FOO.BAR, OTHER.BAR from FOO " "join OTHER on FOO_BAR = FOO.BAR")) def test_bindParameters(self): """ L{SQLFragment.bind} returns a copy of that L{SQLFragment} with the L{Parameter} objects in its parameter list replaced with the keyword arguments to C{bind}. """ self.assertEquals( Select(From=self.schema.FOO, Where=(self.schema.FOO.BAR > Parameter("testing")).And( self.schema.FOO.BAZ < 7)).toSQL().bind(testing=173), SQLFragment("select * from FOO where BAR > ? and BAZ < ?", [173, 7])) def test_rightHandSideExpression(self): """ Arbitrary expressions may be used as the right-hand side of a comparison operation. 
""" self.assertEquals( Select(From=self.schema.FOO, Where=self.schema.FOO.BAR > (self.schema.FOO.BAZ + 3)).toSQL(), SQLFragment("select * from FOO where BAR > (BAZ + ?)", [3]) ) def test_setSelects(self): """ L{SetExpression} produces set operation on selects. """ # Simple UNION self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 1), SetExpression=Union( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 2), ), ), ).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))), SQLFragment( "(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?)", [1, 2])) # Simple INTERSECT ALL self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 1), SetExpression=Intersect( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 2), ), optype=SetExpression.OPTYPE_ALL ), ).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))), SQLFragment( "(select * from FOO where BAR = ?) INTERSECT ALL (select * from FOO where BAR = ?)", [1, 2])) # Multiple EXCEPTs, not nested, Postgres dialect self.assertEquals( Select( From=self.schema.FOO, SetExpression=Except( ( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 2), ), Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 3), ), ), optype=SetExpression.OPTYPE_DISTINCT, ), ).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))), SQLFragment( "(select * from FOO) EXCEPT DISTINCT (select * from FOO where BAR = ?) EXCEPT DISTINCT (select * from FOO where BAR = ?)", [2, 3])) # Nested EXCEPTs, Oracle dialect self.assertEquals( Select( From=self.schema.FOO, SetExpression=Except( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 2), SetExpression=Except( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 3), ), ), ), ), ).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))), SQLFragment( "(select * from FOO) MINUS ((select * from FOO where BAR = ?) 
MINUS (select * from FOO where BAR = ?))", [2, 3])) # UNION with order by self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 1), SetExpression=Union( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR == 2), ), ), OrderBy=self.schema.FOO.BAR, ).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))), SQLFragment( "(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?) order by BAR", [1, 2])) def test_simpleSubSelects(self): """ L{Max}C{(column)} produces an object in the 'columns' clause that renders the 'max' aggregate in SQL. """ self.assertEquals( Select( [Max(self.schema.BOZ.QUX)], From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ)) ).toSQL(), SQLFragment( "select max(QUX) from (select QUX from BOZ) genid_1")) self.assertEquals( Select( [Count(self.schema.BOZ.QUX)], From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ)) ).toSQL(), SQLFragment( "select count(QUX) from (select QUX from BOZ) genid_1")) self.assertEquals( Select( [Max(self.schema.BOZ.QUX)], From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ, As="alias_BAR")), ).toSQL(), SQLFragment( "select max(QUX) from (select QUX from BOZ) alias_BAR")) def test_setSubSelects(self): """ L{SetExpression} in a From sub-select. """ # Simple UNION self.assertEquals( Select( [Max(self.schema.FOO.BAR)], From=Select( [self.schema.FOO.BAR], From=self.schema.FOO, Where=(self.schema.FOO.BAR == 1), SetExpression=Union( Select( [self.schema.FOO.BAR], From=self.schema.FOO, Where=(self.schema.FOO.BAR == 2), ), ), ) ).toSQL(), SQLFragment( "select max(BAR) from ((select BAR from FOO where BAR = ?) UNION (select BAR from FOO where BAR = ?)) genid_1", [1, 2])) def test_selectColumnAliases(self): """ L{Select} works with aliased columns. 
""" self.assertEquals( Select( [ResultAliasSyntax(self.schema.BOZ.QUX, "BOZ_QUX")], From=self.schema.BOZ ).toSQL(), SQLFragment("select QUX BOZ_QUX from BOZ")) self.assertEquals( Select( [ResultAliasSyntax(Max(self.schema.BOZ.QUX))], From=self.schema.BOZ ).toSQL(), SQLFragment("select max(QUX) genid_1 from BOZ")) alias = ResultAliasSyntax(Max(self.schema.BOZ.QUX)) self.assertEquals( Select([alias.columnReference()], From=Select( [alias], From=self.schema.BOZ) ).toSQL(), SQLFragment("select genid_1 from (select max(QUX) genid_1 from BOZ) genid_2")) alias = ResultAliasSyntax(Len(self.schema.BOZ.QUX)) self.assertEquals( Select([alias.columnReference()], From=Select( [alias], From=self.schema.BOZ) ).toSQL(), SQLFragment("select genid_1 from (select character_length(QUX) genid_1 from BOZ) genid_2")) def test_inSubSelect(self): """ L{ColumnSyntax.In} returns a sub-expression using the SQL 'in' syntax with a sub-select. """ wherein = (self.schema.FOO.BAR.In( Select([self.schema.BOZ.QUX], From=self.schema.BOZ))) self.assertEquals( Select(From=self.schema.FOO, Where=wherein).toSQL(), SQLFragment( "select * from FOO where BAR in (select QUX from BOZ)")) def test_inParameter(self): """ L{ColumnSyntax.In} returns a sub-expression using the SQL 'in' syntax with parameter list. 
""" # One item with IN only items = set(('A',)) self.assertEquals( Select(From=self.schema.FOO, Where=self.schema.FOO.BAR.In(Parameter("names", len(items)))).toSQL().bind(names=items), SQLFragment( "select * from FOO where BAR in (?)", ['A'])) # Two items with IN only items = set(('A', 'B')) self.assertEquals( Select(From=self.schema.FOO, Where=self.schema.FOO.BAR.In(Parameter("names", len(items)))).toSQL().bind(names=items), SQLFragment( "select * from FOO where BAR in (?, ?)", ['A', 'B'])) # Two items with preceding AND self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAZ == Parameter('P1')).And( self.schema.FOO.BAR.In(Parameter("names", len(items)) )) ).toSQL().bind(P1="P1", names=items), SQLFragment( "select * from FOO where BAZ = ? and BAR in (?, ?)", ['P1', 'A', 'B']), ) # Two items with following AND self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR.In(Parameter("names", len(items))).And( self.schema.FOO.BAZ == Parameter('P2') )) ).toSQL().bind(P2="P2", names=items), SQLFragment( "select * from FOO where BAR in (?, ?) and BAZ = ?", ['A', 'B', 'P2']), ) # Two items with preceding OR and following AND self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAZ == Parameter('P1')).Or( self.schema.FOO.BAR.In(Parameter("names", len(items))).And( self.schema.FOO.BAZ == Parameter('P2') )) ).toSQL().bind(P1="P1", P2="P2", names=items), SQLFragment( "select * from FOO where BAZ = ? or BAR in (?, ?) 
and BAZ = ?", ['P1', 'A', 'B', 'P2']), ) # Check various error situations # No count not allowed self.assertRaises(DALError, self.schema.FOO.BAR.In, Parameter("names")) # count=0 not allowed self.assertRaises(DALError, Parameter, "names", 0) # Mismatched count and len(items) self.assertRaises( DALError, Select(From=self.schema.FOO, Where=self.schema.FOO.BAR.In(Parameter("names", len(items)))).toSQL().bind, names=["a", "b", "c", ] ) def test_max(self): """ L{Max}C{(column)} produces an object in the 'columns' clause that renders the 'max' aggregate in SQL. """ self.assertEquals( Select([Max(self.schema.BOZ.QUX)], From=self.schema.BOZ).toSQL(), SQLFragment( "select max(QUX) from BOZ")) def test_countAllCoumns(self): """ L{Count}C{(ALL_COLUMNS)} produces an object in the 'columns' clause that renders the 'count' in SQL. """ self.assertEquals( Select([Count(ALL_COLUMNS)], From=self.schema.BOZ).toSQL(), SQLFragment( "select count(*) from BOZ")) def test_aggregateComparison(self): """ L{Max}C{(column) > constant} produces an object in the 'columns' clause that renders a comparison to the 'max' aggregate in SQL. """ self.assertEquals(Select([Max(self.schema.BOZ.QUX) + 12], From=self.schema.BOZ).toSQL(), SQLFragment("select max(QUX) + ? from BOZ", [12])) def test_multiColumnExpression(self): """ Multiple columns may be provided in an expression in the 'columns' portion of a Select() statement. All arithmetic operators are supported. """ self.assertEquals( Select([((self.schema.FOO.BAR + self.schema.FOO.BAZ) / 3) * 7], From=self.schema.FOO).toSQL(), SQLFragment("select ((BAR + BAZ) / ?) * ? from FOO", [3, 7]) ) def test_len(self): """ Test for the 'Len' function for determining character length of a column. (Note that this should be updated to use different techniques as necessary in different databases.) 
""" self.assertEquals( Select([Len(self.schema.TEXTUAL.MYTEXT)], From=self.schema.TEXTUAL).toSQL(), SQLFragment( "select character_length(MYTEXT) from TEXTUAL")) def test_startswith(self): """ Test for the string starts with comparison. (Note that this should be updated to use different techniques as necessary in different databases.) """ self.assertEquals( Select([ self.schema.TEXTUAL.MYTEXT], From=self.schema.TEXTUAL, Where=self.schema.TEXTUAL.MYTEXT.StartsWith("test"), ).toSQL(), SQLFragment( "select MYTEXT from TEXTUAL where MYTEXT like (? || ?)", ["test", "%"] ) ) def test_endswith(self): """ Test for the string starts with comparison. (Note that this should be updated to use different techniques as necessary in different databases.) """ self.assertEquals( Select([ self.schema.TEXTUAL.MYTEXT], From=self.schema.TEXTUAL, Where=self.schema.TEXTUAL.MYTEXT.EndsWith("test"), ).toSQL(), SQLFragment( "select MYTEXT from TEXTUAL where MYTEXT like (? || ?)", ["%", "test"] ) ) def test_contains(self): """ Test for the string starts with comparison. (Note that this should be updated to use different techniques as necessary in different databases.) """ self.assertEquals( Select([ self.schema.TEXTUAL.MYTEXT], From=self.schema.TEXTUAL, Where=self.schema.TEXTUAL.MYTEXT.Contains("test"), ).toSQL(), SQLFragment( "select MYTEXT from TEXTUAL where MYTEXT like (? || (? || ?))", ["%", "test", "%"] ) ) def test_insert(self): """ L{Insert.toSQL} generates an 'insert' statement with all the relevant columns. """ self.assertEquals( Insert({self.schema.FOO.BAR: 23, self.schema.FOO.BAZ: 9}).toSQL(), SQLFragment("insert into FOO (BAR, BAZ) values (?, ?)", [23, 9])) def test_insertNotEnough(self): """ L{Insert}'s constructor will raise L{NotEnoughValues} if columns have not been specified. 
""" notEnough = self.assertRaises( NotEnoughValues, Insert, {self.schema.OTHER.BAR: 9} ) self.assertEquals(str(notEnough), "Columns [FOO_BAR] required.") def test_insertReturning(self): """ L{Insert}'s C{Return} argument will insert an SQL 'returning' clause. """ self.assertEquals( Insert({self.schema.FOO.BAR: 23, self.schema.FOO.BAZ: 9}, Return=self.schema.FOO.BAR).toSQL(), SQLFragment( "insert into FOO (BAR, BAZ) values (?, ?) returning BAR", [23, 9]) ) def test_insertMultiReturn(self): """ L{Insert}'s C{Return} argument can also be a C{tuple}, which will insert an SQL 'returning' clause with multiple columns. """ self.assertEquals( Insert({self.schema.FOO.BAR: 23, self.schema.FOO.BAZ: 9}, Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ)).toSQL(), SQLFragment( "insert into FOO (BAR, BAZ) values (?, ?) returning BAR, BAZ", [23, 9]) ) def test_insertMultiReturnOracle(self): """ In Oracle's SQL dialect, the 'returning' clause requires an 'into' clause indicating where to put the results, as they can't be simply relayed to the cursor. Further, additional bound variables are required to capture the output parameters. """ self.assertEquals( Insert({self.schema.FOO.BAR: 40, self.schema.FOO.BAZ: 50}, Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ)).toSQL( QueryGenerator(ORACLE_DIALECT, NumericPlaceholder()) ), SQLFragment( "insert into FOO (BAR, BAZ) values (:1, :2) returning BAR, BAZ" " into :3, :4", [40, 50, Parameter("oracle_out_0"), Parameter("oracle_out_1")] ) ) def test_insertMultiReturnSQLite(self): """ In SQLite's SQL dialect, there is no 'returning' clause, but given that SQLite serializes all SQL transactions, you can rely upon 'select' after a write operation to reliably give you exactly what was just modified. Therefore, although 'toSQL' won't include any indication of the return value, the 'on' method will execute a 'select' statement following the insert to retrieve the value. 
""" insertStatement = Insert({self.schema.FOO.BAR: 39, self.schema.FOO.BAZ: 82}, Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ) ) qg = lambda : QueryGenerator(SQLITE_DIALECT, NumericPlaceholder()) self.assertEquals(insertStatement.toSQL(qg()), SQLFragment("insert into FOO (BAR, BAZ) values (:1, :2)", [39, 82]) ) result = [] csql = CatchSQL() insertStatement.on(csql).addCallback(result.append) self.assertEqual(result, [2]) self.assertEqual( csql.execed, [["insert into FOO (BAR, BAZ) values (:1, :2)", [39, 82]], ["select BAR, BAZ from FOO where rowid = last_insert_rowid()", []]] ) def test_insertNoReturnSQLite(self): """ Insert a row I{without} a C{Return=} parameter should also work as normal in sqlite. """ statement = Insert({self.schema.FOO.BAR: 12, self.schema.FOO.BAZ: 48}) csql = CatchSQL() statement.on(csql) self.assertEqual( csql.execed, [["insert into FOO (BAR, BAZ) values (:1, :2)", [12, 48]]] ) def test_updateReturningSQLite(self): """ Since SQLite does not support the SQL 'returning' syntax extension, in order to preserve the rows that will be modified during an UPDATE statement, we must first find the rows that will be affected, then update them, then return the rows that were affected. Since we might be changing even part of the primary key, we use the internal 'rowid' column to uniquely and reliably identify rows in the sqlite database that have been modified. """ csql = CatchSQL() stmt = Update({self.schema.FOO.BAR: 4321}, Where=self.schema.FOO.BAZ == 1234, Return=self.schema.FOO.BAR) csql.nextResult([["sample row id"]]) result = resultOf(stmt.on(csql)) # Three statements were executed; make sure that the result returned was # the result of executing the 3rd (and final) one. self.assertResultList(result, 3) # Check that they were the right statements. 
self.assertEqual(len(csql.execed), 3) self.assertEqual( csql.execed[0], ["select rowid from FOO where BAZ = :1", [1234]] ) self.assertEqual( csql.execed[1], ["update FOO set BAR = :1 where BAZ = :2", [4321, 1234]] ) self.assertEqual( csql.execed[2], ["select BAR from FOO where rowid = :1", ["sample row id"]] ) def test_updateReturningMultipleValuesSQLite(self): """ When SQLite updates multiple values, it must embed the row ID of each subsequent value into its second 'where' clause, as there is no way to pass a list of values to a single statement.. """ csql = CatchSQL() stmt = Update({self.schema.FOO.BAR: 4321}, Where=self.schema.FOO.BAZ == 1234, Return=self.schema.FOO.BAR) csql.nextResult([["one row id"], ["and another"], ["and one more"]]) result = resultOf(stmt.on(csql)) # Three statements were executed; make sure that the result returned was # the result of executing the 3rd (and final) one. self.assertResultList(result, 3) # Check that they were the right statements. self.assertEqual(len(csql.execed), 3) self.assertEqual( csql.execed[0], ["select rowid from FOO where BAZ = :1", [1234]] ) self.assertEqual( csql.execed[1], ["update FOO set BAR = :1 where BAZ = :2", [4321, 1234]] ) self.assertEqual( csql.execed[2], ["select BAR from FOO where rowid = :1 or rowid = :2 or rowid = :3", ["one row id", "and another", "and one more"]] ) def test_deleteReturningSQLite(self): """ When SQLite deletes a value, ... """ csql = CatchSQL() stmt = Delete(From=self.schema.FOO, Where=self.schema.FOO.BAZ == 1234, Return=self.schema.FOO.BAR) result = resultOf(stmt.on(csql)) self.assertResultList(result, 1) self.assertEqual(len(csql.execed), 2) self.assertEqual( csql.execed[0], ["select BAR from FOO where BAZ = :1", [1234]] ) self.assertEqual( csql.execed[1], ["delete from FOO where BAZ = :1", [1234]] ) def test_insertMismatch(self): """ L{Insert} raises L{TableMismatch} if the columns specified aren't all from the same table. 
""" self.assertRaises( TableMismatch, Insert, {self.schema.FOO.BAR: 23, self.schema.FOO.BAZ: 9, self.schema.TEXTUAL.MYTEXT: 'hello'} ) def test_quotingOnKeywordConflict(self): """ 'access' is a keyword, so although our schema parser will leniently accept it, it must be quoted in any outgoing SQL. (This is only done in the Oracle dialect, because it isn't necessary in postgres, and idiosyncratic case-folding rules make it challenging to do it in both.) """ self.assertEquals( Insert({self.schema.LEVELS.ACCESS: 1, self.schema.LEVELS.USERNAME: "hi"}).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))), SQLFragment( 'insert into LEVELS ("ACCESS", USERNAME) values (?, ?)', [1, "hi"]) ) self.assertEquals( Insert({self.schema.LEVELS.ACCESS: 1, self.schema.LEVELS.USERNAME: "hi"}).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))), SQLFragment( 'insert into LEVELS (ACCESS, USERNAME) values (?, ?)', [1, "hi"]) ) def test_updateReturning(self): """ L{update}'s C{Return} argument will update an SQL 'returning' clause. """ self.assertEquals( Update({self.schema.FOO.BAR: 23}, self.schema.FOO.BAZ == 43, Return=self.schema.FOO.BAR).toSQL(), SQLFragment( "update FOO set BAR = ? where BAZ = ? returning BAR", [23, 43]) ) def test_updateMismatch(self): """ L{Update} raises L{TableMismatch} if the columns specified aren't all from the same table. """ self.assertRaises( TableMismatch, Update, {self.schema.FOO.BAR: 23, self.schema.FOO.BAZ: 9, self.schema.TEXTUAL.MYTEXT: 'hello'}, Where=self.schema.FOO.BAZ == 9 ) def test_updateFunction(self): """ L{Update} values may be L{FunctionInvocation}s, to update to computed values in the database. 
""" sqlfunc = Function("hello") self.assertEquals( Update( {self.schema.FOO.BAR: 23, self.schema.FOO.BAZ: sqlfunc()}, Where=self.schema.FOO.BAZ == 9 ).toSQL(), SQLFragment("update FOO set BAR = ?, BAZ = hello() " "where BAZ = ?", [23, 9]) ) def test_insertFunction(self): """ L{Update} values may be L{FunctionInvocation}s, to update to computed values in the database. """ sqlfunc = Function("hello") self.assertEquals( Insert( {self.schema.FOO.BAR: 23, self.schema.FOO.BAZ: sqlfunc()}, ).toSQL(), SQLFragment("insert into FOO (BAR, BAZ) " "values (?, hello())", [23]) ) def test_deleteReturning(self): """ L{Delete}'s C{Return} argument will delete an SQL 'returning' clause. """ self.assertEquals( Delete(self.schema.FOO, Where=self.schema.FOO.BAR == 7, Return=self.schema.FOO.BAZ).toSQL(), SQLFragment( "delete from FOO where BAR = ? returning BAZ", [7]) ) def test_update(self): """ L{Update.toSQL} generates an 'update' statement. """ self.assertEquals( Update({self.schema.FOO.BAR: 4321}, self.schema.FOO.BAZ == 1234).toSQL(), SQLFragment("update FOO set BAR = ? where BAZ = ?", [4321, 1234])) def test_delete(self): """ L{Delete} generates an SQL 'delete' statement. """ self.assertEquals( Delete(self.schema.FOO, Where=self.schema.FOO.BAR == 12).toSQL(), SQLFragment( "delete from FOO where BAR = ?", [12]) ) self.assertEquals( Delete(self.schema.FOO, Where=None).toSQL(), SQLFragment("delete from FOO") ) def test_lock(self): """ L{Lock.exclusive} generates a ('lock table') statement, locking the table in the specified mode. 
""" self.assertEquals(Lock.exclusive(self.schema.FOO).toSQL(), SQLFragment("lock table FOO in exclusive mode")) def test_databaseLock(self): """ L{DatabaseLock} generates a ('pg_advisory_lock') statement """ self.assertEquals(DatabaseLock().toSQL(), SQLFragment("select pg_advisory_lock(1)")) def test_databaseUnlock(self): """ L{DatabaseUnlock} generates a ('pg_advisory_unlock') statement """ self.assertEquals(DatabaseUnlock().toSQL(), SQLFragment("select pg_advisory_unlock(1)")) def test_savepoint(self): """ L{Savepoint} generates a ('savepoint') statement. """ self.assertEquals(Savepoint("test").toSQL(), SQLFragment("savepoint test")) def test_rollbacktosavepoint(self): """ L{RollbackToSavepoint} generates a ('rollback to savepoint') statement. """ self.assertEquals(RollbackToSavepoint("test").toSQL(), SQLFragment("rollback to savepoint test")) def test_releasesavepoint(self): """ L{ReleaseSavepoint} generates a ('release savepoint') statement. """ self.assertEquals(ReleaseSavepoint("test").toSQL(), SQLFragment("release savepoint test")) def test_savepointaction(self): """ L{SavepointAction} generates a ('savepoint') statement. """ self.assertEquals(SavepointAction("test")._name, "test") def test_limit(self): """ A L{Select} object with a 'Limit' keyword parameter will generate a SQL statement with a 'limit' clause. """ self.assertEquals( Select([self.schema.FOO.BAR], From=self.schema.FOO, Limit=123).toSQL(), SQLFragment( "select BAR from FOO limit ?", [123])) def test_limitOracle(self): """ A L{Select} object with a 'Limit' keyword parameter will generate a SQL statement using a ROWNUM subquery for Oracle. See U{this "ask tom" article from 2006 for more information }. 
""" self.assertEquals( Select([self.schema.FOO.BAR], From=self.schema.FOO, Limit=123).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))), SQLFragment( "select * from (select BAR from FOO) " "where ROWNUM <= ?", [123]) ) def test_having(self): """ A L{Select} object with a 'Having' keyword parameter will generate a SQL statement with a 'having' expression. """ self.assertEquals( Select([self.schema.FOO.BAR], From=self.schema.FOO, Having=Max(self.schema.FOO.BAZ) < 7).toSQL(), SQLFragment("select BAR from FOO having max(BAZ) < ?", [7]) ) def test_distinct(self): """ A L{Select} object with a 'Disinct' keyword parameter with a value of C{True} will generate a SQL statement with a 'distinct' keyword preceding its list of columns. """ self.assertEquals( Select([self.schema.FOO.BAR], From=self.schema.FOO, Distinct=True).toSQL(), SQLFragment("select distinct BAR from FOO") ) def test_nextSequenceValue(self): """ When a sequence is used as a value in an expression, it renders as the call to 'nextval' that will produce its next value. """ self.assertEquals( Insert({self.schema.BOZ.QUX: self.schema.A_SEQ}).toSQL(), SQLFragment("insert into BOZ (QUX) values (nextval('A_SEQ'))", [])) def test_nextSequenceValueOracle(self): """ When a sequence is used as a value in an expression in the Oracle dialect, it renders as the 'nextval' attribute of the appropriate sequence. """ self.assertEquals( Insert({self.schema.BOZ.QUX: self.schema.A_SEQ}).toSQL( QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))), SQLFragment("insert into BOZ (QUX) values (A_SEQ.nextval)", [])) def test_nextSequenceDefaultImplicitExplicitOracle(self): """ In Oracle's dialect, sequence defaults can't be implemented without using triggers, so instead we just explicitly always include the sequence default value. 
""" addSQLToSchema( schema=self.schema.model, schemaData="create table DFLTR (a varchar(255), " "b integer default nextval('A_SEQ'));" ) self.assertEquals( Insert({self.schema.DFLTR.a: 'hello'}).toSQL( QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?")) ), SQLFragment("insert into DFLTR (a, b) values " "(?, A_SEQ.nextval)", ['hello']), ) # Should be the same if it's explicitly specified. self.assertEquals( Insert({self.schema.DFLTR.a: 'hello', self.schema.DFLTR.b: self.schema.A_SEQ}).toSQL( QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?")) ), SQLFragment("insert into DFLTR (a, b) values " "(?, A_SEQ.nextval)", ['hello']), ) def test_numericParams(self): """ An L{IAsyncTransaction} with the 'numeric' paramstyle attribute will cause statements to be generated with parameters in the style of :1 :2 :3, as per the DB-API. """ stmts = [] class FakeOracleTxn(object): def execSQL(self, text, params, exc): stmts.append((text, params)) dialect = ORACLE_DIALECT paramstyle = 'numeric' Select([self.schema.FOO.BAR], From=self.schema.FOO, Where=(self.schema.FOO.BAR == 7).And( self.schema.FOO.BAZ == 9) ).on(FakeOracleTxn()) self.assertEquals( stmts, [("select BAR from FOO where BAR = :1 and BAZ = :2", [7, 9])] ) def test_rewriteOracleNULLs_Select(self): """ Oracle databases cannot distinguish between the empty string and C{NULL}. When you insert an empty string, C{cx_Oracle} therefore treats it as a C{None} and will return that when you select it back again. We address this in the schema by dropping 'not null' constraints. Therefore, when executing a statement which includes a string column, 'on' should rewrite None return values from C{cx_Oracle} to be empty bytestrings, but only for string columns. 
""" rows = resultOf( Select([self.schema.NULLCHECK.ASTRING, self.schema.NULLCHECK.ANUMBER], From=self.schema.NULLCHECK).on(NullTestingOracleTxn()))[0] self.assertEquals(rows, [['', None]]) def test_rewriteOracleNULLs_SelectAllColumns(self): """ Same as L{test_rewriteOracleNULLs_Select}, but with the L{ALL_COLUMNS} shortcut. """ rows = resultOf( Select(From=self.schema.NULLCHECK).on(NullTestingOracleTxn()) )[0] self.assertEquals(rows, [['', None]]) def test_nestedLogicalExpressions(self): """ Make sure that logical operator precedence inserts proper parenthesis when needed. e.g. 'a.And(b.Or(c))' needs to be 'a and (b or c)' not 'a and b or c'. """ self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR != 7). And(self.schema.FOO.BAZ != 8). And((self.schema.FOO.BAR == 8).Or(self.schema.FOO.BAZ == 0)) ).toSQL(), SQLFragment("select * from FOO where BAR != ? and BAZ != ? and " "(BAR = ? or BAZ = ?)", [7, 8, 8, 0])) self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR != 7). Or(self.schema.FOO.BAZ != 8). Or((self.schema.FOO.BAR == 8).And(self.schema.FOO.BAZ == 0)) ).toSQL(), SQLFragment("select * from FOO where BAR != ? or BAZ != ? or " "BAR = ? and BAZ = ?", [7, 8, 8, 0])) self.assertEquals( Select( From=self.schema.FOO, Where=(self.schema.FOO.BAR != 7). Or(self.schema.FOO.BAZ != 8). And((self.schema.FOO.BAR == 8).Or(self.schema.FOO.BAZ == 0)) ).toSQL(), SQLFragment("select * from FOO where (BAR != ? or BAZ != ?) and " "(BAR = ? or BAZ = ?)", [7, 8, 8, 0])) def test_updateWithNULL(self): """ As per the DB-API specification, "SQL NULL values are represented by the Python None singleton on input and output." When a C{None} is provided as a value to an L{Update}, it will be relayed to the database as a parameter. """ self.assertEquals( Update({self.schema.BOZ.QUX: None}, Where=self.schema.BOZ.QUX == 7).toSQL(), SQLFragment("update BOZ set QUX = ? 
where QUX = ?", [None, 7]) ) def test_subSelectComparison(self): """ A comparison of a column to a sub-select in a where clause will result in a parenthetical 'Where' clause. """ self.assertEquals( Update( {self.schema.BOZ.QUX: 9}, Where=self.schema.BOZ.QUX == Select([self.schema.FOO.BAR], From=self.schema.FOO, Where=self.schema.FOO.BAZ == 12)).toSQL(), SQLFragment( # NOTE: it's very important that the comparison _always_ go in # this order (column from the UPDATE first, inner SELECT second) # as the other order will be considered a syntax error. "update BOZ set QUX = ? where QUX = (" "select BAR from FOO where BAZ = ?)", [9, 12] ) ) def test_tupleComparison(self): """ A L{Tuple} allows for simultaneous comparison of multiple values in a C{Where} clause. This feature is particularly useful when issuing an L{Update} or L{Delete}, where the comparison is with values from a subselect. (A L{Tuple} will be automatically generated upon comparison to a C{tuple} or C{list}.) """ self.assertEquals( Update( {self.schema.BOZ.QUX: 1}, Where=(self.schema.BOZ.QUX, self.schema.BOZ.QUUX) == Select([self.schema.FOO.BAR, self.schema.FOO.BAZ], From=self.schema.FOO, Where=self.schema.FOO.BAZ == 2)).toSQL(), SQLFragment( # NOTE: it's very important that the comparison _always_ go in # this order (tuple of columns from the UPDATE first, inner # SELECT second) as the other order will be considered a syntax # error. "update BOZ set QUX = ? where (QUX, QUUX) = (" "select BAR, BAZ from FOO where BAZ = ?)", [1, 2] ) ) def test_tupleOfConstantsComparison(self): """ For some reason Oracle requires multiple parentheses for comparisons. 
""" self.assertEquals( Select( [self.schema.FOO.BAR], From=self.schema.FOO, Where=(Tuple([self.schema.FOO.BAR, self.schema.FOO.BAZ]) == Tuple([Constant(7), Constant(9)])) ).toSQL(), SQLFragment( "select BAR from FOO where (BAR, BAZ) = ((?, ?))", [7, 9] ) ) def test_oracleTableTruncation(self): """ L{Table}'s SQL generation logic will truncate table names if the dialect (i.e. Oracle) demands it. (See txdav.common.datastore.sql_tables for the schema translator and enforcement of name uniqueness in the derived schema.) """ addSQLToSchema( self.schema.model, "create table veryveryveryveryveryveryveryverylong " "(foo integer);" ) vvl = self.schema.veryveryveryveryveryveryveryverylong self.assertEquals( Insert({vvl.foo: 1}).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))), SQLFragment( "insert into veryveryveryveryveryveryveryve (foo) values " "(?)", [1] ) ) def test_columnEqualityTruth(self): """ Mostly in support of test_columnsAsDictKeys, the 'same' column should compare True to itself and False to other values. """ s = self.schema self.assertEquals(bool(s.FOO.BAR == s.FOO.BAR), True) self.assertEquals(bool(s.FOO.BAR != s.FOO.BAR), False) self.assertEquals(bool(s.FOO.BAZ != s.FOO.BAR), True) def test_columnsAsDictKeys(self): """ An odd corner of the syntactic sugar provided by the DAL is that the column objects have to participate both in augmented equality comparison ("==" returns an expression object) as well as dictionary keys (for Insert and Update statement objects). Therefore it should be possible to I{manipulate} dictionaries of keys as well. """ values = {self.schema.FOO.BAR: 1} self.assertEquals(values, {self.schema.FOO.BAR: 1}) values.pop(self.schema.FOO.BAR) self.assertEquals(values, {}) class OracleConnectionMethods(object): def test_rewriteOracleNULLs_Insert(self): """ The behavior described in L{test_rewriteOracleNULLs_Select} applies to other statement types as well, specifically those with 'returning' clauses. 
""" # Add 2 cursor variable values so that these will be used by # FakeVariable.getvalue. self.factory.varvals.extend([None, None]) rows = self.resultOf( Insert({self.schema.NULLCHECK.ASTRING: '', self.schema.NULLCHECK.ANUMBER: None}, Return=[self.schema.NULLCHECK.ASTRING, self.schema.NULLCHECK.ANUMBER] ).on(self.createTransaction()))[0] self.assertEquals(rows, [['', None]]) def test_insertMultiReturnOnOracleTxn(self): """ As described in L{test_insertMultiReturnOracle}, Oracle deals with 'returning' clauses by using out parameters. However, this is not quite enough, as the code needs to actually retrieve the values from the out parameters. """ i = Insert({self.schema.FOO.BAR: 40, self.schema.FOO.BAZ: 50}, Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ)) self.factory.varvals.extend(["first val!", "second val!"]) result = self.resultOf(i.on(self.createTransaction())) self.assertEquals(result, [[["first val!", "second val!"]]]) curvars = self.factory.connections[0].cursors[0].variables self.assertEquals(len(curvars), 2) self.assertEquals(curvars[0].type, FakeCXOracleModule.NUMBER) self.assertEquals(curvars[1].type, FakeCXOracleModule.STRING) def test_insertNoReturnOracle(self): """ In addition to being able to execute insert statements with a Return attribute, oracle also ought to be able to execute insert statements with no Return at all. """ # This statement should return nothing from .fetchall(), so... self.factory.hasResults = False i = Insert({self.schema.FOO.BAR: 40, self.schema.FOO.BAZ: 50}) result = self.resultOf(i.on(self.createTransaction())) self.assertEquals(result, [None]) class OracleConnectionTests(ConnectionPoolHelper, ExampleSchemaHelper, OracleConnectionMethods, TestCase): """ Tests which use an oracle connection. """ dialect = ORACLE_DIALECT def setUp(self): """ Create a fake oracle-ish connection pool without using real threads or a real database. 
""" self.patch(syntax, 'cx_Oracle', FakeCXOracleModule) super(OracleConnectionTests, self).setUp() ExampleSchemaHelper.setUp(self) class OracleNetConnectionTests(NetworkedPoolHelper, ExampleSchemaHelper, OracleConnectionMethods, TestCase): dialect = ORACLE_DIALECT def setUp(self): self.patch(syntax, 'cx_Oracle', FakeCXOracleModule) super(OracleNetConnectionTests, self).setUp() ExampleSchemaHelper.setUp(self) self.pump.client.dialect = ORACLE_DIALECT calendarserver-5.2+dfsg/twext/enterprise/dal/test/__init__.py0000644000175000017500000000120712263343324023474 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for twext.enterprise.dal. """ calendarserver-5.2+dfsg/twext/enterprise/dal/record.py0000644000175000017500000003075312263343324022244 0ustar rahulrahul# -*- test-case-name: twext.enterprise.dal.test.test_record -*- ## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. ## """ RECORD: Relational Entity Creation from Objects Representing Data. This is an asynchronous object-relational mapper based on L{twext.enterprise.dal.syntax}. """ from twisted.internet.defer import inlineCallbacks, returnValue from twext.enterprise.dal.syntax import ( Select, Tuple, Constant, ColumnSyntax, Insert, Update, Delete ) from twext.enterprise.util import parseSQLTimestamp # from twext.enterprise.dal.syntax import ExpressionSyntax class ReadOnly(AttributeError): """ A caller attempted to set an attribute on a database-backed record, rather than updating it through L{Record.update}. """ def __init__(self, className, attributeName): self.className = className self.attributeName = attributeName super(ReadOnly, self).__init__("SQL-backed attribute '{0}.{1}' is " "read-only. Use '.update(...)' to " "modify attributes." .format(className, attributeName)) class NoSuchRecord(Exception): """ No matching record could be found. """ class _RecordMeta(type): """ Metaclass for associating a L{fromTable} with a L{Record} at inheritance time. """ def __new__(cls, name, bases, ns): """ Create a new instance of this meta-type. """ newbases = [] table = None namer = None for base in bases: if isinstance(base, fromTable): if table is not None: raise RuntimeError( "Can't define a class from two or more tables at once." ) table = base.table elif getattr(base, "table", None) is not None: raise RuntimeError( "Can't define a record class by inheriting one already " "mapped to a table." 
# TODO: more info ) else: if namer is None: if isinstance(base, _RecordMeta): namer = base newbases.append(base) if table is not None: attrmap = {} colmap = {} allColumns = list(table) for column in allColumns: attrname = namer.namingConvention(column.model.name) attrmap[attrname] = column colmap[column] = attrname ns.update(table=table, __attrmap__=attrmap, __colmap__=colmap) ns.update(attrmap) return super(_RecordMeta, cls).__new__(cls, name, tuple(newbases), ns) class fromTable(object): """ Inherit from this after L{Record} to specify which table your L{Record} subclass is mapped to. """ def __init__(self, aTable): """ @param table: The table to map to. @type table: L{twext.enterprise.dal.syntax.TableSyntax} """ self.table = aTable class Record(object): """ Superclass for all database-backed record classes. (i.e. an object mapped from a database record). @cvar table: the table that represents this L{Record} in the database. @type table: L{TableSyntax} @ivar transaction: The L{IAsyncTransaction} where this record is being loaded. This may be C{None} if this L{Record} is not participating in a transaction, which may be true if it was instantiated but never saved. @cvar __colmap__: map of L{ColumnSyntax} objects to attribute names. @type __colmap__: L{dict} @cvar __attrmap__: map of attribute names to L{ColumnSyntax} objects. @type __attrmap__: L{dict} """ __metaclass__ = _RecordMeta transaction = None def __setattr__(self, name, value): """ Once the transaction is initialized, this object is immutable. If you want to change it, use L{Record.update}. 
""" if self.transaction is not None: raise ReadOnly(self.__class__.__name__, name) return super(Record, self).__setattr__(name, value) def __repr__(self): r = "<{0} record from table {1}".format(self.__class__.__name__, self.table.model.name) for k in sorted(self.__attrmap__.keys()): r += " {0}={1}".format(k, repr(getattr(self, k))) r += ">" return r @staticmethod def namingConvention(columnName): """ Implement the convention for naming-conversion between column names (typically, upper-case database names map to lower-case attribute names). """ words = columnName.lower().split("_") def cap(word): if word.lower() == 'id': return word.upper() else: return word.capitalize() return words[0] + "".join(map(cap, words[1:])) @classmethod def _primaryKeyExpression(cls): return Tuple([ColumnSyntax(c) for c in cls.table.model.primaryKey]) def _primaryKeyValue(self): val = [] for col in self._primaryKeyExpression().columns: val.append(getattr(self, self.__class__.__colmap__[col])) return val @classmethod def _primaryKeyComparison(cls, primaryKey): return (cls._primaryKeyExpression() == Tuple(map(Constant, primaryKey))) @classmethod @inlineCallbacks def load(cls, transaction, *primaryKey): results = (yield cls.query(transaction, cls._primaryKeyComparison(primaryKey))) if len(results) != 1: raise NoSuchRecord() else: returnValue(results[0]) @classmethod @inlineCallbacks def create(cls, transaction, **k): """ Create a row. 
Used like this:: MyRecord.create(transaction, column1=1, column2=u'two') """ self = cls() colmap = {} attrtocol = cls.__attrmap__ needsCols = [] needsAttrs = [] for attr in attrtocol: col = attrtocol[attr] if attr in k: setattr(self, attr, k[attr]) colmap[col] = k.pop(attr) else: if col.model.needsValue(): raise TypeError("required attribute " + repr(attr) + " not passed") else: needsCols.append(col) needsAttrs.append(attr) if k: raise TypeError("received unknown attribute{0}: {1}".format( "s" if len(k) > 1 else "", ", ".join(sorted(k)) )) result = yield (Insert(colmap, Return=needsCols if needsCols else None) .on(transaction)) if needsCols: self._attributesFromRow(zip(needsAttrs, result[0])) self.transaction = transaction returnValue(self) def _attributesFromRow(self, attributeList): """ Take some data loaded from a row and apply it to this instance, converting types as necessary. @param attributeList: a C{list} of 2-C{tuples} of C{(attributeName, attributeValue)}. """ for setAttribute, setValue in attributeList: setColumn = self.__attrmap__[setAttribute] if setColumn.model.type.name == "timestamp": setValue = parseSQLTimestamp(setValue) setattr(self, setAttribute, setValue) def delete(self): """ Delete this row from the database. @return: a L{Deferred} which fires with C{None} when the underlying row has been deleted, or fails with L{NoSuchRecord} if the underlying row was already deleted. """ return Delete(From=self.table, Where=self._primaryKeyComparison(self._primaryKeyValue()) ).on(self.transaction, raiseOnZeroRowCount=NoSuchRecord) @inlineCallbacks def update(self, **kw): """ Modify the given attributes in the database. @return: a L{Deferred} that fires when the updates have been sent to the database. 
""" colmap = {} for k, v in kw.iteritems(): colmap[self.__attrmap__[k]] = v yield (Update(colmap, Where=self._primaryKeyComparison(self._primaryKeyValue())) .on(self.transaction)) self.__dict__.update(kw) @classmethod def pop(cls, transaction, *primaryKey): """ Atomically retrieve and remove a row from this L{Record}'s table with a primary key value of C{primaryKey}. @return: a L{Deferred} that fires with an instance of C{cls}, or fails with L{NoSuchRecord} if there were no records in the database. @rtype: L{Deferred} """ return cls._rowsFromQuery( transaction, Delete(Where=cls._primaryKeyComparison(primaryKey), From=cls.table, Return=list(cls.table)), lambda : NoSuchRecord() ).addCallback(lambda x: x[0]) @classmethod def query(cls, transaction, expr, order=None, ascending=True, group=None): """ Query the table that corresponds to C{cls}, and return instances of C{cls} corresponding to the rows that are returned from that table. @param expr: An L{ExpressionSyntax} that constraints the results of the query. This is most easily produced by accessing attributes on the class; for example, C{MyRecordType.query((MyRecordType.col1 > MyRecordType.col2).And(MyRecordType.col3 == 7))} @param order: A L{ColumnSyntax} to order the resulting record objects by. @param ascending: A boolean; if C{order} is not C{None}, whether to sort in ascending or descending order. @param group: a L{ColumnSyntax} to group the resulting record objects by. """ kw = {} if order is not None: kw.update(OrderBy=order, Ascending=ascending) if group is not None: kw.update(GroupBy=group) return cls._rowsFromQuery(transaction, Select(list(cls.table), From=cls.table, Where=expr, **kw), None) @classmethod def all(cls, transaction): """ Load all rows from the table that corresponds to C{cls} and return instances of C{cls} corresponding to all. 
""" return cls._rowsFromQuery(transaction, Select(list(cls.table), From=cls.table, OrderBy=cls._primaryKeyExpression()), None) @classmethod @inlineCallbacks def _rowsFromQuery(cls, transaction, qry, rozrc): """ Execute the given query, and transform its results into instances of C{cls}. @param transaction: an L{IAsyncTransaction} to execute the query on. @param qry: a L{_DMLStatement} (XXX: maybe _DMLStatement or some interface that defines 'on' should be public?) whose results are the list of columns in C{self.table}. @param rozrc: The C{raiseOnZeroRowCount} argument. @return: a L{Deferred} that succeeds with a C{list} of instances of C{cls} or fails with an exception produced by C{rozrc}. """ rows = yield qry.on(transaction, raiseOnZeroRowCount=rozrc) selves = [] names = [cls.__colmap__[column] for column in list(cls.table)] for row in rows: self = cls() self._attributesFromRow(zip(names, row)) self.transaction = transaction selves.append(self) returnValue(selves) __all__ = [ "ReadOnly", "fromTable", "NoSuchRecord", ] calendarserver-5.2+dfsg/twext/enterprise/dal/parseschema.py0000644000175000017500000005227012263343324023257 0ustar rahulrahul# -*- test-case-name: twext.enterprise.dal.test.test_parseschema -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from __future__ import print_function """ Parser for SQL schema. 
""" from itertools import chain from sqlparse import parse, keywords from sqlparse.tokens import ( Keyword, Punctuation, Number, String, Name, Comparison as CompTok ) from sqlparse.sql import (Comment, Identifier, Parenthesis, IdentifierList, Function, Comparison) from twext.enterprise.dal.model import ( Schema, Table, SQLType, ProcedureCall, Constraint, Sequence, Index) from twext.enterprise.dal.syntax import ( ColumnSyntax, CompoundComparison, Constant, Function as FunctionSyntax ) def _fixKeywords(): """ Work around bugs in SQLParse, adding SEQUENCE as a keyword (since it is treated as one in postgres) and removing ACCESS and SIZE (since we use those as column names). Technically those are keywords in SQL, but they aren't treated as such by postgres's parser. """ keywords.KEYWORDS['SEQUENCE'] = Keyword for columnNameKeyword in ['ACCESS', 'SIZE']: del keywords.KEYWORDS[columnNameKeyword] _fixKeywords() def tableFromCreateStatement(schema, stmt): """ Add a table from a CREATE TABLE sqlparse statement object. @param schema: The schema to add the table statement to. @type schema: L{Schema} @param stmt: The C{CREATE TABLE} statement object. @type stmt: L{Statement} """ i = iterSignificant(stmt) expect(i, ttype=Keyword.DDL, value='CREATE') expect(i, ttype=Keyword, value='TABLE') function = expect(i, cls=Function) i = iterSignificant(function) name = expect(i, cls=Identifier).get_name().encode('utf-8') self = Table(schema, name) parens = expect(i, cls=Parenthesis) cp = _ColumnParser(self, iterSignificant(parens), parens) cp.parse() return self def schemaFromPath(path): """ Get a L{Schema}. @param path: a L{FilePath}-like object containing SQL. @return: a L{Schema} object with the contents of the given C{path} parsed and added to it as L{Table} objects. """ schema = Schema(path.basename()) schemaData = path.getContent() addSQLToSchema(schema, schemaData) return schema def schemaFromString(data): """ Get a L{Schema}. @param data: a C{str} containing SQL. 
@return: a L{Schema} object with the contents of the given C{str} parsed and added to it as L{Table} objects. """ schema = Schema() addSQLToSchema(schema, data) return schema def addSQLToSchema(schema, schemaData): """ Add new SQL to an existing schema. @param schema: The schema to add the new SQL to. @type schema: L{Schema} @param schemaData: A string containing some SQL statements. @type schemaData: C{str} @return: the C{schema} argument """ parsed = parse(schemaData) for stmt in parsed: preface = '' while stmt.tokens and not significant(stmt.tokens[0]): preface += str(stmt.tokens.pop(0)) if not stmt.tokens: continue if stmt.get_type() == 'CREATE': createType = stmt.token_next(1, True).value.upper() if createType == u'TABLE': t = tableFromCreateStatement(schema, stmt) t.addComment(preface) elif createType == u'SEQUENCE': Sequence(schema, stmt.token_next(2, True).get_name().encode('utf-8')) elif createType in (u'INDEX', u'UNIQUE'): signifindex = iterSignificant(stmt) expect(signifindex, ttype=Keyword.DDL, value='CREATE') token = signifindex.next() unique = False if token.match(Keyword, "UNIQUE"): unique = True token = signifindex.next() if not token.match(Keyword, "INDEX"): raise ViolatedExpectation("INDEX or UNQIUE", token.value) indexName = nameOrIdentifier(signifindex.next()) expect(signifindex, ttype=Keyword, value='ON') token = signifindex.next() if isinstance(token, Function): [tableName, columnArgs] = iterSignificant(token) else: tableName = token token = signifindex.next() if token.match(Keyword, "USING"): [_ignore, columnArgs] = iterSignificant(expect(signifindex, cls=Function)) else: raise ViolatedExpectation('USING', token) tableName = nameOrIdentifier(tableName) arggetter = iterSignificant(columnArgs) expect(arggetter, ttype=Punctuation, value=u'(') valueOrValues = arggetter.next() if isinstance(valueOrValues, IdentifierList): valuelist = valueOrValues.get_identifiers() else: valuelist = [valueOrValues] expect(arggetter, ttype=Punctuation, value=u')') 
idx = Index(schema, indexName, schema.tableNamed(tableName), unique) for token in valuelist: columnName = nameOrIdentifier(token) idx.addColumn(idx.table.columnNamed(columnName)) elif stmt.get_type() == 'INSERT': insertTokens = iterSignificant(stmt) expect(insertTokens, ttype=Keyword.DML, value='INSERT') expect(insertTokens, ttype=Keyword, value='INTO') tableName = expect(insertTokens, cls=Identifier).get_name() expect(insertTokens, ttype=Keyword, value='VALUES') values = expect(insertTokens, cls=Parenthesis) vals = iterSignificant(values) expect(vals, ttype=Punctuation, value='(') valuelist = expect(vals, cls=IdentifierList) expect(vals, ttype=Punctuation, value=')') rowData = [] for ident in valuelist.get_identifiers(): rowData.append( {Number.Integer: int, String.Single: _destringify} [ident.ttype](ident.value) ) schema.tableNamed(tableName).insertSchemaRow(rowData) else: print('unknown type:', stmt.get_type()) return schema class _ColumnParser(object): """ Stateful parser for the things between commas. """ def __init__(self, table, parenIter, parens): """ @param table: the L{Table} to add data to. @param parenIter: the iterator. """ self.parens = parens self.iter = parenIter self.table = table def __iter__(self): """ This object is an iterator; return itself. """ return self def next(self): """ Get the next L{IdentifierList}. """ result = self.iter.next() if isinstance(result, IdentifierList): # Expand out all identifier lists, since they seem to pop up # incorrectly. We should never see one in a column list anyway. # http://code.google.com/p/python-sqlparse/issues/detail?id=25 while result.tokens: it = result.tokens.pop() if significant(it): self.pushback(it) return self.next() return result def pushback(self, value): """ Push the value back onto this iterator so it will be returned by the next call to C{next}. """ self.iter = chain(iter((value,)), self.iter) def parse(self): """ Parse everything. 
""" expect(self.iter, ttype=Punctuation, value=u"(") while self.nextColumn(): pass def nextColumn(self): """ Parse the next column or constraint, depending on the next token. """ maybeIdent = self.next() if maybeIdent.ttype == Name: return self.parseColumn(maybeIdent.value) elif isinstance(maybeIdent, Identifier): return self.parseColumn(maybeIdent.get_name()) else: return self.parseConstraint(maybeIdent) def namesInParens(self, parens): parens = iterSignificant(parens) expect(parens, ttype=Punctuation, value="(") idorids = parens.next() if isinstance(idorids, Identifier): idnames = [idorids.get_name()] elif isinstance(idorids, IdentifierList): idnames = [x.get_name() for x in idorids.get_identifiers()] else: raise ViolatedExpectation("identifier or list", repr(idorids)) expect(parens, ttype=Punctuation, value=")") return idnames def readExpression(self, parens): """ Read a given expression from a Parenthesis object. (This is currently a limited parser in support of simple CHECK constraints, not something suitable for a full WHERE Clause.) """ parens = iterSignificant(parens) expect(parens, ttype=Punctuation, value="(") nexttok = parens.next() if isinstance(nexttok, Comparison): lhs, op, rhs = list(iterSignificant(nexttok)) result = CompoundComparison(self.nameOrValue(lhs), op.value.encode("ascii"), self.nameOrValue(rhs)) elif isinstance(nexttok, Identifier): # our version of SQLParse seems to break down and not create a nice # "Comparison" object when a keyword is present. This is just a # simple workaround. 
lhs = self.nameOrValue(nexttok) op = expect(parens, ttype=CompTok).value.encode("ascii") funcName = expect(parens, ttype=Keyword).value.encode("ascii") rhs = FunctionSyntax(funcName)(*[ ColumnSyntax(self.table.columnNamed(x)) for x in self.namesInParens(expect(parens, cls=Parenthesis)) ]) result = CompoundComparison(lhs, op, rhs) expect(parens, ttype=Punctuation, value=")") return result def nameOrValue(self, tok): """ Inspecting a token present in an expression (for a CHECK constraint on this table), return a L{twext.enterprise.dal.syntax} object for that value. """ if isinstance(tok, Identifier): return ColumnSyntax(self.table.columnNamed(tok.get_name())) elif tok.ttype == Number.Integer: return Constant(int(tok.value)) def parseConstraint(self, constraintType): """ Parse a 'free' constraint, described explicitly in the table as opposed to being implicitly associated with a column by being placed after it. """ ident = None # TODO: make use of identifier in tableConstraint, currently only used # for checkConstraint. if constraintType.match(Keyword, 'CONSTRAINT'): ident = expect(self, cls=Identifier).get_name() constraintType = expect(self, ttype=Keyword) if constraintType.match(Keyword, 'PRIMARY'): expect(self, ttype=Keyword, value='KEY') names = self.namesInParens(expect(self, cls=Parenthesis)) self.table.primaryKey = [self.table.columnNamed(n) for n in names] elif constraintType.match(Keyword, 'UNIQUE'): names = self.namesInParens(expect(self, cls=Parenthesis)) self.table.tableConstraint(Constraint.UNIQUE, names) elif constraintType.match(Keyword, 'CHECK'): self.table.checkConstraint(self.readExpression(self.next()), ident) else: raise ViolatedExpectation('PRIMARY or UNIQUE', constraintType) return self.checkEnd(self.next()) def checkEnd(self, val): """ After a column or constraint, check the end. 
""" if val.value == u",": return True elif val.value == u")": return False else: raise ViolatedExpectation(", or )", val) def parseColumn(self, name): """ Parse a column with the given name. """ typeName = self.next() if isinstance(typeName, Function): [funcIdent, args] = iterSignificant(typeName) typeName = funcIdent arggetter = iterSignificant(args) expect(arggetter, value=u'(') typeLength = int(expect(arggetter, ttype=Number.Integer).value.encode('utf-8')) else: maybeTypeArgs = self.next() if isinstance(maybeTypeArgs, Parenthesis): # type arguments significant = iterSignificant(maybeTypeArgs) expect(significant, value=u"(") typeLength = int(significant.next().value) else: # something else typeLength = None self.pushback(maybeTypeArgs) theType = SQLType(typeName.value.encode("utf-8"), typeLength) theColumn = self.table.addColumn( name=name.encode("utf-8"), type=theType ) for val in self: if val.ttype == Punctuation: return self.checkEnd(val) else: expected = True def oneConstraint(t): self.table.tableConstraint(t, [theColumn.name]) if val.match(Keyword, 'PRIMARY'): expect(self, ttype=Keyword, value='KEY') # XXX check to make sure there's no other primary key yet self.table.primaryKey = [theColumn] elif val.match(Keyword, 'UNIQUE'): # XXX add UNIQUE constraint oneConstraint(Constraint.UNIQUE) elif val.match(Keyword, 'NOT'): # possibly not necessary, as 'NOT NULL' is a single keyword # in sqlparse as of 0.1.2 expect(self, ttype=Keyword, value='NULL') oneConstraint(Constraint.NOT_NULL) elif val.match(Keyword, 'NOT NULL'): oneConstraint(Constraint.NOT_NULL) elif val.match(Keyword, 'CHECK'): self.table.checkConstraint(self.readExpression(self.next())) elif val.match(Keyword, 'DEFAULT'): theDefault = self.next() if isinstance(theDefault, Parenthesis): iDefault = iterSignificant(theDefault) expect(iDefault, ttype=Punctuation, value="(") theDefault = iDefault.next() if isinstance(theDefault, Function): thingo = theDefault.tokens[0].get_name() parens = expectSingle( 
theDefault.tokens[-1], cls=Parenthesis ) pareniter = iterSignificant(parens) if thingo.upper() == 'NEXTVAL': expect(pareniter, ttype=Punctuation, value="(") seqname = _destringify( expect(pareniter, ttype=String.Single).value) defaultValue = self.table.schema.sequenceNamed( seqname ) defaultValue.referringColumns.append(theColumn) else: defaultValue = ProcedureCall(thingo.encode('utf-8'), parens) elif theDefault.ttype == Number.Integer: defaultValue = int(theDefault.value) elif (theDefault.ttype == Keyword and theDefault.value.lower() == 'false'): defaultValue = False elif (theDefault.ttype == Keyword and theDefault.value.lower() == 'true'): defaultValue = True elif (theDefault.ttype == Keyword and theDefault.value.lower() == 'null'): defaultValue = None elif theDefault.ttype == String.Single: defaultValue = _destringify(theDefault.value) else: raise RuntimeError( "not sure what to do: default %r" % ( theDefault)) theColumn.setDefaultValue(defaultValue) elif val.match(Keyword, 'REFERENCES'): target = nameOrIdentifier(self.next()) theColumn.doesReferenceName(target) elif val.match(Keyword, 'ON'): expect(self, ttype=Keyword.DML, value='DELETE') refAction = self.next() if refAction.ttype == Keyword and refAction.value.upper() == 'CASCADE': theColumn.deleteAction = 'cascade' elif refAction.ttype == Keyword and refAction.value.upper() == 'SET': setAction = self.next() if setAction.ttype == Keyword and setAction.value.upper() == 'NULL': theColumn.deleteAction = 'set null' elif setAction.ttype == Keyword and setAction.value.upper() == 'DEFAULT': theColumn.deleteAction = 'set default' else: raise RuntimeError("Invalid on delete set %r" % (setAction.value,)) else: raise RuntimeError("Invalid on delete %r" % (refAction.value,)) else: expected = False if not expected: print('UNEXPECTED TOKEN:', repr(val), theColumn) print(self.parens) import pprint pprint.pprint(self.parens.tokens) return 0 class ViolatedExpectation(Exception): """ An expectation about the structure of the 
SQL syntax was violated. """ def __init__(self, expected, got): self.expected = expected self.got = got super(ViolatedExpectation, self).__init__( "Expected %r got %s" % (expected, got) ) def nameOrIdentifier(token): """ Determine if the given object is a name or an identifier, and return the textual value of that name or identifier. @rtype: L{str} """ if isinstance(token, Identifier): return token.get_name() elif token.ttype == Name: return token.value else: raise ViolatedExpectation("identifier or name", repr(token)) def expectSingle(nextval, ttype=None, value=None, cls=None): """ Expect some properties from retrieved value. @param ttype: A token type to compare against. @param value: A value to compare against. @param cls: A class to check if the value is an instance of. @raise ViolatedExpectation: if an unexpected token is found. @return: C{nextval}, if it matches. """ if ttype is not None: if nextval.ttype != ttype: raise ViolatedExpectation(ttype, '%s:%r' % (nextval.ttype, nextval)) if value is not None: if nextval.value.upper() != value.upper(): raise ViolatedExpectation(value, nextval.value) if cls is not None: if nextval.__class__ != cls: raise ViolatedExpectation(cls, '%s:%r' % (nextval.__class__.__name__, nextval)) return nextval def expect(iterator, **kw): """ Retrieve a value from an iterator and check its properties. Same signature as L{expectSingle}, except it takes an iterator instead of a value. @see: L{expectSingle} """ nextval = iterator.next() return expectSingle(nextval, **kw) def significant(token): """ Determine if the token is 'significant', i.e. that it is not a comment and not whitespace. """ # comment has 'None' is_whitespace() result. intentional? return (not isinstance(token, Comment) and not token.is_whitespace()) def iterSignificant(tokenList): """ Iterate tokens that pass the test given by L{significant}, from a given L{TokenList}. 
""" for token in tokenList.tokens: if significant(token): yield token def _destringify(strval): """ Convert a single-quoted SQL string into its actual repsresented value. (Assumes standards compliance, since we should be controlling all the input here. The only quoting syntax respected is "''".) """ return strval[1:-1].replace("''", "'") calendarserver-5.2+dfsg/twext/enterprise/dal/syntax.py0000644000175000017500000016004712263343324022314 0ustar rahulrahul# -*- test-case-name: twext.enterprise.dal.test.test_sqlsyntax -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Syntax wrappers and generators for SQL. """ from itertools import count, repeat from functools import partial from operator import eq, ne from zope.interface import implements from twisted.internet.defer import succeed from twext.enterprise.dal.model import Schema, Table, Column, Sequence, SQLType from twext.enterprise.ienterprise import ( POSTGRES_DIALECT, ORACLE_DIALECT, SQLITE_DIALECT, IDerivedParameter ) from twext.enterprise.util import mapOracleOutputType from twisted.internet.defer import inlineCallbacks, returnValue try: import cx_Oracle cx_Oracle except ImportError: cx_Oracle = None class DALError(Exception): """ Base class for exceptions raised by this module. This can be raised directly for API violations. This exception represents a serious programming error and should normally never be caught or ignored. 
""" class QueryPlaceholder(object): """ Representation of the placeholders required to generate some SQL, for a single statement. Contains information necessary to generate place holder strings based on the database dialect. """ def placeholder(self): raise NotImplementedError("See subclasses.") class FixedPlaceholder(QueryPlaceholder): """ Fixed string used as the place holder. """ def __init__(self, placeholder): self._placeholder = placeholder def placeholder(self): return self._placeholder class NumericPlaceholder(QueryPlaceholder): """ Numeric counter used as the place holder. """ def __init__(self): self._next = count(1).next def placeholder(self): return ':' + str(self._next()) def defaultPlaceholder(): """ Generate a default L{QueryPlaceholder} """ return FixedPlaceholder('?') class QueryGenerator(object): """ Maintains various pieces of transient information needed when building a query. This includes the SQL dialect, the format of the place holder and and automated id generator. """ def __init__(self, dialect=None, placeholder=None): self.dialect = dialect if dialect else POSTGRES_DIALECT if placeholder is None: placeholder = defaultPlaceholder() self.placeholder = placeholder self.generatedID = count(1).next def nextGeneratedID(self): return "genid_%d" % (self.generatedID(),) def shouldQuote(self, name): return (self.dialect == ORACLE_DIALECT and name.lower() in _KEYWORDS) class TableMismatch(Exception): """ A table in a statement did not match with a column. """ class NotEnoughValues(DALError): """ Not enough values were supplied for an L{Insert}. """ class _Statement(object): """ An SQL statement that may be executed. (An abstract base class, must implement several methods.) 
""" _paramstyles = { 'pyformat': partial(FixedPlaceholder, "%s"), 'numeric': NumericPlaceholder, 'qmark': defaultPlaceholder, } def toSQL(self, queryGenerator=None): if queryGenerator is None: queryGenerator = QueryGenerator() return self._toSQL(queryGenerator) def _extraVars(self, txn, queryGenerator): """ A hook for subclasses to provide additional keyword arguments to the C{bind} call when L{_Statement.on} is executed. Currently this is used only for 'out' parameters to capture results when executing statements that do not normally have a result (L{Insert}, L{Delete}, L{Update}). """ return {} def _extraResult(self, result, outvars, queryGenerator): """ A hook for subclasses to manipulate the results of 'on', after they've been retrieved by the database but before they've been given to application code. @param result: a L{Deferred} that will fire with the rows as returned by the database. @type result: C{list} of rows, which are C{list}s or C{tuple}s. @param outvars: a dictionary of extra variables returned by C{self._extraVars}. @param queryGenerator: information about the connection where the statement was executed. @type queryGenerator: L{QueryGenerator} (a subclass thereof) @return: the result to be returned from L{_Statement.on}. @rtype: L{Deferred} firing result rows """ return result def on(self, txn, raiseOnZeroRowCount=None, **kw): """ Execute this statement on a given L{IAsyncTransaction} and return the resulting L{Deferred}. @param txn: the L{IAsyncTransaction} to execute this on. @param raiseOnZeroRowCount: a 0-argument callable which returns an exception to raise if the executed SQL does not affect any rows. @param kw: keyword arguments, mapping names of L{Parameter} objects located somewhere in C{self} @return: results from the database. 
@rtype: a L{Deferred} firing a C{list} of records (C{tuple}s or C{list}s) """ queryGenerator = QueryGenerator(txn.dialect, self._paramstyles[txn.paramstyle]()) outvars = self._extraVars(txn, queryGenerator) kw.update(outvars) fragment = self.toSQL(queryGenerator).bind(**kw) result = txn.execSQL(fragment.text, fragment.parameters, raiseOnZeroRowCount) result = self._extraResult(result, outvars, queryGenerator) if queryGenerator.dialect == ORACLE_DIALECT and result: result.addCallback(self._fixOracleNulls) return result def _resultColumns(self): """ Subclasses must implement this to return a description of the columns expected to be returned. This is a list of L{ColumnSyntax} objects, and possibly other expression syntaxes which will be converted to C{None}. """ raise NotImplementedError( "Each statement subclass must describe its result" ) def _resultShape(self): """ Process the result of the subclass's C{_resultColumns}, as described in the docstring above. """ for expectation in self._resultColumns(): if isinstance(expectation, ColumnSyntax): yield expectation.model else: yield None def _fixOracleNulls(self, rows): """ Oracle treats empty strings as C{NULL}. Fix this by looking at the columns we expect to have returned, and replacing any C{None}s with empty strings in the appropriate position. """ if rows is None: return None newRows = [] for row in rows: newRow = [] for column, description in zip(row, self._resultShape()): if ((description is not None and # FIXME: "is the python type str" is what I mean; this list # should be more centrally maintained description.type.name in ('varchar', 'text', 'char') and column is None )): column = '' newRow.append(column) newRows.append(newRow) return newRows class Syntax(object): """ Base class for syntactic convenience. This class will define dynamic attribute access to represent its underlying model as a Python namespace. You can access the underlying model as '.model'. 
""" modelType = None model = None def __init__(self, model): if not isinstance(model, self.modelType): # make sure we don't get a misleading repr() raise DALError("type mismatch: %r %r", type(self), model) self.model = model def __repr__(self): if self.model is not None: return '' % (self.model,) return super(Syntax, self).__repr__() def comparison(comparator): def __(self, other): if other is None: return NullComparison(self, comparator) if isinstance(other, Select): return NotImplemented if isinstance(other, ColumnSyntax): return ColumnComparison(self, comparator, other) if isinstance(other, ExpressionSyntax): return CompoundComparison(self, comparator, other) else: return CompoundComparison(self, comparator, Constant(other)) return __ class ExpressionSyntax(Syntax): __eq__ = comparison('=') __ne__ = comparison('!=') # NB: these operators "cannot be used with lists" (see ORA-01796) __gt__ = comparison('>') __ge__ = comparison('>=') __lt__ = comparison('<') __le__ = comparison('<=') # TODO: operators aren't really comparisons; these should behave slightly # differently. (For example; in Oracle, 'select 3 = 4 from dual' doesn't # work, but 'select 3 + 4 from dual' does; similarly, you can't do 'select * # from foo where 3 + 4', but you can do 'select * from foo where 3 + 4 > # 0'.) __add__ = comparison("+") __sub__ = comparison("-") __div__ = comparison("/") __mul__ = comparison("*") def __nonzero__(self): raise DALError( "SQL expressions should not be tested for truth value in Python.") def In(self, other): """ We support two forms of the SQL "IN" syntax: one where a list of values is supplied, the other where a sub-select is used to provide a set of values. 
@param other: a constant parameter or sub-select @type other: L{Parameter} or L{Select} """ if isinstance(other, Parameter): if other.count is None: raise DALError("IN expression needs an explicit count of parameters") return CompoundComparison(self, 'in', Constant(other)) else: # Can't be Select.__contains__ because __contains__ gets __nonzero__ # called on its result by the 'in' syntax. return CompoundComparison(self, 'in', other) def StartsWith(self, other): return CompoundComparison(self, "like", CompoundComparison(Constant(other), '||', Constant('%'))) def EndsWith(self, other): return CompoundComparison(self, "like", CompoundComparison(Constant('%'), '||', Constant(other))) def Contains(self, other): return CompoundComparison(self, "like", CompoundComparison(Constant('%'), '||', CompoundComparison(Constant(other), '||', Constant('%')))) class FunctionInvocation(ExpressionSyntax): def __init__(self, function, *args): self.function = function self.args = args def allColumns(self): """ All of the columns in all of the arguments' columns. """ def ac(): for arg in self.args: for column in arg.allColumns(): yield column return list(ac()) def subSQL(self, queryGenerator, allTables): result = SQLFragment(self.function.nameFor(queryGenerator)) result.append(_inParens( _commaJoined(_convert(arg).subSQL(queryGenerator, allTables) for arg in self.args))) return result class Constant(ExpressionSyntax): """ Generates an expression for a place holder where a value will be bound to the query. If the constant is a Parameter with count > 1 then a parenthesized, comma-separated list of place holders will be generated. 
""" def __init__(self, value): self.value = value def allColumns(self): return [] def subSQL(self, queryGenerator, allTables): if isinstance(self.value, Parameter) and self.value.count is not None: return _inParens(_CommaList( [SQLFragment(queryGenerator.placeholder.placeholder(), [self.value] if ctr == 0 else []) for ctr in range(self.value.count)] ).subSQL(queryGenerator, allTables)) else: return SQLFragment(queryGenerator.placeholder.placeholder(), [self.value]) class NamedValue(ExpressionSyntax): """ A constant within the database; something predefined, such as CURRENT_TIMESTAMP. """ def __init__(self, name): self.name = name def subSQL(self, queryGenerator, allTables): return SQLFragment(self.name) class Function(object): """ An L{Function} is a representation of an SQL Function function. """ def __init__(self, name, oracleName=None): self.name = name self.oracleName = oracleName def nameFor(self, queryGenerator): if queryGenerator.dialect == ORACLE_DIALECT and self.oracleName is not None: return self.oracleName return self.name def __call__(self, *args): """ Produce an L{FunctionInvocation} """ return FunctionInvocation(self, *args) Count = Function("count") Sum = Function("sum") Max = Function("max") Len = Function("character_length", "length") Upper = Function("upper") Lower = Function("lower") _sqliteLastInsertRowID = Function("last_insert_rowid") # Use a specific value here for "the convention for case-insensitive values in # the database" so we don't need to keep remembering whether it's upper or # lowercase. CaseFold = Lower class SchemaSyntax(Syntax): """ Syntactic convenience for L{Schema}. 
""" modelType = Schema def __getattr__(self, attr): try: tableModel = self.model.tableNamed(attr) except KeyError: try: seqModel = self.model.sequenceNamed(attr) except KeyError: raise AttributeError("schema has no table or sequence %r" % (attr,)) else: return SequenceSyntax(seqModel) else: syntax = TableSyntax(tableModel) # Needs to be preserved here so that aliasing will work. setattr(self, attr, syntax) return syntax def __iter__(self): for table in self.model.tables: yield TableSyntax(table) class SequenceSyntax(ExpressionSyntax): """ Syntactic convenience for L{Sequence}. """ modelType = Sequence def subSQL(self, queryGenerator, allTables): """ Convert to an SQL fragment. """ if queryGenerator.dialect == ORACLE_DIALECT: fmt = "%s.nextval" else: fmt = "nextval('%s')" return SQLFragment(fmt % (self.model.name,)) def _nameForDialect(name, dialect): """ If the given name is being computed in the oracle dialect, truncate it to 30 characters. """ if dialect == ORACLE_DIALECT: name = name[:30] return name class TableSyntax(Syntax): """ Syntactic convenience for L{Table}. """ modelType = Table def alias(self): """ Return an alias for this L{TableSyntax} so that it might be joined against itself. As in SQL, C{someTable.join(someTable)} is an error; you can't join a table against itself. However, C{t = someTable.alias(); someTable.join(t)} is usable as a 'from' clause. """ return TableAlias(self.model) def join(self, otherTableSyntax, on=None, type=''): """ Create a L{Join}, representing a join between two tables. """ if on is None: type = 'cross' return Join(self, type, otherTableSyntax, on) def subSQL(self, queryGenerator, allTables): """ Generate the L{SQLFragment} for this table's identification; this is for use in a 'from' clause. """ # XXX maybe there should be a specific method which is only invoked # from the FROM clause, that only tables and joins would implement? 
return SQLFragment(_nameForDialect(self.model.name, queryGenerator.dialect)) def __getattr__(self, attr): """ Attributes named after columns on a L{TableSyntax} are returned by accessing their names as attributes. For example, if there is a schema syntax object created from SQL equivalent to 'create table foo (bar integer, baz integer)', 'schemaSyntax.foo.bar' and 'schemaSyntax.foo.baz' """ try: column = self.model.columnNamed(attr) except KeyError: raise AttributeError("table {0} has no column {1}".format( self.model.name, attr )) else: return ColumnSyntax(column) def __iter__(self): """ Yield a L{ColumnSyntax} for each L{Column} in this L{TableSyntax}'s model's table. """ for column in self.model.columns: yield ColumnSyntax(column) def tables(self): """ Return a C{list} of tables involved in the query by this table. (This method is expected by anything that can act as the C{From} clause: see L{Join.tables}) """ return [self] def columnAliases(self): """ Inspect the Python aliases for this table in the given schema. Python aliases for a table are created by setting an attribute on the schema. For example, in a schema which had "schema.MYTABLE.ID = schema.MYTABLE.MYTABLE_ID" applied to it, schema.MYTABLE.columnAliases() would return C{[("ID", schema.MYTABLE.MYTABLE_ID)]}. @return: a list of 2-tuples of (alias (C{str}), column (C{ColumnSyntax})), enumerating all of the Python aliases provided. """ result = {} for k, v in self.__dict__.items(): if isinstance(v, ColumnSyntax): result[k] = v return result def __contains__(self, columnSyntax): if isinstance(columnSyntax, FunctionInvocation): columnSyntax = columnSyntax.arg return (columnSyntax.model.table is self.model) class TableAlias(TableSyntax): """ An alias for a table, under a different name, for the purpose of doing a self-join. """ def subSQL(self, queryGenerator, allTables): """ Return an L{SQLFragment} with a string of the form C{'mytable myalias'} suitable for use in a FROM clause. 
""" result = super(TableAlias, self).subSQL(queryGenerator, allTables) result.append(SQLFragment(" " + self._aliasName(allTables))) return result def _aliasName(self, allTables): """ The alias under which this table will be known in the query. @param allTables: a C{list}, as passed to a C{subSQL} method during SQL generation. @return: a string naming this alias, a unique identifier, albeit one which is only stable within the query which populated C{allTables}. @rtype: C{str} """ anum = [t for t in allTables if isinstance(t, TableAlias)].index(self) + 1 return 'alias%d' % (anum,) def __getattr__(self, attr): return AliasedColumnSyntax(self, self.model.columnNamed(attr)) class Join(object): """ A DAL object representing an SQL 'join' statement. @ivar leftSide: a L{Join} or L{TableSyntax} representing the left side of this join. @ivar rightSide: a L{TableSyntax} representing the right side of this join. @ivar type: the type of join this is. For example, for a left outer join, this would be C{'left outer'}. @type type: C{str} @ivar on: the 'on' clause of this table. @type on: L{ExpressionSyntax} """ def __init__(self, leftSide, type, rightSide, on): self.leftSide = leftSide self.type = type self.rightSide = rightSide self.on = on def subSQL(self, queryGenerator, allTables): stmt = SQLFragment() stmt.append(self.leftSide.subSQL(queryGenerator, allTables)) stmt.text += ' ' if self.type: stmt.text += self.type stmt.text += ' ' stmt.text += 'join ' stmt.append(self.rightSide.subSQL(queryGenerator, allTables)) if self.type != 'cross': stmt.text += ' on ' stmt.append(self.on.subSQL(queryGenerator, allTables)) return stmt def tables(self): """ Return a C{list} of tables which this L{Join} will involve in a query: all those present on the left side, as well as all those present on the right side. 
""" return self.leftSide.tables() + self.rightSide.tables() def join(self, otherTable, on=None, type=None): if on is None: type = 'cross' return Join(self, type, otherTable, on) _KEYWORDS = ["access", # SQL keyword, but we have a column with this name "path", # Not actually a standard keyword, but a function in oracle, and we # have a column with this name. "size", # not actually sure what this is; only experimentally determined # that not quoting it causes an issue. ] class ColumnSyntax(ExpressionSyntax): """ Syntactic convenience for L{Column}. @ivar _alwaysQualified: a boolean indicating whether to always qualify the column name in generated SQL, regardless of whether the column name is specific enough even when unqualified. @type _alwaysQualified: C{bool} """ modelType = Column _alwaysQualified = False def allColumns(self): return [self] def subSQL(self, queryGenerator, allTables): # XXX This, and 'model', could in principle conflict with column names. # Maybe do something about that. name = self.model.name if queryGenerator.shouldQuote(name): name = '"%s"' % (name,) if self._alwaysQualified: qualified = True else: qualified = False for tableSyntax in allTables: if self.model.table is not tableSyntax.model: if self.model.name in (c.name for c in tableSyntax.model.columns): qualified = True break if qualified: return SQLFragment(self._qualify(name, allTables)) else: return SQLFragment(name) def __hash__(self): return hash(self.model) + 10 def _qualify(self, name, allTables): return self.model.table.name + '.' 
+ name class ResultAliasSyntax(ExpressionSyntax): def __init__(self, expression, alias=None): self.expression = expression self.alias = alias def aliasName(self, queryGenerator): if self.alias is None: self.alias = queryGenerator.nextGeneratedID() return self.alias def columnReference(self): return AliasReferenceSyntax(self) def allColumns(self): return self.expression.allColumns() def subSQL(self, queryGenerator, allTables): result = SQLFragment() result.append(self.expression.subSQL(queryGenerator, allTables)) result.append(SQLFragment(" %s" % (self.aliasName(queryGenerator),))) return result class AliasReferenceSyntax(ExpressionSyntax): def __init__(self, resultAlias): self.resultAlias = resultAlias def allColumns(self): return self.resultAlias.allColumns() def subSQL(self, queryGenerator, allTables): return SQLFragment(self.resultAlias.aliasName(queryGenerator)) class AliasedColumnSyntax(ColumnSyntax): """ An L{AliasedColumnSyntax} is like a L{ColumnSyntax}, but it generates SQL for a column of a table under an alias, rather than directly. i.e. this is used for C{'something.col'} in C{'select something.col from tablename something'} rather than the 'col' in C{'select col from tablename'}. @see: L{TableSyntax.alias} """ _alwaysQualified = True def __init__(self, tableAlias, model): super(AliasedColumnSyntax, self).__init__(model) self._tableAlias = tableAlias def _qualify(self, name, allTables): return self._tableAlias._aliasName(allTables) + '.' 
+ name class Comparison(ExpressionSyntax): def __init__(self, a, op, b): self.a = a self.op = op self.b = b def _subexpression(self, expr, queryGenerator, allTables): result = expr.subSQL(queryGenerator, allTables) if self.op not in ('and', 'or') and isinstance(expr, Comparison): result = _inParens(result) return result def booleanOp(self, operand, other): return CompoundComparison(self, operand, other) def And(self, other): return self.booleanOp('and', other) def Or(self, other): return self.booleanOp('or', other) class NullComparison(Comparison): """ A L{NullComparison} is a comparison of a column or expression with None. """ def __init__(self, a, op): # 'b' is always None for this comparison type super(NullComparison, self).__init__(a, op, None) def subSQL(self, queryGenerator, allTables): sqls = SQLFragment() sqls.append(self.a.subSQL(queryGenerator, allTables)) sqls.text += " is " if self.op != "=": sqls.text += "not " sqls.text += "null" return sqls class CompoundComparison(Comparison): """ A compound comparison; two or more constraints, joined by an operation (currently only AND or OR). 
""" def allColumns(self): return self.a.allColumns() + self.b.allColumns() def subSQL(self, queryGenerator, allTables): if (queryGenerator.dialect == ORACLE_DIALECT and isinstance(self.b, Constant) and self.b.value == '' and self.op in ('=', '!=')): return NullComparison(self.a, self.op).subSQL(queryGenerator, allTables) stmt = SQLFragment() result = self._subexpression(self.a, queryGenerator, allTables) if (isinstance(self.a, CompoundComparison) and self.a.op == 'or' and self.op == 'and'): result = _inParens(result) stmt.append(result) stmt.text += ' %s ' % (self.op,) result = self._subexpression(self.b, queryGenerator, allTables) if (isinstance(self.b, CompoundComparison) and self.b.op == 'or' and self.op == 'and'): result = _inParens(result) if isinstance(self.b, Tuple): # If the right-hand side of the comparison is a Tuple, it needs to # be double-parenthesized in Oracle, as per # http://docs.oracle.com/cd/B28359_01/server.111/b28286/expressions015.htm#i1033664 # because it is an expression list. result = _inParens(result) stmt.append(result) return stmt _operators = {"=": eq, "!=": ne} class ColumnComparison(CompoundComparison): """ Comparing two columns is the same as comparing any other two expressions, except that Python can retrieve a truth value, so that columns may be compared for value equality in scripts that want to interrogate schemas. 
""" def __nonzero__(self): thunk = _operators.get(self.op) if thunk is None: return super(ColumnComparison, self).__nonzero__() return thunk(self.a.model, self.b.model) class _AllColumns(NamedValue): def __init__(self): self.name = "*" def allColumns(self): return [] ALL_COLUMNS = _AllColumns() class _SomeColumns(object): def __init__(self, columns): self.columns = columns def subSQL(self, queryGenerator, allTables): first = True cstatement = SQLFragment() for column in self.columns: if first: first = False else: cstatement.append(SQLFragment(", ")) cstatement.append(column.subSQL(queryGenerator, allTables)) return cstatement def _checkColumnsMatchTables(columns, tables): """ Verify that the given C{columns} match the given C{tables}; that is, that every L{TableSyntax} referenced by every L{ColumnSyntax} referenced by every L{ExpressionSyntax} in the given C{columns} list is present in the given C{tables} list. @param columns: a L{list} of L{ExpressionSyntax}, each of which references some set of L{ColumnSyntax}es via its C{allColumns} method. @param tables: a L{list} of L{TableSyntax} @return: L{None} @rtype: L{NoneType} @raise TableMismatch: if any table referenced by a column is I{not} found in C{tables} """ for expression in columns: for column in expression.allColumns(): for table in tables: if column in table: break else: raise TableMismatch("{} not found in {}".format( column, tables )) return None class Tuple(ExpressionSyntax): def __init__(self, columns): self.columns = columns def __iter__(self): return iter(self.columns) def subSQL(self, queryGenerator, allTables): return _inParens(_commaJoined(c.subSQL(queryGenerator, allTables) for c in self.columns)) def allColumns(self): return self.columns class SetExpression(object): """ A UNION, INTERSECT, or EXCEPT construct used inside a SELECT. 
""" OPTYPE_ALL = "all" OPTYPE_DISTINCT = "distinct" def __init__(self, selects, optype=None): """ @param selects: a single Select or a list of Selects @type selects: C{list} or L{Select} @param optype: whether to use the ALL, DISTINCT constructs: C{None} use neither, OPTYPE_ALL, or OPTYPE_DISTINCT @type optype: C{str} """ if isinstance(selects, Select): selects = (selects,) self.selects = selects self.optype = optype for select in self.selects: if not isinstance(select, Select): raise DALError("Must have SELECT statements in a set expression") if self.optype not in (None, SetExpression.OPTYPE_ALL, SetExpression.OPTYPE_DISTINCT,): raise DALError("Must have either 'all' or 'distinct' in a set expression") def subSQL(self, queryGenerator, allTables): result = SQLFragment() for select in self.selects: result.append(self.setOpSQL(queryGenerator)) if self.optype == SetExpression.OPTYPE_ALL: result.append(SQLFragment("ALL ")) elif self.optype == SetExpression.OPTYPE_DISTINCT: result.append(SQLFragment("DISTINCT ")) result.append(select.subSQL(queryGenerator, allTables)) return result def allColumns(self): return [] class Union(SetExpression): """ A UNION construct used inside a SELECT. """ def setOpSQL(self, queryGenerator): return SQLFragment(" UNION ") class Intersect(SetExpression): """ An INTERSECT construct used inside a SELECT. """ def setOpSQL(self, queryGenerator): return SQLFragment(" INTERSECT ") class Except(SetExpression): """ An EXCEPT construct used inside a SELECT. """ def setOpSQL(self, queryGenerator): if queryGenerator.dialect == POSTGRES_DIALECT: return SQLFragment(" EXCEPT ") elif queryGenerator.dialect == ORACLE_DIALECT: return SQLFragment(" MINUS ") else: raise NotImplementedError("Unsupported dialect") class Select(_Statement): """ 'select' statement. 
""" def __init__(self, columns=None, Where=None, From=None, OrderBy=None, GroupBy=None, Limit=None, ForUpdate=False, NoWait=False, Ascending=None, Having=None, Distinct=False, As=None, SetExpression=None): self.From = From self.Where = Where self.Distinct = Distinct if not isinstance(OrderBy, (Tuple, list, tuple, type(None))): OrderBy = [OrderBy] self.OrderBy = OrderBy if not isinstance(GroupBy, (list, tuple, type(None))): GroupBy = [GroupBy] self.GroupBy = GroupBy self.Limit = Limit self.Having = Having self.SetExpression = SetExpression if columns is None: columns = ALL_COLUMNS else: _checkColumnsMatchTables(columns, From.tables()) columns = _SomeColumns(columns) self.columns = columns self.ForUpdate = ForUpdate self.NoWait = NoWait self.Ascending = Ascending self.As = As # A FROM that uses a sub-select will need the AS alias name if isinstance(self.From, Select): if self.From.As is None: self.From.As = "" def __eq__(self, other): """ Create a comparison. """ if isinstance(other, (list, tuple)): other = Tuple(other) return CompoundComparison(other, '=', self) def _toSQL(self, queryGenerator): """ @return: a 'select' statement with placeholders and arguments @rtype: L{SQLFragment} """ if self.SetExpression is not None: stmt = SQLFragment("(") else: stmt = SQLFragment() stmt.append(SQLFragment("select ")) if self.Distinct: stmt.text += "distinct " allTables = self.From.tables() stmt.append(self.columns.subSQL(queryGenerator, allTables)) stmt.text += " from " stmt.append(self.From.subSQL(queryGenerator, allTables)) if self.Where is not None: wherestmt = self.Where.subSQL(queryGenerator, allTables) stmt.text += " where " stmt.append(wherestmt) if self.GroupBy is not None: stmt.text += " group by " fst = True for subthing in self.GroupBy: if fst: fst = False else: stmt.text += ', ' stmt.append(subthing.subSQL(queryGenerator, allTables)) if self.Having is not None: havingstmt = self.Having.subSQL(queryGenerator, allTables) stmt.text += " having " 
stmt.append(havingstmt) if self.SetExpression is not None: stmt.append(SQLFragment(")")) stmt.append(self.SetExpression.subSQL(queryGenerator, allTables)) if self.OrderBy is not None: stmt.text += " order by " fst = True for subthing in self.OrderBy: if fst: fst = False else: stmt.text += ', ' stmt.append(subthing.subSQL(queryGenerator, allTables)) if self.Ascending is not None: if self.Ascending: kw = " asc" else: kw = " desc" stmt.append(SQLFragment(kw)) if self.ForUpdate: stmt.text += " for update" if self.NoWait: stmt.text += " nowait" if self.Limit is not None: limitConst = Constant(self.Limit).subSQL(queryGenerator, allTables) if queryGenerator.dialect == ORACLE_DIALECT: wrapper = SQLFragment("select * from (") wrapper.append(stmt) wrapper.append(SQLFragment(") where ROWNUM <= ")) stmt = wrapper else: stmt.text += " limit " stmt.append(limitConst) return stmt def subSQL(self, queryGenerator, allTables): result = SQLFragment("(") result.append(self.toSQL(queryGenerator)) result.append(SQLFragment(")")) if self.As is not None: if self.As == "": self.As = queryGenerator.nextGeneratedID() result.append(SQLFragment(" %s" % (self.As,))) return result def _resultColumns(self): """ Determine the list of L{ColumnSyntax} objects that will represent the result. Normally just the list of selected columns; if wildcard syntax is used though, determine the ordering from the database. """ if self.columns is ALL_COLUMNS: # TODO: Possibly this rewriting should always be done, before even # executing the query, so that if we develop a schema mismatch with # the database (additional columns), the application will still see # the right rows. for table in self.From.tables(): for column in table: yield column else: for column in self.columns.columns: yield column def tables(self): """ Determine the tables used by the result columns. 
""" if self.columns is ALL_COLUMNS: # TODO: Possibly this rewriting should always be done, before even # executing the query, so that if we develop a schema mismatch with # the database (additional columns), the application will still see # the right rows. return self.From.tables() else: tables = set([column.model.table for column in self.columns.columns if isinstance(column, ColumnSyntax)]) for table in self.From.tables(): tables.add(table.model) return [TableSyntax(table) for table in tables] def _commaJoined(stmts): first = True cstatement = SQLFragment() for stmt in stmts: if first: first = False else: cstatement.append(SQLFragment(", ")) cstatement.append(stmt) return cstatement def _inParens(stmt): result = SQLFragment("(") result.append(stmt) result.append(SQLFragment(")")) return result def _fromSameTable(columns): """ Extract the common table used by a list of L{Column} objects, raising L{TableMismatch}. """ table = columns[0].table for column in columns: if table is not column.table: raise TableMismatch("Columns must all be from the same table.") return table def _modelsFromMap(columnMap): """ Get the L{Column} objects from a mapping of L{ColumnSyntax} to values. """ return [c.model for c in columnMap.keys()] class _CommaList(object): def __init__(self, subfragments): self.subfragments = subfragments def subSQL(self, queryGenerator, allTables): return _commaJoined(f.subSQL(queryGenerator, allTables) for f in self.subfragments) class _DMLStatement(_Statement): """ Common functionality of Insert/Update/Delete statements. """ def _returningClause(self, queryGenerator, stmt, allTables): """ Add a dialect-appropriate 'returning' clause to the end of the given SQL statement. @param queryGenerator: describes the database we are generating the statement for. 
@type queryGenerator: L{QueryGenerator} @param stmt: the SQL fragment generated without the 'returning' clause @type stmt: L{SQLFragment} @param allTables: all tables involved in the query; see any C{subSQL} method. @return: the C{stmt} parameter. """ retclause = self.Return if retclause is None: return stmt if isinstance(retclause, (tuple, list)): retclause = _CommaList(retclause) if queryGenerator.dialect == SQLITE_DIALECT: # sqlite does this another way. return stmt elif retclause is not None: stmt.text += ' returning ' stmt.append(retclause.subSQL(queryGenerator, allTables)) if queryGenerator.dialect == ORACLE_DIALECT: stmt.text += ' into ' params = [] retvals = self._returnAsList() for n, _ignore_v in enumerate(retvals): params.append( Constant(Parameter("oracle_out_" + str(n))) .subSQL(queryGenerator, allTables) ) stmt.append(_commaJoined(params)) return stmt def _returnAsList(self): if not isinstance(self.Return, (tuple, list)): return [self.Return] else: return self.Return def _extraVars(self, txn, queryGenerator): if self.Return is None: return [] result = [] rvars = self._returnAsList() if queryGenerator.dialect == ORACLE_DIALECT: for n, v in enumerate(rvars): result.append(("oracle_out_" + str(n), _OracleOutParam(v))) return result def _extraResult(self, result, outvars, queryGenerator): if queryGenerator.dialect == ORACLE_DIALECT and self.Return is not None: def processIt(shouldBeNone): result = [[v.value for _ignore_k, v in outvars]] return result return result.addCallback(processIt) else: return result def _resultColumns(self): return self._returnAsList() class _OracleOutParam(object): """ A parameter that will be populated using the cx_Oracle API for host variables. 
""" implements(IDerivedParameter) def __init__(self, columnSyntax): self.typeID = columnSyntax.model.type.name.lower() def preQuery(self, cursor): typeMap = {'integer': cx_Oracle.NUMBER, 'text': cx_Oracle.NCLOB, 'varchar': cx_Oracle.STRING, 'timestamp': cx_Oracle.TIMESTAMP} self.var = cursor.var(typeMap[self.typeID]) return self.var def postQuery(self, cursor): self.value = mapOracleOutputType(self.var.getvalue()) self.var = None class Insert(_DMLStatement): """ 'insert' statement. """ def __init__(self, columnMap, Return=None): self.columnMap = columnMap self.Return = Return columns = _modelsFromMap(columnMap) table = _fromSameTable(columns) required = [column for column in table.columns if column.needsValue()] unspecified = [column for column in required if column not in columns] if unspecified: raise NotEnoughValues( 'Columns [%s] required.' % (', '.join([c.name for c in unspecified]))) def _toSQL(self, queryGenerator): """ @return: a 'insert' statement with placeholders and arguments @rtype: L{SQLFragment} """ columnsAndValues = self.columnMap.items() tableModel = columnsAndValues[0][0].model.table specifiedColumnModels = [x.model for x in self.columnMap.keys()] if queryGenerator.dialect == ORACLE_DIALECT: # See test_nextSequenceDefaultImplicitExplicitOracle. 
for column in tableModel.columns: if isinstance(column.default, Sequence): columnSyntax = ColumnSyntax(column) if column not in specifiedColumnModels: columnsAndValues.append( (columnSyntax, SequenceSyntax(column.default)) ) sortedColumns = sorted(columnsAndValues, key=lambda (c, v): c.model.name) allTables = [] stmt = SQLFragment('insert into ') stmt.append(TableSyntax(tableModel).subSQL(queryGenerator, allTables)) stmt.append(SQLFragment(" ")) stmt.append(_inParens(_commaJoined( [c.subSQL(queryGenerator, allTables) for (c, _ignore_v) in sortedColumns]))) stmt.append(SQLFragment(" values ")) stmt.append(_inParens(_commaJoined( [_convert(v).subSQL(queryGenerator, allTables) for (c, v) in sortedColumns]))) return self._returningClause(queryGenerator, stmt, allTables) def on(self, txn, *a, **kw): """ Override to provide extra logic for L{Insert}s that return values on databases that don't provide return values as part of their C{INSERT} behavior. """ result = super(_DMLStatement, self).on(txn, *a, **kw) if self.Return is not None and txn.dialect == SQLITE_DIALECT: table = self._returnAsList()[0].model.table return Select(self._returnAsList(), # TODO: error reporting when 'return' includes columns # foreign to the primary table. From=TableSyntax(table), Where=ColumnSyntax(Column(table, "rowid", SQLType("integer", None))) == _sqliteLastInsertRowID() ).on(txn, *a, **kw) return result def _convert(x): """ Convert a value to an appropriate SQL AST node. (Currently a simple isinstance, could be promoted to use adaptation if we want to get fancy.) """ if isinstance(x, ExpressionSyntax): return x else: return Constant(x) class Update(_DMLStatement): """ 'update' statement @ivar columnMap: A L{dict} mapping L{ColumnSyntax} objects to values to change; values may be simple database values (such as L{str}, L{unicode}, L{datetime.datetime}, L{float}, L{int} etc) or L{Parameter} instances. 
@type columnMap: L{dict} """ def __init__(self, columnMap, Where, Return=None): super(Update, self).__init__() _fromSameTable(_modelsFromMap(columnMap)) self.columnMap = columnMap self.Where = Where self.Return = Return @inlineCallbacks def on(self, txn, *a, **kw): """ Override to provide extra logic for L{Update}s that return values on databases that don't provide return values as part of their C{UPDATE} behavior. """ doExtra = self.Return is not None and txn.dialect == SQLITE_DIALECT upcall = lambda: super(_DMLStatement, self).on(txn, *a, **kw) if doExtra: table = self._returnAsList()[0].model.table rowidcol = ColumnSyntax(Column(table, "rowid", SQLType("integer", None))) prequery = Select([rowidcol], From=TableSyntax(table), Where=self.Where) preresult = prequery.on(txn, *a, **kw) before = yield preresult yield upcall() result = (yield Select(self._returnAsList(), # TODO: error reporting when 'return' includes # columns foreign to the primary table. From=TableSyntax(table), Where=reduce(lambda left, right: left.Or(right), ((rowidcol == x) for [x] in before)) ).on(txn, *a, **kw)) returnValue(result) else: returnValue((yield upcall())) def _toSQL(self, queryGenerator): """ @return: a 'insert' statement with placeholders and arguments @rtype: L{SQLFragment} """ sortedColumns = sorted(self.columnMap.items(), key=lambda (c, v): c.model.name) allTables = [] result = SQLFragment('update ') result.append( TableSyntax(sortedColumns[0][0].model.table).subSQL( queryGenerator, allTables) ) result.text += ' set ' result.append( _commaJoined( [c.subSQL(queryGenerator, allTables).append( SQLFragment(" = ").subSQL(queryGenerator, allTables) ).append(_convert(v).subSQL(queryGenerator, allTables)) for (c, v) in sortedColumns] ) ) if self.Where is not None: result.append(SQLFragment(' where ')) result.append(self.Where.subSQL(queryGenerator, allTables)) return self._returningClause(queryGenerator, result, allTables) class Delete(_DMLStatement): """ 'delete' statement. 
""" def __init__(self, From, Where, Return=None): """ If Where is None then all rows will be deleted. """ self.From = From self.Where = Where self.Return = Return def _toSQL(self, queryGenerator): result = SQLFragment() allTables = self.From.tables() result.text += 'delete from ' result.append(self.From.subSQL(queryGenerator, allTables)) if self.Where is not None: result.text += ' where ' result.append(self.Where.subSQL(queryGenerator, allTables)) return self._returningClause(queryGenerator, result, allTables) @inlineCallbacks def on(self, txn, *a, **kw): upcall = lambda: super(Delete, self).on(txn, *a, **kw) if txn.dialect == SQLITE_DIALECT and self.Return is not None: result = yield Select(self._returnAsList(), From=self.From, Where=self.Where).on(txn, *a, **kw) yield upcall() else: result = yield upcall() returnValue(result) class _LockingStatement(_Statement): """ A statement related to lock management, which implicitly has no results. """ def _resultColumns(self): """ No columns should be expected, so return an infinite iterator of None. """ return repeat(None) class Lock(_LockingStatement): """ An SQL 'lock' statement. """ def __init__(self, table, mode): self.table = table self.mode = mode @classmethod def exclusive(cls, table): return cls(table, 'exclusive') def _toSQL(self, queryGenerator): if queryGenerator.dialect == SQLITE_DIALECT: # FIXME - this is only stubbed out for testing right now, actual # concurrency would require some kind of locking statement here. # BEGIN IMMEDIATE maybe, if that's okay in the middle of a # transaction or repeatedly? 
return SQLFragment('select null') return SQLFragment('lock table ').append( self.table.subSQL(queryGenerator, [self.table])).append( SQLFragment(' in %s mode' % (self.mode,))) class DatabaseLock(_LockingStatement): """ An SQL exclusive session level advisory lock """ def _toSQL(self, queryGenerator): assert(queryGenerator.dialect == POSTGRES_DIALECT) return SQLFragment('select pg_advisory_lock(1)') def on(self, txn, *a, **kw): """ Override on() to only execute on Postgres """ if txn.dialect == POSTGRES_DIALECT: return super(DatabaseLock, self).on(txn, *a, **kw) return succeed(None) class DatabaseUnlock(_LockingStatement): """ An SQL exclusive session level advisory lock """ def _toSQL(self, queryGenerator): assert(queryGenerator.dialect == POSTGRES_DIALECT) return SQLFragment('select pg_advisory_unlock(1)') def on(self, txn, *a, **kw): """ Override on() to only execute on Postgres """ if txn.dialect == POSTGRES_DIALECT: return super(DatabaseUnlock, self).on(txn, *a, **kw) return succeed(None) class Savepoint(_LockingStatement): """ An SQL 'savepoint' statement. """ def __init__(self, name): self.name = name def _toSQL(self, queryGenerator): return SQLFragment('savepoint %s' % (self.name,)) class RollbackToSavepoint(_LockingStatement): """ An SQL 'rollback to savepoint' statement. """ def __init__(self, name): self.name = name def _toSQL(self, queryGenerator): return SQLFragment('rollback to savepoint %s' % (self.name,)) class ReleaseSavepoint(_LockingStatement): """ An SQL 'release savepoint' statement. 
""" def __init__(self, name): self.name = name def _toSQL(self, queryGenerator): return SQLFragment('release savepoint %s' % (self.name,)) class SavepointAction(object): def __init__(self, name): self._name = name def acquire(self, txn): return Savepoint(self._name).on(txn) def rollback(self, txn): return RollbackToSavepoint(self._name).on(txn) def release(self, txn): if txn.dialect == ORACLE_DIALECT: # There is no 'release savepoint' statement in oracle, but then, we # don't need it because there's no resource to manage. Just don't # do anything. return NoOp() else: return ReleaseSavepoint(self._name).on(txn) class NoOp(object): def on(self, *a, **kw): return succeed(None) class SQLFragment(object): """ Combination of SQL text and arguments; a statement which may be executed against a database. """ def __init__(self, text="", parameters=None): self.text = text if parameters is None: parameters = [] self.parameters = parameters def bind(self, **kw): params = [] for parameter in self.parameters: if isinstance(parameter, Parameter): if parameter.count is not None: if parameter.count != len(kw[parameter.name]): raise DALError("Number of place holders does not match number of items to bind") for item in kw[parameter.name]: params.append(item) else: params.append(kw[parameter.name]) else: params.append(parameter) return SQLFragment(self.text, params) def append(self, anotherStatement): self.text += anotherStatement.text self.parameters += anotherStatement.parameters return self def __eq__(self, stmt): if not isinstance(stmt, SQLFragment): return NotImplemented return (self.text, self.parameters) == (stmt.text, stmt.parameters) def __ne__(self, stmt): if not isinstance(stmt, SQLFragment): return NotImplemented return not self.__eq__(stmt) def __repr__(self): return self.__class__.__name__ + repr((self.text, self.parameters)) def subSQL(self, queryGenerator, allTables): return self class Parameter(object): """ Used to represent a place holder for a value to be bound to 
the query at a later date. If count > 1, then a "set" of parenthesized, comma separate place holders will be generated. """ def __init__(self, name, count=None): self.name = name self.count = count if self.count is not None and self.count < 1: raise DALError("Must have Parameter.count > 0") def __eq__(self, param): if not isinstance(param, Parameter): return NotImplemented return self.name == param.name and self.count == param.count def __ne__(self, param): if not isinstance(param, Parameter): return NotImplemented return not self.__eq__(param) def __repr__(self): return 'Parameter(%r)' % (self.name,) # Common helpers: # current timestamp in UTC format. Hack to support standard syntax for this, # rather than the compatibility procedure found in various databases. utcNowSQL = NamedValue("CURRENT_TIMESTAMP at time zone 'UTC'") # You can't insert a column with no rows. In SQL that just isn't valid syntax, # and in this DAL you need at least one key or we can't tell what table you're # talking about. Luckily there's the 'default' keyword to the rescue, which, in # the context of an INSERT statement means 'use the default value explicitly'. # (Although this is a special keyword in a CREATE statement, in an INSERT it # behaves like an expression to the best of my knowledge.) default = NamedValue('default') calendarserver-5.2+dfsg/twext/enterprise/dal/__init__.py0000644000175000017500000000205112263343324022513 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. ## """ Toolkit for building a Data-Access Layer (DAL). This includes an abstract representation of SQL objects like tables, columns, sequences and queries, a parser to convert your schema to that representation, and tools for working with it. In some ways this is similar to the low levels of something like SQLAlchemy, but it is designed to be more introspectable, to allow for features like automatic caching and index detection. NB: work in progress. """ calendarserver-5.2+dfsg/twext/enterprise/util.py0000644000175000017500000000624012263343324021175 0ustar rahulrahul# -*- test-case-name: twext.enterprise.test.test_util -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Utilities for dealing with different databases. """ from datetime import datetime SQL_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S.%f" def parseSQLTimestamp(ts): """ Parse an SQL timestamp string. """ # Handle case where fraction seconds may not be present if len(ts) < len(SQL_TIMESTAMP_FORMAT): ts += ".0" return datetime.strptime(ts, SQL_TIMESTAMP_FORMAT) def mapOracleOutputType(column): """ Map a single output value from cx_Oracle based on some rules and expectations that we have based on the pgdb bindings. @param column: a single value from a column. 
@return: a converted value based on the type of the input; oracle CLOBs and datetime timestamps will be converted to strings, unicode values will be converted to UTF-8 encoded byte sequences (C{str}s), and floating point numbers will be converted to integer types if they are integers. Any other types will be left alone. """ if hasattr(column, 'read'): # Try to detect large objects and format convert them to # strings on the fly. We need to do this as we read each # row, due to the issue described here - # http://cx-oracle.sourceforge.net/html/lob.html - in # particular, the part where it says "In particular, do not # use the fetchall() method". column = column.read() elif isinstance(column, datetime): # cx_Oracle properly maps the type of timestamps to datetime # objects. However, our code is mostly written against # PyGreSQL, which just emits strings as results and expects # to have to convert them itself.. Since it's easier to # just detect the datetimes and stringify them, for now # we'll do that. return column.strftime(SQL_TIMESTAMP_FORMAT) elif isinstance(column, float): # cx_Oracle maps _all_ nubmers to float types, which is more consistent, # but we expect the database to be able to store integers as integers # (in fact almost all the values in our schema are integers), so we map # those values which exactly match back into integers. if int(column) == column: return int(column) else: return column if isinstance(column, unicode): # Finally, we process all data as UTF-8 bytestrings in order to reduce # memory consumption. Pass any unicode string values back to the # application as unicode. 
column = column.encode('utf-8') return column calendarserver-5.2+dfsg/twext/protocols/0000755000175000017500000000000012322625326017511 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/protocols/test/0000755000175000017500000000000012322625326020470 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/protocols/test/test_memcache.py0000644000175000017500000004450111156045201023637 0ustar rahulrahul# Copyright (c) 2007-2009 Twisted Matrix Laboratories. # See LICENSE for details. """ Test the memcache client protocol. """ from twext.protocols.memcache import MemCacheProtocol, NoSuchCommand from twext.protocols.memcache import ClientError, ServerError from twisted.trial.unittest import TestCase from twisted.test.proto_helpers import StringTransportWithDisconnection from twisted.internet.task import Clock from twisted.internet.defer import Deferred, gatherResults, TimeoutError class MemCacheTestCase(TestCase): """ Test client protocol class L{MemCacheProtocol}. """ def setUp(self): """ Create a memcache client, connect it to a string protocol, and make it use a deterministic clock. """ self.proto = MemCacheProtocol() self.clock = Clock() self.proto.callLater = self.clock.callLater self.transport = StringTransportWithDisconnection() self.transport.protocol = self.proto self.proto.makeConnection(self.transport) def _test(self, d, send, recv, result): """ Shortcut method for classic tests. @param d: the resulting deferred from the memcache command. @type d: C{Deferred} @param send: the expected data to be sent. @type send: C{str} @param recv: the data to simulate as reception. @type recv: C{str} @param result: the expected result. 
@type result: C{any} """ def cb(res): self.assertEquals(res, result) self.assertEquals(self.transport.value(), send) d.addCallback(cb) self.proto.dataReceived(recv) return d def test_get(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. """ return self._test(self.proto.get("foo"), "get foo\r\n", "VALUE foo 0 3\r\nbar\r\nEND\r\n", (0, "bar")) def test_emptyGet(self): """ Test getting a non-available key: it should succeed but return C{None} as value and C{0} as flag. """ return self._test(self.proto.get("foo"), "get foo\r\n", "END\r\n", (0, None)) def test_set(self): """ L{MemCacheProtocol.set} should return a L{Deferred} which is called back with C{True} when the operation succeeds. """ return self._test(self.proto.set("foo", "bar"), "set foo 0 0 3\r\nbar\r\n", "STORED\r\n", True) def test_add(self): """ L{MemCacheProtocol.add} should return a L{Deferred} which is called back with C{True} when the operation succeeds. """ return self._test(self.proto.add("foo", "bar"), "add foo 0 0 3\r\nbar\r\n", "STORED\r\n", True) def test_replace(self): """ L{MemCacheProtocol.replace} should return a L{Deferred} which is called back with C{True} when the operation succeeds. """ return self._test(self.proto.replace("foo", "bar"), "replace foo 0 0 3\r\nbar\r\n", "STORED\r\n", True) def test_errorAdd(self): """ Test an erroneous add: if a L{MemCacheProtocol.add} is called but the key already exists on the server, it returns a B{NOT STORED} answer, which should callback the resulting L{Deferred} with C{False}. """ return self._test(self.proto.add("foo", "bar"), "add foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False) def test_errorReplace(self): """ Test an erroneous replace: if a L{MemCacheProtocol.replace} is called but the key doesn't exist on the server, it returns a B{NOT STORED} answer, which should callback the resulting L{Deferred} with C{False}. 
""" return self._test(self.proto.replace("foo", "bar"), "replace foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False) def test_delete(self): """ L{MemCacheProtocol.delete} should return a L{Deferred} which is called back with C{True} when the server notifies a success. """ return self._test(self.proto.delete("bar"), "delete bar\r\n", "DELETED\r\n", True) def test_errorDelete(self): """ Test a error during a delete: if key doesn't exist on the server, it returns a B{NOT FOUND} answer which should callback the resulting L{Deferred} with C{False}. """ return self._test(self.proto.delete("bar"), "delete bar\r\n", "NOT FOUND\r\n", False) def test_increment(self): """ Test incrementing a variable: L{MemCacheProtocol.increment} should return a L{Deferred} which is called back with the incremented value of the given key. """ return self._test(self.proto.increment("foo"), "incr foo 1\r\n", "4\r\n", 4) def test_decrement(self): """ Test decrementing a variable: L{MemCacheProtocol.decrement} should return a L{Deferred} which is called back with the decremented value of the given key. """ return self._test( self.proto.decrement("foo"), "decr foo 1\r\n", "5\r\n", 5) def test_incrementVal(self): """ L{MemCacheProtocol.increment} takes an optional argument C{value} which should replace the default value of 1 when specified. """ return self._test(self.proto.increment("foo", 8), "incr foo 8\r\n", "4\r\n", 4) def test_decrementVal(self): """ L{MemCacheProtocol.decrement} takes an optional argument C{value} which should replace the default value of 1 when specified. """ return self._test(self.proto.decrement("foo", 3), "decr foo 3\r\n", "5\r\n", 5) def test_stats(self): """ Test retrieving server statistics via the L{MemCacheProtocol.stats} command: it should parse the data sent by the server and call back the resulting L{Deferred} with a dictionary of the received statistics. 
""" return self._test(self.proto.stats(), "stats\r\n", "STAT foo bar\r\nSTAT egg spam\r\nEND\r\n", {"foo": "bar", "egg": "spam"}) def test_statsWithArgument(self): """ L{MemCacheProtocol.stats} takes an optional C{str} argument which, if specified, is sent along with the I{STAT} command. The I{STAT} responses from the server are parsed as key/value pairs and returned as a C{dict} (as in the case where the argument is not specified). """ return self._test(self.proto.stats("blah"), "stats blah\r\n", "STAT foo bar\r\nSTAT egg spam\r\nEND\r\n", {"foo": "bar", "egg": "spam"}) def test_version(self): """ Test version retrieval via the L{MemCacheProtocol.version} command: it should return a L{Deferred} which is called back with the version sent by the server. """ return self._test(self.proto.version(), "version\r\n", "VERSION 1.1\r\n", "1.1") def test_flushAll(self): """ L{MemCacheProtocol.flushAll} should return a L{Deferred} which is called back with C{True} if the server acknowledges success. """ return self._test(self.proto.flushAll(), "flush_all\r\n", "OK\r\n", True) def test_invalidGetResponse(self): """ If the value returned doesn't match the expected key of the current, we should get an error in L{MemCacheProtocol.dataReceived}. """ self.proto.get("foo") s = "spamegg" self.assertRaises(RuntimeError, self.proto.dataReceived, "VALUE bar 0 %s\r\n%s\r\nEND\r\n" % (len(s), s)) def test_timeOut(self): """ Test the timeout on outgoing requests: when timeout is detected, all current commands should fail with a L{TimeoutError}, and the connection should be closed. 
""" d1 = self.proto.get("foo") d2 = self.proto.get("bar") d3 = Deferred() self.proto.connectionLost = d3.callback self.clock.advance(self.proto.persistentTimeOut) self.assertFailure(d1, TimeoutError) self.assertFailure(d2, TimeoutError) def checkMessage(error): self.assertEquals(str(error), "Connection timeout") d1.addCallback(checkMessage) return gatherResults([d1, d2, d3]) def test_timeoutRemoved(self): """ When a request gets a response, no pending timeout call should remain around. """ d = self.proto.get("foo") self.clock.advance(self.proto.persistentTimeOut - 1) self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n") def check(result): self.assertEquals(result, (0, "bar")) self.assertEquals(len(self.clock.calls), 0) d.addCallback(check) return d def test_timeOutRaw(self): """ Test the timeout when raw mode was started: the timeout should not be reset until all the data has been received, so we can have a L{TimeoutError} when waiting for raw data. """ d1 = self.proto.get("foo") d2 = Deferred() self.proto.connectionLost = d2.callback self.proto.dataReceived("VALUE foo 0 10\r\n12345") self.clock.advance(self.proto.persistentTimeOut) self.assertFailure(d1, TimeoutError) return gatherResults([d1, d2]) def test_timeOutStat(self): """ Test the timeout when stat command has started: the timeout should not be reset until the final B{END} is received. """ d1 = self.proto.stats() d2 = Deferred() self.proto.connectionLost = d2.callback self.proto.dataReceived("STAT foo bar\r\n") self.clock.advance(self.proto.persistentTimeOut) self.assertFailure(d1, TimeoutError) return gatherResults([d1, d2]) def test_timeoutPipelining(self): """ When two requests are sent, a timeout call should remain around for the second request, and its timeout time should be correct. 
""" d1 = self.proto.get("foo") d2 = self.proto.get("bar") d3 = Deferred() self.proto.connectionLost = d3.callback self.clock.advance(self.proto.persistentTimeOut - 1) self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n") def check(result): self.assertEquals(result, (0, "bar")) self.assertEquals(len(self.clock.calls), 1) for i in range(self.proto.persistentTimeOut): self.clock.advance(1) return self.assertFailure(d2, TimeoutError).addCallback(checkTime) def checkTime(ignored): # Check that the timeout happened C{self.proto.persistentTimeOut} # after the last response self.assertEquals(self.clock.seconds(), 2 * self.proto.persistentTimeOut - 1) d1.addCallback(check) return d1 def test_timeoutNotReset(self): """ Check that timeout is not resetted for every command, but keep the timeout from the first command without response. """ d1 = self.proto.get("foo") d3 = Deferred() self.proto.connectionLost = d3.callback self.clock.advance(self.proto.persistentTimeOut - 1) d2 = self.proto.get("bar") self.clock.advance(1) self.assertFailure(d1, TimeoutError) self.assertFailure(d2, TimeoutError) return gatherResults([d1, d2, d3]) def test_tooLongKey(self): """ Test that an error is raised when trying to use a too long key: the called command should return a L{Deferred} which fail with a L{ClientError}. """ d1 = self.assertFailure(self.proto.set("a" * 500, "bar"), ClientError) d2 = self.assertFailure(self.proto.increment("a" * 500), ClientError) d3 = self.assertFailure(self.proto.get("a" * 500), ClientError) d4 = self.assertFailure(self.proto.append("a" * 500, "bar"), ClientError) d5 = self.assertFailure(self.proto.prepend("a" * 500, "bar"), ClientError) return gatherResults([d1, d2, d3, d4, d5]) def test_invalidCommand(self): """ When an unknown command is sent directly (not through public API), the server answers with an B{ERROR} token, and the command should fail with L{NoSuchCommand}. 
""" d = self.proto._set("egg", "foo", "bar", 0, 0, "") self.assertEquals(self.transport.value(), "egg foo 0 0 3\r\nbar\r\n") self.assertFailure(d, NoSuchCommand) self.proto.dataReceived("ERROR\r\n") return d def test_clientError(self): """ Test the L{ClientError} error: when the server send a B{CLIENT_ERROR} token, the originating command should fail with L{ClientError}, and the error should contain the text sent by the server. """ a = "eggspamm" d = self.proto.set("foo", a) self.assertEquals(self.transport.value(), "set foo 0 0 8\r\neggspamm\r\n") self.assertFailure(d, ClientError) def check(err): self.assertEquals(str(err), "We don't like egg and spam") d.addCallback(check) self.proto.dataReceived("CLIENT_ERROR We don't like egg and spam\r\n") return d def test_serverError(self): """ Test the L{ServerError} error: when the server send a B{SERVER_ERROR} token, the originating command should fail with L{ServerError}, and the error should contain the text sent by the server. """ a = "eggspamm" d = self.proto.set("foo", a) self.assertEquals(self.transport.value(), "set foo 0 0 8\r\neggspamm\r\n") self.assertFailure(d, ServerError) def check(err): self.assertEquals(str(err), "zomg") d.addCallback(check) self.proto.dataReceived("SERVER_ERROR zomg\r\n") return d def test_unicodeKey(self): """ Using a non-string key as argument to commands should raise an error. """ d1 = self.assertFailure(self.proto.set(u"foo", "bar"), ClientError) d2 = self.assertFailure(self.proto.increment(u"egg"), ClientError) d3 = self.assertFailure(self.proto.get(1), ClientError) d4 = self.assertFailure(self.proto.delete(u"bar"), ClientError) d5 = self.assertFailure(self.proto.append(u"foo", "bar"), ClientError) d6 = self.assertFailure(self.proto.prepend(u"foo", "bar"), ClientError) return gatherResults([d1, d2, d3, d4, d5, d6]) def test_unicodeValue(self): """ Using a non-string value should raise an error. 
""" return self.assertFailure(self.proto.set("foo", u"bar"), ClientError) def test_pipelining(self): """ Test that multiple requests can be sent subsequently to the server, and that the protocol order the responses correctly and dispatch to the corresponding client command. """ d1 = self.proto.get("foo") d1.addCallback(self.assertEquals, (0, "bar")) d2 = self.proto.set("bar", "spamspamspam") d2.addCallback(self.assertEquals, True) d3 = self.proto.get("egg") d3.addCallback(self.assertEquals, (0, "spam")) self.assertEquals(self.transport.value(), "get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n") self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n" "STORED\r\n" "VALUE egg 0 4\r\nspam\r\nEND\r\n") return gatherResults([d1, d2, d3]) def test_getInChunks(self): """ If the value retrieved by a C{get} arrive in chunks, the protocol should be able to reconstruct it and to produce the good value. """ d = self.proto.get("foo") d.addCallback(self.assertEquals, (0, "0123456789")) self.assertEquals(self.transport.value(), "get foo\r\n") self.proto.dataReceived("VALUE foo 0 10\r\n0123456") self.proto.dataReceived("789") self.proto.dataReceived("\r\nEND") self.proto.dataReceived("\r\n") return d def test_append(self): """ L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set} method: it should return a L{Deferred} which is called back with C{True} when the operation succeeds. """ return self._test(self.proto.append("foo", "bar"), "append foo 0 0 3\r\nbar\r\n", "STORED\r\n", True) def test_prepend(self): """ L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set} method: it should return a L{Deferred} which is called back with C{True} when the operation succeeds. """ return self._test(self.proto.prepend("foo", "bar"), "prepend foo 0 0 3\r\nbar\r\n", "STORED\r\n", True) def test_gets(self): """ L{MemCacheProtocol.get} should handle an additional cas result when C{withIdentifier} is C{True} and forward it in the resulting L{Deferred}. 
""" return self._test(self.proto.get("foo", True), "gets foo\r\n", "VALUE foo 0 3 1234\r\nbar\r\nEND\r\n", (0, "1234", "bar")) def test_emptyGets(self): """ Test getting a non-available key with gets: it should succeed but return C{None} as value, C{0} as flag and an empty cas value. """ return self._test(self.proto.get("foo", True), "gets foo\r\n", "END\r\n", (0, "", None)) def test_checkAndSet(self): """ L{MemCacheProtocol.checkAndSet} passes an additional cas identifier that the server should handle to check if the data has to be updated. """ return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"), "cas foo 0 0 3 1234\r\nbar\r\n", "STORED\r\n", True) def test_casUnknowKey(self): """ When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the resulting L{Deferred} should fire with C{False}. """ return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"), "cas foo 0 0 3 1234\r\nbar\r\n", "EXISTS\r\n", False) calendarserver-5.2+dfsg/twext/protocols/test/__init__.py0000644000175000017500000000120712263343324022600 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extentions to twisted.protocols """ calendarserver-5.2+dfsg/twext/protocols/memcache.py0000644000175000017500000004637212147725751021651 0ustar rahulrahul# -*- test-case-name: twisted.test.test_memcache -*- # Copyright (c) 2007-2009 Twisted Matrix Laboratories. # See LICENSE for details. """ Memcache client protocol. 
Memcached is a caching server, storing data in the form of pairs key/value, and memcache is the protocol to talk with it. To connect to a server, create a factory for L{MemCacheProtocol}:: from twisted.internet import reactor, protocol from twisted.protocols.memcache import MemCacheProtocol, DEFAULT_PORT d = protocol.ClientCreator(reactor, MemCacheProtocol ).connectTCP("localhost", DEFAULT_PORT) def doSomething(proto): # Here you call the memcache operations return proto.set("mykey", "a lot of data") d.addCallback(doSomething) reactor.run() All the operations of the memcache protocol are present, but L{MemCacheProtocol.set} and L{MemCacheProtocol.get} are the more important. See U{http://code.sixapart.com/svn/memcached/trunk/server/doc/protocol.txt} for more information about the protocol. """ try: from collections import deque except ImportError: class deque(list): def popleft(self): return self.pop(0) from twisted.protocols.basic import LineReceiver from twisted.protocols.policies import TimeoutMixin from twisted.internet.defer import Deferred, fail, TimeoutError from twext.python.log import Logger log = Logger() DEFAULT_PORT = 11211 class NoSuchCommand(Exception): """ Exception raised when a non existent command is called. """ class ClientError(Exception): """ Error caused by an invalid client call. """ class ServerError(Exception): """ Problem happening on the server. """ class Command(object): """ Wrap a client action into an object, that holds the values used in the protocol. @ivar _deferred: the L{Deferred} object that will be fired when the result arrives. @type _deferred: L{Deferred} @ivar command: name of the command sent to the server. @type command: C{str} """ def __init__(self, command, **kwargs): """ Create a command. @param command: the name of the command. 
@type command: C{str} @param kwargs: this values will be stored as attributes of the object for future use """ self.command = command self._deferred = Deferred() for k, v in kwargs.items(): setattr(self, k, v) def success(self, value): """ Shortcut method to fire the underlying deferred. """ self._deferred.callback(value) def fail(self, error): """ Make the underlying deferred fails. """ self._deferred.errback(error) class MemCacheProtocol(LineReceiver, TimeoutMixin): """ MemCache protocol: connect to a memcached server to store/retrieve values. @ivar persistentTimeOut: the timeout period used to wait for a response. @type persistentTimeOut: C{int} @ivar _current: current list of requests waiting for an answer from the server. @type _current: C{deque} of L{Command} @ivar _lenExpected: amount of data expected in raw mode, when reading for a value. @type _lenExpected: C{int} @ivar _getBuffer: current buffer of data, used to store temporary data when reading in raw mode. @type _getBuffer: C{list} @ivar _bufferLength: the total amount of bytes in C{_getBuffer}. @type _bufferLength: C{int} """ MAX_KEY_LENGTH = 250 def __init__(self, timeOut=60): """ Create the protocol. @param timeOut: the timeout to wait before detecting that the connection is dead and close it. It's expressed in seconds. @type timeOut: C{int} """ self._current = deque() self._lenExpected = None self._getBuffer = None self._bufferLength = None self.persistentTimeOut = self.timeOut = timeOut def timeoutConnection(self): """ Close the connection in case of timeout. """ for cmd in self._current: cmd.fail(TimeoutError("Connection timeout")) self.transport.loseConnection() def sendLine(self, line): """ Override sendLine to add a timeout to response. """ if not self._current: self.setTimeout(self.persistentTimeOut) LineReceiver.sendLine(self, line) def rawDataReceived(self, data): """ Collect data for a get. 
""" self.resetTimeout() self._getBuffer.append(data) self._bufferLength += len(data) if self._bufferLength >= self._lenExpected + 2: data = "".join(self._getBuffer) buf = data[:self._lenExpected] rem = data[self._lenExpected + 2:] val = buf self._lenExpected = None self._getBuffer = None self._bufferLength = None cmd = self._current[0] cmd.value = val self.setLineMode(rem) def cmd_STORED(self): """ Manage a success response to a set operation. """ self._current.popleft().success(True) def cmd_NOT_STORED(self): """ Manage a specific 'not stored' response to a set operation: this is not an error, but some condition wasn't met. """ self._current.popleft().success(False) def cmd_END(self): """ This the end token to a get or a stat operation. """ cmd = self._current.popleft() if cmd.command == "get": cmd.success((cmd.flags, cmd.value)) elif cmd.command == "gets": cmd.success((cmd.flags, cmd.cas, cmd.value)) elif cmd.command == "stats": cmd.success(cmd.values) def cmd_NOT_FOUND(self): """ Manage error response for incr/decr/delete. """ self._current.popleft().success(False) def cmd_VALUE(self, line): """ Prepare the reading a value after a get. """ cmd = self._current[0] if cmd.command == "get": key, flags, length = line.split() cas = "" else: key, flags, length, cas = line.split() self._lenExpected = int(length) self._getBuffer = [] self._bufferLength = 0 if cmd.key != key: raise RuntimeError("Unexpected commands answer.") cmd.flags = int(flags) cmd.length = self._lenExpected cmd.cas = cas self.setRawMode() def cmd_STAT(self, line): """ Reception of one stat line. """ cmd = self._current[0] key, val = line.split(" ", 1) cmd.values[key] = val def cmd_VERSION(self, versionData): """ Read version token. """ self._current.popleft().success(versionData) def cmd_ERROR(self): """ An non-existent command has been sent. 
""" log.error("Non-existent command sent.") cmd = self._current.popleft() cmd.fail(NoSuchCommand()) def cmd_CLIENT_ERROR(self, errText): """ An invalid input as been sent. """ log.error("Invalid input: %s" % (errText,)) cmd = self._current.popleft() cmd.fail(ClientError(errText)) def cmd_SERVER_ERROR(self, errText): """ An error has happened server-side. """ log.error("Server error: %s" % (errText,)) cmd = self._current.popleft() cmd.fail(ServerError(errText)) def cmd_DELETED(self): """ A delete command has completed successfully. """ self._current.popleft().success(True) def cmd_OK(self): """ The last command has been completed. """ self._current.popleft().success(True) def cmd_EXISTS(self): """ A C{checkAndSet} update has failed. """ self._current.popleft().success(False) def lineReceived(self, line): """ Receive line commands from the server. """ self.resetTimeout() token = line.split(" ", 1)[0] # First manage standard commands without space cmd = getattr(self, "cmd_%s" % (token,), None) if cmd is not None: args = line.split(" ", 1)[1:] if args: cmd(args[0]) else: cmd() else: # Then manage commands with space in it line = line.replace(" ", "_") cmd = getattr(self, "cmd_%s" % (line,), None) if cmd is not None: cmd() else: # Increment/Decrement response cmd = self._current.popleft() val = int(line) cmd.success(val) if not self._current: # No pending request, remove timeout self.setTimeout(None) def increment(self, key, val=1): """ Increment the value of C{key} by given value (default to 1). C{key} must be consistent with an int. Return the new value. @param key: the key to modify. @type key: C{str} @param val: the value to increment. @type val: C{int} @return: a deferred with will be called back with the new value associated with the key (after the increment). @rtype: L{Deferred} """ return self._incrdecr("incr", key, val) def decrement(self, key, val=1): """ Decrement the value of C{key} by given value (default to 1). C{key} must be consistent with an int. 
Return the new value, coerced to 0 if negative. @param key: the key to modify. @type key: C{str} @param val: the value to decrement. @type val: C{int} @return: a deferred with will be called back with the new value associated with the key (after the decrement). @rtype: L{Deferred} """ return self._incrdecr("decr", key, val) def _incrdecr(self, cmd, key, val): """ Internal wrapper for incr/decr. """ if not isinstance(key, str): return fail(ClientError( "Invalid type for key: %s, expecting a string" % (type(key),))) if len(key) > self.MAX_KEY_LENGTH: return fail(ClientError("Key too long")) fullcmd = "%s %s %d" % (cmd, key, int(val)) self.sendLine(fullcmd) cmdObj = Command(cmd, key=key) self._current.append(cmdObj) return cmdObj._deferred def replace(self, key, val, flags=0, expireTime=0): """ Replace the given C{key}. It must already exist in the server. @param key: the key to replace. @type key: C{str} @param val: the new value associated with the key. @type val: C{str} @param flags: the flags to store with the key. @type flags: C{int} @param expireTime: if different from 0, the relative time in seconds when the key will be deleted from the store. @type expireTime: C{int} @return: a deferred that will fire with C{True} if the operation has succeeded, and C{False} with the key didn't previously exist. @rtype: L{Deferred} """ return self._set("replace", key, val, flags, expireTime, "") def add(self, key, val, flags=0, expireTime=0): """ Add the given C{key}. It must not exist in the server. @param key: the key to add. @type key: C{str} @param val: the value associated with the key. @type val: C{str} @param flags: the flags to store with the key. @type flags: C{int} @param expireTime: if different from 0, the relative time in seconds when the key will be deleted from the store. @type expireTime: C{int} @return: a deferred that will fire with C{True} if the operation has succeeded, and C{False} with the key already exists. 
@rtype: L{Deferred} """ return self._set("add", key, val, flags, expireTime, "") def set(self, key, val, flags=0, expireTime=0): """ Set the given C{key}. @param key: the key to set. @type key: C{str} @param val: the value associated with the key. @type val: C{str} @param flags: the flags to store with the key. @type flags: C{int} @param expireTime: if different from 0, the relative time in seconds when the key will be deleted from the store. @type expireTime: C{int} @return: a deferred that will fire with C{True} if the operation has succeeded. @rtype: L{Deferred} """ return self._set("set", key, val, flags, expireTime, "") def checkAndSet(self, key, val, cas, flags=0, expireTime=0): """ Change the content of C{key} only if the C{cas} value matches the current one associated with the key. Use this to store a value which hasn't been modified since last time you fetched it. @param key: The key to set. @type key: C{str} @param val: The value associated with the key. @type val: C{str} @param cas: Unique 64-bit value returned by previous call of C{get}. @type cas: C{str} @param flags: The flags to store with the key. @type flags: C{int} @param expireTime: If different from 0, the relative time in seconds when the key will be deleted from the store. @type expireTime: C{int} @return: A deferred that will fire with C{True} if the operation has succeeded, C{False} otherwise. @rtype: L{Deferred} """ return self._set("cas", key, val, flags, expireTime, cas) def _set(self, cmd, key, val, flags, expireTime, cas): """ Internal wrapper for setting values. 
""" if not isinstance(key, str): return fail(ClientError( "Invalid type for key: %s, expecting a string" % (type(key),))) if len(key) > self.MAX_KEY_LENGTH: return fail(ClientError("Key too long")) if not isinstance(val, str): return fail(ClientError( "Invalid type for value: %s, expecting a string" % (type(val),))) if cas: cas = " " + cas length = len(val) fullcmd = "%s %s %d %d %d%s" % ( cmd, key, flags, expireTime, length, cas) self.sendLine(fullcmd) self.sendLine(val) cmdObj = Command(cmd, key=key, flags=flags, length=length) self._current.append(cmdObj) return cmdObj._deferred def append(self, key, val): """ Append given data to the value of an existing key. @param key: The key to modify. @type key: C{str} @param val: The value to append to the current value associated with the key. @type val: C{str} @return: A deferred that will fire with C{True} if the operation has succeeded, C{False} otherwise. @rtype: L{Deferred} """ # Even if flags and expTime values are ignored, we have to pass them return self._set("append", key, val, 0, 0, "") def prepend(self, key, val): """ Prepend given data to the value of an existing key. @param key: The key to modify. @type key: C{str} @param val: The value to prepend to the current value associated with the key. @type val: C{str} @return: A deferred that will fire with C{True} if the operation has succeeded, C{False} otherwise. @rtype: L{Deferred} """ # Even if flags and expTime values are ignored, we have to pass them return self._set("prepend", key, val, 0, 0, "") def get(self, key, withIdentifier=False): """ Get the given C{key}. It doesn't support multiple keys. If C{withIdentifier} is set to C{True}, the command issued is a C{gets}, that will return the current identifier associated with the value. This identifier has to be used when issuing C{checkAndSet} update later, using the corresponding method. @param key: The key to retrieve. 
@type key: C{str} @param withIdentifier: If set to C{True}, retrieve the current identifier along with the value and the flags. @type withIdentifier: C{bool} @return: A deferred that will fire with the tuple (flags, value) if C{withIdentifier} is C{False}, or (flags, cas identifier, value) if C{True}. @rtype: L{Deferred} """ if not isinstance(key, str): return fail(ClientError( "Invalid type for key: %s, expecting a string" % (type(key),))) if len(key) > self.MAX_KEY_LENGTH: return fail(ClientError("Key too long")) if withIdentifier: cmd = "gets" else: cmd = "get" fullcmd = "%s %s" % (cmd, key) self.sendLine(fullcmd) cmdObj = Command(cmd, key=key, value=None, flags=0, cas="") self._current.append(cmdObj) return cmdObj._deferred def stats(self, arg=None): """ Get some stats from the server. It will be available as a dict. @param arg: An optional additional string which will be sent along with the I{stats} command. The interpretation of this value by the server is left undefined by the memcache protocol specification. @type arg: L{NoneType} or L{str} @return: a deferred that will fire with a C{dict} of the available statistics. @rtype: L{Deferred} """ cmd = "stats" if arg: cmd = "stats " + arg self.sendLine(cmd) cmdObj = Command("stats", values={}) self._current.append(cmdObj) return cmdObj._deferred def version(self): """ Get the version of the server. @return: a deferred that will fire with the string value of the version. @rtype: L{Deferred} """ self.sendLine("version") cmdObj = Command("version") self._current.append(cmdObj) return cmdObj._deferred def delete(self, key): """ Delete an existing C{key}. @param key: the key to delete. @type key: C{str} @return: a deferred that will be called back with C{True} if the key was successfully deleted, or C{False} if not. 
@rtype: L{Deferred} """ if not isinstance(key, str): return fail(ClientError( "Invalid type for key: %s, expecting a string" % (type(key),))) self.sendLine("delete %s" % key) cmdObj = Command("delete", key=key) self._current.append(cmdObj) return cmdObj._deferred def flushAll(self): """ Flush all cached values. @return: a deferred that will be called back with C{True} when the operation has succeeded. @rtype: L{Deferred} """ self.sendLine("flush_all") cmdObj = Command("flush_all") self._current.append(cmdObj) return cmdObj._deferred __all__ = ["MemCacheProtocol", "DEFAULT_PORT", "NoSuchCommand", "ClientError", "ServerError"] calendarserver-5.2+dfsg/twext/protocols/__init__.py0000644000175000017500000000120712263343324021621 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extentions to twisted.protocols """ calendarserver-5.2+dfsg/twext/backport/0000755000175000017500000000000012322625326017272 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/backport/__init__.py0000644000175000017500000000130212263343324021376 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Backports of portions of Twisted. (Specifically, those required for IPv6 client support). """ calendarserver-5.2+dfsg/twext/backport/internet/0000755000175000017500000000000012322625326021122 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/backport/internet/tcp.py0000644000175000017500000011707511742073632022277 0ustar rahulrahul# -*- test-case-name: twisted.test.test_tcp -*- # Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. """ Various asynchronous TCP/IP classes. End users shouldn't use this module directly - use the reactor APIs instead. """ # System Imports import types import socket import sys import operator import struct from zope.interface import implements from twisted.python.runtime import platformType from twisted.python import versions, deprecate try: # Try to get the memory BIO based startTLS implementation, available since # pyOpenSSL 0.10 from twisted.internet._newtls import ( ConnectionMixin as _TLSConnectionMixin, ClientMixin as _TLSClientMixin, ServerMixin as _TLSServerMixin) except ImportError: try: # Try to get the socket BIO based startTLS implementation, available in # all pyOpenSSL versions from twisted.internet._oldtls import ( ConnectionMixin as _TLSConnectionMixin, ClientMixin as _TLSClientMixin, ServerMixin as _TLSServerMixin) except ImportError: # There is no version of startTLS available class _TLSConnectionMixin(object): TLS = False class _TLSClientMixin(object): pass class _TLSServerMixin(object): pass if platformType == 'win32': # no such thing as WSAEPERM or error code 10001 according to winsock.h or MSDN 
EPERM = object() from errno import WSAEINVAL as EINVAL from errno import WSAEWOULDBLOCK as EWOULDBLOCK from errno import WSAEINPROGRESS as EINPROGRESS from errno import WSAEALREADY as EALREADY from errno import WSAECONNRESET as ECONNRESET from errno import WSAEISCONN as EISCONN from errno import WSAENOTCONN as ENOTCONN from errno import WSAEINTR as EINTR from errno import WSAENOBUFS as ENOBUFS from errno import WSAEMFILE as EMFILE # No such thing as WSAENFILE, either. ENFILE = object() # Nor ENOMEM ENOMEM = object() EAGAIN = EWOULDBLOCK from errno import WSAECONNRESET as ECONNABORTED from twisted.python.win32 import formatError as strerror else: from errno import EPERM from errno import EINVAL from errno import EWOULDBLOCK from errno import EINPROGRESS from errno import EALREADY from errno import ECONNRESET from errno import EISCONN from errno import ENOTCONN from errno import EINTR from errno import ENOBUFS from errno import EMFILE from errno import ENFILE from errno import ENOMEM from errno import EAGAIN from errno import ECONNABORTED from os import strerror from errno import errorcode # Twisted Imports from twisted.internet import base, address, fdesc from twisted.internet.task import deferLater from twisted.python import log, failure, reflect from twisted.python.util import unsignedID from twisted.internet.error import CannotListenError from twisted.internet import abstract, main, interfaces, error # Not all platforms have, or support, this flag. _AI_NUMERICSERV = getattr(socket, "AI_NUMERICSERV", 0) class _SocketCloser(object): _socketShutdownMethod = 'shutdown' def _closeSocket(self, orderly): # The call to shutdown() before close() isn't really necessary, because # we set FD_CLOEXEC now, which will ensure this is the only process # holding the FD, thus ensuring close() really will shutdown the TCP # socket. However, do it anyways, just to be safe. 
skt = self.socket try: if orderly: getattr(skt, self._socketShutdownMethod)(2) else: # Set SO_LINGER to 1,0 which, by convention, causes a # connection reset to be sent when close is called, # instead of the standard FIN shutdown sequence. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)) except socket.error: pass try: skt.close() except socket.error: pass class _AbortingMixin(object): """ Common implementation of C{abortConnection}. @ivar _aborting: Set to C{True} when C{abortConnection} is called. @type _aborting: C{bool} """ _aborting = False def abortConnection(self): """ Aborts the connection immediately, dropping any buffered data. @since: 11.1 """ if self.disconnected or self._aborting: return self._aborting = True self.stopReading() self.stopWriting() self.doRead = lambda *args, **kwargs: None self.doWrite = lambda *args, **kwargs: None self.reactor.callLater(0, self.connectionLost, failure.Failure(error.ConnectionAborted())) class Connection(_TLSConnectionMixin, abstract.FileDescriptor, _SocketCloser, _AbortingMixin): """ Superclass of all socket-based FileDescriptors. This is an abstract superclass of all objects which represent a TCP/IP connection based socket. @ivar logstr: prefix used when logging events related to this connection. @type logstr: C{str} """ implements(interfaces.ITCPTransport, interfaces.ISystemHandle) def __init__(self, skt, protocol, reactor=None): abstract.FileDescriptor.__init__(self, reactor=reactor) self.socket = skt self.socket.setblocking(0) self.fileno = skt.fileno self.protocol = protocol def getHandle(self): """Return the socket for this connection.""" return self.socket def doRead(self): """Calls self.protocol.dataReceived with all available data. This reads up to self.bufferSize bytes of data from its socket, then calls self.dataReceived(data) to process it. If the connection is not lost through an error in the physical recv(), this function will return the result of the dataReceived call. 
""" try: data = self.socket.recv(self.bufferSize) except socket.error, se: if se.args[0] == EWOULDBLOCK: return else: return main.CONNECTION_LOST if not data: return main.CONNECTION_DONE rval = self.protocol.dataReceived(data) if rval is not None: offender = self.protocol.dataReceived warningFormat = ( 'Returning a value other than None from %(fqpn)s is ' 'deprecated since %(version)s.') warningString = deprecate.getDeprecationWarningString( offender, versions.Version('Twisted', 11, 0, 0), format=warningFormat) deprecate.warnAboutFunction(offender, warningString) return rval def writeSomeData(self, data): """ Write as much as possible of the given data to this TCP connection. This sends up to C{self.SEND_LIMIT} bytes from C{data}. If the connection is lost, an exception is returned. Otherwise, the number of bytes successfully written is returned. """ try: # Limit length of buffer to try to send, because some OSes are too # stupid to do so themselves (ahem windows) return self.socket.send(buffer(data, 0, self.SEND_LIMIT)) except socket.error, se: if se.args[0] == EINTR: return self.writeSomeData(data) elif se.args[0] in (EWOULDBLOCK, ENOBUFS): return 0 else: return main.CONNECTION_LOST def _closeWriteConnection(self): try: getattr(self.socket, self._socketShutdownMethod)(1) except socket.error: pass p = interfaces.IHalfCloseableProtocol(self.protocol, None) if p: try: p.writeConnectionLost() except: f = failure.Failure() log.err() self.connectionLost(f) def readConnectionLost(self, reason): p = interfaces.IHalfCloseableProtocol(self.protocol, None) if p: try: p.readConnectionLost() except: log.err() self.connectionLost(failure.Failure()) else: self.connectionLost(reason) def connectionLost(self, reason): """See abstract.FileDescriptor.connectionLost(). """ # Make sure we're not called twice, which can happen e.g. if # abortConnection() is called from protocol's dataReceived and then # code immediately after throws an exception that reaches the # reactor. 
We can't rely on "disconnected" attribute for this check # since twisted.internet._oldtls does evil things to it: if not hasattr(self, "socket"): return abstract.FileDescriptor.connectionLost(self, reason) self._closeSocket(not reason.check(error.ConnectionAborted)) protocol = self.protocol del self.protocol del self.socket del self.fileno protocol.connectionLost(reason) logstr = "Uninitialized" def logPrefix(self): """Return the prefix to log with when I own the logging thread. """ return self.logstr def getTcpNoDelay(self): return operator.truth(self.socket.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)) def setTcpNoDelay(self, enabled): self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled) def getTcpKeepAlive(self): return operator.truth(self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)) def setTcpKeepAlive(self, enabled): self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, enabled) class _BaseBaseClient(object): """ Code shared with other (non-POSIX) reactors for management of general outgoing connections. Requirements upon subclasses are documented as instance variables rather than abstract methods, in order to avoid MRO confusion, since this base is mixed in to unfortunately weird and distinctive multiple-inheritance hierarchies and many of these attributes are provided by peer classes rather than descendant classes in those hierarchies. @ivar addressFamily: The address family constant (C{socket.AF_INET}, C{socket.AF_INET6}, C{socket.AF_UNIX}) of the underlying socket of this client connection. @type addressFamily: C{int} @ivar socketType: The socket type constant (C{socket.SOCK_STREAM} or C{socket.SOCK_DGRAM}) of the underlying socket. @type socketType: C{int} @ivar _requiresResolution: A flag indicating whether the address of this client will require name resolution. C{True} if the hostname of said address indicates a name that must be resolved by hostname lookup, C{False} if it indicates an IP address literal. 
@type _requiresResolution: C{bool} @cvar _commonConnection: Subclasses must provide this attribute, which indicates the L{Connection}-alike class to invoke C{__init__} and C{connectionLost} on. @type _commonConnection: C{type} @ivar _stopReadingAndWriting: Subclasses must implement in order to remove this transport from its reactor's notifications in response to a terminated connection attempt. @type _stopReadingAndWriting: 0-argument callable returning C{None} @ivar _closeSocket: Subclasses must implement in order to close the socket in response to a terminated connection attempt. @type _closeSocket: 1-argument callable; see L{_SocketCloser._closeSocket} @ivar _collectSocketDetails: Clean up references to the attached socket in its underlying OS resource (such as a file descriptor or file handle), as part of post connection-failure cleanup. @type _collectSocketDetails: 0-argument callable returning C{None}. @ivar reactor: The class pointed to by C{_commonConnection} should set this attribute in its constructor. @type reactor: L{twisted.internet.interfaces.IReactorTime}, L{twisted.internet.interfaces.IReactorCore}, L{twisted.internet.interfaces.IReactorFDSet} """ addressFamily = socket.AF_INET socketType = socket.SOCK_STREAM def _finishInit(self, whenDone, skt, error, reactor): """ Called by subclasses to continue to the stage of initialization where the socket connect attempt is made. @param whenDone: A 0-argument callable to invoke once the connection is set up. This is C{None} if the connection could not be prepared due to a previous error. @param skt: The socket object to use to perform the connection. @type skt: C{socket._socketobject} @param error: The error to fail the connection with. @param reactor: The reactor to use for this client. 
@type reactor: L{twisted.internet.interfaces.IReactorTime} """ if whenDone: self._commonConnection.__init__(self, skt, None, reactor) reactor.callLater(0, whenDone) else: reactor.callLater(0, self.failIfNotConnected, error) def resolveAddress(self): """ Resolve the name that was passed to this L{_BaseBaseClient}, if necessary, and then move on to attempting the connection once an address has been determined. (The connection will be attempted immediately within this function if either name resolution can be synchronous or the address was an IP address literal.) @note: You don't want to call this method from outside, as it won't do anything useful; it's just part of the connection bootstrapping process. Also, although this method is on L{_BaseBaseClient} for historical reasons, it's not used anywhere except for L{Client} itself. @return: C{None} """ if self._requiresResolution: d = self.reactor.resolve(self.addr[0]) d.addCallback(lambda n: (n,) + self.addr[1:]) d.addCallbacks(self._setRealAddress, self.failIfNotConnected) else: self._setRealAddress(self.addr) def _setRealAddress(self, address): """ Set the resolved address of this L{_BaseBaseClient} and initiate the connection attempt. @param address: Depending on whether this is an IPv4 or IPv6 connection attempt, a 2-tuple of C{(host, port)} or a 4-tuple of C{(host, port, flow, scope)}. At this point it is a fully resolved address, and the 'host' portion will always be an IP address, not a DNS name. """ self.realAddress = address self.doConnect() def failIfNotConnected(self, err): """ Generic method called when the attemps to connect failed. It basically cleans everything it can: call connectionFailed, stop read and write, delete socket related members. 
""" if (self.connected or self.disconnected or not hasattr(self, "connector")): return self._stopReadingAndWriting() try: self._closeSocket(True) except AttributeError: pass else: self._collectSocketDetails() self.connector.connectionFailed(failure.Failure(err)) del self.connector def stopConnecting(self): """ If a connection attempt is still outstanding (i.e. no connection is yet established), immediately stop attempting to connect. """ self.failIfNotConnected(error.UserError()) def connectionLost(self, reason): """ Invoked by lower-level logic when it's time to clean the socket up. Depending on the state of the connection, either inform the attached L{Connector} that the connection attempt has failed, or inform the connected L{IProtocol} that the established connection has been lost. @param reason: the reason that the connection was terminated @type reason: L{Failure} """ if not self.connected: self.failIfNotConnected(error.ConnectError(string=reason)) else: self._commonConnection.connectionLost(self, reason) self.connector.connectionLost(reason) class BaseClient(_BaseBaseClient, _TLSClientMixin, Connection): """ A base class for client TCP (and similiar) sockets. @ivar realAddress: The address object that will be used for socket.connect; this address is an address tuple (the number of elements dependent upon the address family) which does not contain any names which need to be resolved. @type realAddress: C{tuple} @ivar _base: L{Connection}, which is the base class of this class which has all of the useful file descriptor methods. This is used by L{_TLSServerMixin} to call the right methods to directly manipulate the transport, as is necessary for writing TLS-encrypted bytes (whereas those methods on L{Server} will go through another layer of TLS if it has been enabled). """ _base = Connection _commonConnection = Connection def _stopReadingAndWriting(self): """ Implement the POSIX-ish (i.e. 
L{twisted.internet.interfaces.IReactorFDSet}) method of detaching this socket from the reactor for L{_BaseBaseClient}. """ if hasattr(self, "reactor"): # this doesn't happen if we failed in __init__ self.stopReading() self.stopWriting() def _collectSocketDetails(self): """ Clean up references to the socket and its file descriptor. @see: L{_BaseBaseClient} """ del self.socket, self.fileno def createInternetSocket(self): """(internal) Create a non-blocking socket using self.addressFamily, self.socketType. """ s = socket.socket(self.addressFamily, self.socketType) s.setblocking(0) fdesc._setCloseOnExec(s.fileno()) return s def doConnect(self): """ Initiate the outgoing connection attempt. @note: Applications do not need to call this method; it will be invoked internally as part of L{IReactorTCP.connectTCP}. """ self.doWrite = self.doConnect self.doRead = self.doConnect if not hasattr(self, "connector"): # this happens when connection failed but doConnect # was scheduled via a callLater in self._finishInit return err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err: self.failIfNotConnected(error.getConnectError((err, strerror(err)))) return # doConnect gets called twice. The first time we actually need to # start the connection attempt. The second time we don't really # want to (SO_ERROR above will have taken care of any errors, and if # it reported none, the mere fact that doConnect was called again is # sufficient to indicate that the connection has succeeded), but it # is not /particularly/ detrimental to do so. This should get # cleaned up some day, though. 
try: connectResult = self.socket.connect_ex(self.realAddress) except socket.error, se: connectResult = se.args[0] if connectResult: if connectResult == EISCONN: pass # on Windows EINVAL means sometimes that we should keep trying: # http://msdn.microsoft.com/library/default.asp?url=/library/en-us/winsock/winsock/connect_2.asp elif ((connectResult in (EWOULDBLOCK, EINPROGRESS, EALREADY)) or (connectResult == EINVAL and platformType == "win32")): self.startReading() self.startWriting() return else: self.failIfNotConnected(error.getConnectError((connectResult, strerror(connectResult)))) return # If I have reached this point without raising or returning, that means # that the socket is connected. del self.doWrite del self.doRead # we first stop and then start, to reset any references to the old doRead self.stopReading() self.stopWriting() self._connectDone() def _connectDone(self): """ This is a hook for when a connection attempt has succeeded. Here, we build the protocol from the L{twisted.internet.protocol.ClientFactory} that was passed in, compute a log string, begin reading so as to send traffic to the newly built protocol, and finally hook up the protocol itself. This hook is overridden by L{ssl.Client} to initiate the TLS protocol. """ self.protocol = self.connector.buildProtocol(self.getPeer()) self.connected = 1 logPrefix = self._getLogPrefix(self.protocol) self.logstr = "%s,client" % logPrefix self.startReading() self.protocol.makeConnection(self) _NUMERIC_ONLY = socket.AI_NUMERICHOST | _AI_NUMERICSERV def _resolveIPv6(ip, port): """ Resolve an IPv6 literal into an IPv6 address. This is necessary to resolve any embedded scope identifiers to the relevant C{sin6_scope_id} for use with C{socket.connect()}, C{socket.listen()}, or C{socket.bind()}; see U{RFC 3493 } for more information. @param ip: An IPv6 address literal. @type ip: C{str} @param port: A port number. 
@type port: C{int} @return: a 4-tuple of C{(host, port, flow, scope)}, suitable for use as an IPv6 address. @raise socket.gaierror: if either the IP or port is not numeric as it should be. """ return socket.getaddrinfo(ip, port, 0, 0, 0, _NUMERIC_ONLY)[0][4] class _BaseTCPClient(object): """ Code shared with other (non-POSIX) reactors for management of outgoing TCP connections (both TCPv4 and TCPv6). @note: In order to be functional, this class must be mixed into the same hierarchy as L{_BaseBaseClient}. It would subclass L{_BaseBaseClient} directly, but the class hierarchy here is divided in strange ways out of the need to share code along multiple axes; specifically, with the IOCP reactor and also with UNIX clients in other reactors. @ivar _addressType: The Twisted _IPAddress implementation for this client @type _addressType: L{IPv4Address} or L{IPv6Address} @ivar connector: The L{Connector} which is driving this L{_BaseTCPClient}'s connection attempt. @ivar addr: The address that this socket will be connecting to. @type addr: If IPv4, a 2-C{tuple} of C{(str host, int port)}. If IPv6, a 4-C{tuple} of (C{str host, int port, int ignored, int scope}). @ivar createInternetSocket: Subclasses must implement this as a method to create a python socket object of the appropriate address family and socket type. @type createInternetSocket: 0-argument callable returning C{socket._socketobject}. 
""" _addressType = address.IPv4Address def __init__(self, host, port, bindAddress, connector, reactor=None): # BaseClient.__init__ is invoked later self.connector = connector self.addr = (host, port) whenDone = self.resolveAddress err = None skt = None if abstract.isIPAddress(host): self._requiresResolution = False elif abstract.isIPv6Address(host): self._requiresResolution = False self.addr = _resolveIPv6(host, port) self.addressFamily = socket.AF_INET6 self._addressType = address.IPv6Address else: self._requiresResolution = True try: skt = self.createInternetSocket() except socket.error, se: err = error.ConnectBindError(se.args[0], se.args[1]) whenDone = None if whenDone and bindAddress is not None: try: if abstract.isIPv6Address(bindAddress[0]): bindinfo = _resolveIPv6(*bindAddress) else: bindinfo = bindAddress skt.bind(bindinfo) except socket.error, se: err = error.ConnectBindError(se.args[0], se.args[1]) whenDone = None self._finishInit(whenDone, skt, err, reactor) def getHost(self): """ Returns an L{IPv4Address} or L{IPv6Address}. This indicates the address from which I am connecting. """ return self._addressType('TCP', *self.socket.getsockname()[:2]) def getPeer(self): """ Returns an L{IPv4Address} or L{IPv6Address}. This indicates the address that I am connected to. """ # an ipv6 realAddress has more than two elements, but the IPv6Address # constructor still only takes two. return self._addressType('TCP', *self.realAddress[:2]) def __repr__(self): s = '<%s to %s at %x>' % (self.__class__, self.addr, unsignedID(self)) return s class Client(_BaseTCPClient, BaseClient): """ A transport for a TCP protocol; either TCPv4 or TCPv6. Do not create these directly; use L{IReactorTCP.connectTCP}. """ class Server(_TLSServerMixin, Connection): """ Serverside socket-stream connection class. This is a serverside network connection transport; a socket which came from an accept() on a server. 
@ivar _base: L{Connection}, which is the base class of this class which has all of the useful file descriptor methods. This is used by L{_TLSServerMixin} to call the right methods to directly manipulate the transport, as is necessary for writing TLS-encrypted bytes (whereas those methods on L{Server} will go through another layer of TLS if it has been enabled). """ _base = Connection _addressType = address.IPv4Address def __init__(self, sock, protocol, client, server, sessionno, reactor): """ Server(sock, protocol, client, server, sessionno) Initialize it with a socket, a protocol, a descriptor for my peer (a tuple of host, port describing the other end of the connection), an instance of Port, and a session number. """ Connection.__init__(self, sock, protocol, reactor) if len(client) != 2: self._addressType = address.IPv6Address self.server = server self.client = client self.sessionno = sessionno self.hostname = client[0] logPrefix = self._getLogPrefix(self.protocol) self.logstr = "%s,%s,%s" % (logPrefix, sessionno, self.hostname) self.repstr = "<%s #%s on %s>" % (self.protocol.__class__.__name__, self.sessionno, self.server._realPortNumber) self.startReading() self.connected = 1 def __repr__(self): """A string representation of this connection. """ return self.repstr def getHost(self): """ Returns an L{IPv4Address} or L{IPv6Address}. This indicates the server's address. """ host, port = self.socket.getsockname()[:2] return self._addressType('TCP', host, port) def getPeer(self): """ Returns an L{IPv4Address} or L{IPv6Address}. This indicates the client's address. """ return self._addressType('TCP', *self.client[:2]) class Port(base.BasePort, _SocketCloser): """ A TCP server port, listening for connections. When a connection is accepted, this will call a factory's buildProtocol with the incoming address as an argument, according to the specification described in L{twisted.internet.interfaces.IProtocolFactory}. 
If you wish to change the sort of transport that will be used, the C{transport} attribute will be called with the signature expected for C{Server.__init__}, so it can be replaced. @ivar deferred: a deferred created when L{stopListening} is called, and that will fire when connection is lost. This is not to be used it directly: prefer the deferred returned by L{stopListening} instead. @type deferred: L{defer.Deferred} @ivar disconnecting: flag indicating that the L{stopListening} method has been called and that no connections should be accepted anymore. @type disconnecting: C{bool} @ivar connected: flag set once the listen has successfully been called on the socket. @type connected: C{bool} @ivar _type: A string describing the connections which will be created by this port. Normally this is C{"TCP"}, since this is a TCP port, but when the TLS implementation re-uses this class it overrides the value with C{"TLS"}. Only used for logging. @ivar _preexistingSocket: If not C{None}, a L{socket.socket} instance which was created and initialized outside of the reactor and will be used to listen for connections (instead of a new socket being created by this L{Port}). """ implements(interfaces.IListeningPort) socketType = socket.SOCK_STREAM transport = Server sessionno = 0 interface = '' backlog = 50 _type = 'TCP' # Actual port number being listened on, only set to a non-None # value when we are actually listening. _realPortNumber = None # An externally initialized socket that we will use, rather than creating # our own. _preexistingSocket = None addressFamily = socket.AF_INET _addressType = address.IPv4Address def __init__(self, port, factory, backlog=50, interface='', reactor=None): """Initialize with a numeric port to listen on. 
""" base.BasePort.__init__(self, reactor=reactor) self.port = port self.factory = factory self.backlog = backlog if abstract.isIPv6Address(interface): self.addressFamily = socket.AF_INET6 self._addressType = address.IPv6Address self.interface = interface @classmethod def _fromListeningDescriptor(cls, reactor, fd, addressFamily, factory): """ Create a new L{Port} based on an existing listening I{SOCK_STREAM} I{AF_INET} socket. Arguments are the same as to L{Port.__init__}, except where noted. @param fd: An integer file descriptor associated with a listening socket. The socket must be in non-blocking mode. Any additional attributes desired, such as I{FD_CLOEXEC}, must also be set already. @param addressFamily: The address family (sometimes called I{domain}) of the existing socket. For example, L{socket.AF_INET}. @return: A new instance of C{cls} wrapping the socket given by C{fd}. """ port = socket.fromfd(fd, addressFamily, cls.socketType) interface = port.getsockname()[0] self = cls(None, factory, None, interface, reactor) self._preexistingSocket = port return self def __repr__(self): if self._realPortNumber is not None: return "<%s of %s on %s>" % (self.__class__, self.factory.__class__, self._realPortNumber) else: return "<%s of %s (not listening)>" % (self.__class__, self.factory.__class__) def createInternetSocket(self): s = base.BasePort.createInternetSocket(self) if platformType == "posix" and sys.platform != "cygwin": s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s def startListening(self): """Create and bind my socket, and begin listening on it. This is called on unserialization, and must be called after creating a server to begin listening on the specified port. 
""" if self._preexistingSocket is None: # Create a new socket and make it listen try: skt = self.createInternetSocket() if self.addressFamily == socket.AF_INET6: addr = _resolveIPv6(self.interface, self.port) else: addr = (self.interface, self.port) skt.bind(addr) except socket.error, le: raise CannotListenError, (self.interface, self.port, le) skt.listen(self.backlog) else: # Re-use the externally specified socket skt = self._preexistingSocket self._preexistingSocket = None # Make sure that if we listened on port 0, we update that to # reflect what the OS actually assigned us. self._realPortNumber = skt.getsockname()[1] log.msg("%s starting on %s" % ( self._getLogPrefix(self.factory), self._realPortNumber)) # The order of the next 5 lines is kind of bizarre. If no one # can explain it, perhaps we should re-arrange them. self.factory.doStart() self.connected = True self.socket = skt self.fileno = self.socket.fileno self.numberAccepts = 100 self.startReading() def _buildAddr(self, address): host, port = address[:2] return self._addressType('TCP', host, port) def doRead(self): """Called when my socket is ready for reading. This accepts a connection and calls self.protocol() to handle the wire-level protocol. """ try: if platformType == "posix": numAccepts = self.numberAccepts else: # win32 event loop breaks if we do more than one accept() # in an iteration of the event loop. numAccepts = 1 for i in range(numAccepts): # we need this so we can deal with a factory's buildProtocol # calling our loseConnection if self.disconnecting: return try: skt, addr = self.socket.accept() except socket.error, e: if e.args[0] in (EWOULDBLOCK, EAGAIN): self.numberAccepts = i break elif e.args[0] == EPERM: # Netfilter on Linux may have rejected the # connection, but we get told to try to accept() # anyway. continue elif e.args[0] in (EMFILE, ENOBUFS, ENFILE, ENOMEM, ECONNABORTED): # Linux gives EMFILE when a process is not allowed # to allocate any more file descriptors. 
*BSD and # Win32 give (WSA)ENOBUFS. Linux can also give # ENFILE if the system is out of inodes, or ENOMEM # if there is insufficient memory to allocate a new # dentry. ECONNABORTED is documented as possible on # both Linux and Windows, but it is not clear # whether there are actually any circumstances under # which it can happen (one might expect it to be # possible if a client sends a FIN or RST after the # server sends a SYN|ACK but before application code # calls accept(2), however at least on Linux this # _seems_ to be short-circuited by syncookies. log.msg("Could not accept new connection (%s)" % ( errorcode[e.args[0]],)) break raise fdesc._setCloseOnExec(skt.fileno()) protocol = self.factory.buildProtocol(self._buildAddr(addr)) if protocol is None: skt.close() continue s = self.sessionno self.sessionno = s+1 transport = self.transport(skt, protocol, addr, self, s, self.reactor) protocol.makeConnection(transport) else: self.numberAccepts = self.numberAccepts+20 except: # Note that in TLS mode, this will possibly catch SSL.Errors # raised by self.socket.accept() # # There is no "except SSL.Error:" above because SSL may be # None if there is no SSL support. In any case, all the # "except SSL.Error:" suite would probably do is log.deferr() # and return, so handling it here works just as well. log.deferr() def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)): """ Stop accepting connections on this port. This will shut down the socket and call self.connectionLost(). It returns a deferred which will fire successfully when the port is actually closed, or with a failure if an error occurs shutting down. 
""" self.disconnecting = True self.stopReading() if self.connected: self.deferred = deferLater( self.reactor, 0, self.connectionLost, connDone) return self.deferred stopListening = loseConnection def _logConnectionLostMsg(self): """ Log message for closing port """ log.msg('(%s Port %s Closed)' % (self._type, self._realPortNumber)) def connectionLost(self, reason): """ Cleans up the socket. """ self._logConnectionLostMsg() self._realPortNumber = None base.BasePort.connectionLost(self, reason) self.connected = False self._closeSocket(True) del self.socket del self.fileno try: self.factory.doStop() finally: self.disconnecting = False def logPrefix(self): """Returns the name of my class, to prefix log entries with. """ return reflect.qual(self.factory.__class__) def getHost(self): """ Return an L{IPv4Address} or L{IPv6Address} indicating the listening address of this port. """ host, port = self.socket.getsockname()[:2] return self._addressType('TCP', host, port) class Connector(base.BaseConnector): """ A L{Connector} provides of L{twisted.internet.interfaces.IConnector} for all POSIX-style reactors. @ivar _addressType: the type returned by L{Connector.getDestination}. Either L{IPv4Address} or L{IPv6Address}, depending on the type of address. @type _addressType: C{type} """ _addressType = address.IPv4Address def __init__(self, host, port, factory, timeout, bindAddress, reactor=None): if isinstance(port, types.StringTypes): try: port = socket.getservbyname(port, 'tcp') except socket.error, e: raise error.ServiceNameUnknownError(string="%s (%r)" % (e, port)) self.host, self.port = host, port if abstract.isIPv6Address(host): self._addressType = address.IPv6Address self.bindAddress = bindAddress base.BaseConnector.__init__(self, factory, timeout, reactor) def _makeTransport(self): """ Create a L{Client} bound to this L{Connector}. 
@return: a new L{Client} @rtype: L{Client} """ return Client(self.host, self.port, self.bindAddress, self, self.reactor) def getDestination(self): """ @see: L{twisted.internet.interfaces.IConnector.getDestination}. """ return self._addressType('TCP', self.host, self.port) calendarserver-5.2+dfsg/twext/backport/internet/address.py0000644000175000017500000000771511742073632023135 0ustar rahulrahul# Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. """ Address objects for network connections. """ import warnings, os from zope.interface import implements from twisted.internet.interfaces import IAddress from twisted.python import util class _IPAddress(object, util.FancyEqMixin): """ An L{_IPAddress} represents the address of an IP socket endpoint, providing common behavior for IPv4 and IPv6. @ivar type: A string describing the type of transport, either 'TCP' or 'UDP'. @ivar host: A string containing the presentation format of the IP address; for example, "127.0.0.1" or "::1". @type host: C{str} @ivar port: An integer representing the port number. @type port: C{int} """ implements(IAddress) compareAttributes = ('type', 'host', 'port') def __init__(self, type, host, port): assert type in ('TCP', 'UDP') self.type = type self.host = host self.port = port def __repr__(self): return '%s(%s, %r, %d)' % ( self.__class__.__name__, self.type, self.host, self.port) def __hash__(self): return hash((self.type, self.host, self.port)) class IPv4Address(_IPAddress): """ An L{IPv4Address} represents the address of an IPv4 socket endpoint. @ivar host: A string containing a dotted-quad IPv4 address; for example, "127.0.0.1". 
@type host: C{str} """ def __init__(self, type, host, port, _bwHack=None): _IPAddress.__init__(self, type, host, port) if _bwHack is not None: warnings.warn("twisted.internet.address.IPv4Address._bwHack " "is deprecated since Twisted 11.0", DeprecationWarning, stacklevel=2) class IPv6Address(_IPAddress): """ An L{IPv6Address} represents the address of an IPv6 socket endpoint. @ivar host: A string containing a colon-separated, hexadecimal formatted IPv6 address; for example, "::1". @type host: C{str} """ class UNIXAddress(object, util.FancyEqMixin): """ Object representing a UNIX socket endpoint. @ivar name: The filename associated with this socket. @type name: C{str} """ implements(IAddress) compareAttributes = ('name', ) def __init__(self, name, _bwHack = None): self.name = name if _bwHack is not None: warnings.warn("twisted.internet.address.UNIXAddress._bwHack is deprecated since Twisted 11.0", DeprecationWarning, stacklevel=2) if getattr(os.path, 'samefile', None) is not None: def __eq__(self, other): """ overriding L{util.FancyEqMixin} to ensure the os level samefile check is done if the name attributes do not match. """ res = super(UNIXAddress, self).__eq__(other) if res == False: try: return os.path.samefile(self.name, other.name) except OSError: pass return res def __repr__(self): return 'UNIXAddress(%r)' % (self.name,) def __hash__(self): try: s1 = os.stat(self.name) return hash((s1.st_ino, s1.st_dev)) except OSError: return hash(self.name) # These are for buildFactory backwards compatability due to # stupidity-induced inconsistency. class _ServerFactoryIPv4Address(IPv4Address): """Backwards compatability hack. Just like IPv4Address in practice.""" def __eq__(self, other): if isinstance(other, tuple): warnings.warn("IPv4Address.__getitem__ is deprecated. 
Use attributes instead.", category=DeprecationWarning, stacklevel=2) return (self.host, self.port) == other elif isinstance(other, IPv4Address): a = (self.type, self.host, self.port) b = (other.type, other.host, other.port) return a == b return False calendarserver-5.2+dfsg/twext/backport/internet/__init__.py0000644000175000017500000000131612263343324023233 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Backports of portions of L{twisted.internet}. (Specifically, those required for IPv6 client support). """ calendarserver-5.2+dfsg/twext/backport/internet/endpoints.py0000644000175000017500000011475011742073632023511 0ustar rahulrahul# -*- test-case-name: twisted.internet.test.test_endpoints -*- # Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. """ Implementations of L{IStreamServerEndpoint} and L{IStreamClientEndpoint} that wrap the L{IReactorTCP}, L{IReactorSSL}, and L{IReactorUNIX} interfaces. This also implements an extensible mini-language for describing endpoints, parsed by the L{clientFromString} and L{serverFromString} functions. 
@since: 10.1 """ import os, socket from zope.interface import implements, directlyProvides import warnings from twisted.internet import interfaces, defer, error, fdesc from twisted.internet.protocol import ClientFactory, Protocol from twisted.plugin import IPlugin, getPlugins from twisted.internet.interfaces import IStreamServerEndpointStringParser from twisted.internet.interfaces import IStreamClientEndpointStringParser from twisted.python.filepath import FilePath #from twisted.python.systemd import ListenFDs __all__ = ["clientFromString", "serverFromString", "TCP4ServerEndpoint", "TCP4ClientEndpoint", "UNIXServerEndpoint", "UNIXClientEndpoint", "SSL4ServerEndpoint", "SSL4ClientEndpoint", "AdoptedStreamServerEndpoint"] class _WrappingProtocol(Protocol): """ Wrap another protocol in order to notify my user when a connection has been made. @ivar _connectedDeferred: The L{Deferred} that will callback with the C{wrappedProtocol} when it is connected. @ivar _wrappedProtocol: An L{IProtocol} provider that will be connected. """ def __init__(self, connectedDeferred, wrappedProtocol): """ @param connectedDeferred: The L{Deferred} that will callback with the C{wrappedProtocol} when it is connected. @param wrappedProtocol: An L{IProtocol} provider that will be connected. """ self._connectedDeferred = connectedDeferred self._wrappedProtocol = wrappedProtocol if interfaces.IHalfCloseableProtocol.providedBy( self._wrappedProtocol): directlyProvides(self, interfaces.IHalfCloseableProtocol) def logPrefix(self): """ Transparently pass through the wrapped protocol's log prefix. 
""" if interfaces.ILoggingContext.providedBy(self._wrappedProtocol): return self._wrappedProtocol.logPrefix() return self._wrappedProtocol.__class__.__name__ def connectionMade(self): """ Connect the C{self._wrappedProtocol} to our C{self.transport} and callback C{self._connectedDeferred} with the C{self._wrappedProtocol} """ self._wrappedProtocol.makeConnection(self.transport) self._connectedDeferred.callback(self._wrappedProtocol) def dataReceived(self, data): """ Proxy C{dataReceived} calls to our C{self._wrappedProtocol} """ return self._wrappedProtocol.dataReceived(data) def connectionLost(self, reason): """ Proxy C{connectionLost} calls to our C{self._wrappedProtocol} """ return self._wrappedProtocol.connectionLost(reason) def readConnectionLost(self): """ Proxy L{IHalfCloseableProtocol.readConnectionLost} to our C{self._wrappedProtocol} """ self._wrappedProtocol.readConnectionLost() def writeConnectionLost(self): """ Proxy L{IHalfCloseableProtocol.writeConnectionLost} to our C{self._wrappedProtocol} """ self._wrappedProtocol.writeConnectionLost() class _WrappingFactory(ClientFactory): """ Wrap a factory in order to wrap the protocols it builds. @ivar _wrappedFactory: A provider of I{IProtocolFactory} whose buildProtocol method will be called and whose resulting protocol will be wrapped. @ivar _onConnection: An L{Deferred} that fires when the protocol is connected @ivar _connector: A L{connector } that is managing the current or previous connection attempt. """ protocol = _WrappingProtocol def __init__(self, wrappedFactory): """ @param wrappedFactory: A provider of I{IProtocolFactory} whose buildProtocol method will be called and whose resulting protocol will be wrapped. """ self._wrappedFactory = wrappedFactory self._onConnection = defer.Deferred(canceller=self._canceller) def startedConnecting(self, connector): """ A connection attempt was started. Remember the connector which started said attempt, for use later. 
""" self._connector = connector def _canceller(self, deferred): """ The outgoing connection attempt was cancelled. Fail that L{Deferred} with a L{error.ConnectingCancelledError}. @param deferred: The L{Deferred } that was cancelled; should be the same as C{self._onConnection}. @type deferred: L{Deferred } @note: This relies on startedConnecting having been called, so it may seem as though there's a race condition where C{_connector} may not have been set. However, using public APIs, this condition is impossible to catch, because a connection API (C{connectTCP}/C{SSL}/C{UNIX}) is always invoked before a L{_WrappingFactory}'s L{Deferred } is returned to C{connect()}'s caller. @return: C{None} """ deferred.errback( error.ConnectingCancelledError( self._connector.getDestination())) self._connector.stopConnecting() def doStart(self): """ Start notifications are passed straight through to the wrapped factory. """ self._wrappedFactory.doStart() def doStop(self): """ Stop notifications are passed straight through to the wrapped factory. """ self._wrappedFactory.doStop() def buildProtocol(self, addr): """ Proxy C{buildProtocol} to our C{self._wrappedFactory} or errback the C{self._onConnection} L{Deferred}. @return: An instance of L{_WrappingProtocol} or C{None} """ try: proto = self._wrappedFactory.buildProtocol(addr) except: self._onConnection.errback() else: return self.protocol(self._onConnection, proto) def clientConnectionFailed(self, connector, reason): """ Errback the C{self._onConnection} L{Deferred} when the client connection fails. """ if not self._onConnection.called: self._onConnection.errback(reason) class TCP4ServerEndpoint(object): """ TCP server endpoint with an IPv4 configuration @ivar _reactor: An L{IReactorTCP} provider. @type _port: int @ivar _port: The port number on which to listen for incoming connections. 
@type _backlog: int @ivar _backlog: size of the listen queue @type _interface: str @ivar _interface: the hostname to bind to, defaults to '' (all) """ implements(interfaces.IStreamServerEndpoint) def __init__(self, reactor, port, backlog=50, interface=''): """ @param reactor: An L{IReactorTCP} provider. @param port: The port number used listening @param backlog: size of the listen queue @param interface: the hostname to bind to, defaults to '' (all) """ self._reactor = reactor self._port = port self._listenArgs = dict(backlog=50, interface='') self._backlog = backlog self._interface = interface def listen(self, protocolFactory): """ Implement L{IStreamServerEndpoint.listen} to listen on a TCP socket """ return defer.execute(self._reactor.listenTCP, self._port, protocolFactory, backlog=self._backlog, interface=self._interface) class TCP4ClientEndpoint(object): """ TCP client endpoint with an IPv4 configuration. @ivar _reactor: An L{IReactorTCP} provider. @type _host: str @ivar _host: The hostname to connect to as a C{str} @type _port: int @ivar _port: The port to connect to as C{int} @type _timeout: int @ivar _timeout: number of seconds to wait before assuming the connection has failed. @type _bindAddress: tuple @type _bindAddress: a (host, port) tuple of local address to bind to, or None. """ implements(interfaces.IStreamClientEndpoint) def __init__(self, reactor, host, port, timeout=30, bindAddress=None): """ @param reactor: An L{IReactorTCP} provider @param host: A hostname, used when connecting @param port: The port number, used when connecting @param timeout: number of seconds to wait before assuming the connection has failed. @param bindAddress: a (host, port tuple of local address to bind to, or None. """ self._reactor = reactor self._host = host self._port = port self._timeout = timeout self._bindAddress = bindAddress def connect(self, protocolFactory): """ Implement L{IStreamClientEndpoint.connect} to connect via TCP. 
""" try: wf = _WrappingFactory(protocolFactory) self._reactor.connectTCP( self._host, self._port, wf, timeout=self._timeout, bindAddress=self._bindAddress) return wf._onConnection except: return defer.fail() class SSL4ServerEndpoint(object): """ SSL secured TCP server endpoint with an IPv4 configuration. @ivar _reactor: An L{IReactorSSL} provider. @type _host: str @ivar _host: The hostname to connect to as a C{str} @type _port: int @ivar _port: The port to connect to as C{int} @type _sslContextFactory: L{OpenSSLCertificateOptions} @var _sslContextFactory: SSL Configuration information as an L{OpenSSLCertificateOptions} @type _backlog: int @ivar _backlog: size of the listen queue @type _interface: str @ivar _interface: the hostname to bind to, defaults to '' (all) """ implements(interfaces.IStreamServerEndpoint) def __init__(self, reactor, port, sslContextFactory, backlog=50, interface=''): """ @param reactor: An L{IReactorSSL} provider. @param port: The port number used listening @param sslContextFactory: An instance of L{twisted.internet._sslverify.OpenSSLCertificateOptions}. @param timeout: number of seconds to wait before assuming the connection has failed. @param bindAddress: a (host, port tuple of local address to bind to, or None. """ self._reactor = reactor self._port = port self._sslContextFactory = sslContextFactory self._backlog = backlog self._interface = interface def listen(self, protocolFactory): """ Implement L{IStreamServerEndpoint.listen} to listen for SSL on a TCP socket. """ return defer.execute(self._reactor.listenSSL, self._port, protocolFactory, contextFactory=self._sslContextFactory, backlog=self._backlog, interface=self._interface) class SSL4ClientEndpoint(object): """ SSL secured TCP client endpoint with an IPv4 configuration @ivar _reactor: An L{IReactorSSL} provider. 
@type _host: str @ivar _host: The hostname to connect to as a C{str} @type _port: int @ivar _port: The port to connect to as C{int} @type _sslContextFactory: L{OpenSSLCertificateOptions} @var _sslContextFactory: SSL Configuration information as an L{OpenSSLCertificateOptions} @type _timeout: int @ivar _timeout: number of seconds to wait before assuming the connection has failed. @type _bindAddress: tuple @ivar _bindAddress: a (host, port) tuple of local address to bind to, or None. """ implements(interfaces.IStreamClientEndpoint) def __init__(self, reactor, host, port, sslContextFactory, timeout=30, bindAddress=None): """ @param reactor: An L{IReactorSSL} provider. @param host: A hostname, used when connecting @param port: The port number, used when connecting @param sslContextFactory: SSL Configuration information as An instance of L{OpenSSLCertificateOptions}. @param timeout: number of seconds to wait before assuming the connection has failed. @param bindAddress: a (host, port tuple of local address to bind to, or None. """ self._reactor = reactor self._host = host self._port = port self._sslContextFactory = sslContextFactory self._timeout = timeout self._bindAddress = bindAddress def connect(self, protocolFactory): """ Implement L{IStreamClientEndpoint.connect} to connect with SSL over TCP. """ try: wf = _WrappingFactory(protocolFactory) self._reactor.connectSSL( self._host, self._port, wf, self._sslContextFactory, timeout=self._timeout, bindAddress=self._bindAddress) return wf._onConnection except: return defer.fail() class UNIXServerEndpoint(object): """ UnixSocket server endpoint. @type path: str @ivar path: a path to a unix socket on the filesystem. @type _listenArgs: dict @ivar _listenArgs: A C{dict} of keyword args that will be passed to L{IReactorUNIX.listenUNIX} @var _reactor: An L{IReactorTCP} provider. 
""" implements(interfaces.IStreamServerEndpoint) def __init__(self, reactor, address, backlog=50, mode=0666, wantPID=0): """ @param reactor: An L{IReactorUNIX} provider. @param address: The path to the Unix socket file, used when listening @param listenArgs: An optional dict of keyword args that will be passed to L{IReactorUNIX.listenUNIX} @param backlog: number of connections to allow in backlog. @param mode: mode to set on the unix socket. This parameter is deprecated. Permissions should be set on the directory which contains the UNIX socket. @param wantPID: if True, create a pidfile for the socket. """ self._reactor = reactor self._address = address self._backlog = backlog self._mode = mode self._wantPID = wantPID def listen(self, protocolFactory): """ Implement L{IStreamServerEndpoint.listen} to listen on a UNIX socket. """ return defer.execute(self._reactor.listenUNIX, self._address, protocolFactory, backlog=self._backlog, mode=self._mode, wantPID=self._wantPID) class UNIXClientEndpoint(object): """ UnixSocket client endpoint. @type _path: str @ivar _path: a path to a unix socket on the filesystem. @type _timeout: int @ivar _timeout: number of seconds to wait before assuming the connection has failed. @type _checkPID: bool @ivar _checkPID: if True, check for a pid file to verify that a server is listening. @var _reactor: An L{IReactorUNIX} provider. """ implements(interfaces.IStreamClientEndpoint) def __init__(self, reactor, path, timeout=30, checkPID=0): """ @param reactor: An L{IReactorUNIX} provider. @param path: The path to the Unix socket file, used when connecting @param timeout: number of seconds to wait before assuming the connection has failed. @param checkPID: if True, check for a pid file to verify that a server is listening. 
""" self._reactor = reactor self._path = path self._timeout = timeout self._checkPID = checkPID def connect(self, protocolFactory): """ Implement L{IStreamClientEndpoint.connect} to connect via a UNIX Socket """ try: wf = _WrappingFactory(protocolFactory) self._reactor.connectUNIX( self._path, wf, timeout=self._timeout, checkPID=self._checkPID) return wf._onConnection except: return defer.fail() class AdoptedStreamServerEndpoint(object): """ An endpoint for listening on a file descriptor initialized outside of Twisted. @ivar _used: A C{bool} indicating whether this endpoint has been used to listen with a factory yet. C{True} if so. """ _close = os.close _setNonBlocking = staticmethod(fdesc.setNonBlocking) def __init__(self, reactor, fileno, addressFamily): """ @param reactor: An L{IReactorSocket} provider. @param fileno: An integer file descriptor corresponding to a listening I{SOCK_STREAM} socket. @param addressFamily: The address family of the socket given by C{fileno}. """ self.reactor = reactor self.fileno = fileno self.addressFamily = addressFamily self._used = False def listen(self, factory): """ Implement L{IStreamServerEndpoint.listen} to start listening on, and then close, C{self._fileno}. """ if self._used: return defer.fail(error.AlreadyListened()) self._used = True try: self._setNonBlocking(self.fileno) port = self.reactor.adoptStreamPort( self.fileno, self.addressFamily, factory) self._close(self.fileno) except: return defer.fail() return defer.succeed(port) def _parseTCP(factory, port, interface="", backlog=50): """ Internal parser function for L{_parseServer} to convert the string arguments for a TCP(IPv4) stream endpoint into the structured arguments. @param factory: the protocol factory being parsed, or C{None}. (This was a leftover argument from when this code was in C{strports}, and is now mostly None and unused.) 
@type factory: L{IProtocolFactory} or C{NoneType} @param port: the integer port number to bind @type port: C{str} @param interface: the interface IP to listen on @param backlog: the length of the listen queue @type backlog: C{str} @return: a 2-tuple of (args, kwargs), describing the parameters to L{IReactorTCP.listenTCP} (or, modulo argument 2, the factory, arguments to L{TCP4ServerEndpoint}. """ return (int(port), factory), {'interface': interface, 'backlog': int(backlog)} def _parseUNIX(factory, address, mode='666', backlog=50, lockfile=True): """ Internal parser function for L{_parseServer} to convert the string arguments for a UNIX (AF_UNIX/SOCK_STREAM) stream endpoint into the structured arguments. @param factory: the protocol factory being parsed, or C{None}. (This was a leftover argument from when this code was in C{strports}, and is now mostly None and unused.) @type factory: L{IProtocolFactory} or C{NoneType} @param address: the pathname of the unix socket @type address: C{str} @param backlog: the length of the listen queue @type backlog: C{str} @param lockfile: A string '0' or '1', mapping to True and False respectively. See the C{wantPID} argument to C{listenUNIX} @return: a 2-tuple of (args, kwargs), describing the parameters to L{IReactorTCP.listenUNIX} (or, modulo argument 2, the factory, arguments to L{UNIXServerEndpoint}. """ return ( (address, factory), {'mode': int(mode, 8), 'backlog': int(backlog), 'wantPID': bool(int(lockfile))}) def _parseSSL(factory, port, privateKey="server.pem", certKey=None, sslmethod=None, interface='', backlog=50): """ Internal parser function for L{_parseServer} to convert the string arguments for an SSL (over TCP/IPv4) stream endpoint into the structured arguments. @param factory: the protocol factory being parsed, or C{None}. (This was a leftover argument from when this code was in C{strports}, and is now mostly None and unused.) 
@type factory: L{IProtocolFactory} or C{NoneType} @param port: the integer port number to bind @type port: C{str} @param interface: the interface IP to listen on @param backlog: the length of the listen queue @type backlog: C{str} @param privateKey: The file name of a PEM format private key file. @type privateKey: C{str} @param certKey: The file name of a PEM format certificate file. @type certKey: C{str} @param sslmethod: The string name of an SSL method, based on the name of a constant in C{OpenSSL.SSL}. Must be one of: "SSLv23_METHOD", "SSLv2_METHOD", "SSLv3_METHOD", "TLSv1_METHOD". @type sslmethod: C{str} @return: a 2-tuple of (args, kwargs), describing the parameters to L{IReactorSSL.listenSSL} (or, modulo argument 2, the factory, arguments to L{SSL4ServerEndpoint}. """ from twisted.internet import ssl if certKey is None: certKey = privateKey kw = {} if sslmethod is not None: kw['sslmethod'] = getattr(ssl.SSL, sslmethod) cf = ssl.DefaultOpenSSLContextFactory(privateKey, certKey, **kw) return ((int(port), factory, cf), {'interface': interface, 'backlog': int(backlog)}) class _SystemdParser(object): """ Stream server endpoint string parser for the I{systemd} endpoint type. @ivar prefix: See L{IStreamClientEndpointStringParser.prefix}. @ivar _sddaemon: A L{ListenFDs} instance used to translate an index into an actual file descriptor. """ implements(IPlugin, IStreamServerEndpointStringParser) #_sddaemon = ListenFDs.fromEnvironment() prefix = "systemd" def _parseServer(self, reactor, domain, index): """ Internal parser function for L{_parseServer} to convert the string arguments for a systemd server endpoint into structured arguments for L{AdoptedStreamServerEndpoint}. @param reactor: An L{IReactorSocket} provider. @param domain: The domain (or address family) of the socket inherited from systemd. This is a string like C{"INET"} or C{"UNIX"}, ie the name of an address family from the L{socket} module, without the C{"AF_"} prefix. 
@type domain: C{str} @param index: An offset into the list of file descriptors inherited from systemd. @type index: C{str} @return: A two-tuple of parsed positional arguments and parsed keyword arguments (a tuple and a dictionary). These can be used to construct a L{AdoptedStreamServerEndpoint}. """ index = int(index) fileno = self._sddaemon.inheritedDescriptors()[index] addressFamily = getattr(socket, 'AF_' + domain) return AdoptedStreamServerEndpoint(reactor, fileno, addressFamily) def parseStreamServer(self, reactor, *args, **kwargs): # Delegate to another function with a sane signature. This function has # an insane signature to trick zope.interface into believing the # interface is correctly implemented. return self._parseServer(reactor, *args, **kwargs) _serverParsers = {"tcp": _parseTCP, "unix": _parseUNIX, "ssl": _parseSSL, } _OP, _STRING = range(2) def _tokenize(description): """ Tokenize a strports string and yield each token. @param description: a string as described by L{serverFromString} or L{clientFromString}. @return: an iterable of 2-tuples of (L{_OP} or L{_STRING}, string). Tuples starting with L{_OP} will contain a second element of either ':' (i.e. 'next parameter') or '=' (i.e. 'assign parameter value'). For example, the string 'hello:greet\=ing=world' would result in a generator yielding these values:: _STRING, 'hello' _OP, ':' _STRING, 'greet=ing' _OP, '=' _STRING, 'world' """ current = '' ops = ':=' nextOps = {':': ':=', '=': ':'} description = iter(description) for n in description: if n in ops: yield _STRING, current yield _OP, n current = '' ops = nextOps[n] elif n == '\\': current += description.next() else: current += n yield _STRING, current def _parse(description): """ Convert a description string into a list of positional and keyword parameters, using logic vaguely like what Python does. @param description: a string as described by L{serverFromString} or L{clientFromString}. 
@return: a 2-tuple of C{(args, kwargs)}, where 'args' is a list of all ':'-separated C{str}s not containing an '=' and 'kwargs' is a map of all C{str}s which do contain an '='. For example, the result of C{_parse('a:b:d=1:c')} would be C{(['a', 'b', 'c'], {'d': '1'})}. """ args, kw = [], {} def add(sofar): if len(sofar) == 1: args.append(sofar[0]) else: kw[sofar[0]] = sofar[1] sofar = () for (type, value) in _tokenize(description): if type is _STRING: sofar += (value,) elif value == ':': add(sofar) sofar = () add(sofar) return args, kw # Mappings from description "names" to endpoint constructors. _endpointServerFactories = { 'TCP': TCP4ServerEndpoint, 'SSL': SSL4ServerEndpoint, 'UNIX': UNIXServerEndpoint, } _endpointClientFactories = { 'TCP': TCP4ClientEndpoint, 'SSL': SSL4ClientEndpoint, 'UNIX': UNIXClientEndpoint, } _NO_DEFAULT = object() def _parseServer(description, factory, default=None): """ Parse a stports description into a 2-tuple of arguments and keyword values. @param description: A description in the format explained by L{serverFromString}. @type description: C{str} @param factory: A 'factory' argument; this is left-over from twisted.application.strports, it's not really used. @type factory: L{IProtocolFactory} or L{None} @param default: Deprecated argument, specifying the default parser mode to use for unqualified description strings (those which do not have a ':' and prefix). @type default: C{str} or C{NoneType} @return: a 3-tuple of (plugin or name, arguments, keyword arguments) """ args, kw = _parse(description) if not args or (len(args) == 1 and not kw): deprecationMessage = ( "Unqualified strport description passed to 'service'." "Use qualified endpoint descriptions; for example, 'tcp:%s'." 
% (description,)) if default is None: default = 'tcp' warnings.warn( deprecationMessage, category=DeprecationWarning, stacklevel=4) elif default is _NO_DEFAULT: raise ValueError(deprecationMessage) # If the default has been otherwise specified, the user has already # been warned. args[0:0] = [default] endpointType = args[0] parser = _serverParsers.get(endpointType) if parser is None: for plugin in getPlugins(IStreamServerEndpointStringParser): if plugin.prefix == endpointType: return (plugin, args[1:], kw) raise ValueError("Unknown endpoint type: '%s'" % (endpointType,)) return (endpointType.upper(),) + parser(factory, *args[1:], **kw) def _serverFromStringLegacy(reactor, description, default): """ Underlying implementation of L{serverFromString} which avoids exposing the deprecated 'default' argument to anything but L{strports.service}. """ nameOrPlugin, args, kw = _parseServer(description, None, default) if type(nameOrPlugin) is not str: plugin = nameOrPlugin return plugin.parseStreamServer(reactor, *args, **kw) else: name = nameOrPlugin # Chop out the factory. args = args[:1] + args[2:] return _endpointServerFactories[name](reactor, *args, **kw) def serverFromString(reactor, description): """ Construct a stream server endpoint from an endpoint description string. The format for server endpoint descriptions is a simple string. It is a prefix naming the type of endpoint, then a colon, then the arguments for that endpoint. For example, you can call it like this to create an endpoint that will listen on TCP port 80:: serverFromString(reactor, "tcp:80") Additional arguments may be specified as keywords, separated with colons. 
For example, you can specify the interface for a TCP server endpoint to bind to like this:: serverFromString(reactor, "tcp:80:interface=127.0.0.1") SSL server endpoints may be specified with the 'ssl' prefix, and the private key and certificate files may be specified by the C{privateKey} and C{certKey} arguments:: serverFromString(reactor, "ssl:443:privateKey=key.pem:certKey=crt.pem") If a private key file name (C{privateKey}) isn't provided, a "server.pem" file is assumed to exist which contains the private key. If the certificate file name (C{certKey}) isn't provided, the private key file is assumed to contain the certificate as well. You may escape colons in arguments with a backslash, which you will need to use if you want to specify a full pathname argument on Windows:: serverFromString(reactor, "ssl:443:privateKey=C\\:/key.pem:certKey=C\\:/cert.pem") finally, the 'unix' prefix may be used to specify a filesystem UNIX socket, optionally with a 'mode' argument to specify the mode of the socket file created by C{listen}:: serverFromString(reactor, "unix:/var/run/finger") serverFromString(reactor, "unix:/var/run/finger:mode=660") This function is also extensible; new endpoint types may be registered as L{IStreamServerEndpointStringParser} plugins. See that interface for more information. @param reactor: The server endpoint will be constructed with this reactor. @param description: The strports description to parse. @return: A new endpoint which can be used to listen with the parameters given by by C{description}. @rtype: L{IStreamServerEndpoint} @raise ValueError: when the 'description' string cannot be parsed. @since: 10.2 """ return _serverFromStringLegacy(reactor, description, _NO_DEFAULT) def quoteStringArgument(argument): """ Quote an argument to L{serverFromString} and L{clientFromString}. 
Since arguments are separated with colons and colons are escaped with backslashes, some care is necessary if, for example, you have a pathname, you may be tempted to interpolate into a string like this:: serverFromString("ssl:443:privateKey=%s" % (myPathName,)) This may appear to work, but will have portability issues (Windows pathnames, for example). Usually you should just construct the appropriate endpoint type rather than interpolating strings, which in this case would be L{SSL4ServerEndpoint}. There are some use-cases where you may need to generate such a string, though; for example, a tool to manipulate a configuration file which has strports descriptions in it. To be correct in those cases, do this instead:: serverFromString("ssl:443:privateKey=%s" % (quoteStringArgument(myPathName),)) @param argument: The part of the endpoint description string you want to pass through. @type argument: C{str} @return: The quoted argument. @rtype: C{str} """ return argument.replace('\\', '\\\\').replace(':', '\\:') def _parseClientTCP(*args, **kwargs): """ Perform any argument value coercion necessary for TCP client parameters. Valid positional arguments to this function are host and port. Valid keyword arguments to this function are all L{IReactorTCP.connectTCP} arguments. @return: The coerced values as a C{dict}. """ if len(args) == 2: kwargs['port'] = int(args[1]) kwargs['host'] = args[0] elif len(args) == 1: if 'host' in kwargs: kwargs['port'] = int(args[0]) else: kwargs['host'] = args[0] try: kwargs['port'] = int(kwargs['port']) except KeyError: pass try: kwargs['timeout'] = int(kwargs['timeout']) except KeyError: pass return kwargs def _loadCAsFromDir(directoryPath): """ Load certificate-authority certificate objects in a given directory. @param directoryPath: a L{FilePath} pointing at a directory to load .pem files from. @return: a C{list} of L{OpenSSL.crypto.X509} objects. 
""" from twisted.internet import ssl caCerts = {} for child in directoryPath.children(): if not child.basename().split('.')[-1].lower() == 'pem': continue try: data = child.getContent() except IOError: # Permission denied, corrupt disk, we don't care. continue try: theCert = ssl.Certificate.loadPEM(data) except ssl.SSL.Error: # Duplicate certificate, invalid certificate, etc. We don't care. pass else: caCerts[theCert.digest()] = theCert.original return caCerts.values() def _parseClientSSL(*args, **kwargs): """ Perform any argument value coercion necessary for SSL client parameters. Valid keyword arguments to this function are all L{IReactorSSL.connectSSL} arguments except for C{contextFactory}. Instead, C{certKey} (the path name of the certificate file) C{privateKey} (the path name of the private key associated with the certificate) are accepted and used to construct a context factory. Valid positional arguments to this function are host and port. @param caCertsDir: The one parameter which is not part of L{IReactorSSL.connectSSL}'s signature, this is a path name used to construct a list of certificate authority certificates. The directory will be scanned for files ending in C{.pem}, all of which will be considered valid certificate authorities for this connection. @type caCertsDir: C{str} @return: The coerced values as a C{dict}. 
""" from twisted.internet import ssl kwargs = _parseClientTCP(*args, **kwargs) certKey = kwargs.pop('certKey', None) privateKey = kwargs.pop('privateKey', None) caCertsDir = kwargs.pop('caCertsDir', None) if certKey is not None: certx509 = ssl.Certificate.loadPEM( FilePath(certKey).getContent()).original else: certx509 = None if privateKey is not None: privateKey = ssl.PrivateCertificate.loadPEM( FilePath(privateKey).getContent()).privateKey.original else: privateKey = None if caCertsDir is not None: verify = True caCerts = _loadCAsFromDir(FilePath(caCertsDir)) else: verify = False caCerts = None kwargs['sslContextFactory'] = ssl.CertificateOptions( method=ssl.SSL.SSLv23_METHOD, certificate=certx509, privateKey=privateKey, verify=verify, caCerts=caCerts ) return kwargs def _parseClientUNIX(**kwargs): """ Perform any argument value coercion necessary for UNIX client parameters. Valid keyword arguments to this function are all L{IReactorUNIX.connectUNIX} arguments except for C{checkPID}. Instead, C{lockfile} is accepted and has the same meaning. @return: The coerced values as a C{dict}. """ try: kwargs['checkPID'] = bool(int(kwargs.pop('lockfile'))) except KeyError: pass try: kwargs['timeout'] = int(kwargs['timeout']) except KeyError: pass return kwargs _clientParsers = { 'TCP': _parseClientTCP, 'SSL': _parseClientSSL, 'UNIX': _parseClientUNIX, } def clientFromString(reactor, description): """ Construct a client endpoint from a description string. Client description strings are much like server description strings, although they take all of their arguments as keywords, aside from host and port. 
You can create a TCP client endpoint with the 'host' and 'port' arguments, like so:: clientFromString(reactor, "tcp:host=www.example.com:port=80") or, without specifying host and port keywords:: clientFromString(reactor, "tcp:www.example.com:80") Or you can specify only one or the other, as in the following 2 examples:: clientFromString(reactor, "tcp:host=www.example.com:80") clientFromString(reactor, "tcp:www.example.com:port=80") or an SSL client endpoint with those arguments, plus the arguments used by the server SSL, for a client certificate:: clientFromString(reactor, "ssl:web.example.com:443:" "privateKey=foo.pem:certKey=foo.pem") to specify your certificate trust roots, you can identify a directory with PEM files in it with the C{caCertsDir} argument:: clientFromString(reactor, "ssl:host=web.example.com:port=443:" "caCertsDir=/etc/ssl/certs") This function is also extensible; new endpoint types may be registered as L{IStreamClientEndpointStringParser} plugins. See that interface for more information. @param reactor: The client endpoint will be constructed with this reactor. @param description: The strports description to parse. @return: A new endpoint which can be used to connect with the parameters given by by C{description}. 
@rtype: L{IStreamClientEndpoint} @since: 10.2 """ args, kwargs = _parse(description) aname = args.pop(0) name = aname.upper() for plugin in getPlugins(IStreamClientEndpointStringParser): if plugin.prefix.upper() == name: return plugin.parseStreamClient(*args, **kwargs) if name not in _clientParsers: raise ValueError("Unknown endpoint type: %r" % (aname,)) kwargs = _clientParsers[name](*args, **kwargs) return _endpointClientFactories[name](reactor, **kwargs) calendarserver-5.2+dfsg/twext/who/0000755000175000017500000000000012322625326016262 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/who/index.py0000644000175000017500000001572712263343324017756 0ustar rahulrahul# -*- test-case-name: twext.who.test.test_xml -*- ## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Indexed directory service implementation. 
""" __all__ = [ "DirectoryService", "DirectoryRecord", ] from itertools import chain from twisted.python.constants import Names, NamedConstant from twisted.internet.defer import succeed, inlineCallbacks, returnValue from twext.who.util import ConstantsContainer from twext.who.util import describe, uniqueResult, iterFlags from twext.who.idirectory import FieldName as BaseFieldName from twext.who.expression import MatchExpression, MatchType, MatchFlags from twext.who.directory import DirectoryService as BaseDirectoryService from twext.who.directory import DirectoryRecord as BaseDirectoryRecord ## # Data type extentions ## class FieldName(Names): memberUIDs = NamedConstant() memberUIDs.description = "member UIDs" memberUIDs.multiValue = True ## # Directory Service ## class DirectoryService(BaseDirectoryService): """ XML directory service. """ fieldName = ConstantsContainer(chain( BaseDirectoryService.fieldName.iterconstants(), FieldName.iterconstants() )) indexedFields = ( BaseFieldName.recordType, BaseFieldName.uid, BaseFieldName.guid, BaseFieldName.shortNames, BaseFieldName.emailAddresses, FieldName.memberUIDs, ) def __init__(self, realmName): BaseDirectoryService.__init__(self, realmName) self.flush() @property def index(self): self.loadRecords() return self._index @index.setter def index(self, value): self._index = value def loadRecords(self): """ Load records. """ raise NotImplementedError("Subclasses must implement loadRecords().") def flush(self): """ Flush the index. 
""" self._index = None @staticmethod def _queryFlags(flags): predicate = lambda x: x normalize = lambda x: x if flags is not None: for flag in iterFlags(flags): if flag == MatchFlags.NOT: predicate = lambda x: not x elif flag == MatchFlags.caseInsensitive: normalize = lambda x: x.lower() else: raise NotImplementedError( "Unknown query flag: {0}".format(describe(flag)) ) return predicate, normalize def indexedRecordsFromMatchExpression(self, expression, records=None): """ Finds records in the internal indexes matching a single expression. @param expression: an expression @type expression: L{object} """ predicate, normalize = self._queryFlags(expression.flags) fieldIndex = self.index[expression.fieldName] matchValue = normalize(expression.fieldValue) matchType = expression.matchType if matchType == MatchType.startsWith: indexKeys = ( key for key in fieldIndex if predicate(normalize(key).startswith(matchValue)) ) elif matchType == MatchType.contains: indexKeys = ( key for key in fieldIndex if predicate(matchValue in normalize(key)) ) elif matchType == MatchType.equals: if predicate(True): indexKeys = (matchValue,) else: indexKeys = ( key for key in fieldIndex if normalize(key) != matchValue ) else: raise NotImplementedError( "Unknown match type: {0}".format(describe(matchType)) ) matchingRecords = set() for key in indexKeys: matchingRecords |= fieldIndex.get(key, frozenset()) if records is not None: matchingRecords &= records return succeed(matchingRecords) def unIndexedRecordsFromMatchExpression(self, expression, records=None): """ Finds records not in the internal indexes matching a single expression. 
@param expression: an expression @type expression: L{object} """ predicate, normalize = self._queryFlags(expression.flags) matchValue = normalize(expression.fieldValue) matchType = expression.matchType if matchType == MatchType.startsWith: match = lambda fieldValue: predicate( fieldValue.startswith(matchValue) ) elif matchType == MatchType.contains: match = lambda fieldValue: predicate(matchValue in fieldValue) elif matchType == MatchType.equals: match = lambda fieldValue: predicate(fieldValue == matchValue) else: raise NotImplementedError( "Unknown match type: {0}".format(describe(matchType)) ) result = set() if records is None: records = ( uniqueResult(values) for values in self.index[self.fieldName.uid].itervalues() ) for record in records: fieldValues = record.fields.get(expression.fieldName, None) if fieldValues is None: continue for fieldValue in fieldValues: if match(normalize(fieldValue)): result.add(record) return succeed(result) def recordsFromExpression(self, expression, records=None): if isinstance(expression, MatchExpression): if expression.fieldName in self.indexedFields: return self.indexedRecordsFromMatchExpression( expression, records=records ) else: return self.unIndexedRecordsFromMatchExpression( expression, records=records ) else: return BaseDirectoryService.recordsFromExpression( self, expression, records=records ) class DirectoryRecord(BaseDirectoryRecord): """ XML directory record """ @inlineCallbacks def members(self): members = set() for uid in getattr(self, "memberUIDs", ()): members.add((yield self.service.recordWithUID(uid))) returnValue(members) def groups(self): return self.service.recordsWithFieldValue( FieldName.memberUIDs, self.uid ) calendarserver-5.2+dfsg/twext/who/expression.py0000644000175000017500000000513612263343324021037 0ustar rahulrahul# -*- test-case-name: twext.who.test.test_expression -*- ## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Directory query expressions. """ __all__ = [ "MatchType", "MatchFlags", "MatchExpression", ] from twisted.python.constants import Names, NamedConstant from twisted.python.constants import Flags, FlagConstant ## # Match expression ## class MatchType(Names): """ Query match types. """ equals = NamedConstant() startsWith = NamedConstant() contains = NamedConstant() equals.description = "equals" startsWith.description = "starts with" contains.description = "contains" class MatchFlags(Flags): """ Match expression flags. """ NOT = FlagConstant() NOT.description = "not" caseInsensitive = FlagConstant() caseInsensitive.description = "case insensitive" class MatchExpression(object): """ Query for a matching value in a given field. 
@ivar fieldName: a L{NamedConstant} specifying the field @ivar fieldValue: a text value to match @ivar matchType: a L{NamedConstant} specifying the match algorythm @ivar flags: L{NamedConstant} specifying additional options """ def __init__( self, fieldName, fieldValue, matchType=MatchType.equals, flags=None ): self.fieldName = fieldName self.fieldValue = fieldValue self.matchType = matchType self.flags = flags def __repr__(self): def describe(constant): return getattr(constant, "description", str(constant)) if self.flags is None: flags = "" else: flags = " ({0})".format(describe(self.flags)) return ( "<{self.__class__.__name__}: {fieldName!r} " "{matchType} {fieldValue!r}{flags}>" .format( self=self, fieldName=describe(self.fieldName), matchType=describe(self.matchType), fieldValue=describe(self.fieldValue), flags=flags, ) ) calendarserver-5.2+dfsg/twext/who/directory.py0000644000175000017500000002674412263343324020654 0ustar rahulrahul# -*- test-case-name: twext.who.test.test_directory -*- ## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
##

"""
Generic directory service base implementation
"""

__all__ = [
    "DirectoryService",
    "DirectoryRecord",
]

from uuid import UUID

from zope.interface import implementer

from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.defer import succeed, fail

from twext.who.idirectory import QueryNotSupportedError, NotAllowedError
from twext.who.idirectory import FieldName, RecordType
from twext.who.idirectory import Operand
from twext.who.idirectory import IDirectoryService, IDirectoryRecord
from twext.who.expression import MatchExpression
from twext.who.util import uniqueResult, describe



@implementer(IDirectoryService)
class DirectoryService(object):
    """
    Generic implementation of L{IDirectoryService}.

    This is a complete implementation of L{IDirectoryService}, with support for
    the query operands in L{Operand}.

    The C{recordsWith*} methods are all implemented in terms of
    L{recordsWithFieldValue}, which is in turn implemented in terms of
    L{recordsFromExpression}.
    L{recordsFromQuery} is also implemented in terms of
    L{recordsFromExpression}.

    L{recordsFromExpression} (and therefore most uses of the other methods)
    will always fail with a L{QueryNotSupportedError}.

    A subclass should therefore override L{recordsFromExpression} with an
    implementation that handles any queries that it can support and defers to
    its superclass' implementation for any query it cannot support.

    A subclass may override L{recordsFromQuery} if it is to support additional
    operands.

    L{updateRecords} and L{removeRecords} will fail with L{NotAllowedError}
    when asked to modify data.
    A subclass should override these methods if it is to allow editing of
    directory information.

    @cvar recordType: a L{Names} class or compatible object (eg.
        L{ConstantsContainer}) which contains the L{NamedConstant}s denoting
        the record types that are supported by this directory service.

    @cvar fieldName: a L{Names} class or compatible object (eg.
        L{ConstantsContainer}) which contains the L{NamedConstant}s denoting
        the record field names that are supported by this directory service.

    @cvar normalizedFields: a L{dict} mapping field names (ie. L{NamedConstant}s
        contained in the C{fieldName} class variable) to callables that take a
        field value and return a normalized field value.
    """

    recordType = RecordType
    fieldName = FieldName

    normalizedFields = {
        FieldName.guid: lambda g: UUID(g).hex,
        FieldName.emailAddresses: lambda e: bytes(e).lower(),
    }

    def __init__(self, realmName):
        """
        @param realmName: a realm name
        @type realmName: unicode
        """
        self.realmName = realmName

    def __repr__(self):
        return (
            "<{self.__class__.__name__} {self.realmName!r}>"
            .format(self=self)
        )

    def recordTypes(self):
        return self.recordType.iterconstants()

    def recordsFromExpression(self, expression, records=None):
        """
        Finds records matching a single expression.

        @note: The implementation in L{DirectoryService} always raises
            L{QueryNotSupportedError}.

        @note: This L{DirectoryService} adds a C{records} keyword argument to
            the interface defined by L{IDirectoryService}.
            This allows the implementation of
            L{DirectoryService.recordsFromQuery} to narrow the scope of records
            being searched as it applies expressions.
            This is therefore relevant to subclasses, which need to support the
            added parameter, but not to users of L{IDirectoryService}.

        @param expression: an expression to apply
        @type expression: L{object}

        @param records: a set of records to limit the search to. C{None} if
            the whole directory should be searched.
        @type records: L{set} or L{frozenset}

        @return: The matching records.
        @rtype: deferred iterable of L{IDirectoryRecord}s

        @raises: L{QueryNotSupportedError} if the expression is not
            supported by this directory service.
        """
        return fail(QueryNotSupportedError(
            "Unknown expression: {0}".format(expression)
        ))

    @inlineCallbacks
    def recordsFromQuery(self, expressions, operand=Operand.AND):
        """
        Find records matching C{expressions} combined with C{operand}.

        Expressions are applied in order with L{recordsFromExpression}; for
        L{Operand.AND} the records found so far are passed as C{records} to
        narrow each subsequent search.

        @param expressions: an iterable of expressions.
        @param operand: L{Operand.AND} or L{Operand.OR}.

        @return: a L{Deferred} firing with the matching records (an empty
            tuple if C{expressions} is empty or an AND query runs out of
            matches, otherwise a L{set}).

        @raise QueryNotSupportedError: (as a failure) for any other operand,
            once a second expression is applied.
        """
        expressionIterator = iter(expressions)

        try:
            expression = expressionIterator.next()
        except StopIteration:
            returnValue(())

        results = set((yield self.recordsFromExpression(expression)))

        # Continue with the *remaining* expressions; iterating C{expressions}
        # afresh would re-apply the first one.
        for expression in expressionIterator:
            if operand == Operand.AND:
                if not results:
                    # No need to bother continuing here
                    returnValue(())

                records = results
            else:
                records = None

            recordsMatchingExpression = frozenset((
                yield self.recordsFromExpression(expression, records=records)
            ))

            if operand == Operand.AND:
                results &= recordsMatchingExpression
            elif operand == Operand.OR:
                results |= recordsMatchingExpression
            else:
                raise QueryNotSupportedError(
                    "Unknown operand: {0}".format(operand)
                )

        returnValue(results)

    def recordsWithFieldValue(self, fieldName, value):
        return self.recordsFromExpression(MatchExpression(fieldName, value))

    @inlineCallbacks
    def recordWithUID(self, uid):
        returnValue(uniqueResult(
            (yield self.recordsWithFieldValue(FieldName.uid, uid))
        ))

    @inlineCallbacks
    def recordWithGUID(self, guid):
        returnValue(uniqueResult(
            (yield self.recordsWithFieldValue(FieldName.guid, guid))
        ))

    def recordsWithRecordType(self, recordType):
        return self.recordsWithFieldValue(FieldName.recordType, recordType)

    @inlineCallbacks
    def recordWithShortName(self, recordType, shortName):
        returnValue(uniqueResult((yield self.recordsFromQuery((
            MatchExpression(FieldName.recordType, recordType),
            MatchExpression(FieldName.shortNames, shortName),
        )))))

    def recordsWithEmailAddress(self, emailAddress):
        return self.recordsWithFieldValue(
            FieldName.emailAddresses,
            emailAddress,
        )

    def updateRecords(self, records, create=False):
        # Refuse only if there is actually something to update; an empty
        # iterable succeeds.
        for record in records:
            return fail(NotAllowedError("Record updates not allowed."))
        return succeed(None)

    def removeRecords(self, uids):
        # Refuse only if there is actually something to remove; an empty
        # iterable succeeds.
        for uid in uids:
            return fail(NotAllowedError("Record removal not allowed."))
        return succeed(None)
@implementer(IDirectoryRecord) class DirectoryRecord(object): """ Generic implementation of L{IDirectoryService}. This is an incomplete implementation of L{IDirectoryRecord}. L{groups} will always fail with L{NotImplementedError} and L{members} will do so if this is a group record. A subclass should override these methods to support group membership and complete this implementation. @cvar requiredFields: an iterable of field names that must be present in all directory records. """ requiredFields = ( FieldName.uid, FieldName.recordType, FieldName.shortNames, ) def __init__(self, service, fields): for fieldName in self.requiredFields: if fieldName not in fields or not fields[fieldName]: raise ValueError("{0} field is required.".format(fieldName)) if FieldName.isMultiValue(fieldName): values = fields[fieldName] if len(values) == 0: raise ValueError( "{0} field must have at least one value." .format(fieldName) ) for value in values: if not value: raise ValueError( "{0} field must not be empty.".format(fieldName) ) if ( fields[FieldName.recordType] not in service.recordType.iterconstants() ): raise ValueError( "Record type must be one of {0!r}, not {1!r}.".format( tuple(service.recordType.iterconstants()), fields[FieldName.recordType], ) ) # Normalize fields normalizedFields = {} for name, value in fields.items(): normalize = service.normalizedFields.get(name, None) if normalize is None: normalizedFields[name] = value continue if FieldName.isMultiValue(name): normalizedFields[name] = tuple((normalize(v) for v in value)) else: normalizedFields[name] = normalize(value) self.service = service self.fields = normalizedFields def __repr__(self): return ( "<{self.__class__.__name__} ({recordType}){shortName}>".format( self=self, recordType=describe(self.recordType), shortName=self.shortNames[0], ) ) def __eq__(self, other): if IDirectoryRecord.implementedBy(other.__class__): return ( self.service == other.service and self.fields == other.fields ) return NotImplemented def 
__ne__(self, other): eq = self.__eq__(other) if eq is NotImplemented: return NotImplemented return not eq def __getattr__(self, name): try: fieldName = self.service.fieldName.lookupByName(name) except ValueError: raise AttributeError(name) try: return self.fields[fieldName] except KeyError: raise AttributeError(name) def description(self): description = [self.__class__.__name__, ":"] for name, value in self.fields.items(): if hasattr(name, "description"): name = name.description else: name = str(name) if hasattr(value, "description"): value = value.description else: value = str(value) description.append("\n ") description.append(name) description.append(" = ") description.append(value) return "".join(description) def members(self): if self.recordType == RecordType.group: return fail( NotImplementedError("Subclasses must implement members()") ) return succeed(()) def groups(self): return fail(NotImplementedError("Subclasses must implement groups()")) calendarserver-5.2+dfsg/twext/who/test/0000755000175000017500000000000012322625326017241 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/who/test/test_xml.py0000644000175000017500000006230112263343324021453 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ XML directory service tests. 
""" from time import sleep from twisted.trial import unittest from twisted.python.filepath import FilePath from twisted.internet.defer import inlineCallbacks from twext.who.idirectory import NoSuchRecordError from twext.who.idirectory import Operand from twext.who.expression import MatchExpression, MatchType, MatchFlags from twext.who.xml import ParseError from twext.who.xml import DirectoryService, DirectoryRecord from twext.who.test import test_directory class BaseTest(unittest.TestCase): def service(self, xmlData=None): return xmlService(self.mktemp(), xmlData) def assertRecords(self, records, uids): self.assertEquals( frozenset((record.uid for record in records)), frozenset((uids)), ) class DirectoryServiceBaseTest( BaseTest, test_directory.BaseDirectoryServiceTest, ): def test_repr(self): service = self.service() self.assertEquals(repr(service), "") service.loadRecords() self.assertEquals(repr(service), "") @inlineCallbacks def test_recordWithUID(self): service = self.service() record = (yield service.recordWithUID("__null__")) self.assertEquals(record, None) record = (yield service.recordWithUID("__wsanchez__")) self.assertEquals(record.uid, "__wsanchez__") @inlineCallbacks def test_recordWithGUID(self): service = self.service() record = ( yield service.recordWithGUID( "6C495FCD-7E78-4D5C-AA66-BC890AD04C9D" ) ) self.assertEquals(record, None) @inlineCallbacks def test_recordsWithRecordType(self): service = self.service() records = (yield service.recordsWithRecordType(object())) self.assertEquals(set(records), set()) records = ( yield service.recordsWithRecordType(service.recordType.user) ) self.assertRecords( records, ( "__wsanchez__", "__glyph__", "__sagen__", "__cdaboo__", "__dre__", "__exarkun__", "__dreid__", "__alyssa__", "__joe__", ), ) records = ( yield service.recordsWithRecordType(service.recordType.group) ) self.assertRecords( records, ( "__calendar-dev__", "__twisted__", "__developers__", ), ) @inlineCallbacks def test_recordWithShortName(self): 
service = self.service() record = ( yield service.recordWithShortName( service.recordType.user, "null", ) ) self.assertEquals(record, None) record = ( yield service.recordWithShortName( service.recordType.user, "wsanchez", ) ) self.assertEquals(record.uid, "__wsanchez__") record = ( yield service.recordWithShortName( service.recordType.user, "wilfredo_sanchez", ) ) self.assertEquals(record.uid, "__wsanchez__") @inlineCallbacks def test_recordsWithEmailAddress(self): service = self.service() records = ( yield service.recordsWithEmailAddress( "wsanchez@bitbucket.calendarserver.org" ) ) self.assertRecords(records, ("__wsanchez__",)) records = ( yield service.recordsWithEmailAddress( "wsanchez@devnull.twistedmatrix.com" ) ) self.assertRecords(records, ("__wsanchez__",)) records = ( yield service.recordsWithEmailAddress( "shared@example.com" ) ) self.assertRecords(records, ("__sagen__", "__dre__")) class DirectoryServiceRealmTest(BaseTest): def test_realmNameImmutable(self): def setRealmName(): service = self.service() service.realmName = "foo" self.assertRaises(AssertionError, setRealmName) class DirectoryServiceParsingTest(BaseTest): def test_reloadInterval(self): service = self.service() service.loadRecords(stat=False) lastRefresh = service._lastRefresh self.assertTrue(service._lastRefresh) sleep(1) service.loadRecords(stat=False) self.assertEquals(lastRefresh, service._lastRefresh) def test_reloadStat(self): service = self.service() service.loadRecords(loadNow=True) lastRefresh = service._lastRefresh self.assertTrue(service._lastRefresh) sleep(1) service.loadRecords(loadNow=True) self.assertEquals(lastRefresh, service._lastRefresh) def test_badXML(self): service = self.service(xmlData="Hello") self.assertRaises(ParseError, service.loadRecords) def test_badRootElement(self): service = self.service(xmlData=( """ """ )) self.assertRaises(ParseError, service.loadRecords) try: service.loadRecords() except ParseError as e: self.assertTrue(str(e).startswith("Incorrect root 
element"), e) else: raise AssertionError def test_noRealmName(self): service = self.service(xmlData=( """ """ )) self.assertRaises(ParseError, service.loadRecords) try: service.loadRecords() except ParseError as e: self.assertTrue(str(e).startswith("No realm name"), e) else: raise AssertionError def test_unknownFieldElementsClean(self): service = self.service() self.assertEquals(set(service.unknownFieldElements), set()) def test_unknownFieldElementsDirty(self): service = self.service(xmlData=( """ __wsanchez__ wsanchez Community and Freedom Party """ )) self.assertEquals( set(service.unknownFieldElements), set(("political-affiliation",)) ) def test_unknownRecordTypesClean(self): service = self.service() self.assertEquals(set(service.unknownRecordTypes), set()) def test_unknownRecordTypesDirty(self): service = self.service(xmlData=( """ __d600__ d600 Nikon D600 """ )) self.assertEquals(set(service.unknownRecordTypes), set(("camera",))) class DirectoryServiceQueryTest(BaseTest): @inlineCallbacks def test_queryAnd(self): service = self.service() records = yield service.recordsFromQuery( ( service.query("emailAddresses", "shared@example.com"), service.query("shortNames", "sagen"), ), operand=Operand.AND ) self.assertRecords(records, ("__sagen__",)) @inlineCallbacks def test_queryAndNoneFirst(self): """ Test optimized case, where first expression yields no results. 
""" service = self.service() records = yield service.recordsFromQuery( ( service.query("emailAddresses", "nobody@example.com"), service.query("shortNames", "sagen"), ), operand=Operand.AND ) self.assertRecords(records, ()) @inlineCallbacks def test_queryOr(self): service = self.service() records = yield service.recordsFromQuery( ( service.query("emailAddresses", "shared@example.com"), service.query("shortNames", "wsanchez"), ), operand=Operand.OR ) self.assertRecords(records, ("__sagen__", "__dre__", "__wsanchez__")) @inlineCallbacks def test_queryNot(self): service = self.service() records = yield service.recordsFromQuery( ( service.query("emailAddresses", "shared@example.com"), service.query("shortNames", "sagen", flags=MatchFlags.NOT), ), operand=Operand.AND ) self.assertRecords(records, ("__dre__",)) @inlineCallbacks def test_queryNotNoIndex(self): service = self.service() records = yield service.recordsFromQuery( ( service.query("emailAddresses", "shared@example.com"), service.query( "fullNames", "Andre LaBranche", flags=MatchFlags.NOT ), ), operand=Operand.AND ) self.assertRecords(records, ("__sagen__",)) @inlineCallbacks def test_queryCaseInsensitive(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "shortNames", "SagEn", flags=MatchFlags.caseInsensitive ), )) self.assertRecords(records, ("__sagen__",)) @inlineCallbacks def test_queryCaseInsensitiveNoIndex(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "fullNames", "moRGen SAGen", flags=MatchFlags.caseInsensitive ), )) self.assertRecords(records, ("__sagen__",)) @inlineCallbacks def test_queryStartsWith(self): service = self.service() records = yield service.recordsFromQuery(( service.query("shortNames", "wil", matchType=MatchType.startsWith), )) self.assertRecords(records, ("__wsanchez__",)) @inlineCallbacks def test_queryStartsWithNoIndex(self): service = self.service() records = yield service.recordsFromQuery(( 
service.query( "fullNames", "Wilfredo", matchType=MatchType.startsWith ), )) self.assertRecords(records, ("__wsanchez__",)) @inlineCallbacks def test_queryStartsWithNot(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "shortNames", "w", matchType=MatchType.startsWith, flags=MatchFlags.NOT, ), )) self.assertRecords( records, ( '__alyssa__', '__calendar-dev__', '__cdaboo__', '__developers__', '__dre__', '__dreid__', '__exarkun__', '__glyph__', '__joe__', '__sagen__', '__twisted__', ), ) @inlineCallbacks def test_queryStartsWithNotAny(self): """ FIXME?: In the this case, the record __wsanchez__ has two shortNames, and one doesn't match the query. Should it be included or not? It is, because one matches the query, but should NOT require that all match? """ service = self.service() records = yield service.recordsFromQuery(( service.query( "shortNames", "wil", matchType=MatchType.startsWith, flags=MatchFlags.NOT, ), )) self.assertRecords( records, ( '__alyssa__', '__calendar-dev__', '__cdaboo__', '__developers__', '__dre__', '__dreid__', '__exarkun__', '__glyph__', '__joe__', '__sagen__', '__twisted__', '__wsanchez__', ), ) @inlineCallbacks def test_queryStartsWithNotNoIndex(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "fullNames", "Wilfredo", matchType=MatchType.startsWith, flags=MatchFlags.NOT, ), )) self.assertRecords( records, ( '__alyssa__', '__calendar-dev__', '__cdaboo__', '__developers__', '__dre__', '__dreid__', '__exarkun__', '__glyph__', '__joe__', '__sagen__', '__twisted__', ), ) @inlineCallbacks def test_queryStartsWithCaseInsensitive(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "shortNames", "WIL", matchType=MatchType.startsWith, flags=MatchFlags.caseInsensitive, ), )) self.assertRecords(records, ("__wsanchez__",)) @inlineCallbacks def test_queryStartsWithCaseInsensitiveNoIndex(self): service = self.service() records = yield 
service.recordsFromQuery(( service.query( "fullNames", "wilfrEdo", matchType=MatchType.startsWith, flags=MatchFlags.caseInsensitive, ), )) self.assertRecords(records, ("__wsanchez__",)) @inlineCallbacks def test_queryContains(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "shortNames", "sanchez", matchType=MatchType.contains ), )) self.assertRecords(records, ("__wsanchez__",)) @inlineCallbacks def test_queryContainsNoIndex(self): service = self.service() records = yield service.recordsFromQuery(( service.query("fullNames", "fred", matchType=MatchType.contains), )) self.assertRecords(records, ("__wsanchez__",)) @inlineCallbacks def test_queryContainsNot(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "shortNames", "sanchez", matchType=MatchType.contains, flags=MatchFlags.NOT, ), )) self.assertRecords( records, ( '__alyssa__', '__calendar-dev__', '__cdaboo__', '__developers__', '__dre__', '__dreid__', '__exarkun__', '__glyph__', '__joe__', '__sagen__', '__twisted__', ), ) @inlineCallbacks def test_queryContainsNotNoIndex(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "fullNames", "fred", matchType=MatchType.contains, flags=MatchFlags.NOT, ), )) self.assertRecords( records, ( '__alyssa__', '__calendar-dev__', '__cdaboo__', '__developers__', '__dre__', '__dreid__', '__exarkun__', '__glyph__', '__joe__', '__sagen__', '__twisted__', ), ) @inlineCallbacks def test_queryContainsCaseInsensitive(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "shortNames", "Sanchez", matchType=MatchType.contains, flags=MatchFlags.caseInsensitive, ), )) self.assertRecords(records, ("__wsanchez__",)) @inlineCallbacks def test_queryContainsCaseInsensitiveNoIndex(self): service = self.service() records = yield service.recordsFromQuery(( service.query( "fullNames", "frEdo", matchType=MatchType.contains, 
flags=MatchFlags.caseInsensitive, ), )) self.assertRecords(records, ("__wsanchez__",)) class DirectoryServiceMutableTest(BaseTest): @inlineCallbacks def test_updateRecord(self): service = self.service() record = (yield service.recordWithUID("__wsanchez__")) fields = record.fields.copy() fields[service.fieldName.fullNames] = ["Wilfredo Sanchez Vega"] updatedRecord = DirectoryRecord(service, fields) yield service.updateRecords((updatedRecord,)) # Verify change is present immediately record = (yield service.recordWithUID("__wsanchez__")) self.assertEquals( set(record.fullNames), set(("Wilfredo Sanchez Vega",)) ) # Verify change is persisted service.flush() record = (yield service.recordWithUID("__wsanchez__")) self.assertEquals( set(record.fullNames), set(("Wilfredo Sanchez Vega",)) ) @inlineCallbacks def test_addRecord(self): service = self.service() newRecord = DirectoryRecord( service, fields={ service.fieldName.uid: "__plugh__", service.fieldName.recordType: service.recordType.user, service.fieldName.shortNames: ("plugh",), } ) yield service.updateRecords((newRecord,), create=True) # Verify change is present immediately record = (yield service.recordWithUID("__plugh__")) self.assertEquals(set(record.shortNames), set(("plugh",))) # Verify change is persisted service.flush() record = (yield service.recordWithUID("__plugh__")) self.assertEquals(set(record.shortNames), set(("plugh",))) def test_addRecordNoCreate(self): service = self.service() newRecord = DirectoryRecord( service, fields={ service.fieldName.uid: "__plugh__", service.fieldName.recordType: service.recordType.user, service.fieldName.shortNames: ("plugh",), } ) self.assertFailure( service.updateRecords((newRecord,)), NoSuchRecordError ) @inlineCallbacks def test_removeRecord(self): service = self.service() yield service.removeRecords(("__wsanchez__",)) # Verify change is present immediately self.assertEquals((yield service.recordWithUID("__wsanchez__")), None) # Verify change is persisted service.flush() 
self.assertEquals((yield service.recordWithUID("__wsanchez__")), None) def test_removeRecordNoExist(self): service = self.service() return service.removeRecords(("__plugh__",)) class DirectoryRecordTest(BaseTest, test_directory.BaseDirectoryRecordTest): @inlineCallbacks def test_members(self): service = self.service() record = (yield service.recordWithUID("__wsanchez__")) members = (yield record.members()) self.assertEquals(set(members), set()) record = (yield service.recordWithUID("__twisted__")) members = (yield record.members()) self.assertEquals( set((member.uid for member in members)), set(( "__wsanchez__", "__glyph__", "__exarkun__", "__dreid__", "__dre__", )) ) record = (yield service.recordWithUID("__developers__")) members = (yield record.members()) self.assertEquals( set((member.uid for member in members)), set(( "__calendar-dev__", "__twisted__", "__alyssa__", )) ) @inlineCallbacks def test_groups(self): service = self.service() record = (yield service.recordWithUID("__wsanchez__")) groups = (yield record.groups()) self.assertEquals( set(group.uid for group in groups), set(( "__calendar-dev__", "__twisted__", )) ) class QueryMixIn(object): def query(self, field, value, matchType=MatchType.equals, flags=None): name = getattr(self.fieldName, field) assert name is not None return MatchExpression( name, value, matchType=matchType, flags=flags, ) class TestService(DirectoryService, QueryMixIn): pass def xmlService(tmp, xmlData=None, serviceClass=None): if xmlData is None: xmlData = testXMLConfig if serviceClass is None: serviceClass = TestService filePath = FilePath(tmp) filePath.setContent(xmlData) return serviceClass(filePath) testXMLConfig = """ __wsanchez__ wsanchez wilfredo_sanchez Wilfredo Sanchez zehcnasw wsanchez@bitbucket.calendarserver.org wsanchez@devnull.twistedmatrix.com __glyph__ glyph Glyph Lefkowitz hpylg glyph@bitbucket.calendarserver.org glyph@devnull.twistedmatrix.com __sagen__ sagen Morgen Sagen negas sagen@bitbucket.calendarserver.org 
shared@example.com __cdaboo__ cdaboo Cyrus Daboo suryc cdaboo@bitbucket.calendarserver.org __dre__ dre Andre LaBranche erd dre@bitbucket.calendarserver.org shared@example.com __exarkun__ exarkun Jean-Paul Calderone nucraxe exarkun@devnull.twistedmatrix.com __dreid__ dreid David Reid dierd dreid@devnull.twistedmatrix.com __joe__ joe Joe Schmoe eoj joe@example.com __alyssa__ alyssa Alyssa P. Hacker assyla alyssa@example.com __calendar-dev__ calendar-dev Calendar Server developers dev@bitbucket.calendarserver.org __wsanchez__ __glyph__ __sagen__ __cdaboo__ __dre__ __twisted__ twisted Twisted Matrix Laboratories hack@devnull.twistedmatrix.com __wsanchez__ __glyph__ __exarkun__ __dreid__ __dre__ __developers__ developers All Developers __calendar-dev__ __twisted__ __alyssa__ """ calendarserver-5.2+dfsg/twext/who/test/test_aggregate.py0000644000175000017500000001562312263343324022606 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Aggregate directory service tests. 
""" from twisted.python.components import proxyForInterface from twisted.trial import unittest from twext.who.idirectory import IDirectoryService, DirectoryConfigurationError from twext.who.aggregate import DirectoryService from twext.who.util import ConstantsContainer from twext.who.test import test_directory, test_xml from twext.who.test.test_xml import QueryMixIn, xmlService from twext.who.test.test_xml import TestService as XMLTestService class BaseTest(object): def service(self, services=None): if services is None: services = (self.xmlService(),) # # Make sure aggregate DirectoryService isn't making # implementation assumptions about the IDirectoryService # objects it gets. # services = tuple(( proxyForInterface(IDirectoryService)(s) for s in services )) class TestService(DirectoryService, QueryMixIn): pass return TestService("xyzzy", services) def xmlService(self, xmlData=None, serviceClass=None): return xmlService(self.mktemp(), xmlData, serviceClass) class DirectoryServiceBaseTest(BaseTest, test_xml.DirectoryServiceBaseTest): def test_repr(self): service = self.service() self.assertEquals(repr(service), "") class DirectoryServiceQueryTest(BaseTest, test_xml.DirectoryServiceQueryTest): pass class DirectoryServiceImmutableTest( BaseTest, test_directory.BaseDirectoryServiceImmutableTest, ): pass class AggregatedBaseTest(BaseTest): def service(self): class UsersDirectoryService(XMLTestService): recordType = ConstantsContainer((XMLTestService.recordType.user,)) class GroupsDirectoryService(XMLTestService): recordType = ConstantsContainer((XMLTestService.recordType.group,)) usersService = self.xmlService( testXMLConfigUsers, UsersDirectoryService ) groupsService = self.xmlService( testXMLConfigGroups, GroupsDirectoryService ) return BaseTest.service(self, (usersService, groupsService)) class DirectoryServiceAggregatedBaseTest( AggregatedBaseTest, DirectoryServiceBaseTest, ): pass class DirectoryServiceAggregatedQueryTest( AggregatedBaseTest, 
test_xml.DirectoryServiceQueryTest, ): pass class DirectoryServiceAggregatedImmutableTest( AggregatedBaseTest, test_directory.BaseDirectoryServiceImmutableTest, ): pass class DirectoryServiceTests(BaseTest, unittest.TestCase): def test_conflictingRecordTypes(self): self.assertRaises( DirectoryConfigurationError, BaseTest.service, self, (self.xmlService(), self.xmlService(testXMLConfigUsers)), ) testXMLConfigUsers = """ __wsanchez__ wsanchez wilfredo_sanchez Wilfredo Sanchez zehcnasw wsanchez@bitbucket.calendarserver.org wsanchez@devnull.twistedmatrix.com __glyph__ glyph Glyph Lefkowitz hpylg glyph@bitbucket.calendarserver.org glyph@devnull.twistedmatrix.com __sagen__ sagen Morgen Sagen negas sagen@bitbucket.calendarserver.org shared@example.com __cdaboo__ cdaboo Cyrus Daboo suryc cdaboo@bitbucket.calendarserver.org __dre__ dre Andre LaBranche erd dre@bitbucket.calendarserver.org shared@example.com __exarkun__ exarkun Jean-Paul Calderone nucraxe exarkun@devnull.twistedmatrix.com __dreid__ dreid David Reid dierd dreid@devnull.twistedmatrix.com __joe__ joe Joe Schmoe eoj joe@example.com __alyssa__ alyssa Alyssa P. Hacker assyla alyssa@example.com """ testXMLConfigGroups = """ __calendar-dev__ calendar-dev Calendar Server developers dev@bitbucket.calendarserver.org __wsanchez__ __glyph__ __sagen__ __cdaboo__ __dre__ __twisted__ twisted Twisted Matrix Laboratories hack@devnull.twistedmatrix.com __wsanchez__ __glyph__ __exarkun__ __dreid__ __dre__ __developers__ developers All Developers __calendar-dev__ __twisted__ __alyssa__ """ calendarserver-5.2+dfsg/twext/who/test/test_util.py0000644000175000017500000000701212263343324021626 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Directory service utility tests. """ from twisted.trial import unittest from twisted.python.constants import Names, NamedConstant from twisted.python.constants import Flags, FlagConstant from twext.who.idirectory import DirectoryServiceError from twext.who.util import ConstantsContainer from twext.who.util import uniqueResult, describe class Tools(Names): hammer = NamedConstant() screwdriver = NamedConstant() hammer.description = "nail pounder" screwdriver.description = "screw twister" class Instruments(Names): hammer = NamedConstant() chisel = NamedConstant() class Switches(Flags): r = FlagConstant() g = FlagConstant() b = FlagConstant() r.description = "red" g.description = "green" b.description = "blue" black = FlagConstant() class ConstantsContainerTest(unittest.TestCase): def test_conflict(self): constants = set((Tools.hammer, Instruments.hammer)) self.assertRaises(ValueError, ConstantsContainer, constants) def test_attrs(self): constants = set((Tools.hammer, Tools.screwdriver, Instruments.chisel)) container = ConstantsContainer(constants) self.assertEquals(container.hammer, Tools.hammer) self.assertEquals(container.screwdriver, Tools.screwdriver) self.assertEquals(container.chisel, Instruments.chisel) self.assertRaises(AttributeError, lambda: container.plugh) def test_iterconstants(self): constants = set((Tools.hammer, Tools.screwdriver, Instruments.chisel)) container = ConstantsContainer(constants) self.assertEquals( set(container.iterconstants()), constants, ) def test_lookupByName(self): constants = set(( Instruments.hammer, Tools.screwdriver, 
Instruments.chisel, )) container = ConstantsContainer(constants) self.assertEquals( container.lookupByName("hammer"), Instruments.hammer, ) self.assertEquals( container.lookupByName("screwdriver"), Tools.screwdriver, ) self.assertEquals( container.lookupByName("chisel"), Instruments.chisel, ) self.assertRaises( ValueError, container.lookupByName, "plugh", ) class UtilTest(unittest.TestCase): def test_uniqueResult(self): self.assertEquals(1, uniqueResult((1,))) self.assertRaises(DirectoryServiceError, uniqueResult, (1, 2, 3)) def test_describe(self): self.assertEquals("nail pounder", describe(Tools.hammer)) self.assertEquals("hammer", describe(Instruments.hammer)) def test_describeFlags(self): self.assertEquals("blue", describe(Switches.b)) self.assertEquals("red|green", describe(Switches.r | Switches.g)) self.assertEquals("blue|black", describe(Switches.b | Switches.black)) calendarserver-5.2+dfsg/twext/who/test/test_directory.py0000644000175000017500000002562612263343324022670 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Generic directory service base implementation tests. 
""" from zope.interface.verify import verifyObject, BrokenMethodImplementation from twisted.trial import unittest from twisted.trial.unittest import SkipTest from twisted.internet.defer import inlineCallbacks from twext.who.idirectory import QueryNotSupportedError, NotAllowedError from twext.who.idirectory import RecordType, FieldName from twext.who.idirectory import IDirectoryService, IDirectoryRecord from twext.who.directory import DirectoryService, DirectoryRecord class ServiceMixIn(object): realmName = "xyzzy" def service(self): if not hasattr(self, "_service"): self._service = DirectoryService(self.realmName) return self._service class BaseDirectoryServiceTest(ServiceMixIn): def test_interface(self): service = self.service() try: verifyObject(IDirectoryService, service) except BrokenMethodImplementation as e: self.fail(e) def test_init(self): service = self.service() self.assertEquals(service.realmName, self.realmName) def test_repr(self): service = self.service() self.assertEquals(repr(service), "") def test_recordTypes(self): service = self.service() self.assertEquals( set(service.recordTypes()), set(service.recordType.iterconstants()) ) @inlineCallbacks def test_recordsFromQueryNone(self): service = self.service() records = (yield service.recordsFromQuery(())) for record in records: self.failTest("No records expected") def test_recordsFromQueryBogus(self): service = self.service() self.assertFailure( service.recordsFromQuery((object(),)), QueryNotSupportedError ) def test_recordWithUID(self): raise SkipTest("Subclasses should implement this test.") def test_recordWithGUID(self): raise SkipTest("Subclasses should implement this test.") def test_recordsWithRecordType(self): raise SkipTest("Subclasses should implement this test.") def test_recordWithShortName(self): raise SkipTest("Subclasses should implement this test.") def test_recordsWithEmailAddress(self): raise SkipTest("Subclasses should implement this test.") class 
DirectoryServiceTest(unittest.TestCase, BaseDirectoryServiceTest): def test_recordsFromExpression(self): service = self.service() result = yield(service.recordsFromExpression(None)) self.assertFailure(result, QueryNotSupportedError) def test_recordWithUID(self): service = self.service() self.assertFailure( service.recordWithUID(None), QueryNotSupportedError ) def test_recordWithGUID(self): service = self.service() self.assertFailure( service.recordWithGUID(None), QueryNotSupportedError ) def test_recordsWithRecordType(self): service = self.service() self.assertFailure( service.recordsWithRecordType(None), QueryNotSupportedError ) def test_recordWithShortName(self): service = self.service() self.assertFailure( service.recordWithShortName(None, None), QueryNotSupportedError ) def test_recordsWithEmailAddress(self): service = self.service() self.assertFailure( service.recordsWithEmailAddress(None), QueryNotSupportedError ) class BaseDirectoryServiceImmutableTest(ServiceMixIn): def test_updateRecordsNotAllowed(self): service = self.service() newRecord = DirectoryRecord( service, fields={ service.fieldName.uid: "__plugh__", service.fieldName.recordType: service.recordType.user, service.fieldName.shortNames: ("plugh",), } ) self.assertFailure( service.updateRecords((newRecord,), create=True), NotAllowedError, ) self.assertFailure( service.updateRecords((newRecord,), create=False), NotAllowedError, ) def test_removeRecordsNotAllowed(self): service = self.service() service.removeRecords(()) self.assertFailure( service.removeRecords(("foo",)), NotAllowedError, ) class DirectoryServiceImmutableTest( unittest.TestCase, BaseDirectoryServiceImmutableTest, ): pass class BaseDirectoryRecordTest(ServiceMixIn): fields_wsanchez = { FieldName.uid: "UID:wsanchez", FieldName.recordType: RecordType.user, FieldName.shortNames: ("wsanchez", "wilfredo_sanchez"), FieldName.fullNames: ( "Wilfredo Sanchez", "Wilfredo Sanchez Vega", ), FieldName.emailAddresses: ( "wsanchez@calendarserver.org", 
"wsanchez@example.com", ) } fields_glyph = { FieldName.uid: "UID:glyph", FieldName.recordType: RecordType.user, FieldName.shortNames: ("glyph",), FieldName.fullNames: ("Glyph Lefkowitz",), FieldName.emailAddresses: ("glyph@calendarserver.org",) } fields_sagen = { FieldName.uid: "UID:sagen", FieldName.recordType: RecordType.user, FieldName.shortNames: ("sagen",), FieldName.fullNames: ("Morgen Sagen",), FieldName.emailAddresses: ("sagen@CalendarServer.org",) } fields_staff = { FieldName.uid: "UID:staff", FieldName.recordType: RecordType.group, FieldName.shortNames: ("staff",), FieldName.fullNames: ("Staff",), FieldName.emailAddresses: ("staff@CalendarServer.org",) } def makeRecord(self, fields=None, service=None): if fields is None: fields = self.fields_wsanchez if service is None: service = self.service() return DirectoryRecord(service, fields) def test_interface(self): record = self.makeRecord() try: verifyObject(IDirectoryRecord, record) except BrokenMethodImplementation as e: self.fail(e) def test_init(self): service = self.service() wsanchez = self.makeRecord(self.fields_wsanchez, service=service) self.assertEquals(wsanchez.service, service) self.assertEquals(wsanchez.fields, self.fields_wsanchez) def test_initWithNoUID(self): fields = self.fields_wsanchez.copy() del fields[FieldName.uid] self.assertRaises(ValueError, self.makeRecord, fields) fields = self.fields_wsanchez.copy() fields[FieldName.uid] = "" self.assertRaises(ValueError, self.makeRecord, fields) def test_initWithNoRecordType(self): fields = self.fields_wsanchez.copy() del fields[FieldName.recordType] self.assertRaises(ValueError, self.makeRecord, fields) fields = self.fields_wsanchez.copy() fields[FieldName.recordType] = "" self.assertRaises(ValueError, self.makeRecord, fields) def test_initWithNoShortNames(self): fields = self.fields_wsanchez.copy() del fields[FieldName.shortNames] self.assertRaises(ValueError, self.makeRecord, fields) fields = self.fields_wsanchez.copy() 
fields[FieldName.shortNames] = () self.assertRaises(ValueError, self.makeRecord, fields) fields = self.fields_wsanchez.copy() fields[FieldName.shortNames] = ("",) self.assertRaises(ValueError, self.makeRecord, fields) fields = self.fields_wsanchez.copy() fields[FieldName.shortNames] = ("wsanchez", "") self.assertRaises(ValueError, self.makeRecord, fields) def test_initWithBogusRecordType(self): fields = self.fields_wsanchez.copy() fields[FieldName.recordType] = object() self.assertRaises(ValueError, self.makeRecord, fields) def test_initNormalize(self): sagen = self.makeRecord(self.fields_sagen) self.assertEquals( sagen.fields[FieldName.emailAddresses], ("sagen@calendarserver.org",) ) def test_compare(self): fields_glyphmod = self.fields_glyph.copy() del fields_glyphmod[FieldName.emailAddresses] plugh = DirectoryService("plugh") wsanchez = self.makeRecord(self.fields_wsanchez) wsanchezmod = self.makeRecord(self.fields_wsanchez, plugh) glyph = self.makeRecord(self.fields_glyph) glyphmod = self.makeRecord(fields_glyphmod) self.assertEquals(wsanchez, wsanchez) self.assertNotEqual(wsanchez, glyph) self.assertNotEqual(glyph, glyphmod) # UID matches, other fields don't self.assertNotEqual(glyphmod, wsanchez) self.assertNotEqual(wsanchez, wsanchezmod) # Different service def test_attributeAccess(self): wsanchez = self.makeRecord(self.fields_wsanchez) self.assertEquals( wsanchez.recordType, wsanchez.fields[FieldName.recordType] ) self.assertEquals( wsanchez.uid, wsanchez.fields[FieldName.uid] ) self.assertEquals( wsanchez.shortNames, wsanchez.fields[FieldName.shortNames] ) self.assertEquals( wsanchez.emailAddresses, wsanchez.fields[FieldName.emailAddresses] ) @inlineCallbacks def test_members(self): wsanchez = self.makeRecord(self.fields_wsanchez) self.assertEquals( set((yield wsanchez.members())), set() ) raise SkipTest("Subclasses should implement this test.") def test_groups(self): raise SkipTest("Subclasses should implement this test.") class 
DirectoryRecordTest(unittest.TestCase, BaseDirectoryRecordTest): def test_members(self): wsanchez = self.makeRecord(self.fields_wsanchez) self.assertEquals( set((yield wsanchez.members())), set() ) staff = self.makeRecord(self.fields_staff) self.assertFailure(staff.members(), NotImplementedError) def test_groups(self): wsanchez = self.makeRecord(self.fields_wsanchez) self.assertFailure(wsanchez.groups(), NotImplementedError) calendarserver-5.2+dfsg/twext/who/test/test_expression.py0000644000175000017500000000323112263343324023047 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Directory service expression tests. """ from twisted.trial import unittest from twext.who.idirectory import FieldName from twext.who.expression import MatchExpression, MatchType, MatchFlags class MatchExpressionTest(unittest.TestCase): def test_repr(self): self.assertEquals( "", repr(MatchExpression( FieldName.fullNames, "Wilfredo Sanchez", )), ) self.assertEquals( "", repr(MatchExpression( FieldName.fullNames, "Sanchez", matchType=MatchType.contains, )), ) self.assertEquals( "", repr(MatchExpression( FieldName.fullNames, "Wilfredo", matchType=MatchType.startsWith, flags=MatchFlags.NOT, )), ) calendarserver-5.2+dfsg/twext/who/test/__init__.py0000644000175000017500000000121312263343324021346 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Directory service integration tests """ calendarserver-5.2+dfsg/twext/who/aggregate.py0000644000175000017500000000543612263343324020571 0ustar rahulrahul# -*- test-case-name: twext.who.test.test_aggregate -*- ## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Directory service which aggregates multiple directory services. """ __all__ = [ "DirectoryService", "DirectoryRecord", ] from itertools import chain from twisted.internet.defer import gatherResults, FirstError from twext.who.idirectory import DirectoryConfigurationError from twext.who.idirectory import IDirectoryService from twext.who.index import DirectoryService as BaseDirectoryService from twext.who.index import DirectoryRecord from twext.who.util import ConstantsContainer class DirectoryService(BaseDirectoryService): """ Aggregate directory service. 
""" def __init__(self, realmName, services): recordTypes = set() for service in services: if not IDirectoryService.implementedBy(service.__class__): raise ValueError( "Not a directory service: {0}".format(service) ) for recordType in service.recordTypes(): if recordType in recordTypes: raise DirectoryConfigurationError( "Aggregated services may not vend " "the same record type: {0}" .format(recordType) ) recordTypes.add(recordType) BaseDirectoryService.__init__(self, realmName) self._services = tuple(services) @property def services(self): return self._services @property def recordType(self): if not hasattr(self, "_recordType"): self._recordType = ConstantsContainer(chain(*tuple( s.recordTypes() for s in self.services ))) return self._recordType def recordsFromExpression(self, expression, records=None): ds = [] for service in self.services: d = service.recordsFromExpression(expression, records) ds.append(d) def unwrapFirstError(f): f.trap(FirstError) return f.value.subFailure d = gatherResults(ds, consumeErrors=True) d.addCallback(lambda results: chain(*results)) d.addErrback(unwrapFirstError) return d calendarserver-5.2+dfsg/twext/who/xml.py0000644000175000017500000003041212263343324017433 0ustar rahulrahul# -*- test-case-name: twext.who.test.test_xml -*- ## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from __future__ import absolute_import """ XML directory service implementation. 
""" __all__ = [ "ParseError", "DirectoryService", "DirectoryRecord", ] from time import time from xml.etree.ElementTree import parse as parseXML from xml.etree.ElementTree import ParseError as XMLParseError from xml.etree.ElementTree import tostring as etreeToString from xml.etree.ElementTree import Element as XMLElement from twisted.python.constants import Values, ValueConstant from twisted.internet.defer import fail from twext.who.idirectory import DirectoryServiceError from twext.who.idirectory import NoSuchRecordError, UnknownRecordTypeError from twext.who.idirectory import RecordType, FieldName as BaseFieldName from twext.who.index import DirectoryService as BaseDirectoryService from twext.who.index import DirectoryRecord from twext.who.index import FieldName as IndexFieldName ## # Exceptions ## class ParseError(DirectoryServiceError): """ Parse error. """ ## # XML constants ## class Element(Values): directory = ValueConstant("directory") record = ValueConstant("record") # # Field names # uid = ValueConstant("uid") uid.fieldName = BaseFieldName.uid guid = ValueConstant("guid") guid.fieldName = BaseFieldName.guid shortName = ValueConstant("short-name") shortName.fieldName = BaseFieldName.shortNames fullName = ValueConstant("full-name") fullName.fieldName = BaseFieldName.fullNames emailAddress = ValueConstant("email") emailAddress.fieldName = BaseFieldName.emailAddresses password = ValueConstant("password") password.fieldName = BaseFieldName.password memberUID = ValueConstant("member-uid") memberUID.fieldName = IndexFieldName.memberUIDs class Attribute(Values): realm = ValueConstant("realm") recordType = ValueConstant("type") class Value(Values): # # Booleans # true = ValueConstant("true") false = ValueConstant("false") # # Record types # user = ValueConstant("user") user.recordType = RecordType.user group = ValueConstant("group") group.recordType = RecordType.group ## # Directory Service ## class DirectoryService(BaseDirectoryService): """ XML directory 
service. """ element = Element attribute = Attribute value = Value refreshInterval = 4 def __init__(self, filePath): BaseDirectoryService.__init__(self, realmName=noRealmName) self.filePath = filePath def __repr__(self): realmName = self._realmName if realmName is None: realmName = "(not loaded)" else: realmName = repr(realmName) return ( "<{self.__class__.__name__} {realmName}>".format( self=self, realmName=realmName, ) ) @property def realmName(self): self.loadRecords() return self._realmName @realmName.setter def realmName(self, value): if value is not noRealmName: raise AssertionError("realmName may not be set directly") @property def unknownRecordTypes(self): self.loadRecords() return self._unknownRecordTypes @property def unknownFieldElements(self): self.loadRecords() return self._unknownFieldElements def loadRecords(self, loadNow=False, stat=True): """ Load records from L{self.filePath}. Does nothing if a successful refresh has happened within the last L{self.refreshInterval} seconds. @param loadNow: If true, load now (ignoring L{self.refreshInterval}) @type loadNow: L{type} @param stat: If true, check file metadata and don't reload if unchanged. @type loadNow: L{type} """ # # Punt if we've read the file recently # now = time() if not loadNow and now - self._lastRefresh <= self.refreshInterval: return # # Punt if we've read the file and it's still the same. 
# if stat: self.filePath.restat() cacheTag = ( self.filePath.getModificationTime(), self.filePath.getsize() ) if cacheTag == self._cacheTag: return else: cacheTag = None # # Open and parse the file # try: fh = self.filePath.open() try: etree = parseXML(fh) except XMLParseError as e: raise ParseError(e) finally: fh.close() # # Pull data from DOM # directoryNode = etree.getroot() if directoryNode.tag != self.element.directory.value: raise ParseError( "Incorrect root element: {0}".format(directoryNode.tag) ) realmName = directoryNode.get( self.attribute.realm.value, "" ).encode("utf-8") if not realmName: raise ParseError("No realm name.") unknownRecordTypes = set() unknownFieldElements = set() records = set() for recordNode in directoryNode: try: records.add( self.parseRecordNode(recordNode, unknownFieldElements) ) except UnknownRecordTypeError as e: unknownRecordTypes.add(e.token) # # Store results # index = {} for fieldName in self.indexedFields: index[fieldName] = {} for record in records: for fieldName in self.indexedFields: values = record.fields.get(fieldName, None) if values is not None: if not BaseFieldName.isMultiValue(fieldName): values = (values,) for value in values: index[fieldName].setdefault(value, set()).add(record) self._realmName = realmName self._unknownRecordTypes = unknownRecordTypes self._unknownFieldElements = unknownFieldElements self._cacheTag = cacheTag self._lastRefresh = now self.index = index return etree def parseRecordNode(self, recordNode, unknownFieldElements=None): recordTypeAttribute = recordNode.get( self.attribute.recordType.value, "" ).encode("utf-8") if recordTypeAttribute: try: recordType = ( self.value.lookupByValue(recordTypeAttribute).recordType ) except (ValueError, AttributeError): raise UnknownRecordTypeError(recordTypeAttribute) else: recordType = self.recordType.user fields = {} fields[self.fieldName.recordType] = recordType for fieldNode in recordNode: try: fieldElement = self.element.lookupByValue(fieldNode.tag) except 
ValueError: if unknownFieldElements is not None: unknownFieldElements.add(fieldNode.tag) try: fieldName = fieldElement.fieldName except AttributeError: if unknownFieldElements is not None: unknownFieldElements.add(fieldNode.tag) value = fieldNode.text.encode("utf-8") if BaseFieldName.isMultiValue(fieldName): values = fields.setdefault(fieldName, []) values.append(value) else: fields[fieldName] = value return DirectoryRecord(self, fields) def _uidForRecordNode(self, recordNode): uidNode = recordNode.find(self.element.uid.value) if uidNode is None: raise NotImplementedError("No UID node") return uidNode.text def flush(self): BaseDirectoryService.flush(self) self._realmName = None self._unknownRecordTypes = None self._unknownFieldElements = None self._cacheTag = None self._lastRefresh = 0 def updateRecords(self, records, create=False): # Index the records to update by UID recordsByUID = dict(((record.uid, record) for record in records)) # Index the record type -> attribute mappings. recordTypes = {} for valueName in self.value.iterconstants(): recordType = getattr(valueName, "recordType", None) if recordType is not None: recordTypes[recordType] = valueName.value del valueName # Index the field name -> element mappings. 
fieldNames = {} for elementName in self.element.iterconstants(): fieldName = getattr(elementName, "fieldName", None) if fieldName is not None: fieldNames[fieldName] = elementName.value del elementName directoryNode = self._directoryNodeForEditing() def fillRecordNode(recordNode, record): for (name, value) in record.fields.items(): if name == self.fieldName.recordType: if value in recordTypes: recordNode.set( self.attribute.recordType.value, recordTypes[value] ) else: raise AssertionError( "Unknown record type: {0}".format(value) ) else: if name in fieldNames: tag = fieldNames[name] if BaseFieldName.isMultiValue(name): values = value else: values = (value,) for value in values: subNode = XMLElement(tag) subNode.text = value recordNode.append(subNode) else: raise AssertionError( "Unknown field name: {0!r}".format(name) ) # Walk through the record nodes in the XML tree and apply # updates. for recordNode in directoryNode: uid = self._uidForRecordNode(recordNode) record = recordsByUID.get(uid, None) if record: recordNode.clear() fillRecordNode(recordNode, record) del recordsByUID[uid] if recordsByUID: if not create: return fail(NoSuchRecordError(recordsByUID.keys())) for uid, record in recordsByUID.items(): recordNode = XMLElement(self.element.record.value) fillRecordNode(recordNode, record) directoryNode.append(recordNode) self._writeDirectoryNode(directoryNode) def removeRecords(self, uids): directoryNode = self._directoryNodeForEditing() # # Walk through the record nodes in the XML tree and start # zapping. # for recordNode in directoryNode: uid = self._uidForRecordNode(recordNode) if uid in uids: directoryNode.remove(recordNode) self._writeDirectoryNode(directoryNode) def _directoryNodeForEditing(self): """ Drop cached data and load the XML DOM. 
""" self.flush() etree = self.loadRecords(loadNow=True) return etree.getroot() def _writeDirectoryNode(self, directoryNode): self.filePath.setContent(etreeToString(directoryNode)) self.flush() noRealmName = object() calendarserver-5.2+dfsg/twext/who/idirectory.py0000644000175000017500000002623112263343324021014 0ustar rahulrahul# -*- test-case-name: twext.who.test -*- ## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Directory service interface. """ __all__ = [ "DirectoryServiceError", "DirectoryConfigurationError", "DirectoryAvailabilityError", "UnknownRecordTypeError", "QueryNotSupportedError", "NoSuchRecordError", "NotAllowedError", "RecordType", "FieldName", "Operand", "IDirectoryService", "IDirectoryRecord", ] from zope.interface import Attribute, Interface from twisted.python.constants import Names, NamedConstant # # Exceptions # class DirectoryServiceError(Exception): """ Directory service generic error. """ class DirectoryConfigurationError(DirectoryServiceError): """ Directory configuration error. """ class DirectoryAvailabilityError(DirectoryServiceError): """ Directory not available. """ class UnknownRecordTypeError(DirectoryServiceError): """ Unknown record type. """ def __init__(self, token): DirectoryServiceError.__init__(self, token) self.token = token class QueryNotSupportedError(DirectoryServiceError): """ Query not supported. 
""" class NoSuchRecordError(DirectoryServiceError): """ Record does not exist. """ class NotAllowedError(DirectoryServiceError): """ It seems you aren't permitted to do that. """ # # Data Types # class RecordType(Names): """ Constants for common directory record types. """ user = NamedConstant() group = NamedConstant() user.description = "user" group.description = "group" class FieldName(Names): """ Constants for common directory record field names. Fields as assciated with either a single value or an iterable of values. @cvar uid: The primary unique identifier for a directory record. The associated value must be a L{unicode}. @cvar guid: The globally unique identifier for a directory record. The associated value must be a L{UUID} or C{None}. @cvar recordType: The type of a directory record. The associated value must be a L{NamedConstant}. @cvar shortNames: The short names for a directory record. The associated values must L{unicode}s and there must be at least one associated value. @cvar fullNames: The full names for a directory record. The associated values must be L{unicode}s. @cvar emailAddresses: The email addresses for a directory record. The associated values must be L{unicodes}. @cvar password: The clear text password for a directory record. The associated value must be a L{unicode} or C{None}. """ uid = NamedConstant() guid = NamedConstant() recordType = NamedConstant() shortNames = NamedConstant() fullNames = NamedConstant() emailAddresses = NamedConstant() password = NamedConstant() uid.description = "UID" guid.description = "GUID" recordType.description = "record type" shortNames.description = "short names" fullNames.description = "full names" emailAddresses.description = "email addresses" password.description = "password" shortNames.multiValue = True fullNames.multiValue = True emailAddresses.multiValue = True @staticmethod def isMultiValue(name): """ Check for whether a field is multi-value (as opposed to single-value). 
@return: C{True} if the field is multi-value, C{False} otherwise. @rtype: L{BOOL} """ return getattr(name, "multiValue", False) class Operand(Names): """ Contants for common operands. """ OR = NamedConstant() AND = NamedConstant() OR.description = "or" AND.description = "and" # # Interfaces # class IDirectoryService(Interface): """ Directory service. A directory service is a service that vends information about principals such as users, locations, printers, and other resources. This information is provided in the form of directory records. A directory service can be queried for the types of records it supports, and for specific records matching certain criteria. A directory service may allow support the editing, removal and addition of records. Services are read-only should fail with L{NotAllowedError} in editing methods. The L{FieldName.uid} field, the L{FieldName.guid} field (if not C{None}), and the combination of the L{FieldName.recordType} and L{FieldName.shortName} fields must be unique to each directory record vended by a directory service. """ realmName = Attribute( "The name of the authentication realm this service represents." ) def recordTypes(): """ Get the record types supported by this directory service. @return: The record types that are supported by this directory service. @rtype: iterable of L{NamedConstant}s """ def recordsFromExpression(self, expression): """ Find records matching an expression. @param expression: an expression to apply @type expression: L{object} @return: The matching records. @rtype: deferred iterable of L{IDirectoryRecord}s @raises: L{QueryNotSupportedError} if the expression is not supported by this directory service. """ def recordsFromQuery(expressions, operand=Operand.AND): """ Find records by composing a query consisting of an iterable of expressions and an operand. 
@param expressions: expressions to query against @type expressions: iterable of L{object}s @param operand: an operand @type operand: a L{NamedConstant} @return: The matching records. @rtype: deferred iterable of L{IDirectoryRecord}s @raises: L{QueryNotSupportedError} if the query is not supported by this directory service. """ def recordsWithFieldValue(fieldName, value): """ Find records that have the given field name with the given value. @param fieldName: a field name @type fieldName: L{NamedConstant} @param value: a value to match @type value: L{bytes} @return: The matching records. @rtype: deferred iterable of L{IDirectoryRecord}s """ def recordWithUID(uid): """ Find the record that has the given UID. @param uid: a UID @type uid: L{bytes} @return: The matching record or C{None} if there is no match. @rtype: deferred L{IDirectoryRecord}s or C{None} """ def recordWithGUID(guid): """ Find the record that has the given GUID. @param guid: a GUID @type guid: L{bytes} @return: The matching record or C{None} if there is no match. @rtype: deferred L{IDirectoryRecord}s or C{None} """ def recordsWithRecordType(recordType): """ Find the records that have the given record type. @param recordType: a record type @type recordType: L{NamedConstant} @return: The matching records. @rtype: deferred iterable of L{IDirectoryRecord}s """ def recordWithShortName(recordType, shortName): """ Find the record that has the given record type and short name. @param recordType: a record type @type recordType: L{NamedConstant} @param shortName: a short name @type shortName: L{bytes} @return: The matching record or C{None} if there is no match. @rtype: deferred L{IDirectoryRecord}s or C{None} """ def recordsWithEmailAddress(emailAddress): """ Find the records that have the given email address. @param emailAddress: an email address @type emailAddress: L{bytes} @return: The matching records. 
@rtype: deferred iterable of L{IDirectoryRecord}s """ def updateRecords(records, create=False): """ Updates existing directory records. @param records: the records to update @type records: iterable of L{IDirectoryRecord}s @param create: if true, create records if necessary @type create: boolean @return: unspecifiied @rtype: deferred object @raises L{NotAllowedError}: if the update is not allowed by the directory service. """ def removeRecords(uids): """ Removes the records with the given UIDs. @param uids: the UIDs of the records to remove @type uids: iterable of L{bytes} @return: unspecifiied @rtype: deferred object @raises L{NotAllowedError}: if the removal is not allowed by the directory service. """ class IDirectoryRecord(Interface): """ Directory record. A directory record corresponds to a principal, and contains information about the principal such as idenfiers, names and passwords. This information is stored in a set of fields (a mapping of field names and values). Some fields allow for multiple values while others allow only one value. This is discoverable by calling L{FieldName.isMultiValue} on the field name. The field L{FieldName.recordType} will be present in all directory records, as all records must have a type. Which other fields are required is implementation-specific. Principals (called group principals) may have references to other principals as members. Records representing group principals will typically be records with the record type L{RecordType.group}, but it is not prohibited for other record types to have members. Fields may also be accessed as attributes. For example: C{record.recordType} is equivalent to C{record.fields[FieldName.recordType]}. """ service = Attribute("The L{IDirectoryService} this record exists in.") fields = Attribute("A mapping with L{NamedConstant} keys.") def members(): """ Find the records that are members of this group. Only direct members are included; members of members are not expanded. 
@return: a deferred iterable of L{IDirectoryRecord}s which are direct members of this group. """ def groups(): """ Find the group records that this record is a member of. Only groups for which this record is a direct member is are included; membership is not expanded. @return: a deferred iterable of L{IDirectoryRecord}s which are groups that this record is a member of. """ calendarserver-5.2+dfsg/twext/who/__init__.py0000644000175000017500000000125612263343324020376 0ustar rahulrahul# -*- test-case-name: twext.who.test -*- ## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Directory service integration """ calendarserver-5.2+dfsg/twext/who/util.py0000644000175000017500000000476412263343324017623 0ustar rahulrahul# -*- test-case-name: twext.who.test.test_util -*- ## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Directory service module utilities. 
""" __all__ = [ "ConstantsContainer", "uniqueResult", "describe", "iterFlags", ] from twisted.python.constants import FlagConstant from twext.who.idirectory import DirectoryServiceError class ConstantsContainer(object): """ A container for constants. """ def __init__(self, constants): myConstants = {} for constant in constants: if constant.name in myConstants: raise ValueError("Name conflict: {0}".format(constant.name)) myConstants[constant.name] = constant self._constants = myConstants def __getattr__(self, name): try: return self._constants[name] except KeyError: raise AttributeError(name) def iterconstants(self): return self._constants.itervalues() def lookupByName(self, name): try: return self._constants[name] except KeyError: raise ValueError(name) def uniqueResult(values): result = None for value in values: if result is None: result = value else: raise DirectoryServiceError( "Multiple values found where one expected." ) return result def describe(constant): if isinstance(constant, FlagConstant): parts = [] for flag in iterFlags(constant): parts.append(getattr(flag, "description", flag.name)) return "|".join(parts) else: return getattr(constant, "description", constant.name) def iterFlags(flags): if hasattr(flags, "__iter__"): return flags else: # Work around http://twistedmatrix.com/trac/ticket/6302 # FIXME: This depends on a private attribute (flags._container) return (flags._container.lookupByName(name) for name in flags.names) calendarserver-5.2+dfsg/twext/web2/0000755000175000017500000000000012322625326016324 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/client/0000755000175000017500000000000012322625326017602 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/client/interfaces.py0000644000175000017500000000476012263343324022305 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_client -*- ## # Copyright (c) 2007 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## from zope.interface import Interface class IHTTPClientManager(Interface): """I coordinate between multiple L{HTTPClientProtocol} objects connected to a single server to facilite request queuing and pipelining. """ def clientBusy(proto): """Called when the L{HTTPClientProtocol} doesn't want to accept anymore requests. @param proto: The L{HTTPClientProtocol} that is changing state. @type proto: L{HTTPClientProtocol} """ pass def clientIdle(proto): """Called when an L{HTTPClientProtocol} is able to accept more requests. @param proto: The L{HTTPClientProtocol} that is changing state. @type proto: L{HTTPClientProtocol} """ pass def clientPipelining(proto): """Called when the L{HTTPClientProtocol} determines that it is able to support request pipelining. @param proto: The L{HTTPClientProtocol} that is changing state. 
@type proto: L{HTTPClientProtocol} """ pass def clientGone(proto): """Called when the L{HTTPClientProtocol} disconnects from the server. @param proto: The L{HTTPClientProtocol} that is changing state. @type proto: L{HTTPClientProtocol} """ pass calendarserver-5.2+dfsg/twext/web2/client/http.py0000644000175000017500000003040612263343324021135 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_client -*- ## # Copyright (c) 2001-2007 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ Client-side HTTP implementation. 
""" from zope.interface import implements from twisted.internet.defer import Deferred from twisted.protocols.basic import LineReceiver from twisted.protocols.policies import TimeoutMixin from twext.web2.responsecode import BAD_REQUEST, HTTP_VERSION_NOT_SUPPORTED from twext.web2.http import parseVersion, Response from twext.web2.http_headers import Headers from twext.web2.stream import ProducerStream, StreamProducer, IByteStream from twext.web2.channel.http import HTTPParser, PERSIST_NO_PIPELINE, PERSIST_PIPELINE from twext.web2.client.interfaces import IHTTPClientManager class ProtocolError(Exception): """ Exception raised when a HTTP error happened. """ class ClientRequest(object): """ A class for describing an HTTP request to be sent to the server. """ def __init__(self, method, uri, headers, stream): """ @param method: The HTTP method to for this request, ex: 'GET', 'HEAD', 'POST', etc. @type method: C{str} @param uri: The URI of the resource to request, this may be absolute or relative, however the interpretation of this URI is left up to the remote server. @type uri: C{str} @param headers: Headers to be sent to the server. It is important to note that this object does not create any implicit headers. So it is up to the HTTP Client to add required headers such as 'Host'. @type headers: C{dict}, L{twext.web2.http_headers.Headers}, or C{None} @param stream: Content body to send to the remote HTTP server. 
@type stream: L{twext.web2.stream.IByteStream} """ self.method = method self.uri = uri if isinstance(headers, Headers): self.headers = headers else: self.headers = Headers(headers or {}) if stream is not None: self.stream = IByteStream(stream) else: self.stream = None class HTTPClientChannelRequest(HTTPParser): parseCloseAsEnd = True outgoing_version = "HTTP/1.1" chunkedOut = False finished = False closeAfter = False def __init__(self, channel, request, closeAfter): HTTPParser.__init__(self, channel) self.request = request self.closeAfter = closeAfter self.transport = self.channel.transport self.responseDefer = Deferred() def submit(self): l = [] request = self.request if request.method == "HEAD": # No incoming data will arrive. self.length = 0 l.append('%s %s %s\r\n' % (request.method, request.uri, self.outgoing_version)) if request.headers is not None: for name, valuelist in request.headers.getAllRawHeaders(): for value in valuelist: l.append("%s: %s\r\n" % (name, value)) if request.stream is not None: if request.stream.length is not None: l.append("%s: %s\r\n" % ('Content-Length', request.stream.length)) else: # Got a stream with no length. Send as chunked and hope, against # the odds, that the server actually supports chunked uploads. l.append("%s: %s\r\n" % ('Transfer-Encoding', 'chunked')) self.chunkedOut = True if self.closeAfter: l.append("%s: %s\r\n" % ('Connection', 'close')) else: l.append("%s: %s\r\n" % ('Connection', 'Keep-Alive')) l.append("\r\n") self.transport.writeSequence(l) d = StreamProducer(request.stream).beginProducing(self) d.addCallback(self._finish).addErrback(self._error) def registerProducer(self, producer, streaming): """ Register a producer. 
""" self.transport.registerProducer(producer, streaming) def unregisterProducer(self): self.transport.unregisterProducer() def write(self, data): if not data: return elif self.chunkedOut: self.transport.writeSequence(("%X\r\n" % len(data), data, "\r\n")) else: self.transport.write(data) def _finish(self, x): """ We are finished writing data. """ if self.chunkedOut: # write last chunk and closing CRLF self.transport.write("0\r\n\r\n") self.finished = True self.channel.requestWriteFinished(self) del self.transport def _error(self, err): """ Abort parsing, and depending of the status of the request, either fire the C{responseDefer} if no response has been sent yet, or close the stream. """ self.abortParse() if hasattr(self, 'stream') and self.stream is not None: self.stream.finish(err) else: self.responseDefer.errback(err) def _abortWithError(self, errcode, text): """ Abort parsing by forwarding a C{ProtocolError} to C{_error}. """ self._error(ProtocolError(text)) def connectionLost(self, reason): self._error(reason) def gotInitialLine(self, initialLine): parts = initialLine.split(' ', 2) # Parse the initial request line if len(parts) != 3: self._abortWithError(BAD_REQUEST, "Bad response line: %s" % (initialLine,)) return strversion, self.code, message = parts try: protovers = parseVersion(strversion) if protovers[0] != 'http': raise ValueError() except ValueError: self._abortWithError(BAD_REQUEST, "Unknown protocol: %s" % (strversion,)) return self.version = protovers[1:3] # Ensure HTTP 0 or HTTP 1. if self.version[0] != 1: self._abortWithError(HTTP_VERSION_NOT_SUPPORTED, 'Only HTTP 1.x is supported.') return ## FIXME: Actually creates Response, function is badly named! def createRequest(self): self.stream = ProducerStream(self.length) self.response = Response(self.code, self.inHeaders, self.stream) self.stream.registerProducer(self, True) del self.inHeaders ## FIXME: Actually processes Response, function is badly named! 
def processRequest(self): self.responseDefer.callback(self.response) def handleContentChunk(self, data): self.stream.write(data) def handleContentComplete(self): self.stream.finish() class EmptyHTTPClientManager(object): """ A dummy HTTPClientManager. It doesn't do any client management, and is meant to be used only when creating an HTTPClientProtocol directly. """ implements(IHTTPClientManager) def clientBusy(self, proto): pass def clientIdle(self, proto): pass def clientPipelining(self, proto): pass def clientGone(self, proto): pass class HTTPClientProtocol(LineReceiver, TimeoutMixin, object): """ A HTTP 1.1 Client with request pipelining support. """ chanRequest = None maxHeaderLength = 10240 firstLine = 1 readPersistent = PERSIST_NO_PIPELINE # inputTimeOut should be pending whenever a complete request has # been written but the complete response has not yet been # received, and be reset every time data is received. inputTimeOut = 60 * 4 def __init__(self, manager=None): """ @param manager: The object this client reports it state to. @type manager: L{IHTTPClientManager} """ self.outRequest = None self.inRequests = [] if manager is None: manager = EmptyHTTPClientManager() self.manager = manager def lineReceived(self, line): if not self.inRequests: # server sending random unrequested data. self.transport.loseConnection() return # If not currently writing this request, set timeout if self.inRequests[0] is not self.outRequest: self.setTimeout(self.inputTimeOut) if self.firstLine: self.firstLine = 0 self.inRequests[0].gotInitialLine(line) else: self.inRequests[0].lineReceived(line) def rawDataReceived(self, data): if not self.inRequests: # Server sending random unrequested data. 
self.transport.loseConnection() return # If not currently writing this request, set timeout if self.inRequests[0] is not self.outRequest: self.setTimeout(self.inputTimeOut) self.inRequests[0].rawDataReceived(data) def submitRequest(self, request, closeAfter=True): """ @param request: The request to send to a remote server. @type request: L{ClientRequest} @param closeAfter: If True the 'Connection: close' header will be sent, otherwise 'Connection: keep-alive' @type closeAfter: C{bool} @rtype: L{twisted.internet.defer.Deferred} @return: A Deferred which will be called back with the L{twext.web2.http.Response} from the server. """ # Assert we're in a valid state to submit more assert self.outRequest is None assert ((self.readPersistent is PERSIST_NO_PIPELINE and not self.inRequests) or self.readPersistent is PERSIST_PIPELINE) self.manager.clientBusy(self) if closeAfter: self.readPersistent = False self.outRequest = chanRequest = HTTPClientChannelRequest(self, request, closeAfter) self.inRequests.append(chanRequest) chanRequest.submit() return chanRequest.responseDefer def requestWriteFinished(self, request): assert request is self.outRequest self.outRequest = None # Tell the manager if more requests can be submitted. self.setTimeout(self.inputTimeOut) if self.readPersistent is PERSIST_PIPELINE: self.manager.clientPipelining(self) def requestReadFinished(self, request): assert self.inRequests[0] is request del self.inRequests[0] self.firstLine = True if not self.inRequests: if self.readPersistent: self.setTimeout(None) self.manager.clientIdle(self) else: self.transport.loseConnection() def setReadPersistent(self, persist): self.readPersistent = persist if not persist: # Tell all requests but first to abort. for request in self.inRequests[1:]: request.connectionLost(None) del self.inRequests[1:] def connectionLost(self, reason): self.readPersistent = False self.setTimeout(None) self.manager.clientGone(self) # Tell all requests to abort. 
for request in self.inRequests: if request is not None: request.connectionLost(reason) calendarserver-5.2+dfsg/twext/web2/client/__init__.py0000644000175000017500000000241012263343324021707 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_client -*- ## # Copyright (c) 2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ Twisted.web2.client: Client Implementation """ calendarserver-5.2+dfsg/twext/web2/channel/0000755000175000017500000000000012322625325017733 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/channel/http.py0000644000175000017500000012505112263343324021270 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_http -*- ## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2008-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# ## import time import warnings import socket from random import randint from cStringIO import StringIO from zope.interface import implements from twisted.internet import interfaces, protocol, reactor from twisted.internet.defer import succeed, Deferred from twisted.protocols import policies, basic from twext.python.log import Logger from twext.web2 import responsecode from twext.web2 import http_headers from twext.web2 import http from twext.web2.http import RedirectResponse from twext.web2.server import Request from twistedcaldav.config import config from twistedcaldav import accounting log = Logger() class OverloadedLoggingServerProtocol (protocol.Protocol): def __init__(self, retryAfter, outstandingRequests): self.retryAfter = retryAfter self.outstandingRequests = outstandingRequests def connectionMade(self): log.info(overloaded=self) self.transport.write( "HTTP/1.0 503 Service Unavailable\r\n" "Content-Type: text/html\r\n" ) if self.retryAfter: self.transport.write( "Retry-After: %s\r\n" % (self.retryAfter,) ) self.transport.write( "Connection: close\r\n\r\n" "Service Unavailable" "

Service Unavailable

" "The server is currently overloaded, " "please try again later." ) self.transport.loseConnection() class SSLRedirectRequest(Request): """ An L{SSLRedirectRequest} prevents processing if the request is over plain HTTP; instead, it redirects to HTTPS. """ def process(self): ignored, secure = self.chanRequest.getHostInfo() if not secure: if config.SSLPort == 443: location = ( "https://%s%s" % (config.ServerHostName, self.uri) ) else: location = ( "https://%s:%d%s" % (config.ServerHostName, config.SSLPort, self.uri) ) return super(SSLRedirectRequest, self).writeResponse( RedirectResponse(location) ) else: return super(SSLRedirectRequest, self).process() # >% PERSIST_NO_PIPELINE, PERSIST_PIPELINE = (1,2) _cachedHostNames = {} def _cachedGetHostByAddr(hostaddr): hostname = _cachedHostNames.get(hostaddr) if hostname is None: try: hostname = socket.gethostbyaddr(hostaddr)[0] except socket.herror: hostname = hostaddr _cachedHostNames[hostaddr]=hostname return hostname class StringTransport(object): """ I am a StringIO wrapper that conforms for the transport API. I support the 'writeSequence' method. """ def __init__(self): self.s = StringIO() def writeSequence(self, seq): self.s.write(''.join(seq)) def __getattr__(self, attr): return getattr(self.__dict__['s'], attr) class AbortedException(Exception): pass class HTTPParser(object): """This class handles the parsing side of HTTP processing. With a suitable subclass, it can parse either the client side or the server side of the connection. """ # Class config: parseCloseAsEnd = False # Instance vars chunkedIn = False headerlen = 0 length = None inHeaders = None partialHeader = '' connHeaders = None finishedReading = False channel = None # For subclassing... 
# Needs attributes: # version # Needs functions: # createRequest() # processRequest() # _abortWithError() # handleContentChunk(data) # handleContentComplete() # Needs functions to exist on .channel # channel.maxHeaderLength # channel.requestReadFinished(self) # channel.setReadPersistent(self, persistent) # (from LineReceiver): # channel.setRawMode() # channel.setLineMode(extraneous) # channel.pauseProducing() # channel.resumeProducing() # channel.stopProducing() def __init__(self, channel): self.inHeaders = http_headers.Headers() self.channel = channel def lineReceived(self, line): if self.chunkedIn: # Parsing a chunked input if self.chunkedIn == 1: # First we get a line like "chunk-size [';' chunk-extension]" # (where chunk extension is just random crap as far as we're concerned) # RFC says to ignore any extensions you don't recognize -- that's all of them. chunksize = line.split(';', 1)[0] try: self.length = int(chunksize, 16) except: self._abortWithError(responsecode.BAD_REQUEST, "Invalid chunk size, not a hex number: %s!" % chunksize) if self.length < 0: self._abortWithError(responsecode.BAD_REQUEST, "Invalid chunk size, negative.") if self.length == 0: # We're done, parse the trailers line self.chunkedIn = 3 else: # Read self.length bytes of raw data self.channel.setRawMode() elif self.chunkedIn == 2: # After we got data bytes of the appropriate length, we end up here, # waiting for the CRLF, then go back to get the next chunk size. if line != '': self._abortWithError(responsecode.BAD_REQUEST, "Excess %d bytes sent in chunk transfer mode" % len(line)) self.chunkedIn = 1 elif self.chunkedIn == 3: # TODO: support Trailers (maybe! but maybe not!) # After getting the final "0" chunk we're here, and we *EAT MERCILESSLY* # any trailer headers sent, and wait for the blank line to terminate the # request. 
if line == '': self.allContentReceived() # END of chunk handling elif line == '': # Empty line => End of headers if self.partialHeader: self.headerReceived(self.partialHeader) self.partialHeader = '' self.allHeadersReceived() # can set chunkedIn self.createRequest() if self.chunkedIn: # stay in linemode waiting for chunk header pass elif self.length == 0: # no content expected self.allContentReceived() else: # await raw data as content self.channel.setRawMode() # Should I do self.pauseProducing() here? self.processRequest() else: self.headerlen += len(line) if self.headerlen > self.channel.maxHeaderLength: self._abortWithError(responsecode.BAD_REQUEST, 'Headers too long.') if line[0] in ' \t': # Append a header continuation self.partialHeader += line else: if self.partialHeader: self.headerReceived(self.partialHeader) self.partialHeader = line def rawDataReceived(self, data): """Handle incoming content.""" datalen = len(data) if datalen < self.length: self.handleContentChunk(data) self.length = self.length - datalen else: self.handleContentChunk(data[:self.length]) extraneous = data[self.length:] channel = self.channel # could go away from allContentReceived. if not self.chunkedIn: self.allContentReceived() else: # NOTE: in chunked mode, self.length is the size of the current chunk, # so we still have more to read. self.chunkedIn = 2 # Read next chunksize channel.setLineMode(extraneous) def headerReceived(self, line): """ Store this header away. Check for too much header data (> channel.maxHeaderLength) and non-ASCII characters; abort the connection with C{BAD_REQUEST} if so. 
""" nameval = line.split(':', 1) if len(nameval) != 2: self._abortWithError(responsecode.BAD_REQUEST, "No ':' in header.") name, val = nameval for field in name, val: try: field.decode('ascii') except UnicodeDecodeError: self._abortWithError(responsecode.BAD_REQUEST, "Headers must be ASCII") val = val.lstrip(' \t') self.inHeaders.addRawHeader(name, val) def allHeadersReceived(self): # Split off connection-related headers connHeaders = self.splitConnectionHeaders() # Set connection parameters from headers self.setConnectionParams(connHeaders) self.connHeaders = connHeaders def allContentReceived(self): self.finishedReading = True self.channel.requestReadFinished(self) self.handleContentComplete() def splitConnectionHeaders(self): """ Split off connection control headers from normal headers. The normal headers are then passed on to user-level code, while the connection headers are stashed in .connHeaders and used for things like request/response framing. This corresponds roughly with the HTTP RFC's description of 'hop-by-hop' vs 'end-to-end' headers in RFC2616 S13.5.1, with the following exceptions: - proxy-authenticate and proxy-authorization are not treated as connection headers. - content-length is, as it is intimately related with low-level HTTP parsing, and is made available to user-level code via the stream length, rather than a header value. (except for HEAD responses, in which case it is NOT used by low-level HTTP parsing, and IS kept in the normal headers. """ def move(name): h = inHeaders.getRawHeaders(name, None) if h is not None: inHeaders.removeHeader(name) connHeaders.setRawHeaders(name, h) # NOTE: According to HTTP spec, we're supposed to eat the # 'Proxy-Authenticate' and 'Proxy-Authorization' headers also, but that # doesn't sound like a good idea to me, because it makes it impossible # to have a non-authenticating transparent proxy in front of an # authenticating proxy. An authenticating proxy can eat them itself. 
# # 'Proxy-Connection' is an undocumented HTTP 1.0 abomination. connHeaderNames = ['content-length', 'connection', 'keep-alive', 'te', 'trailers', 'transfer-encoding', 'upgrade', 'proxy-connection'] inHeaders = self.inHeaders connHeaders = http_headers.Headers() move('connection') if self.version < (1,1): # Remove all headers mentioned in Connection, because a HTTP 1.0 # proxy might have erroneously forwarded it from a 1.1 client. for name in connHeaders.getHeader('connection', ()): if inHeaders.hasHeader(name): inHeaders.removeHeader(name) else: # Otherwise, just add the headers listed to the list of those to move connHeaderNames.extend(connHeaders.getHeader('connection', ())) # If the request was HEAD, self.length has been set to 0 by # HTTPClientRequest.submit; in this case, Content-Length should # be treated as a response header, not a connection header. # Note: this assumes the invariant that .length will always be None # coming into this function, unless this is a HEAD request. if self.length is not None: connHeaderNames.remove('content-length') for headername in connHeaderNames: move(headername) return connHeaders def setConnectionParams(self, connHeaders): # Figure out persistent connection stuff if self.version >= (1,1): if 'close' in connHeaders.getHeader('connection', ()): readPersistent = False else: readPersistent = PERSIST_PIPELINE elif 'keep-alive' in connHeaders.getHeader('connection', ()): readPersistent = PERSIST_NO_PIPELINE else: readPersistent = False # Okay, now implement section 4.4 Message Length to determine # how to find the end of the incoming HTTP message. transferEncoding = connHeaders.getHeader('transfer-encoding') if transferEncoding: if transferEncoding[-1] == 'chunked': # Chunked self.chunkedIn = 1 # Cut off the chunked encoding (cause it's special) transferEncoding = transferEncoding[:-1] elif not self.parseCloseAsEnd: # Would close on end of connection, except this can't happen for # client->server data. 
(Well..it could actually, since TCP has half-close # but the HTTP spec says it can't, so we'll pretend it's right.) self._abortWithError(responsecode.BAD_REQUEST, "Transfer-Encoding received without chunked in last position.") # TODO: support gzip/etc encodings. # FOR NOW: report an error if the client uses any encodings. # They shouldn't, because we didn't send a TE: header saying it's okay. if transferEncoding: self._abortWithError(responsecode.NOT_IMPLEMENTED, "Transfer-Encoding %s not supported." % transferEncoding) else: # No transfer-coding. self.chunkedIn = 0 if self.parseCloseAsEnd: # If no Content-Length, then it's indeterminate length data # (unless the responsecode was one of the special no body ones) # Also note that for HEAD requests, connHeaders won't have # content-length even if the response did. if self.code in http.NO_BODY_CODES: self.length = 0 else: self.length = connHeaders.getHeader('content-length', self.length) # If it's an indeterminate stream without transfer encoding, it must be # the last request. if self.length is None: readPersistent = False else: # If no Content-Length either, assume no content. self.length = connHeaders.getHeader('content-length', 0) # Set the calculated persistence self.channel.setReadPersistent(readPersistent) def abortParse(self): # If we're erroring out while still reading the request if not self.finishedReading: self.finishedReading = True self.channel.setReadPersistent(False) self.channel.requestReadFinished(self) # producer interface def pauseProducing(self): if not self.finishedReading: self.channel.pauseProducing() def resumeProducing(self): if not self.finishedReading: self.channel.resumeProducing() def stopProducing(self): if not self.finishedReading: self.channel.stopProducing() class HTTPChannelRequest(HTTPParser): """This class handles the state and parsing for one HTTP request. It is responsible for all the low-level connection oriented behavior. 
Thus, it takes care of keep-alive, de-chunking, etc., and passes the non-connection headers on to the user-level Request object.""" command = path = version = None queued = 0 request = None out_version = "HTTP/1.1" def __init__(self, channel, queued=0): HTTPParser.__init__(self, channel) self.queued=queued # Buffer writes to a string until we're first in line # to write a response if queued: self.transport = StringTransport() else: self.transport = self.channel.transport # set the version to a fallback for error generation self.version = (1,0) def gotInitialLine(self, initialLine): parts = initialLine.split() # Parse the initial request line if len(parts) != 3: if len(parts) == 1: parts.append('/') if len(parts) == 2 and parts[1][0] == '/': parts.append('HTTP/0.9') else: self._abortWithError(responsecode.BAD_REQUEST, 'Bad request line: %s' % initialLine) self.command, self.path, strversion = parts try: protovers = http.parseVersion(strversion) if protovers[0] != 'http': raise ValueError() except ValueError: self._abortWithError(responsecode.BAD_REQUEST, "Unknown protocol: %s" % strversion) self.version = protovers[1:3] # Ensure HTTP 0 or HTTP 1. if self.version[0] > 1: self._abortWithError(responsecode.HTTP_VERSION_NOT_SUPPORTED, 'Only HTTP 0.9 and HTTP 1.x are supported.') if self.version[0] == 0: # simulate end of headers, as HTTP 0 doesn't have headers. 
self.lineReceived('') def lineLengthExceeded(self, line, wasFirst=False): code = wasFirst and responsecode.REQUEST_URI_TOO_LONG or responsecode.BAD_REQUEST self._abortWithError(code, 'Header line too long.') def createRequest(self): self.request = self.channel.requestFactory(self, self.command, self.path, self.version, self.length, self.inHeaders) del self.inHeaders def processRequest(self): self.request.process() def handleContentChunk(self, data): self.request.handleContentChunk(data) def handleContentComplete(self): self.request.handleContentComplete() ############## HTTPChannelRequest *RESPONSE* methods ############# producer = None chunkedOut = False finished = False ##### Request Callbacks ##### def writeIntermediateResponse(self, code, headers=None): if self.version >= (1,1): self._writeHeaders(code, headers, False) def writeHeaders(self, code, headers): self._writeHeaders(code, headers, True) def _writeHeaders(self, code, headers, addConnectionHeaders): # HTTP 0.9 doesn't have headers. if self.version[0] == 0: return l = [] code_message = responsecode.RESPONSES.get(code, "Unknown Status") l.append('%s %s %s\r\n' % (self.out_version, code, code_message)) if headers is not None: for name, valuelist in headers.getAllRawHeaders(): for value in valuelist: l.append("%s: %s\r\n" % (name, value)) if addConnectionHeaders: # if we don't have a content length, we send data in # chunked mode, so that we can support persistent connections. 
if (headers.getHeader('content-length') is None and self.command != "HEAD" and code not in http.NO_BODY_CODES): if self.version >= (1,1): l.append("%s: %s\r\n" % ('Transfer-Encoding', 'chunked')) self.chunkedOut = True else: # Cannot use persistent connections if we can't do chunking self.channel.dropQueuedRequests() if self.channel.isLastRequest(self): l.append("%s: %s\r\n" % ('Connection', 'close')) elif self.version < (1,1): l.append("%s: %s\r\n" % ('Connection', 'Keep-Alive')) l.append("\r\n") self.transport.writeSequence(l) def write(self, data): if not data: return elif self.chunkedOut: self.transport.writeSequence(("%X\r\n" % len(data), data, "\r\n")) else: self.transport.write(data) def finish(self): """We are finished writing data.""" if self.finished: warnings.warn("Warning! request.finish called twice.", stacklevel=2) return if self.chunkedOut: # write last chunk and closing CRLF self.transport.write("0\r\n\r\n") self.finished = True if not self.queued: self._cleanup() def abortConnection(self, closeWrite=True): """Abort the HTTP connection because of some kind of unrecoverable error. If closeWrite=False, then only abort reading, but leave the writing side alone. This is mostly for internal use by the HTTP request parsing logic, so that it can call an error page generator. Otherwise, completely shut down the connection. 
""" self.abortParse() if closeWrite: if self.producer: self.producer.stopProducing() self.unregisterProducer() self.finished = True if self.queued: self.transport.reset() self.transport.truncate() else: self._cleanup() def getHostInfo(self): return self.channel._host, self.channel._secure def getRemoteHost(self): return self.channel.transport.getPeer() ##### End Request Callbacks ##### def _abortWithError(self, errorcode, text=''): """Handle low level protocol errors.""" headers = http_headers.Headers() headers.setHeader('content-length', len(text)+1) self.abortConnection(closeWrite=False) self.writeHeaders(errorcode, headers) self.write(text) self.write("\n") self.finish() log.warn("Aborted request (%d) %s" % (errorcode, text)) raise AbortedException def _cleanup(self): """Called when have finished responding and are no longer queued.""" if self.producer: log.error(RuntimeError("Producer was not unregistered for %s" % self)) self.unregisterProducer() self.channel.requestWriteFinished(self) del self.transport # methods for channel - end users should not use these def noLongerQueued(self): """Notify the object that it is no longer queued. We start writing whatever data we have to the transport, etc. This method is not intended for users. """ if not self.queued: raise RuntimeError, "noLongerQueued() got called unnecessarily." self.queued = 0 # set transport to real one and send any buffer data data = self.transport.getvalue() self.transport = self.channel.transport if data: self.transport.write(data) # if we have producer, register it with transport if (self.producer is not None) and not self.finished: self.transport.registerProducer(self.producer, True) # if we're finished, clean up if self.finished: self._cleanup() # consumer interface def registerProducer(self, producer, streaming): """Register a producer. 
""" if self.producer: raise ValueError, "registering producer %s before previous one (%s) was unregistered" % (producer, self.producer) self.producer = producer if self.queued: producer.pauseProducing() else: self.transport.registerProducer(producer, streaming) def unregisterProducer(self): """Unregister the producer.""" if not self.queued: self.transport.unregisterProducer() self.producer = None def connectionLost(self, reason): """connection was lost""" if self.queued and self.producer: self.producer.stopProducing() self.producer = None if self.request: self.request.connectionLost(reason) class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin, object): """A receiver for HTTP requests. Handles splitting up the connection for the multiple HTTPChannelRequests that may be in progress on this channel. @ivar timeOut: number of seconds to wait before terminating an idle connection. @ivar maxPipeline: number of outstanding in-progress requests to allow before pausing the input. @ivar maxHeaderLength: number of bytes of header to accept from the client. """ implements(interfaces.IHalfCloseableProtocol) ## Configuration parameters. Set in instances or subclasses. # How many simultaneous requests to handle. maxPipeline = 4 # Timeout when between two requests betweenRequestsTimeOut = 15 # Timeout between lines or bytes while reading a request inputTimeOut = 60 * 4 # Timeout between end of request read and end of response write idleTimeOut = 60 * 5 # Timeout when closing non-persistent connection closeTimeOut = 20 # maximum length of headers (10KiB) maxHeaderLength = 10240 # Allow persistent connections? 
allowPersistentConnections = True # ChannelRequest chanRequestFactory = HTTPChannelRequest requestFactory = http.Request _first_line = 2 readPersistent = PERSIST_PIPELINE _readLost = False _writeLost = False _abortTimer = None chanRequest = None def _callLater(self, secs, fun): reactor.callLater(secs, fun) def __init__(self): # the request queue self.requests = [] def connectionMade(self): self._secure = interfaces.ISSLTransport(self.transport, None) is not None address = self.transport.getHost() self._host = _cachedGetHostByAddr(address.host) self.setTimeout(self.inputTimeOut) self.factory.addConnectedChannel(self) def lineReceived(self, line): if self._first_line: self.setTimeout(self.inputTimeOut) # if this connection is not persistent, drop any data which # the client (illegally) sent after the last request. if not self.readPersistent: self.dataReceived = self.lineReceived = lambda *args: None return # IE sends an extraneous empty line (\r\n) after a POST request; # eat up such a line, but only ONCE if not line and self._first_line == 1: self._first_line = 2 return self._first_line = 0 if not self.allowPersistentConnections: # Don't allow a second request self.readPersistent = False try: self.chanRequest = self.chanRequestFactory(self, len(self.requests)) self.requests.append(self.chanRequest) self.chanRequest.gotInitialLine(line) except AbortedException: pass else: try: self.chanRequest.lineReceived(line) except AbortedException: pass def lineLengthExceeded(self, line): if self._first_line: # Fabricate a request object to respond to the line length violation. 
self.chanRequest = self.chanRequestFactory(self, len(self.requests)) self.requests.append(self.chanRequest) self.chanRequest.gotInitialLine("GET fake HTTP/1.0") try: self.chanRequest.lineLengthExceeded(line, self._first_line) except AbortedException: pass def rawDataReceived(self, data): self.setTimeout(self.inputTimeOut) try: self.chanRequest.rawDataReceived(data) except AbortedException: pass def requestReadFinished(self, request): if(self.readPersistent is PERSIST_NO_PIPELINE or len(self.requests) >= self.maxPipeline): self.pauseProducing() # reset state variables self._first_line = 1 self.chanRequest = None self.setLineMode() # Set an idle timeout, in case this request takes a long # time to finish generating output. if len(self.requests) > 0: self.setTimeout(self.idleTimeOut) def _startNextRequest(self): # notify next request, if present, it can start writing del self.requests[0] if self._writeLost: self.transport.loseConnection() elif self.requests: self.requests[0].noLongerQueued() # resume reading if allowed to if(not self._readLost and self.readPersistent is not PERSIST_NO_PIPELINE and len(self.requests) < self.maxPipeline): self.resumeProducing() elif self._readLost: # No more incoming data, they already closed! self.transport.loseConnection() else: # no requests in queue, resume reading self.setTimeout(self.betweenRequestsTimeOut) self.resumeProducing() def setReadPersistent(self, persistent): if self.readPersistent: # only allow it to be set if it's not currently False self.readPersistent = persistent def dropQueuedRequests(self): """Called when a response is written that forces a connection close.""" self.readPersistent = False # Tell all requests but first to abort. 
for request in self.requests[1:]: request.connectionLost(None) del self.requests[1:] def isLastRequest(self, request): # Is this channel handling the last possible request return not self.readPersistent and self.requests[-1] == request def requestWriteFinished(self, request): """Called by first request in queue when it is done.""" if request != self.requests[0]: raise TypeError # Don't del because we haven't finished cleanup, so, # don't want queue len to be 0 yet. self.requests[0] = None if self.readPersistent or len(self.requests) > 1: # Do this in the next reactor loop so as to # not cause huge call stacks with fast # incoming requests. self._callLater(0, self._startNextRequest) else: # Set an abort timer in case an orderly close hangs self.setTimeout(None) self._abortTimer = reactor.callLater(self.closeTimeOut, self._abortTimeout) #reactor.callLater(0.1, self.transport.loseConnection) self.transport.loseConnection() def timeoutConnection(self): #log.info("Timing out client: %s" % str(self.transport.getPeer())) # Set an abort timer in case an orderly close hangs self._abortTimer = reactor.callLater(self.closeTimeOut, self._abortTimeout) policies.TimeoutMixin.timeoutConnection(self) def _abortTimeout(self): log.error("Connection aborted - took too long to close: {c}", c=str(self.transport.getPeer())) self._abortTimer = None self.transport.abortConnection() def readConnectionLost(self): """Read connection lost""" # If in the lingering-close state, lose the socket. if self._abortTimer: self._abortTimer.cancel() self._abortTimer = None self.transport.loseConnection() return # If between requests, drop connection # when all current requests have written their data. self._readLost = True if not self.requests: # No requests in progress, lose now. self.transport.loseConnection() # If currently in the process of reading a request, this is # probably a client abort, so lose the connection. 
if self.chanRequest: self.transport.loseConnection() def connectionLost(self, reason): self.factory.removeConnectedChannel(self) self._writeLost = True self.readConnectionLost() self.setTimeout(None) # Tell all requests to abort. for request in self.requests: if request is not None: request.connectionLost(reason) class OverloadedServerProtocol(protocol.Protocol): def connectionMade(self): self.transport.write("HTTP/1.0 503 Service Unavailable\r\n" "Content-Type: text/html\r\n" "Connection: close\r\n\r\n" "503 Service Unavailable" "

Service Unavailable

" "The server is currently overloaded, " "please try again later.") self.transport.loseConnection() class HTTPFactory(protocol.ServerFactory): """ Factory for HTTP server. @ivar outstandingRequests: the number of currently connected HTTP channels. @type outstandingRequests: C{int} @ivar connectedChannels: all the channels that have currently active connections. @type connectedChannels: C{set} of L{HTTPChannel} """ protocol = HTTPChannel protocolArgs = None def __init__(self, requestFactory, maxRequests=600, **kwargs): self.maxRequests = maxRequests self.protocolArgs = kwargs self.protocolArgs['requestFactory'] = requestFactory self.connectedChannels = set() self.allConnectionsClosedDeferred = None def buildProtocol(self, addr): if self.outstandingRequests >= self.maxRequests: return OverloadedServerProtocol() p = protocol.ServerFactory.buildProtocol(self, addr) for arg,value in self.protocolArgs.iteritems(): setattr(p, arg, value) return p def addConnectedChannel(self, channel): """ Add a connected channel to the set of currently connected channels and increase the outstanding request count. """ self.connectedChannels.add(channel) def removeConnectedChannel(self, channel): """ Remove a connected channel from the set of currently connected channels and decrease the outstanding request count. If someone is waiting for all the requests to be completed, self.allConnectionsClosedDeferred will be non-None; fire that callback when the number of outstanding requests hits zero. """ self.connectedChannels.remove(channel) if self.allConnectionsClosedDeferred is not None: if self.outstandingRequests == 0: self.allConnectionsClosedDeferred.callback(None) @property def outstandingRequests(self): return len(self.connectedChannels) def allConnectionsClosed(self): """ Return a Deferred that will fire when all outstanding requests have completed. 
@return: A Deferred with a result of None """ if self.outstandingRequests == 0: return succeed(None) self.allConnectionsClosedDeferred = Deferred() return self.allConnectionsClosedDeferred class HTTP503LoggingFactory (HTTPFactory): """ Factory for HTTP server which emits a 503 response when overloaded. """ def __init__(self, requestFactory, maxRequests=600, retryAfter=0, vary=False, **kwargs): self.retryAfter = retryAfter self.vary = vary HTTPFactory.__init__(self, requestFactory, maxRequests, **kwargs) def buildProtocol(self, addr): if self.vary: retryAfter = randint(int(self.retryAfter * 1/2), int(self.retryAfter * 3/2)) else: retryAfter = self.retryAfter if self.outstandingRequests >= self.maxRequests: return OverloadedLoggingServerProtocol(retryAfter, self.outstandingRequests) p = protocol.ServerFactory.buildProtocol(self, addr) for arg,value in self.protocolArgs.iteritems(): setattr(p, arg, value) return p class HTTPLoggingChannelRequest(HTTPChannelRequest): class TransportLoggingWrapper(object): def __init__(self, transport, logData): self.transport = transport self.logData = logData def write(self, data): if self.logData is not None and data: self.logData.append(data) self.transport.write(data) def writeSequence(self, seq): if self.logData is not None and seq: self.logData.append(''.join(seq)) self.transport.writeSequence(seq) def __getattr__(self, attr): return getattr(self.__dict__['transport'], attr) class LogData(object): def __init__(self): self.request = [] self.response = [] def __init__(self, channel, queued=0): super(HTTPLoggingChannelRequest, self).__init__(channel, queued) if accounting.accountingEnabledForCategory("HTTP"): self.logData = HTTPLoggingChannelRequest.LogData() self.transport = HTTPLoggingChannelRequest.TransportLoggingWrapper(self.transport, self.logData.response) else: self.logData = None def gotInitialLine(self, initialLine): if self.logData is not None: self.startTime = time.time() self.logData.request.append(">>>> Request 
starting at: %.3f\r\n\r\n" % (self.startTime,)) self.logData.request.append("%s\r\n" % (initialLine,)) super(HTTPLoggingChannelRequest, self).gotInitialLine(initialLine) def lineReceived(self, line): if self.logData is not None: # We don't want to log basic credentials loggedLine = line if line.lower().startswith("authorization:"): bits = line[14:].strip().split(" ") if bits[0].lower() == "basic" and len(bits) == 2: loggedLine = "%s %s %s" % (line[:14], bits[0], "X" * len(bits[1])) self.logData.request.append("%s\r\n" % (loggedLine,)) super(HTTPLoggingChannelRequest, self).lineReceived(line) def handleContentChunk(self, data): if self.logData is not None: self.logData.request.append(data) super(HTTPLoggingChannelRequest, self).handleContentChunk(data) def handleContentComplete(self): if self.logData is not None: doneTime = time.time() self.logData.request.append("\r\n\r\n>>>> Request complete at: %.3f (elapsed: %.1f ms)" % (doneTime, 1000 * (doneTime - self.startTime),)) super(HTTPLoggingChannelRequest, self).handleContentComplete() def writeHeaders(self, code, headers): if self.logData is not None: doneTime = time.time() self.logData.response.append("\r\n\r\n<<<< Response sending at: %.3f (elapsed: %.1f ms)\r\n\r\n" % (doneTime, 1000 * (doneTime - self.startTime),)) super(HTTPLoggingChannelRequest, self).writeHeaders(code, headers) def finish(self): super(HTTPLoggingChannelRequest, self).finish() if self.logData is not None: doneTime = time.time() self.logData.response.append("\r\n\r\n<<<< Response complete at: %.3f (elapsed: %.1f ms)\r\n" % (doneTime, 1000 * (doneTime - self.startTime),)) accounting.emitAccounting("HTTP", "", "".join(self.logData.request) + "".join(self.logData.response), self.command) HTTPChannel.chanRequestFactory = HTTPLoggingChannelRequest class LimitingHTTPFactory(HTTPFactory): """ HTTPFactory which stores maxAccepts on behalf of the MaxAcceptPortMixin @ivar myServer: a reference to a L{MaxAcceptTCPServer} that this L{LimitingHTTPFactory} 
will limit. This must be set externally. """ def __init__(self, requestFactory, maxRequests=600, maxAccepts=100, **kwargs): HTTPFactory.__init__(self, requestFactory, maxRequests, **kwargs) self.maxAccepts = maxAccepts def buildProtocol(self, addr): """ Override L{HTTPFactory.buildProtocol} in order to avoid ever returning an L{OverloadedServerProtocol}; this should be handled in other ways. """ p = protocol.ServerFactory.buildProtocol(self, addr) for arg, value in self.protocolArgs.iteritems(): setattr(p, arg, value) return p def addConnectedChannel(self, channel): """ Override L{HTTPFactory.addConnectedChannel} to pause listening on the socket when there are too many outstanding channels. """ HTTPFactory.addConnectedChannel(self, channel) if self.outstandingRequests >= self.maxRequests: self.myServer.myPort.stopReading() def removeConnectedChannel(self, channel): """ Override L{HTTPFactory.removeConnectedChannel} to resume listening on the socket when there are too many outstanding channels. """ HTTPFactory.removeConnectedChannel(self, channel) if self.outstandingRequests < self.maxRequests: self.myServer.myPort.startReading() __all__ = [ "HTTPFactory", "HTTP503LoggingFactory", "LimitingHTTPFactory", "SSLRedirectRequest", ] calendarserver-5.2+dfsg/twext/web2/channel/__init__.py0000644000175000017500000000257012263343324022050 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_cgi,twext.web2.test.test_http -*- ## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ Various backend channel implementations for web2. """ from twext.web2.channel.http import HTTPFactory __all__ = ['HTTPFactory'] calendarserver-5.2+dfsg/twext/web2/dav/0000755000175000017500000000000012322625326017076 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/dav/test/0000755000175000017500000000000012322625325020054 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/dav/test/test_move.py0000644000175000017500000001067212263343324022441 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## import os import twext.web2.dav.test.util import twext.web2.dav.test.test_copy from twext.web2 import responsecode from twext.web2.dav.test.util import serialize from twext.web2.dav.test.test_copy import sumFile class MOVE(twext.web2.dav.test.util.TestCase): """ MOVE request """ # FIXME: # Check that properties are being moved def test_MOVE_create(self): """ MOVE to new resource. """ def test(response, path, isfile, sum, uri, depth, dst_path): if response.code != responsecode.CREATED: self.fail("Incorrect response code for MOVE %s: %s != %s" % (uri, response.code, responsecode.CREATED)) if response.headers.getHeader("location") is None: self.fail("Reponse to MOVE %s with CREATE status is missing location: header." 
% (uri,)) if isfile: if not os.path.isfile(dst_path): self.fail("MOVE %s produced no output file" % (uri,)) if sum != sumFile(dst_path): self.fail("MOVE %s produced different file" % (uri,)) else: if not os.path.isdir(dst_path): self.fail("MOVE %s produced no output directory" % (uri,)) if sum != sumFile(dst_path): self.fail("isdir %s produced different directory" % (uri,)) return serialize(self.send, work(self, test)) def test_MOVE_exists(self): """ MOVE to existing resource. """ def test(response, path, isfile, sum, uri, depth, dst_path): if response.code != responsecode.PRECONDITION_FAILED: self.fail("Incorrect response code for MOVE without overwrite %s: %s != %s" % (uri, response.code, responsecode.PRECONDITION_FAILED)) else: # FIXME: Check XML error code (2518bis) pass return serialize(self.send, work(self, test, overwrite=False)) def test_MOVE_overwrite(self): """ MOVE to existing resource with overwrite header. """ def test(response, path, isfile, sum, uri, depth, dst_path): if response.code != responsecode.NO_CONTENT: self.fail("Incorrect response code for MOVE with overwrite %s: %s != %s" % (uri, response.code, responsecode.NO_CONTENT)) else: # FIXME: Check XML error code (2518bis) pass return serialize(self.send, work(self, test, overwrite=True)) def test_MOVE_no_parent(self): """ MOVE to resource with no parent. 
""" def test(response, path, isfile, sum, uri, depth, dst_path): if response.code != responsecode.CONFLICT: self.fail("Incorrect response code for MOVE with no parent %s: %s != %s" % (uri, response.code, responsecode.CONFLICT)) else: # FIXME: Check XML error code (2518bis) pass return serialize(self.send, work(self, test, dst=os.path.join(self.docroot, "elvislives!"))) def work(self, test, overwrite=None, dst=None): return twext.web2.dav.test.test_copy.work(self, test, overwrite, dst, depths=(None,)) calendarserver-5.2+dfsg/twext/web2/dav/test/tworequest_client.py0000644000175000017500000000171311337102650024204 0ustar rahulrahulimport socket, sys test_type = sys.argv[1] port = int(sys.argv[2]) socket_type = sys.argv[3] s = socket.socket(socket.AF_INET) s.connect(("127.0.0.1", port)) s.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 40000) if socket_type == 'ssl': s2 = socket.ssl(s) send=s2.write recv=s2.read else: send=s.send recv=s.recv print >> sys.stderr, ">> Making %s request to port %d" % (socket_type, port) send("PUT /forbidden HTTP/1.1\r\n") send("Host: localhost\r\n") print >> sys.stderr, ">> Sending lots of data" send("Content-Length: 100\r\n\r\n") send("X"*100) send("PUT /forbidden HTTP/1.1\r\n") send("Host: localhost\r\n") print >> sys.stderr, ">> Sending lots of data" send("Content-Length: 100\r\n\r\n") send("X"*100) #import time #time.sleep(5) print >> sys.stderr, ">> Getting data" data='' while len(data) < 299999: try: x=recv(10000) except: break if x == '': break data+=x sys.stdout.write(data) calendarserver-5.2+dfsg/twext/web2/dav/test/test_put.py0000644000175000017500000001223212263343324022275 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##

import os
import filecmp

from twext.web2 import responsecode
from twext.web2.iweb import IResponse
from twext.web2.stream import FileStream
from twext.web2.http import HTTPError

import twext.web2.dav.test.util
from twext.web2.test.test_server import SimpleRequest
from twext.web2.dav.test.util import serialize

class PUT(twext.web2.dav.test.util.TestCase):
    """
    PUT request
    """
    def test_PUT_simple(self):
        """
        PUT request: PUT every regular file in the docroot to /dst and check
        the status, that /dst exists with identical content, and that an
        ETag header is returned.
        """
        dst_path = os.path.join(self.docroot, "dst")

        def checkResult(response, path):
            response = IResponse(response)

            if response.code not in (
                responsecode.CREATED,
                responsecode.NO_CONTENT
            ):
                self.fail("PUT failed: %s" % (response.code,))

            if not os.path.isfile(dst_path):
                self.fail("PUT failed to create file %s." % (dst_path,))

            if not filecmp.cmp(path, dst_path):
                self.fail("PUT failed to preserve data for file %s in file %s."
                          % (path, dst_path))

            etag = response.headers.getHeader("etag")
            if not etag:
                self.fail("No etag header in PUT response %r." % (response,))

        #
        # We need to serialize these request & test iterations because they can
        # interfere with each other.
        #
        def work():
            dst_uri = "/dst"

            for name in os.listdir(self.docroot):
                if name == "dst":
                    continue

                path = os.path.join(self.docroot, name)

                # Can't really PUT something you can't read
                if not os.path.isfile(path): continue

                # NOTE(review): do_test reads ``path`` from the enclosing
                # loop at call time; this relies on serialize() running each
                # test before the generator advances — confirm.
                def do_test(response): checkResult(response, path)

                request = SimpleRequest(self.site, "PUT", dst_uri)
                request.stream = FileStream(file(path, "rb"))

                yield (request, do_test)

        return serialize(self.send, work())

    def test_PUT_again(self):
        """
        PUT on existing resource with If-None-Match header
        """
        dst_path = os.path.join(self.docroot, "dst")
        dst_uri = "/dst"

        def work():
            # Expected status for each successive PUT to /dst; each step
            # depends on the state left by the previous one.
            for code in (
                responsecode.CREATED,
                responsecode.PRECONDITION_FAILED,
                responsecode.NO_CONTENT,
                responsecode.PRECONDITION_FAILED,
                responsecode.NO_CONTENT,
                responsecode.CREATED,
            ):
                # ``code=code`` binds this iteration's expected status.
                def checkResult(response, code=code):
                    response = IResponse(response)

                    if response.code != code:
                        self.fail("Incorrect response code for PUT (%s != %s)"
                                  % (response.code, code))

                # Failures surface as HTTPError; check the carried response.
                def onError(f):
                    f.trap(HTTPError)
                    return checkResult(f.value.response)

                request = SimpleRequest(self.site, "PUT", dst_uri)
                request.stream = FileStream(file(__file__, "rb"))

                if code == responsecode.CREATED:
                    if os.path.isfile(dst_path):
                        os.remove(dst_path)
                    request.headers.setHeader("if-none-match", ("*",))
                elif code == responsecode.PRECONDITION_FAILED:
                    request.headers.setHeader("if-none-match", ("*",))

                yield (request, (checkResult, onError))

        return serialize(self.send, work())

    def test_PUT_no_parent(self):
        """
        PUT with no parent
        """
        dst_uri = "/put/no/parent"

        def checkResult(response):
            response = IResponse(response)

            if response.code != responsecode.CONFLICT:
                self.fail("Incorrect response code for PUT with no parent (%s != %s)"
                          % (response.code, responsecode.CONFLICT))

        request = SimpleRequest(self.site, "PUT", dst_uri)
        request.stream = FileStream(file(__file__, "rb"))

        return self.send(request, checkResult)
calendarserver-5.2+dfsg/twext/web2/dav/test/test_acl.py0000644000175000017500000003566512263343324022243 0ustar rahulrahul##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## import os from twisted.cred.portal import Portal from twext.web2 import responsecode from twext.web2.auth import basic from twext.web2.stream import MemoryStream from twext.web2.dav.util import davXMLFromStream from twext.web2.dav.auth import TwistedPasswordProperty, IPrincipal, DavRealm, TwistedPropertyChecker, AuthenticationWrapper from twext.web2.dav.fileop import rmdir from twext.web2.test.test_server import SimpleRequest from twext.web2.dav.test.util import Site, serialize from twext.web2.dav.test.test_resource import \ TestDAVPrincipalResource, TestPrincipalsCollection from txdav.xml import element import twext.web2.dav.test.util class ACL(twext.web2.dav.test.util.TestCase): """ RFC 3744 (WebDAV ACL) tests. """ def createDocumentRoot(self): docroot = self.mktemp() os.mkdir(docroot) userResource = TestDAVPrincipalResource("/principals/users/user01") userResource.writeDeadProperty(TwistedPasswordProperty("user01")) principalCollection = TestPrincipalsCollection( "/principals/", children={"users": TestPrincipalsCollection( "/principals/users/", children={"user01": userResource})}) rootResource = self.resource_class( docroot, principalCollections=(principalCollection,)) portal = Portal(DavRealm()) portal.registerChecker(TwistedPropertyChecker()) credentialFactories = (basic.BasicCredentialFactory(""),) loginInterfaces = (IPrincipal,) self.site = Site(AuthenticationWrapper( rootResource, portal, credentialFactories, credentialFactories, loginInterfaces )) rootResource.setAccessControlList(self.grant(element.All())) for name, acl in ( ("none" , self.grant()), ("read" , self.grant(element.Read())), ("read-write" , self.grant(element.Read(), element.Write())), ("unlock" , self.grant(element.Unlock())), ("all" , self.grant(element.All())), ): filename = os.path.join(docroot, name) if not os.path.isfile(filename): file(filename, "w").close() resource = self.resource_class(filename) resource.setAccessControlList(acl) for 
name, acl in ( ("nobind" , self.grant()), ("bind" , self.grant(element.Bind())), ("unbind" , self.grant(element.Bind(), element.Unbind())), ): dirname = os.path.join(docroot, name) if not os.path.isdir(dirname): os.mkdir(dirname) resource = self.resource_class(dirname) resource.setAccessControlList(acl) return docroot def restore(self): # Get rid of whatever messed up state the test has now so that we'll # get a fresh docroot. This isn't very cool; tests should be doing # less so that they don't need a fresh copy of this state. if hasattr(self, "_docroot"): rmdir(self._docroot) del self._docroot def test_COPY_MOVE_source(self): """ Verify source access controls during COPY and MOVE. """ def work(): dst_path = os.path.join(self.docroot, "copy_dst") dst_uri = "/" + os.path.basename(dst_path) for src, status in ( ("nobind", responsecode.FORBIDDEN), ("bind", responsecode.FORBIDDEN), ("unbind", responsecode.CREATED), ): src_path = os.path.join(self.docroot, "src_" + src) src_uri = "/" + os.path.basename(src_path) if not os.path.isdir(src_path): os.mkdir(src_path) src_resource = self.resource_class(src_path) src_resource.setAccessControlList({ "nobind": self.grant(), "bind" : self.grant(element.Bind()), "unbind": self.grant(element.Bind(), element.Unbind()) }[src]) for name, acl in ( ("none" , self.grant()), ("read" , self.grant(element.Read())), ("read-write" , self.grant(element.Read(), element.Write())), ("unlock" , self.grant(element.Unlock())), ("all" , self.grant(element.All())), ): filename = os.path.join(src_path, name) if not os.path.isfile(filename): file(filename, "w").close() self.resource_class(filename).setAccessControlList(acl) for method in ("COPY", "MOVE"): for name, code in ( ("none" , {"COPY": responsecode.FORBIDDEN, "MOVE": status}[method]), ("read" , {"COPY": responsecode.CREATED, "MOVE": status}[method]), ("read-write" , {"COPY": responsecode.CREATED, "MOVE": status}[method]), ("unlock" , {"COPY": responsecode.FORBIDDEN, "MOVE": status}[method]), 
("all" , {"COPY": responsecode.CREATED, "MOVE": status}[method]), ): path = os.path.join(src_path, name) uri = src_uri + "/" + name request = SimpleRequest(self.site, method, uri) request.headers.setHeader("destination", dst_uri) _add_auth_header(request) def test(response, code=code, path=path): if os.path.isfile(dst_path): os.remove(dst_path) if response.code != code: return self.oops(request, response, code, method, name) yield (request, test) return serialize(self.send, work()) def test_COPY_MOVE_dest(self): """ Verify destination access controls during COPY and MOVE. """ def work(): src_path = os.path.join(self.docroot, "read") uri = "/" + os.path.basename(src_path) for method in ("COPY", "MOVE"): for name, code in ( ("nobind" , responsecode.FORBIDDEN), ("bind" , responsecode.CREATED), ("unbind" , responsecode.CREATED), ): dst_parent_path = os.path.join(self.docroot, name) dst_path = os.path.join(dst_parent_path, "dst") request = SimpleRequest(self.site, method, uri) request.headers.setHeader("destination", "/" + name + "/dst") _add_auth_header(request) def test(response, code=code, dst_path=dst_path): if os.path.isfile(dst_path): os.remove(dst_path) if response.code != code: return self.oops(request, response, code, method, name) yield (request, test) self.restore() return serialize(self.send, work()) def test_DELETE(self): """ Verify access controls during DELETE. 
""" def work(): for name, code in ( ("nobind" , responsecode.FORBIDDEN), ("bind" , responsecode.FORBIDDEN), ("unbind" , responsecode.NO_CONTENT), ): collection_path = os.path.join(self.docroot, name) path = os.path.join(collection_path, "dst") file(path, "w").close() request = SimpleRequest(self.site, "DELETE", "/" + name + "/dst") _add_auth_header(request) def test(response, code=code, path=path): if response.code != code: return self.oops(request, response, code, "DELETE", name) yield (request, test) return serialize(self.send, work()) def test_UNLOCK(self): """ Verify access controls during UNLOCK of unowned lock. """ raise NotImplementedError() test_UNLOCK.todo = "access controls on UNLOCK unimplemented" def test_MKCOL_PUT(self): """ Verify access controls during MKCOL. """ for method in ("MKCOL", "PUT"): def work(): for name, code in ( ("nobind" , responsecode.FORBIDDEN), ("bind" , responsecode.CREATED), ("unbind" , responsecode.CREATED), ): collection_path = os.path.join(self.docroot, name) path = os.path.join(collection_path, "dst") if os.path.isfile(path): os.remove(path) elif os.path.isdir(path): os.rmdir(path) request = SimpleRequest(self.site, method, "/" + name + "/dst") _add_auth_header(request) def test(response, code=code, path=path): if response.code != code: return self.oops(request, response, code, method, name) yield (request, test) return serialize(self.send, work()) def test_PUT_exists(self): """ Verify access controls during PUT of existing file. 
""" def work(): for name, code in ( ("none" , responsecode.FORBIDDEN), ("read" , responsecode.FORBIDDEN), ("read-write" , responsecode.NO_CONTENT), ("unlock" , responsecode.FORBIDDEN), ("all" , responsecode.NO_CONTENT), ): path = os.path.join(self.docroot, name) request = SimpleRequest(self.site, "PUT", "/" + name) _add_auth_header(request) def test(response, code=code, path=path): if response.code != code: return self.oops(request, response, code, "PUT", name) yield (request, test) return serialize(self.send, work()) def test_PROPFIND(self): """ Verify access controls during PROPFIND. """ raise NotImplementedError() test_PROPFIND.todo = "access controls on PROPFIND unimplemented" def test_PROPPATCH(self): """ Verify access controls during PROPPATCH. """ def work(): for name, code in ( ("none" , responsecode.FORBIDDEN), ("read" , responsecode.FORBIDDEN), ("read-write" , responsecode.MULTI_STATUS), ("unlock" , responsecode.FORBIDDEN), ("all" , responsecode.MULTI_STATUS), ): path = os.path.join(self.docroot, name) request = SimpleRequest(self.site, "PROPPATCH", "/" + name) request.stream = MemoryStream( element.WebDAVDocument(element.PropertyUpdate()).toxml() ) _add_auth_header(request) def test(response, code=code, path=path): if response.code != code: return self.oops(request, response, code, "PROPPATCH", name) yield (request, test) return serialize(self.send, work()) def test_GET_REPORT(self): """ Verify access controls during GET and REPORT. """ def work(): for method in ("GET", "REPORT"): if method == "GET": ok = responsecode.OK elif method == "REPORT": ok = responsecode.MULTI_STATUS else: raise AssertionError("We shouldn't be here. 
(method = %r)" % (method,)) for name, code in ( ("none" , responsecode.FORBIDDEN), ("read" , ok), ("read-write" , ok), ("unlock" , responsecode.FORBIDDEN), ("all" , ok), ): path = os.path.join(self.docroot, name) request = SimpleRequest(self.site, method, "/" + name) if method == "REPORT": request.stream = MemoryStream(element.PrincipalPropertySearch().toxml()) _add_auth_header(request) def test(response, code=code, path=path): if response.code != code: return self.oops(request, response, code, method, name) yield (request, test) return serialize(self.send, work()) def oops(self, request, response, code, method, name): def gotResponseData(doc): if doc is None: doc_xml = None else: doc_xml = doc.toxml() def fail(acl): self.fail("Incorrect status code %s (!= %s) for %s of resource %s with %s ACL: %s\nACL: %s" % (response.code, code, method, request.uri, name, doc_xml, acl.toxml())) def getACL(resource): return resource.accessControlList(request) d = request.locateResource(request.uri) d.addCallback(getACL) d.addCallback(fail) return d d = davXMLFromStream(response.stream) d.addCallback(gotResponseData) return d def _add_auth_header(request): request.headers.setHeader( "authorization", ("basic", "user01:user01".encode("base64")) ) calendarserver-5.2+dfsg/twext/web2/dav/test/test_mkcol.py0000644000175000017500000000600112263343324022567 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## import os from twext.web2 import responsecode from twext.web2.iweb import IResponse from twext.web2.stream import MemoryStream from twext.web2.dav.fileop import rmdir from twext.web2.test.test_server import SimpleRequest import twext.web2.dav.test.util class MKCOL(twext.web2.dav.test.util.TestCase): """ MKCOL request """ # FIXME: # Try in nonexistant parent collection. # Try on existing resource. # Try with request body? 
def test_MKCOL(self): """ MKCOL request """ path, uri = self.mkdtemp("collection") rmdir(path) def check_result(response): response = IResponse(response) if response.code != responsecode.CREATED: self.fail("MKCOL response %s != %s" % (response.code, responsecode.CREATED)) if not os.path.isdir(path): self.fail("MKCOL did not create directory %s" % (path,)) request = SimpleRequest(self.site, "MKCOL", uri) return self.send(request, check_result) def test_MKCOL_invalid_body(self): """ MKCOL request with invalid request body (Any body at all is invalid in our implementation; there is no such thing as a valid body.) """ path, uri = self.mkdtemp("collection") rmdir(path) def check_result(response): response = IResponse(response) if response.code != responsecode.UNSUPPORTED_MEDIA_TYPE: self.fail("MKCOL response %s != %s" % (response.code, responsecode.UNSUPPORTED_MEDIA_TYPE)) if os.path.isdir(path): self.fail("MKCOL incorrectly created directory %s" % (path,)) request = SimpleRequest(self.site, "MKCOL", uri) request.stream = MemoryStream("This is not a valid MKCOL request body.") return self.send(request, check_result) calendarserver-5.2+dfsg/twext/web2/dav/test/test_static.py0000644000175000017500000000452412263343324022761 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twext.web2.dav.test import util from txdav.xml import element as davxml from twext.web2.stream import readStream from twext.web2.test.test_server import SimpleRequest class DAVFileTest(util.TestCase): def test_renderPrivileges(self): """ Verify that a directory listing includes children which you don't have access to. """ request = SimpleRequest(self.site, "GET", "/") def setEmptyACL(resource): resource.setAccessControlList(davxml.ACL()) # Empty ACL = no access return resource def renderRoot(_): d = request.locateResource("/") d.addCallback(lambda r: r.render(request)) return d def assertListing(response): data = [] d = readStream(response.stream, lambda s: data.append(str(s))) d.addCallback(lambda _: self.failIf( 'dir2/' not in "".join(data), "'dir2' expected in listing: %r" % (data,) )) return d d = request.locateResource("/dir2") d.addCallback(setEmptyACL) d.addCallback(renderRoot) d.addCallback(assertListing) return d calendarserver-5.2+dfsg/twext/web2/dav/test/data/0000755000175000017500000000000012322625325020765 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/0000755000175000017500000000000012322625325021565 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPFIND_request.xml0000644000175000017500000000051411337102650025274 0ustar rahulrahul calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/REPORT_request.xml0000644000175000017500000000044211337102650025066 0ustar rahulrahul 
calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPFIND_response.xml0000644000175000017500000000424611337102650025450 0ustar rahulrahul /uploads/ 2005-07-05T23:08:01Z Tue, 05 Jul 2005 23:08:01 GMT "77a99-66-27dd9640" httpd/unix-directory HTTP/1.1 200 OK /uploads/foo.txt 2005-07-05T23:08:08Z 19 Tue, 05 Jul 2005 23:08:08 GMT "77a9f-13-28486600" F text/plain HTTP/1.1 200 OK calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPFIND_bad.xml0000644000175000017500000000014411337102650024331 0ustar rahulrahul calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPPATCH_request.xml0000644000175000017500000000173011337102650025414 0ustar rahulrahul value0 value1 value2 value3 value4 value5 value6 value7 value8 value9 calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPFIND_nonamespace.xml0000644000175000017500000000017211337102650026075 0ustar rahulrahul calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/REPORT_response.xml0000644000175000017500000000347311337102650025243 0ustar rahulrahul http://www.webdav.org/foo.html http://repo.webdav.org/his/23 http://repo.webdav.org/his/23/ver/1 Fred http://www.webdav.org/ws/dev/sally HTTP/1.1 200 OK http://repo.webdav.org/his/23/ver/2 Sally http://repo.webdav.org/act/add-refresh-cmd HTTP/1.1 200 OK HTTP/1.1 200 OK HTTP/1.1 200 OK calendarserver-5.2+dfsg/twext/web2/dav/test/data/quota_100.txt0000644000175000017500000000014411337102650023232 0ustar rahulrahul123456789 123456789 123456789 123456789 123456789 123456789 123456789 123456789 123456789 123456789 calendarserver-5.2+dfsg/twext/web2/dav/test/test_prop.py0000644000175000017500000003330212263343324022446 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twext.web2 import responsecode from twext.web2.iweb import IResponse from twext.web2.stream import MemoryStream from twext.web2 import http_headers from twext.web2.dav.util import davXMLFromStream from twext.web2.test.test_server import SimpleRequest from twext.web2.dav.test.util import serialize from txdav.xml import element as davxml from txdav.xml.element import dav_namespace, lookupElement import twext.web2.dav.test.util # Remove dynamic live properties that exist dynamicLiveProperties = ( (dav_namespace, "quota-available-bytes" ), (dav_namespace, "quota-used-bytes" ), ) class PROP(twext.web2.dav.test.util.TestCase): """ PROPFIND, PROPPATCH requests """ def liveProperties(self): return [lookupElement(qname)() for qname in self.site.resource.liveProperties() if (qname[0] == dav_namespace) and qname not in dynamicLiveProperties] def test_PROPFIND_basic(self): """ PROPFIND request """ def check_result(response): response = IResponse(response) if response.code != responsecode.MULTI_STATUS: self.fail("Incorrect response code for PROPFIND (%s != %s)" % (response.code, responsecode.MULTI_STATUS)) content_type = response.headers.getHeader("content-type") if content_type not in (http_headers.MimeType("text", "xml"), http_headers.MimeType("application", "xml")): self.fail("Incorrect content-type for PROPFIND response (%r not in %r)" % (content_type, (http_headers.MimeType("text", "xml"), http_headers.MimeType("application", "xml")))) return davXMLFromStream(response.stream).addCallback(check_xml) def check_xml(doc): multistatus = doc.root_element if not isinstance(multistatus, davxml.MultiStatus): self.fail("PROPFIND response XML root element is not multistatus: %r" % (multistatus,)) for response in multistatus.childrenOfType(davxml.PropertyStatusResponse): if response.childOfType(davxml.HRef) == "/": for propstat in response.childrenOfType(davxml.PropertyStatus): status = propstat.childOfType(davxml.Status) 
properties = propstat.childOfType(davxml.PropertyContainer).children if status.code != responsecode.OK: self.fail("PROPFIND failed (status %s) to locate live properties: %s" % (status.code, properties)) properties_to_find = [p.qname() for p in self.liveProperties()] for property in properties: qname = property.qname() if qname in properties_to_find: properties_to_find.remove(qname) else: self.fail("PROPFIND found property we didn't ask for: %r" % (property,)) if properties_to_find: self.fail("PROPFIND failed to find properties: %r" % (properties_to_find,)) break else: self.fail("No response for URI /") query = davxml.PropertyFind(davxml.PropertyContainer(*self.liveProperties())) request = SimpleRequest(self.site, "PROPFIND", "/") depth = "1" if depth is not None: request.headers.setHeader("depth", depth) request.stream = MemoryStream(query.toxml()) return self.send(request, check_result) def test_PROPFIND_list(self): """ PROPFIND with allprop, propname """ def check_result(which): def _check_result(response): response = IResponse(response) if response.code != responsecode.MULTI_STATUS: self.fail("Incorrect response code for PROPFIND (%s != %s)" % (response.code, responsecode.MULTI_STATUS)) return davXMLFromStream(response.stream).addCallback(check_xml, which) return _check_result def check_xml(doc, which): response = doc.root_element.childOfType(davxml.PropertyStatusResponse) self.failUnless( response.childOfType(davxml.HRef) == "/", "Incorrect response URI: %s != /" % (response.childOfType(davxml.HRef),) ) for propstat in response.childrenOfType(davxml.PropertyStatus): status = propstat.childOfType(davxml.Status) properties = propstat.childOfType(davxml.PropertyContainer).children if status.code != responsecode.OK: self.fail("PROPFIND failed (status %s) to locate live properties: %s" % (status.code, properties)) if which.name == "allprop": properties_to_find = [p.qname() for p in self.liveProperties() if not p.hidden] else: properties_to_find = [p.qname() for p in 
self.liveProperties()] for property in properties: qname = property.qname() if qname in properties_to_find: properties_to_find.remove(qname) elif qname[0] != dav_namespace: pass else: self.fail("PROPFIND with %s found property we didn't expect: %r" % (which.name, property)) if which.name == "propname": # Element should be empty self.failUnless(len(property.children) == 0) else: # Element should have a value, unless the property exists and is empty... # Verify that there is a value for live properties for which we know # that this should be the case. if property.namespace == dav_namespace and property.name in ( "getetag", "getcontenttype", "getlastmodified", "creationdate", "displayname", ): self.failIf( len(property.children) == 0, "Property has no children: %r" % (property.toxml(),) ) if properties_to_find: self.fail("PROPFIND with %s failed to find properties: %r" % (which.name, properties_to_find)) properties = propstat.childOfType(davxml.PropertyContainer).children def work(): for which in (davxml.AllProperties(), davxml.PropertyName()): query = davxml.PropertyFind(which) request = SimpleRequest(self.site, "PROPFIND", "/") request.headers.setHeader("depth", "0") request.stream = MemoryStream(query.toxml()) yield (request, check_result(which)) return serialize(self.send, work()) def test_PROPPATCH_basic(self): """ PROPPATCH """ # FIXME: # Do PROPFIND to make sure it's still there # Test nonexistant resource # Test None namespace in property def check_patch_response(response): response = IResponse(response) if response.code != responsecode.MULTI_STATUS: self.fail("Incorrect response code for PROPFIND (%s != %s)" % (response.code, responsecode.MULTI_STATUS)) content_type = response.headers.getHeader("content-type") if content_type not in (http_headers.MimeType("text", "xml"), http_headers.MimeType("application", "xml")): self.fail("Incorrect content-type for PROPPATCH response (%r not in %r)" % (content_type, (http_headers.MimeType("text", "xml"), 
http_headers.MimeType("application", "xml")))) return davXMLFromStream(response.stream).addCallback(check_patch_xml) def check_patch_xml(doc): multistatus = doc.root_element if not isinstance(multistatus, davxml.MultiStatus): self.fail("PROPFIND response XML root element is not multistatus: %r" % (multistatus,)) # Requested a property change one resource, so there should be exactly one response response = multistatus.childOfType(davxml.Response) # Should have a response description (its contents are arbitrary) response.childOfType(davxml.ResponseDescription) # Requested property change was on / self.failUnless( response.childOfType(davxml.HRef) == "/", "Incorrect response URI: %s != /" % (response.childOfType(davxml.HRef),) ) # Requested one property change, so there should be exactly one property status propstat = response.childOfType(davxml.PropertyStatus) # And the contained property should be a SpiffyProperty self.failIf( propstat.childOfType(davxml.PropertyContainer).childOfType(SpiffyProperty) is None, "Not a SpiffyProperty in PROPPATCH property status: %s" % (propstat.toxml()) ) # And the status should be 200 self.failUnless( propstat.childOfType(davxml.Status).code == responsecode.OK, "Incorrect status code for PROPPATCH of property %s: %s != %s" % (propstat.childOfType(davxml.PropertyContainer).toxml(), propstat.childOfType(davxml.Status).code, responsecode.OK) ) patch = davxml.PropertyUpdate( davxml.Set( davxml.PropertyContainer( SpiffyProperty.fromString("This is a spiffy resource.") ) ) ) request = SimpleRequest(self.site, "PROPPATCH", "/") request.stream = MemoryStream(patch.toxml()) return self.send(request, check_patch_response) def test_PROPPATCH_liveprop(self): """ PROPPATCH on a live property """ prop = davxml.GETETag.fromString("some-etag-string") patch = davxml.PropertyUpdate(davxml.Set(davxml.PropertyContainer(prop))) return self._simple_PROPPATCH(patch, prop, responsecode.FORBIDDEN, "edit of live property") def test_PROPPATCH_exists_not(self): 
""" PROPPATCH remove a non-existant property """ prop = davxml.Timeout() # Timeout isn't a valid property, so it won't exist. patch = davxml.PropertyUpdate(davxml.Remove(davxml.PropertyContainer(prop))) return self._simple_PROPPATCH(patch, prop, responsecode.OK, "remove of non-existant property") def _simple_PROPPATCH(self, patch, prop, expected_code, what): def check_result(response): response = IResponse(response) if response.code != responsecode.MULTI_STATUS: self.fail("Incorrect response code for PROPPATCH (%s != %s)" % (response.code, responsecode.MULTI_STATUS)) return davXMLFromStream(response.stream).addCallback(check_xml) def check_xml(doc): response = doc.root_element.childOfType(davxml.Response) propstat = response.childOfType(davxml.PropertyStatus) self.failUnless( response.childOfType(davxml.HRef) == "/", "Incorrect response URI: %s != /" % (response.childOfType(davxml.HRef),) ) self.failIf( propstat.childOfType(davxml.PropertyContainer).childOfType(prop) is None, "Not a %s in PROPPATCH property status: %s" % (prop.sname(), propstat.toxml()) ) self.failUnless( propstat.childOfType(davxml.Status).code == expected_code, "Incorrect status code for PROPPATCH %s: %s != %s" % (what, propstat.childOfType(davxml.Status).code, expected_code) ) request = SimpleRequest(self.site, "PROPPATCH", "/") request.stream = MemoryStream(patch.toxml()) return self.send(request, check_result) class SpiffyProperty (davxml.WebDAVTextElement): namespace = "http://twistedmatrix.com/ns/private/tests" name = "spiffyproperty" calendarserver-5.2+dfsg/twext/web2/dav/test/test_copy.py0000644000175000017500000001601512263343324022442 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from hashlib import md5 import os import urllib import twext.web2.dav.test.util from twext.web2 import responsecode from twext.web2.test.test_server import SimpleRequest from twext.web2.dav.test.util import dircmp, serialize from twext.web2.dav.fileop import rmdir class COPY(twext.web2.dav.test.util.TestCase): """ COPY request """ # FIXME: # Check that properties are being copied def test_COPY_create(self): """ COPY to new resource. """ def test(response, path, isfile, sum, uri, depth, dst_path): if response.code != responsecode.CREATED: self.fail("Incorrect response code for COPY %s (depth=%r): %s != %s" % (uri, depth, response.code, responsecode.CREATED)) if response.headers.getHeader("location") is None: self.fail("Reponse to COPY %s (depth=%r) with CREATE status is missing location: header." 
% (uri, depth)) if os.path.isfile(path): if not os.path.isfile(dst_path): self.fail("COPY %s (depth=%r) produced no output file" % (uri, depth)) if not cmp(path, dst_path): self.fail("COPY %s (depth=%r) produced different file" % (uri, depth)) os.remove(dst_path) elif os.path.isdir(path): if not os.path.isdir(dst_path): self.fail("COPY %s (depth=%r) produced no output directory" % (uri, depth)) if depth in ("infinity", None): if dircmp(path, dst_path): self.fail("COPY %s (depth=%r) produced different directory" % (uri, depth)) elif depth == "0": for filename in os.listdir(dst_path): self.fail("COPY %s (depth=%r) shouldn't copy directory contents (eg. %s)" % (uri, depth, filename)) else: raise AssertionError("Unknown depth: %r" % (depth,)) rmdir(dst_path) else: self.fail("Source %s is neither a file nor a directory" % (path,)) return serialize(self.send, work(self, test)) def test_COPY_exists(self): """ COPY to existing resource. """ def test(response, path, isfile, sum, uri, depth, dst_path): if response.code != responsecode.PRECONDITION_FAILED: self.fail("Incorrect response code for COPY without overwrite %s: %s != %s" % (uri, response.code, responsecode.PRECONDITION_FAILED)) else: # FIXME: Check XML error code (2518bis) pass return serialize(self.send, work(self, test, overwrite=False)) def test_COPY_overwrite(self): """ COPY to existing resource with overwrite header. """ def test(response, path, isfile, sum, uri, depth, dst_path): if response.code != responsecode.NO_CONTENT: self.fail("Incorrect response code for COPY with overwrite %s: %s != %s" % (uri, response.code, responsecode.NO_CONTENT)) else: # FIXME: Check XML error code (2518bis) pass self.failUnless(os.path.exists(dst_path), "COPY didn't produce file: %s" % (dst_path,)) return serialize(self.send, work(self, test, overwrite=True)) def test_COPY_no_parent(self): """ COPY to resource with no parent. 
""" def test(response, path, isfile, sum, uri, depth, dst_path): if response.code != responsecode.CONFLICT: self.fail("Incorrect response code for COPY with no parent %s: %s != %s" % (uri, response.code, responsecode.CONFLICT)) else: # FIXME: Check XML error code (2518bis) pass return serialize(self.send, work(self, test, dst=os.path.join(self.docroot, "elvislives!"))) def work(self, test, overwrite=None, dst=None, depths=("0", "infinity", None)): if dst is None: dst = os.path.join(self.docroot, "dst") os.mkdir(dst) for basename in os.listdir(self.docroot): if basename == "dst": continue uri = urllib.quote("/" + basename) path = os.path.join(self.docroot, basename) isfile = os.path.isfile(path) sum = sumFile(path) dst_path = os.path.join(dst, basename) dst_uri = urllib.quote("/dst/" + basename) if not isfile: uri += "/" dst_uri += "/" if overwrite is not None: # Create a file at dst_path to create a conflict file(dst_path, "w").close() for depth in depths: def do_test(response, path=path, isfile=isfile, sum=sum, uri=uri, depth=depth, dst_path=dst_path): test(response, path, isfile, sum, uri, depth, dst_path) request = SimpleRequest(self.site, self.__class__.__name__, uri) request.headers.setHeader("destination", dst_uri) if depth is not None: request.headers.setHeader("depth", depth) if overwrite is not None: request.headers.setHeader("overwrite", overwrite) yield (request, do_test) def sumFile(path): m = md5() if os.path.isfile(path): f = file(path) try: m.update(f.read()) finally: f.close() elif os.path.isdir(path): for dir, subdirs, files in os.walk(path): for filename in files: m.update(filename) f = file(os.path.join(dir, filename)) try: m.update(f.read()) finally: f.close() for dirname in subdirs: m.update(dirname + "/") else: raise AssertionError() return m.digest() calendarserver-5.2+dfsg/twext/web2/dav/test/test_http.py0000644000175000017500000000670212263343324022451 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## import errno from twisted.python.failure import Failure from twext.web2 import responsecode from twext.web2.http import HTTPError from twext.web2.dav.http import ErrorResponse, statusForFailure import twext.web2.dav.test.util class HTTP(twext.web2.dav.test.util.TestCase): """ HTTP Utilities """ def test_statusForFailure_errno(self): """ statusForFailure() for exceptions with known errno values """ for ex_class in (IOError, OSError): for exception, result in ( (ex_class(errno.EACCES, "Permission denied" ), responsecode.FORBIDDEN), (ex_class(errno.EPERM , "Permission denied" ), responsecode.FORBIDDEN), (ex_class(errno.ENOSPC, "No space available"), responsecode.INSUFFICIENT_STORAGE_SPACE), (ex_class(errno.ENOENT, "No such file" ), responsecode.NOT_FOUND), ): self._check_exception(exception, result) def test_statusForFailure_HTTPError(self): """ statusForFailure() for HTTPErrors """ for code in responsecode.RESPONSES: self._check_exception(HTTPError(code), code) self._check_exception(HTTPError(ErrorResponse(code, ("http://twistedmatrix.com/", "bar"))), code) def test_statusForFailure_exception(self): """ statusForFailure() for known/unknown exceptions """ for exception, result in ( (NotImplementedError("Duh..."), responsecode.NOT_IMPLEMENTED), ): self._check_exception(exception, result) class UnknownException (Exception): pass try: self._check_exception(UnknownException(), None) except UnknownException: pass else: self.fail("Unknown exception should have re-raised.") def _check_exception(self, exception, result): try: raise exception except Exception: failure = Failure() status = statusForFailure(failure) self.failUnless( status == result, "Failure %r (%s) generated incorrect status code: %s != %s" % (failure, failure.value, status, result) ) else: raise AssertionError("We shouldn't be here.") calendarserver-5.2+dfsg/twext/web2/dav/test/test_resource.py0000644000175000017500000004114612263343324023322 0ustar 
rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twisted.internet.defer import DeferredList, waitForDeferred, deferredGenerator, succeed from twisted.cred.portal import Portal from twext.web2 import responsecode from twext.web2.http import HTTPError from twext.web2.auth import basic from twext.web2.server import Site from txdav.xml import element as davxml from twext.web2.dav.resource import DAVResource, AccessDeniedError, \ DAVPrincipalResource, DAVPrincipalCollectionResource, davPrivilegeSet from twext.web2.dav.auth import TwistedPasswordProperty, DavRealm, TwistedPropertyChecker, IPrincipal, AuthenticationWrapper from twext.web2.test.test_server import SimpleRequest from twext.web2.dav.test.util import InMemoryPropertyStore import twext.web2.dav.test.util class TestCase(twext.web2.dav.test.util.TestCase): def setUp(self): twext.web2.dav.test.util.TestCase.setUp(self) TestResource._cachedPropertyStores = {} class GenericDAVResource(TestCase): def setUp(self): TestCase.setUp(self) rootresource = TestResource(None, { "file1": TestResource("/file1"), "file2": AuthAllResource("/file2"), "dir1": TestResource("/dir1/", { "subdir1": TestResource("/dir1/subdir1/",{}) }), "dir2": AuthAllResource("/dir2/", { "file1": TestResource("/dir2/file1"), "file2": TestResource("/dir2/file2"), "subdir1": TestResource("/dir2/subdir1/", { "file1": TestResource("/dir2/subdir1/file1"), "file2": TestResource("/dir2/subdir1/file2") }) }) }) self.site = Site(rootresource) def test_findChildren(self): """ This test asserts that we have: 1) not found any unexpected children 2) found all expected children It does this for all depths C{"0"}, C{"1"}, and C{"infintiy"} """ expected_children = { "0": [], "1": [ "/file1", "/file2", "/dir1/", "/dir2/", ], "infinity": [ "/file1", "/file2", "/dir1/", "/dir1/subdir1/", "/dir2/", "/dir2/file1", "/dir2/file2", "/dir2/subdir1/", "/dir2/subdir1/file1", "/dir2/subdir1/file2", ], } request = SimpleRequest(self.site, "GET", "/") resource = 
waitForDeferred(request.locateResource("/")) yield resource resource = resource.getResult() def checkChildren(resource, uri): self.assertEquals(uri, resource.uri) if uri not in expected_children[depth]: unexpected_children.append(uri) else: found_children.append(uri) for depth in ["0", "1", "infinity"]: found_children = [] unexpected_children = [] fc = resource.findChildren(depth, request, checkChildren) completed = waitForDeferred(fc) yield completed completed.getResult() self.assertEquals( unexpected_children, [], "Found unexpected children: %r" % (unexpected_children,) ) expected_children[depth].sort() found_children.sort() self.assertEquals(expected_children[depth], found_children) test_findChildren = deferredGenerator(test_findChildren) def test_findChildrenWithPrivileges(self): """ This test revokes read privileges for the C{"/file2"} and C{"/dir2/"} resource to verify that we can not find them giving our unauthenticated privileges. """ expected_children = [ "/file1", "/dir1/", ] request = SimpleRequest(self.site, "GET", "/") resource = waitForDeferred(request.locateResource("/")) yield resource resource = resource.getResult() def checkChildren(resource, uri): self.assertEquals(uri, resource.uri) if uri not in expected_children: unexpected_children.append(uri) else: found_children.append(uri) found_children = [] unexpected_children = [] privileges = waitForDeferred(resource.currentPrivileges(request)) yield privileges privileges = privileges.getResult() fc = resource.findChildren("1", request, checkChildren, privileges) completed = waitForDeferred(fc) yield completed completed.getResult() self.assertEquals( unexpected_children, [], "Found unexpected children: %r" % (unexpected_children,) ) expected_children.sort() found_children.sort() self.assertEquals(expected_children, found_children) test_findChildrenWithPrivileges = deferredGenerator(test_findChildrenWithPrivileges) def test_findChildrenCallbackRaises(self): """ Verify that when the user callback raises 
an exception the completion deferred returned by findChildren errbacks TODO: Verify that the user callback doesn't get called subsequently """ def raiseOnChild(resource, uri): raise Exception("Oh no!") def findChildren(resource): return self.assertFailure( resource.findChildren("infinity", request, raiseOnChild), Exception ) request = SimpleRequest(self.site, "GET", "/") d = request.locateResource("/").addCallback(findChildren) return d class AccessTests(TestCase): def setUp(self): TestCase.setUp(self) gooduser = TestDAVPrincipalResource("/users/gooduser") gooduser.writeDeadProperty(TwistedPasswordProperty("goodpass")) baduser = TestDAVPrincipalResource("/users/baduser") baduser.writeDeadProperty(TwistedPasswordProperty("badpass")) rootresource = TestPrincipalsCollection("/", { "users": TestResource("/users/", {"gooduser": gooduser, "baduser": baduser}) }) protected = TestResource( "/protected", principalCollections=[rootresource]) protected.setAccessControlList(davxml.ACL( davxml.ACE( davxml.Principal(davxml.HRef("/users/gooduser")), davxml.Grant(davxml.Privilege(davxml.All())), davxml.Protected() ) )) rootresource.children["protected"] = protected portal = Portal(DavRealm()) portal.registerChecker(TwistedPropertyChecker()) credentialFactories = (basic.BasicCredentialFactory(""),) loginInterfaces = (IPrincipal,) self.rootresource = rootresource self.site = Site(AuthenticationWrapper( self.rootresource, portal, credentialFactories, credentialFactories, loginInterfaces, )) def checkSecurity(self, request): """ Locate the resource named by the given request's URI, then authorize it for the 'Read' permission. 
""" d = request.locateResource(request.uri) d.addCallback(lambda r: r.authorize(request, (davxml.Read(),))) return d def assertErrorResponse(self, error, expectedcode, otherExpectations=lambda err: None): self.assertEquals(error.response.code, expectedcode) otherExpectations(error) def test_checkPrivileges(self): """ DAVResource.checkPrivileges() """ ds = [] authAllResource = AuthAllResource() requested_access = (davxml.All(),) site = Site(authAllResource) def expectError(failure): failure.trap(AccessDeniedError) errors = failure.value.errors self.failUnless(len(errors) == 1) subpath, denials = errors[0] self.failUnless(subpath is None) self.failUnless( tuple(denials) == requested_access, "%r != %r" % (tuple(denials), requested_access) ) def expectOK(result): self.failUnlessEquals(result, None) def _checkPrivileges(resource): d = resource.checkPrivileges(request, requested_access) return d # No auth; should deny request = SimpleRequest(site, "GET", "/") d = request.locateResource("/").addCallback(_checkPrivileges).addErrback(expectError) ds.append(d) # Has auth; should allow request = SimpleRequest(site, "GET", "/") request.authnUser = davxml.Principal(davxml.HRef("/users/d00d")) request.authzUser = davxml.Principal(davxml.HRef("/users/d00d")) d = request.locateResource("/") d.addCallback(_checkPrivileges) d.addCallback(expectOK) ds.append(d) return DeferredList(ds) def test_authorize(self): """ Authorizing a known user with the correct password will not raise an exception, indicating that the user is properly authorized given their credentials. 
""" request = SimpleRequest(self.site, "GET", "/protected") request.headers.setHeader( "authorization", ("basic", "gooduser:goodpass".encode("base64"))) return self.checkSecurity(request) def test_badUsernameOrPassword(self): request = SimpleRequest(self.site, "GET", "/protected") request.headers.setHeader( "authorization", ("basic", "gooduser:badpass".encode("base64")) ) d = self.assertFailure(self.checkSecurity(request), HTTPError) def expectWwwAuth(err): self.failUnless(err.response.headers.hasHeader("WWW-Authenticate"), "No WWW-Authenticate header present.") d.addCallback(self.assertErrorResponse, responsecode.UNAUTHORIZED, expectWwwAuth) return d def test_lacksPrivileges(self): request = SimpleRequest(self.site, "GET", "/protected") request.headers.setHeader( "authorization", ("basic", "baduser:badpass".encode("base64")) ) d = self.assertFailure(self.checkSecurity(request), HTTPError) d.addCallback(self.assertErrorResponse, responsecode.FORBIDDEN) return d ## # Utilities ## class TestResource (DAVResource): """A simple test resource used for creating trees of DAV Resources """ _cachedPropertyStores = {} acl = davxml.ACL( davxml.ACE( davxml.Principal(davxml.All()), davxml.Grant(davxml.Privilege(davxml.All())), davxml.Protected(), ) ) def __init__(self, uri=None, children=None, principalCollections=()): """ @param uri: A string respresenting the URI of the given resource @param children: a dictionary of names to Resources """ DAVResource.__init__(self, principalCollections=principalCollections) self.children = children self.uri = uri def deadProperties(self): """ Retrieve deadProperties from a special place in memory """ if not hasattr(self, "_dead_properties"): dp = TestResource._cachedPropertyStores.get(self.uri) if dp is None: TestResource._cachedPropertyStores[self.uri] = InMemoryPropertyStore(self) dp = TestResource._cachedPropertyStores[self.uri] self._dead_properties = dp return self._dead_properties def isCollection(self): return self.children is not 
None def listChildren(self): return self.children.keys() def supportedPrivileges(self, request): return succeed(davPrivilegeSet) def currentPrincipal(self, request): if hasattr(request, "authzUser"): return request.authzUser else: return davxml.Principal(davxml.Unauthenticated()) def locateChild(self, request, segments): child = segments[0] if child == "": return self, segments[1:] elif child in self.children: return self.children[child], segments[1:] else: raise HTTPError(404) def setAccessControlList(self, acl): self.acl = acl def accessControlList(self, request, **kwargs): return succeed(self.acl) class TestPrincipalsCollection(DAVPrincipalCollectionResource, TestResource): """ A full implementation of L{IDAVPrincipalCollectionResource}, implemented as a L{TestResource} which assumes a single L{TestResource} child named 'users'. """ def __init__(self, url, children): DAVPrincipalCollectionResource.__init__(self, url) TestResource.__init__(self, url, children, principalCollections=(self,)) def principalForUser(self, user): """ @see L{IDAVPrincipalCollectionResource.principalForUser}. """ return self.principalForShortName('users', user) def principalForAuthID(self, creds): """ Retrieve the principal for the authentication identifier from a set of credentials. Note that although this method is not actually invoked anywhere in web2.dav, this test class is currently imported by CalendarServer, which requires this method. @param creds: credentials which identify a user @type creds: L{twisted.cred.credentials.IUsernameHashedPassword} or L{twisted.cred.credentials.IUsernamePassword} @return: a DAV principal resource representing a user. @rtype: L{IDAVPrincipalResource} or C{NoneType} """ # XXX either move this to CalendarServer entirely or document it on # IDAVPrincipalCollectionResource return self.principalForShortName('users', creds.username) def principalForShortName(self, type, shortName): """ Retrieve the principal of a given type from this resource. 
Note that although this method is not actually invoked anywhere (aside from test methods) in web2.dav, this test class is currently imported by CalendarServer, which requires this method. @param: a short string (such as 'users' or 'groups') identifying both the principal type, and the name of a resource in the 'children' dictionary, which itself is a L{TestResource} with L{IDAVPrincipalCollectionResource} children. @return: a DAV principal resource of the given type with the given name. @rtype: L{IDAVPrincipalResource} or C{NoneType} """ # XXX either move this to CalendarServer entirely or document it on # IDAVPrincipalCollectionResource typeResource = self.children.get(type, None) user = None if typeResource: user = typeResource.children.get(shortName, None) return user class AuthAllResource (TestResource): """ Give Authenticated principals all privileges and deny everyone else. """ acl = davxml.ACL( davxml.ACE( davxml.Principal(davxml.Authenticated()), davxml.Grant(davxml.Privilege(davxml.All())), davxml.Protected(), ) ) class TestDAVPrincipalResource(DAVPrincipalResource, TestResource): """ Get deadProperties from TestResource """ def principalURL(self): return self.uri calendarserver-5.2+dfsg/twext/web2/dav/test/test_lock.py0000644000175000017500000000342612263343324022422 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twisted.trial.unittest import SkipTest import twext.web2.dav.test.util class LOCK_UNLOCK(twext.web2.dav.test.util.TestCase): """ LOCK, UNLOCK requests """ # FIXME: # Check PUT # Check POST # Check PROPPATCH # Check LOCK # Check UNLOCK # Check MOVE, COPY # Check DELETE # Check MKCOL # Check null resource # Check collections # Check depth # Check If header # Refresh lock def test_LOCK_UNLOCK(self): """ LOCK, UNLOCK request """ raise SkipTest("test unimplemented") test_LOCK_UNLOCK.todo = "LOCK/UNLOCK unimplemented" calendarserver-5.2+dfsg/twext/web2/dav/test/test_report_expand.py0000644000175000017500000000276412263343324024350 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twisted.trial.unittest import SkipTest import twext.web2.dav.test.util class REPORT_expand(twext.web2.dav.test.util.TestCase): """ DAV:expand-property REPORT request """ def test_REPORT_expand_property(self): """ DAV:expand-property REPORT request. """ raise SkipTest("test unimplemeted") calendarserver-5.2+dfsg/twext/web2/dav/test/test_util.py0000644000175000017500000001320212263343324022440 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twisted.trial import unittest from twext.web2.dav import util class Utilities(unittest.TestCase): """ Utilities. """ def test_normalizeURL(self): """ normalizeURL() """ self.assertEquals(util.normalizeURL("http://server//foo"), "http://server/foo") self.assertEquals(util.normalizeURL("http://server/foo/.."), "http://server/") self.assertEquals(util.normalizeURL("/foo/bar/..//"), "/foo") self.assertEquals(util.normalizeURL("/foo/bar/.//"), "/foo/bar") self.assertEquals(util.normalizeURL("//foo///bar/../baz"), "/foo/baz") self.assertEquals(util.normalizeURL("//foo///bar/./baz"), "/foo/bar/baz") self.assertEquals(util.normalizeURL("///../"), "/") self.assertEquals(util.normalizeURL("/.."), "/") def test_joinURL(self): """ joinURL() """ self.assertEquals(util.joinURL("http://server/foo/"), "http://server/foo/") self.assertEquals(util.joinURL("http://server/foo", "/bar"), "http://server/foo/bar") self.assertEquals(util.joinURL("http://server/foo", "bar"), "http://server/foo/bar") self.assertEquals(util.joinURL("http://server/foo/", "/bar"), "http://server/foo/bar") self.assertEquals(util.joinURL("http://server/foo/", "/bar/.."), "http://server/foo") self.assertEquals(util.joinURL("http://server/foo/", "/bar/."), "http://server/foo/bar") self.assertEquals(util.joinURL("http://server/foo/", "/bar/../"), "http://server/foo/") self.assertEquals(util.joinURL("http://server/foo/", "/bar/./"), "http://server/foo/bar/") self.assertEquals(util.joinURL("http://server/foo/../", "/bar"), "http://server/bar") self.assertEquals(util.joinURL("/foo/"), "/foo/") self.assertEquals(util.joinURL("/foo", "/bar"), "/foo/bar") self.assertEquals(util.joinURL("/foo", "bar"), 
"/foo/bar") self.assertEquals(util.joinURL("/foo/", "/bar"), "/foo/bar") self.assertEquals(util.joinURL("/foo/", "/bar/.."), "/foo") self.assertEquals(util.joinURL("/foo/", "/bar/."), "/foo/bar") self.assertEquals(util.joinURL("/foo/", "/bar/../"), "/foo/") self.assertEquals(util.joinURL("/foo/", "/bar/./"), "/foo/bar/") self.assertEquals(util.joinURL("/foo/../", "/bar"), "/bar") self.assertEquals(util.joinURL("/foo", "/../"), "/") self.assertEquals(util.joinURL("/foo", "/./"), "/foo/") def test_parentForURL(self): """ parentForURL() """ self.assertEquals(util.parentForURL("http://server/"), None) self.assertEquals(util.parentForURL("http://server//"), None) self.assertEquals(util.parentForURL("http://server/foo/.."), None) self.assertEquals(util.parentForURL("http://server/foo/../"), None) self.assertEquals(util.parentForURL("http://server/foo/."), "http://server/") self.assertEquals(util.parentForURL("http://server/foo/./"), "http://server/") self.assertEquals(util.parentForURL("http://server/foo"), "http://server/") self.assertEquals(util.parentForURL("http://server//foo"), "http://server/") self.assertEquals(util.parentForURL("http://server/foo/bar/.."), "http://server/") self.assertEquals(util.parentForURL("http://server/foo/bar/."), "http://server/foo/") self.assertEquals(util.parentForURL("http://server/foo/bar"), "http://server/foo/") self.assertEquals(util.parentForURL("http://server/foo/bar/"), "http://server/foo/") self.assertEquals(util.parentForURL("http://server/foo/bar?x=1&y=2"), "http://server/foo/") self.assertEquals(util.parentForURL("http://server/foo/bar/?x=1&y=2"), "http://server/foo/") self.assertEquals(util.parentForURL("/"), None) self.assertEquals(util.parentForURL("/foo/.."), None) self.assertEquals(util.parentForURL("/foo/../"), None) self.assertEquals(util.parentForURL("/foo/."), "/") self.assertEquals(util.parentForURL("/foo/./"), "/") self.assertEquals(util.parentForURL("/foo"), "/") self.assertEquals(util.parentForURL("/foo"), "/") 
self.assertEquals(util.parentForURL("/foo/bar/.."), "/") self.assertEquals(util.parentForURL("/foo/bar/."), "/foo/") self.assertEquals(util.parentForURL("/foo/bar"), "/foo/") self.assertEquals(util.parentForURL("/foo/bar/"), "/foo/") self.assertEquals(util.parentForURL("/foo/bar?x=1&y=2"), "/foo/") self.assertEquals(util.parentForURL("/foo/bar/?x=1&y=2"), "/foo/") calendarserver-5.2+dfsg/twext/web2/dav/test/test_quota.py0000644000175000017500000001546712263343324022633 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twext.web2 import responsecode from twext.web2.iweb import IResponse from twext.web2.stream import FileStream import twext.web2.dav.test.util from twext.web2.test.test_server import SimpleRequest from twext.web2.dav.test.util import Site from txdav.xml import element as davxml import os class QuotaBase(twext.web2.dav.test.util.TestCase): def createDocumentRoot(self): docroot = self.mktemp() os.mkdir(docroot) rootresource = self.resource_class(docroot) rootresource.setAccessControlList(self.grantInherit(davxml.All())) self.site = Site(rootresource) self.site.resource.setQuotaRoot(None, 100000) return docroot def checkQuota(self, value): def _defer(quota): self.assertEqual(quota, value) d = self.site.resource.currentQuotaUse(None) d.addCallback(_defer) return d class QuotaEmpty(QuotaBase): def test_Empty_Quota(self): return self.checkQuota(0) class QuotaPUT(QuotaBase): def test_Quota_PUT(self): """ Quota change on PUT """ dst_uri = "/dst" def checkResult(response): response = IResponse(response) if response.code != responsecode.CREATED: self.fail("Incorrect response code for PUT (%s != %s)" % (response.code, responsecode.CREATED)) return self.checkQuota(100) request = SimpleRequest(self.site, "PUT", dst_uri) request.stream = FileStream(file(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb")) return self.send(request, checkResult) class QuotaDELETE(QuotaBase): def test_Quota_DELETE(self): """ Quota change on DELETE """ dst_uri = "/dst" def checkPUTResult(response): response = IResponse(response) if response.code != responsecode.CREATED: self.fail("Incorrect response code for PUT (%s != %s)" % (response.code, responsecode.CREATED)) def doDelete(_ignore): def checkDELETEResult(response): response = IResponse(response) if response.code != responsecode.NO_CONTENT: self.fail("Incorrect response code for PUT (%s != %s)" % (response.code, responsecode.NO_CONTENT)) return self.checkQuota(0) request = 
SimpleRequest(self.site, "DELETE", dst_uri) return self.send(request, checkDELETEResult) d = self.checkQuota(100) d.addCallback(doDelete) return d request = SimpleRequest(self.site, "PUT", dst_uri) request.stream = FileStream(file(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb")) return self.send(request, checkPUTResult) class OverQuotaPUT(QuotaBase): def test_Quota_PUT(self): """ Quota change on PUT """ dst_uri = "/dst" self.site.resource.setQuotaRoot(None, 90) def checkResult(response): response = IResponse(response) if response.code != responsecode.INSUFFICIENT_STORAGE_SPACE: self.fail("Incorrect response code for PUT (%s != %s)" % (response.code, responsecode.INSUFFICIENT_STORAGE_SPACE)) return self.checkQuota(0) request = SimpleRequest(self.site, "PUT", dst_uri) request.stream = FileStream(file(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb")) return self.send(request, checkResult) class QuotaOKAdjustment(QuotaBase): def test_Quota_OK_Adjustment(self): """ Quota adjustment OK """ dst_uri = "/dst" def checkPUTResult(response): response = IResponse(response) if response.code != responsecode.CREATED: self.fail("Incorrect response code for PUT (%s != %s)" % (response.code, responsecode.CREATED)) def doOKAdjustment(_ignore): def checkAdjustmentResult(_ignore): return self.checkQuota(10) d = self.site.resource.quotaSizeAdjust(None, -90) d.addCallback(checkAdjustmentResult) return d d = self.checkQuota(100) d.addCallback(doOKAdjustment) return d request = SimpleRequest(self.site, "PUT", dst_uri) request.stream = FileStream(file(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb")) return self.send(request, checkPUTResult) class QuotaBadAdjustment(QuotaBase): def test_Quota_Bad_Adjustment(self): """ Quota adjustment too much """ dst_uri = "/dst" def checkPUTResult(response): response = IResponse(response) if response.code != responsecode.CREATED: self.fail("Incorrect response code for PUT (%s != %s)" % 
(response.code, responsecode.CREATED)) def doBadAdjustment(_ignore): def checkAdjustmentResult(_ignore): return self.checkQuota(100) d = self.site.resource.quotaSizeAdjust(None, -200) d.addCallback(checkAdjustmentResult) return d d = self.checkQuota(100) d.addCallback(doBadAdjustment) return d request = SimpleRequest(self.site, "PUT", dst_uri) request.stream = FileStream(file(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb")) return self.send(request, checkPUTResult) calendarserver-5.2+dfsg/twext/web2/dav/test/test_xattrprops.py0000644000175000017500000004145112022736174023722 0ustar rahulrahul# Copyright (c) 2009 Twisted Matrix Laboratories. # See LICENSE for details. """ Tests for L{twext.web2.dav.xattrprops}. """ from zlib import compress, decompress from pickle import dumps from cPickle import UnpicklingError from twext.python.filepath import CachingFilePath as FilePath from twisted.trial.unittest import TestCase from twext.web2.responsecode import NOT_FOUND, INTERNAL_SERVER_ERROR from twext.web2.responsecode import FORBIDDEN from twext.web2.http import HTTPError from twext.web2.dav.static import DAVFile from txdav.xml.element import Depth, WebDAVDocument try: from twext.web2.dav.xattrprops import xattrPropertyStore except ImportError: xattrPropertyStore = None else: from xattr import xattr class ExtendedAttributesPropertyStoreTests(TestCase): """ Tests for L{xattrPropertyStore}. """ if xattrPropertyStore is None: skip = "xattr package missing, cannot test xattr property store" def setUp(self): """ Create a resource and a xattr property store for it. 
""" self.resourcePath = FilePath(self.mktemp()) self.resourcePath.setContent("") self.attrs = xattr(self.resourcePath.path) self.resource = DAVFile(self.resourcePath.path) self.propertyStore = xattrPropertyStore(self.resource) def test_getAbsent(self): """ L{xattrPropertyStore.get} raises L{HTTPError} with a I{NOT FOUND} response code if passed the name of an attribute for which there is no corresponding value. """ error = self.assertRaises(HTTPError, self.propertyStore.get, ("foo", "bar")) self.assertEquals(error.response.code, NOT_FOUND) def _forbiddenTest(self, method): # Remove access to the directory containing the file so that getting # extended attributes from it fails with EPERM. self.resourcePath.parent().chmod(0) # Make sure to restore access to it later so that it can be deleted # after the test run is finished. self.addCleanup(self.resourcePath.parent().chmod, 0700) # Try to get a property from it - and fail. document = self._makeValue() error = self.assertRaises( HTTPError, getattr(self.propertyStore, method), document.root_element.qname()) # Make sure that the status is FORBIDDEN, a roughly reasonable mapping # of the EPERM failure. self.assertEquals(error.response.code, FORBIDDEN) def _missingTest(self, method): # Remove access to the directory containing the file so that getting # extended attributes from it fails with EPERM. self.resourcePath.parent().chmod(0) # Make sure to restore access to it later so that it can be deleted # after the test run is finished. self.addCleanup(self.resourcePath.parent().chmod, 0700) # Try to get a property from it - and fail. document = self._makeValue() error = self.assertRaises( HTTPError, getattr(self.propertyStore, method), document.root_element.qname()) # Make sure that the status is FORBIDDEN, a roughly reasonable mapping # of the EPERM failure. 
self.assertEquals(error.response.code, FORBIDDEN) def test_getErrors(self): """ If there is a problem getting the specified property (aside from the property not existing), L{xattrPropertyStore.get} raises L{HTTPError} with a status code which is determined by the nature of the problem. """ self._forbiddenTest('get') def test_getMissing(self): """ Test missing file. """ resourcePath = FilePath(self.mktemp()) resource = DAVFile(resourcePath.path) propertyStore = xattrPropertyStore(resource) # Try to get a property from it - and fail. document = self._makeValue() error = self.assertRaises( HTTPError, propertyStore.get, document.root_element.qname()) # Make sure that the status is NOT FOUND. self.assertEquals(error.response.code, NOT_FOUND) def _makeValue(self, uid=None): """ Create and return any old WebDAVDocument for use by the get tests. """ element = Depth(uid if uid is not None else "0") document = WebDAVDocument(element) return document def _setValue(self, originalDocument, value, uid=None): element = originalDocument.root_element attribute = ( self.propertyStore.deadPropertyXattrPrefix + (uid if uid is not None else "") + element.sname()) self.attrs[attribute] = value def _getValue(self, originalDocument, uid=None): element = originalDocument.root_element attribute = ( self.propertyStore.deadPropertyXattrPrefix + (uid if uid is not None else "") + element.sname()) return self.attrs[attribute] def _checkValue(self, originalDocument, uid=None): property = originalDocument.root_element.qname() # Try to load it via xattrPropertyStore.get loadedDocument = self.propertyStore.get(property, uid) # XXX Why isn't this a WebDAVDocument? self.assertIsInstance(loadedDocument, Depth) self.assertEquals(str(loadedDocument), uid if uid else "0") def test_getXML(self): """ If there is an XML document associated with the property name passed to L{xattrPropertyStore.get}, that value is parsed into a L{WebDAVDocument}, the root element of which C{get} then returns. 
""" document = self._makeValue() self._setValue(document, document.toxml()) self._checkValue(document) def test_getCompressed(self): """ If there is a compressed value associated with the property name passed to L{xattrPropertyStore.get}, that value is decompressed and parsed into a L{WebDAVDocument}, the root element of which C{get} then returns. """ document = self._makeValue() self._setValue(document, compress(document.toxml())) self._checkValue(document) def test_getPickled(self): """ If there is a pickled document associated with the property name passed to L{xattrPropertyStore.get}, that value is unpickled into a L{WebDAVDocument}, the root element of which is returned. """ document = self._makeValue() self._setValue(document, dumps(document)) self._checkValue(document) def test_getUpgradeXML(self): """ If the value associated with the property name passed to L{xattrPropertyStore.get} is an uncompressed XML document, it is upgraded on access by compressing it. """ document = self._makeValue() originalValue = document.toxml() self._setValue(document, originalValue) self._checkValue(document) self.assertEquals( decompress(self._getValue(document)), document.root_element.toxml(pretty=False)) def test_getUpgradeCompressedPickle(self): """ If the value associated with the property name passed to L{xattrPropertyStore.get} is a compressed pickled document, it is upgraded on access to the compressed XML format. """ document = self._makeValue() self._setValue(document, compress(dumps(document))) self._checkValue(document) self.assertEquals( decompress(self._getValue(document)), document.root_element.toxml(pretty=False)) def test_getInvalid(self): """ If the value associated with the property name passed to L{xattrPropertyStore.get} cannot be interpreted, an error is logged and L{HTTPError} is raised with the I{INTERNAL SERVER ERROR} response code. """ document = self._makeValue() self._setValue( document, "random garbage goes here! 
\0 that nul is definitely garbage") property = document.root_element.qname() error = self.assertRaises(HTTPError, self.propertyStore.get, property) self.assertEquals(error.response.code, INTERNAL_SERVER_ERROR) self.assertEquals( len(self.flushLoggedErrors(UnpicklingError)), 1) def test_set(self): """ L{xattrPropertyStore.set} accepts a L{WebDAVElement} and stores a compressed XML document representing it in an extended attribute. """ document = self._makeValue() self.propertyStore.set(document.root_element) self.assertEquals( decompress(self._getValue(document)), document.root_element.toxml(pretty=False)) def test_delete(self): """ L{xattrPropertyStore.delete} deletes the named property. """ document = self._makeValue() self.propertyStore.set(document.root_element) self.propertyStore.delete(document.root_element.qname()) self.assertRaises(KeyError, self._getValue, document) def test_deleteNonExistent(self): """ L{xattrPropertyStore.delete} does nothing if passed a property which has no value. """ document = self._makeValue() self.propertyStore.delete(document.root_element.qname()) self.assertRaises(KeyError, self._getValue, document) def test_deleteErrors(self): """ If there is a problem deleting the specified property (aside from the property not existing), L{xattrPropertyStore.delete} raises L{HTTPError} with a status code which is determined by the nature of the problem. """ # Remove the file so that deleting extended attributes of it fails with # EEXIST. self.resourcePath.remove() # Try to delete a property from it - and fail. document = self._makeValue() error = self.assertRaises( HTTPError, self.propertyStore.delete, document.root_element.qname()) # Make sure that the status is NOT FOUND, a roughly reasonable mapping # of the EEXIST failure. self.assertEquals(error.response.code, NOT_FOUND) def test_contains(self): """ L{xattrPropertyStore.contains} returns C{True} if the given property has a value, C{False} otherwise. 
""" document = self._makeValue() self.assertFalse( self.propertyStore.contains(document.root_element.qname())) self._setValue(document, document.toxml()) self.assertTrue( self.propertyStore.contains(document.root_element.qname())) def test_containsError(self): """ If there is a problem checking if the specified property exists (aside from the property not existing), L{xattrPropertyStore.contains} raises L{HTTPError} with a status code which is determined by the nature of the problem. """ self._forbiddenTest('contains') def test_containsMissing(self): """ Test missing file. """ resourcePath = FilePath(self.mktemp()) resource = DAVFile(resourcePath.path) propertyStore = xattrPropertyStore(resource) # Try to get a property from it - and fail. document = self._makeValue() self.assertFalse(propertyStore.contains(document.root_element.qname())) def test_list(self): """ L{xattrPropertyStore.list} returns a C{list} of property names associated with the wrapped file. """ prefix = self.propertyStore.deadPropertyXattrPrefix self.attrs[prefix + '{foo}bar'] = 'baz' self.attrs[prefix + '{bar}baz'] = 'quux' self.assertEquals( set(self.propertyStore.list()), set([(u'foo', u'bar'), (u'bar', u'baz')])) def test_listError(self): """ If there is a problem checking if the specified property exists (aside from the property not existing), L{xattrPropertyStore.contains} raises L{HTTPError} with a status code which is determined by the nature of the problem. """ # Remove access to the directory containing the file so that getting # extended attributes from it fails with EPERM. self.resourcePath.parent().chmod(0) # Make sure to restore access to it later so that it can be deleted # after the test run is finished. self.addCleanup(self.resourcePath.parent().chmod, 0700) # Try to get a property from it - and fail. self._makeValue() error = self.assertRaises(HTTPError, self.propertyStore.list) # Make sure that the status is FORBIDDEN, a roughly reasonable mapping # of the EPERM failure. 
self.assertEquals(error.response.code, FORBIDDEN) def test_listMissing(self): """ Test missing file. """ resourcePath = FilePath(self.mktemp()) resource = DAVFile(resourcePath.path) propertyStore = xattrPropertyStore(resource) # Try to get a property from it - and fail. self.assertEqual(propertyStore.list(), []) def test_get_uids(self): """ L{xattrPropertyStore.get} accepts a L{WebDAVElement} and stores a compressed XML document representing it in an extended attribute. """ for uid in (None, "123", "456",): document = self._makeValue(uid) self._setValue(document, document.toxml(), uid=uid) for uid in (None, "123", "456",): document = self._makeValue(uid) self._checkValue(document, uid=uid) def test_set_uids(self): """ L{xattrPropertyStore.set} accepts a L{WebDAVElement} and stores a compressed XML document representing it in an extended attribute. """ for uid in (None, "123", "456",): document = self._makeValue(uid) self.propertyStore.set(document.root_element, uid=uid) self.assertEquals( decompress(self._getValue(document, uid)), document.root_element.toxml(pretty=False)) def test_delete_uids(self): """ L{xattrPropertyStore.set} accepts a L{WebDAVElement} and stores a compressed XML document representing it in an extended attribute. """ for delete_uid in (None, "123", "456",): for uid in (None, "123", "456",): document = self._makeValue(uid) self.propertyStore.set(document.root_element, uid=uid) self.propertyStore.delete(document.root_element.qname(), uid=delete_uid) self.assertRaises(KeyError, self._getValue, document, uid=delete_uid) for uid in (None, "123", "456",): if uid == delete_uid: continue document = self._makeValue(uid) self.assertEquals( decompress(self._getValue(document, uid)), document.root_element.toxml(pretty=False)) def test_contains_uids(self): """ L{xattrPropertyStore.contains} returns C{True} if the given property has a value, C{False} otherwise. 
""" for uid in (None, "123", "456",): document = self._makeValue(uid) self.assertFalse( self.propertyStore.contains(document.root_element.qname(), uid=uid)) self._setValue(document, document.toxml(), uid=uid) self.assertTrue( self.propertyStore.contains(document.root_element.qname(), uid=uid)) def test_list_uids(self): """ L{xattrPropertyStore.list} returns a C{list} of property names associated with the wrapped file. """ prefix = self.propertyStore.deadPropertyXattrPrefix for uid in (None, "123", "456",): user = uid if uid is not None else "" self.attrs[prefix + '%s{foo}bar' % (user,)] = 'baz%s' % (user,) self.attrs[prefix + '%s{bar}baz' % (user,)] = 'quux%s' % (user,) self.attrs[prefix + '%s{moo}mar%s' % (user, user,)] = 'quux%s' % (user,) for uid in (None, "123", "456",): user = uid if uid is not None else "" self.assertEquals( set(self.propertyStore.list(uid)), set([ (u'foo', u'bar'), (u'bar', u'baz'), (u'moo', u'mar%s' % (user,)), ])) self.assertEquals( set(self.propertyStore.list(filterByUID=False)), set([ (u'foo', u'bar', None), (u'bar', u'baz', None), (u'moo', u'mar', None), (u'foo', u'bar', "123"), (u'bar', u'baz', "123"), (u'moo', u'mar123', "123"), (u'foo', u'bar', "456"), (u'bar', u'baz', "456"), (u'moo', u'mar456', "456"), ])) calendarserver-5.2+dfsg/twext/web2/dav/test/test_pipeline.py0000644000175000017500000000575512263343324023306 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## import sys, os from twisted.internet import utils from twext.web2.test import test_server from twext.web2 import resource from twext.web2 import http from twext.web2.test import test_http from twisted.internet.defer import waitForDeferred, deferredGenerator from twisted.python import util class Pipeline(test_server.BaseCase): """ Pipelined request """ class TestResource(resource.LeafResource): def render(self, req): return http.Response(stream="Host:%s, Path:%s"%(req.host, req.path)) def setUp(self): self.root = self.TestResource() def chanrequest(self, root, uri, length, headers, method, version, prepath, content): self.cr = super(Pipeline, self).chanrequest(root, uri, length, headers, method, version, prepath, content) return self.cr def test_root(self): def _testStreamRead(x): self.assertTrue(self.cr.request.stream.length == 0) return self.assertResponse( (self.root, 'http://host/path', {"content-type":"text/plain",}, "PUT", None, '', "This is some text."), (405, {}, None)).addCallback(_testStreamRead) class SSLPipeline(test_http.SSLServerTest): @deferredGenerator def testAdvancedWorkingness(self): args = ('-u', util.sibpath(__file__, "tworequest_client.py"), "basic", str(self.port), self.type) d = waitForDeferred(utils.getProcessOutputAndValue(sys.executable, args=args, env=os.environ)) yield d; out,err,code = d.getResult() self.assertEquals(code, 0, "Error output:\n%s" % (err,)) self.assertEquals(out, "HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\nHTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n") calendarserver-5.2+dfsg/twext/web2/dav/test/test_options.py0000644000175000017500000000421412263343324023161 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twext.web2.iweb import IResponse import twext.web2.dav.test.util from twext.web2.test.test_server import SimpleRequest class OPTIONS(twext.web2.dav.test.util.TestCase): """ OPTIONS request """ def test_DAV1(self): """ DAV level 1 """ return self._test_level("1") def test_DAV2(self): """ DAV level 2 """ return self._test_level("2") test_DAV2.todo = "DAV level 2 unimplemented" def test_ACL(self): """ DAV ACL """ return self._test_level("access-control") def _test_level(self, level): def doTest(response): response = IResponse(response) dav = response.headers.getHeader("dav") if not dav: self.fail("no DAV header: %s" % (response.headers,)) self.assertIn(level, dav, "no DAV level %s header" % (level,)) return response return self.send(SimpleRequest(self.site, "OPTIONS", "/"), doTest) calendarserver-5.2+dfsg/twext/web2/dav/test/__init__.py0000644000175000017500000000227612263343324022174 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ Tests for twext.web2.dav. """ calendarserver-5.2+dfsg/twext/web2/dav/test/util.py0000644000175000017500000002623012263343324021406 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## import os from urllib import quote as url_quote from filecmp import dircmp as DirCompare from tempfile import mkdtemp from shutil import copy from twisted.trial import unittest from twisted.internet import address from twisted.internet.defer import Deferred from twext.python.log import Logger from twext.web2.http import HTTPError, StatusResponse from twext.web2 import responsecode, server from twext.web2 import http_headers from twext.web2 import stream from twext.web2.dav.resource import TwistedACLInheritable from twext.web2.dav.static import DAVFile from twext.web2.dav.util import joinURL from txdav.xml import element from txdav.xml.base import encodeXMLName from twext.web2.http_headers import MimeType from twext.web2.dav.util import allDataFromStream log = Logger() class SimpleRequest(server.Request): """ A L{SimpleRequest} can be used in cases where a L{server.Request} object is necessary but it is beneficial to bypass the concrete transport (and associated logic with the C{chanRequest} attribute). """ clientproto = (1, 1) def __init__(self, site, method, uri, headers=None, content=None): if not headers: headers = http_headers.Headers(headers) super(SimpleRequest, self).__init__( site=site, chanRequest=None, command=method, path=uri, version=self.clientproto, contentLength=len(content or ''), headers=headers) self.stream = stream.MemoryStream(content or '') self.remoteAddr = address.IPv4Address('TCP', '127.0.0.1', 0) self._parseURL() self.host = 'localhost' self.port = 8080 def writeResponse(self, response): if self.chanRequest: self.chanRequest.writeHeaders(response.code, response.headers) return response class InMemoryPropertyStore (object): """ A dead property store for keeping properties in memory DO NOT USE OUTSIDE OF UNIT TESTS! 
""" def __init__(self, resource): self._dict = {} def get(self, qname): try: property = self._dict[qname] except KeyError: raise HTTPError(StatusResponse( responsecode.NOT_FOUND, "No such property: %s" % (encodeXMLName(*qname),) )) doc = element.WebDAVDocument.fromString(property) return doc.root_element def set(self, property): self._dict[property.qname()] = property.toxml() def delete(self, qname): try: del(self._dict[qname]) except KeyError: pass def contains(self, qname): return qname in self._dict def list(self): return self._dict.keys() class TestFile (DAVFile): _cachedPropertyStores = {} def deadProperties(self): if not hasattr(self, "_dead_properties"): dp = TestFile._cachedPropertyStores.get(self.fp.path) if dp is None: TestFile._cachedPropertyStores[self.fp.path] = InMemoryPropertyStore(self) dp = TestFile._cachedPropertyStores[self.fp.path] self._dead_properties = dp return self._dead_properties def parent(self): return TestFile(self.fp.parent()) class TestCase (unittest.TestCase): resource_class = TestFile def grant(*privileges): return element.ACL(*[ element.ACE( element.Grant(element.Privilege(privilege)), element.Principal(element.All()) ) for privilege in privileges ]) grant = staticmethod(grant) def grantInherit(*privileges): return element.ACL(*[ element.ACE( element.Grant(element.Privilege(privilege)), element.Principal(element.All()), TwistedACLInheritable() ) for privilege in privileges ]) grantInherit = staticmethod(grantInherit) def createDocumentRoot(self): docroot = self.mktemp() os.mkdir(docroot) rootresource = self.resource_class(docroot) rootresource.setAccessControlList(self.grantInherit(element.All())) dirnames = ( os.path.join(docroot, "dir1"), # 0 os.path.join(docroot, "dir2"), # 1 os.path.join(docroot, "dir2", "subdir1"), # 2 os.path.join(docroot, "dir3"), # 3 os.path.join(docroot, "dir4"), # 4 os.path.join(docroot, "dir4", "subdir1"), # 5 os.path.join(docroot, "dir4", "subdir1", "subsubdir1"), # 6 os.path.join(docroot, "dir4", 
"subdir2"), # 7 os.path.join(docroot, "dir4", "subdir2", "dir1"), # 8 os.path.join(docroot, "dir4", "subdir2", "dir2"), # 9 ) for dir in dirnames: os.mkdir(dir) src = os.path.dirname(__file__) filenames = [ os.path.join(src, f) for f in os.listdir(src) if os.path.isfile(os.path.join(src, f)) ] for dirname in (docroot,) + dirnames[3:8 + 1]: for filename in filenames[:5]: copy(filename, dirname) return docroot def _getDocumentRoot(self): if not hasattr(self, "_docroot"): log.info("Setting up docroot for %s" % (self.__class__,)) self._docroot = self.createDocumentRoot() return self._docroot def _setDocumentRoot(self, value): self._docroot = value docroot = property(_getDocumentRoot, _setDocumentRoot) def _getSite(self): if not hasattr(self, "_site"): rootresource = self.resource_class(self.docroot) rootresource.setAccessControlList(self.grantInherit(element.All())) self._site = Site(rootresource) return self._site def _setSite(self, site): self._site = site site = property(_getSite, _setSite) def setUp(self): unittest.TestCase.setUp(self) TestFile._cachedPropertyStores = {} def tearDown(self): unittest.TestCase.tearDown(self) def mkdtemp(self, prefix): """ Creates a new directory in the document root and returns its path and URI. """ path = mkdtemp(prefix=prefix + "_", dir=self.docroot) uri = joinURL("/", url_quote(os.path.basename(path))) + "/" return (os.path.abspath(path), uri) def send(self, request, callback=None): """ Invoke the logic involved in traversing a given L{server.Request} as if a client had sent it; call C{locateResource} to look up the resource to be rendered, and render it by calling its C{renderHTTP} method. @param request: A L{server.Request} (generally, to avoid real I/O, a L{SimpleRequest}) already associated with a site. 
@return: asynchronously return a response object or L{None} @rtype: L{Deferred} firing L{Response} or L{None} """ log.info("Sending %s request for URI %s" % (request.method, request.uri)) d = request.locateResource(request.uri) d.addCallback(lambda resource: resource.renderHTTP(request)) d.addCallback(request._cbFinishRender) if callback: if type(callback) is tuple: d.addCallbacks(*callback) else: d.addCallback(callback) return d def simpleSend(self, method, path="/", body="", mimetype="text", subtype="xml", resultcode=responsecode.OK, headers=()): """ Assemble and send a simple request using L{SimpleRequest}. This L{SimpleRequest} is associated with this L{TestCase}'s C{site} attribute. @param method: the HTTP method @type method: L{bytes} @param path: the absolute path portion of the HTTP URI @type path: L{bytes} @param body: the content body of the request @type body: L{bytes} @param mimetype: the main type of the mime type of the body of the request @type mimetype: L{bytes} @param subtype: the subtype of the mimetype of the body of the request @type subtype: L{bytes} @param resultcode: The expected result code for the response to the request. @type resultcode: L{int} @param headers: An iterable of 2-tuples of C{(header, value)}; headers to set on the outgoing request. @return: a L{Deferred} which fires with a L{bytes} if the request was successfully processed and fails with an L{HTTPError} if not; or, if the resultcode does not match the response's code, fails with L{FailTest}. """ request = SimpleRequest(self.site, method, path, content=body) if headers is not None: for k, v in headers: request.headers.setHeader(k, v) request.headers.setHeader("content-type", MimeType(mimetype, subtype)) def checkResult(response): self.assertEqual(response.code, resultcode) if response.stream is None: return None return allDataFromStream(response.stream) return self.send(request, None).addCallback(checkResult) class Site: # FIXME: There is no ISite interface; there should be. 
# implements(ISite) def __init__(self, resource): self.resource = resource def dircmp(dir1, dir2): dc = DirCompare(dir1, dir2) return bool( dc.left_only or dc.right_only or dc.diff_files or dc.common_funny or dc.funny_files ) def serialize(f, work): d = Deferred() def oops(error): d.errback(error) def do_serialize(_): try: args = work.next() except StopIteration: d.callback(None) else: r = f(*args) r.addCallback(do_serialize) r.addErrback(oops) do_serialize(None) return d calendarserver-5.2+dfsg/twext/web2/dav/test/test_delete.py0000644000175000017500000000525312263343324022734 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## import os import urllib import random from twext.web2 import responsecode from twext.web2.iweb import IResponse from twext.web2.test.test_server import SimpleRequest from twext.web2.dav.test.util import serialize import twext.web2.dav.test.util class DELETE(twext.web2.dav.test.util.TestCase): """ DELETE request """ # FIXME: # Try setting unwriteable perms on file, then delete # Try check response XML for error in some but not all files def test_DELETE(self): """ DELETE request """ def check_result(response, path): response = IResponse(response) if response.code != responsecode.NO_CONTENT: self.fail("DELETE response %s != %s" % (response.code, responsecode.NO_CONTENT)) if os.path.exists(path): self.fail("DELETE did not remove path %s" % (path,)) def work(): for filename in os.listdir(self.docroot): path = os.path.join(self.docroot, filename) uri = urllib.quote("/" + filename) if os.path.isdir(path): uri = uri + "/" def do_test(response, path=path): return check_result(response, path) request = SimpleRequest(self.site, "DELETE", uri) depth = random.choice(("infinity", None)) if depth is not None: request.headers.setHeader("depth", depth) yield (request, do_test) return serialize(self.send, work()) calendarserver-5.2+dfsg/twext/web2/dav/test/test_report.py0000644000175000017500000000523512263343324023005 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## from twext.web2.iweb import IResponse from twext.web2.stream import MemoryStream from twext.web2 import responsecode import twext.web2.dav.test.util from twext.web2.test.test_server import SimpleRequest from txdav.xml import element as davxml class REPORT(twext.web2.dav.test.util.TestCase): """ REPORT request """ def test_REPORT_no_body(self): """ REPORT request with no body """ def do_test(response): response = IResponse(response) if response.code != responsecode.BAD_REQUEST: self.fail("Unexpected response code for REPORT with no body: %s" % (response.code,)) request = SimpleRequest(self.site, "REPORT", "/") request.stream = MemoryStream("") return self.send(request, do_test) def test_REPORT_unknown(self): """ Unknown/bogus report type """ def do_test(response): response = IResponse(response) if response.code != responsecode.FORBIDDEN: self.fail("Unexpected response code for unknown REPORT: %s" % (response.code,)) class GoofyReport (davxml.WebDAVUnknownElement): namespace = "GOOFY:" name = "goofy-report" def __init__(self): super(GoofyReport, self).__init__() request = SimpleRequest(self.site, "REPORT", "/") request.stream = MemoryStream(GoofyReport().toxml()) return self.send(request, do_test) calendarserver-5.2+dfsg/twext/web2/dav/test/test_auth.py0000644000175000017500000000543612263343324022436 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## import collections from twext.web2.dav.auth import AuthenticationWrapper import twext.web2.dav.test.util class AutoWrapperTestCase(twext.web2.dav.test.util.TestCase): def test_basicAuthPrevention(self): """ Ensure authentication factories which are not safe to use over an "unencrypted wire" are not advertised when an insecure (i.e. non-SSL connection is made. 
""" FakeFactory = collections.namedtuple("FakeFactory", ("scheme,")) wireEncryptedfactories = [FakeFactory("basic"), FakeFactory("digest"), FakeFactory("xyzzy")] wireUnencryptedfactories = [FakeFactory("digest"), FakeFactory("xyzzy")] class FakeChannel(object): def __init__(self, secure): self.secure = secure def getHostInfo(self): return "ignored", self.secure class FakeRequest(object): def __init__(self, secure): self.portal = None self.loginInterfaces = None self.credentialFactories = None self.chanRequest = FakeChannel(secure) wrapper = AuthenticationWrapper(None, None, wireEncryptedfactories, wireUnencryptedfactories, None) req = FakeRequest(True) # Connection is over SSL wrapper.hook(req) self.assertEquals( set(req.credentialFactories.keys()), set(["basic", "digest", "xyzzy"]) ) req = FakeRequest(False) # Connection is not over SSL wrapper.hook(req) self.assertEquals( set(req.credentialFactories.keys()), set(["digest", "xyzzy"]) ) calendarserver-5.2+dfsg/twext/web2/dav/idav.py0000644000175000017500000003022412263343324020373 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ web2.dav interfaces. """ __all__ = [ "IDAVResource", "IDAVPrincipalResource", "IDAVPrincipalCollectionResource", ] from twext.web2.iweb import IResource class IDAVResource(IResource): """ WebDAV resource. """ def isCollection(): """ Checks whether this resource is a collection resource. @return: C{True} if this resource is a collection resource, C{False} otherwise. """ def findChildren(depth, request, callback, privileges, inherited_aces): """ Returns an iterable of child resources for the given depth. Because resources do not know their request URIs, chidren are returned as tuples C{(resource, uri)}, where C{resource} is the child resource and C{uri} is a URL path relative to this resource. @param depth: the search depth (one of C{"0"}, C{"1"}, or C{"infinity"}) @param request: The current L{IRequest} responsible for this call. @param callback: C{callable} that will be called for each child found @param privileges: the list of L{Privilege}s to test for. This should default to None. @param inherited_aces: a list of L{Privilege}s for aces being inherited from the parent collection used to bypass inheritance lookup. @return: An L{Deferred} that fires when all the children have been found """ def hasProperty(property, request): """ Checks whether the given property is defined on this resource. @param property: an empty L{davxml.WebDAVElement} instance or a qname tuple. @param request: the request being processed. @return: a deferred value of C{True} if the given property is set on this resource, or C{False} otherwise. """ def readProperty(property, request): """ Reads the given property on this resource. 
@param property: an empty L{davxml.WebDAVElement} class or instance, or a qname tuple. @param request: the request being processed. @return: a deferred L{davxml.WebDAVElement} instance containing the value of the given property. @raise HTTPError: (containing a response with a status code of L{responsecode.CONFLICT}) if C{property} is not set on this resource. """ def writeProperty(property, request): """ Writes the given property on this resource. @param property: a L{davxml.WebDAVElement} instance. @param request: the request being processed. @return: an empty deferred which fires when the operation is completed. @raise HTTPError: (containing a response with a status code of L{responsecode.CONFLICT}) if C{property} is a read-only property. """ def removeProperty(property, request): """ Removes the given property from this resource. @param property: a L{davxml.WebDAVElement} instance or a qname tuple. @param request: the request being processed. @return: an empty deferred which fires when the operation is completed. @raise HTTPError: (containing a response with a status code of L{responsecode.CONFLICT}) if C{property} is a read-only property or if the property does not exist. """ def listProperties(request): """ @param request: the request being processed. @return: a deferred iterable of qnames for all properties defined for this resource. """ def supportedReports(): """ @return: an iterable of L{davxml.Report} elements for each report supported by this resource. """ def authorize(request, privileges, recurse=False): """ Verify that the given request is authorized to perform actions that require the given privileges. @param request: the request being processed. @param privileges: an iterable of L{davxml.WebDAVElement} elements denoting access control privileges. @param recurse: C{True} if a recursive check on all child resources of this resource should be performed as well, C{False} otherwise. 
@return: a Deferred which fires with C{None} when authorization is complete, or errbacks with L{HTTPError} (containing a response with a status code of L{responsecode.UNAUTHORIZED}) if not authorized. """ def principalCollections(): """ @return: an interable of L{IDAVPrincipalCollectionResource}s which contain principals used in ACLs for this resource. """ def setAccessControlList(acl): """ Sets the access control list containing the access control list for this resource. @param acl: an L{davxml.ACL} element. """ def supportedPrivileges(request): """ @param request: the request being processed. @return: a L{Deferred} with an L{davxml.SupportedPrivilegeSet} result describing the access control privileges which are supported by this resource. """ def currentPrivileges(request): """ @param request: the request being processed. @return: a sequence of the access control privileges which are set for the currently authenticated user. """ def accessControlList(request, inheritance=True, expanding=False): """ Obtains the access control list for this resource. @param request: the request being processed. @param inheritance: if True, replace inherited privileges with those from the import resource being inherited from, if False just return whatever is set in this ACL. @param expanding: if C{True}, method is called during parent inheritance expansion, if C{False} then not doing parent expansion. @return: a deferred L{davxml.ACL} element containing the access control list for this resource. """ def privilegesForPrincipal(principal, request): """ Evaluate the set of privileges that apply to the specified principal. This involves examing all ace's and granting/denying as appropriate for the specified principal's membership of the ace's prinicpal. @param request: the request being processed. @return: a list of L{Privilege}s that are allowed on this resource for the specified principal. 
""" ## # Quota ## def quota(request): """ Get current available & used quota values for this resource's quota root collection. @return: a C{tuple} containing two C{int}'s the first is quota-available-bytes, the second is quota-used-bytes, or C{None} if quota is not defined on the resource. """ def hasQuota(request): """ Check whether this resource is undre quota control by checking each parent to see if it has a quota root. @return: C{True} if under quota control, C{False} if not. """ def hasQuotaRoot(request): """ Determine whether the resource has a quota root. @return: a C{True} if this resource has quota root, C{False} otherwise. """ def quotaRoot(request): """ Get the quota root (max. allowed bytes) value for this collection. @return: a C{int} containing the maximum allowed bytes if this collection is quota-controlled, or C{None} if not quota controlled. """ def setQuotaRoot(request, maxsize): """ Set the quota root (max. allowed bytes) value for this collection. @param maxsize: a C{int} containing the maximum allowed bytes for the contents of this collection. """ def quotaSize(request): """ Get the size of this resource (if its a collection get total for all children as well). TODO: Take into account size of dead-properties. @return: a L{Deferred} with a C{int} result containing the size of the resource. """ def currentQuotaUse(request): """ Get the cached quota use value, or if not present (or invalid) determine quota use by brute force. @return: an L{Deferred} with a C{int} result containing the current used byte count if this collection is quota-controlled, or C{None} if not quota controlled. """ def updateQuotaUse(request, adjust): """ Adjust current quota use on this all all parent collections that also have quota roots. @param adjust: a C{int} containing the number of bytes added (positive) or removed (negative) that should be used to adjust the cached total. 
@return: an L{Deferred} with a C{int} result containing the current used byte if this collection is quota-controlled, or C{None} if not quota controlled. """ class IDAVPrincipalResource (IDAVResource): """ WebDAV principal resource. (RFC 3744, section 2) """ def alternateURIs(): """ Provides the URIs of network resources with additional descriptive information about the principal, for example, a URI to an LDAP record. (RFC 3744, section 4.1) @return: a iterable of URIs. """ def principalURL(): """ Provides the URL which must be used to identify this principal in ACL requests. (RFC 3744, section 4.2) @return: a URL. """ def groupMembers(): """ Provides the principal URLs of principals that are direct members of this (group) principal. (RFC 3744, section 4.3) @return: a deferred returning an iterable of principal URLs. """ def expandedGroupMembers(): """ Provides the principal URLs of principals that are members of this (group) principal, as well as members of any group principal which are members of this one. @return: a L{Deferred} that fires with an iterable of principal URLs. """ def groupMemberships(): """ Provides the URLs of the group principals in which the principal is directly a member. (RFC 3744, section 4.4) @return: a deferred containing an iterable of group principal URLs. """ class IDAVPrincipalCollectionResource(IDAVResource): """ WebDAV principal collection resource. (RFC 3744, section 5.8) """ def principalCollectionURL(): """ Provides a URL for this resource which may be used to identify this resource in ACL requests. (RFC 3744, section 5.8) @return: a URL. """ def principalForUser(user): """ Retrieve the principal for a given username. @param user: the (short) name of a user. @type user: C{str} @return: the resource representing the DAV principal resource for the given username. 
@rtype: L{IDAVPrincipalResource} """ calendarserver-5.2+dfsg/twext/web2/dav/http.py0000644000175000017500000003107612263343324020435 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
## """ HTTP Utilities """ __all__ = [ "ErrorResponse", "NeedPrivilegesResponse", "MultiStatusResponse", "ResponseQueue", "PropertyStatusResponseQueue", "statusForFailure", "errorForFailure", "messageForFailure", ] import errno from twisted.python.failure import Failure from twisted.python.filepath import InsecurePath from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.iweb import IResponse from twext.web2.http import Response, HTTPError, StatusResponse from twext.web2.http_headers import MimeType from twext.web2.dav.util import joinURL from txdav.xml import element log = Logger() class ErrorResponse(Response): """ A L{Response} object which contains a status code and a L{element.Error} element. Renders itself as a DAV:error XML document. """ error = None unregistered = True # base class is already registered def __init__(self, code, error, description=None): """ @param code: a response code. @param error: an L{WebDAVElement} identifying the error, or a tuple C{(namespace, name)} with which to create an empty element denoting the error. (The latter is useful in the case of preconditions ans postconditions, not all of which have defined XML element classes.) @param description: an optional string that, if present, will get wrapped in a (twisted_dav_namespace, error-description) element. 
""" if type(error) is tuple: xml_namespace, xml_name = error error = element.WebDAVUnknownElement() error.namespace = xml_namespace error.name = xml_name self.description = description if self.description: output = element.Error(error, element.ErrorDescription(self.description)).toxml() else: output = element.Error(error).toxml() Response.__init__(self, code=code, stream=output) self.headers.setHeader("content-type", MimeType("text", "xml")) self.error = error def __repr__(self): return "<%s %s %s>" % (self.__class__.__name__, self.code, self.error.sname()) class NeedPrivilegesResponse (ErrorResponse): def __init__(self, base_uri, errors): """ An error response which is due to unsufficient privileges, as determined by L{DAVResource.checkPrivileges}. @param base_uri: the base URI for the resources with errors (the URI of the resource on which C{checkPrivileges} was called). @param errors: a sequence of tuples, as returned by C{checkPrivileges}. """ denials = [] for subpath, privileges in errors: if subpath is None: uri = base_uri else: uri = joinURL(base_uri, subpath) for p in privileges: denials.append(element.Resource(element.HRef(uri), element.Privilege(p))) super(NeedPrivilegesResponse, self).__init__(responsecode.FORBIDDEN, element.NeedPrivileges(*denials)) class MultiStatusResponse (Response): """ Multi-status L{Response} object. Renders itself as a DAV:multi-status XML document. """ def __init__(self, xml_responses): """ @param xml_responses: an interable of element.Response objects. """ Response.__init__(self, code=responsecode.MULTI_STATUS, stream=element.MultiStatus(*xml_responses).toxml()) self.headers.setHeader("content-type", MimeType("text", "xml")) class ResponseQueue (object): """ Stores a list of (typically error) responses for use in a L{MultiStatusResponse}. """ def __init__(self, path_basename, method, success_response): """ @param path_basename: the base path for all responses to be added to the queue. 
All paths for responses added to the queue must start with C{path_basename}, which will be stripped from the beginning of each path to determine the response's URI. @param method: the name of the method generating the queue. @param success_response: the response to return in lieu of a L{MultiStatusResponse} if no responses are added to this queue. """ self.responses = [] self.path_basename = path_basename self.path_basename_len = len(path_basename) self.method = method self.success_response = success_response def add(self, path, what): """ Add a response. @param path: a path, which must be a subpath of C{path_basename} as provided to L{__init__}. @param what: a status code or a L{Failure} for the given path. """ assert path.startswith(self.path_basename), "%s does not start with %s" % (path, self.path_basename) if type(what) is int: code = what error = None message = responsecode.RESPONSES[code] elif isinstance(what, Failure): code = statusForFailure(what) error = errorForFailure(what) message = messageForFailure(what) else: raise AssertionError("Unknown data type: %r" % (what,)) if code > 400: # Error codes only log.error("Error during %s for %s: %s" % (self.method, path, message)) uri = path[self.path_basename_len:] children = [] children.append(element.HRef(uri)) children.append(element.Status.fromResponseCode(code)) if error is not None: children.append(error) if message is not None: children.append(element.ResponseDescription(message)) self.responses.append(element.StatusResponse(*children)) def response(self): """ Generate a L{MultiStatusResponse} with the responses contained in the queue or, if no such responses, return the C{success_response} provided to L{__init__}. @return: the response. """ if self.responses: return MultiStatusResponse(self.responses) else: return self.success_response class PropertyStatusResponseQueue (object): """ Stores a list of propstat elements for use in a L{Response} in a L{MultiStatusResponse}. 
""" def __init__(self, method, uri, success_response): """ @param method: the name of the method generating the queue. @param uri: the URI for the response. @param success_response: the status to return if no L{PropertyStatus} are added to this queue. """ self.method = method self.uri = uri self.propstats = [] self.success_response = success_response def add(self, what, property): """ Add a response. @param what: a status code or a L{Failure} for the given path. @param property: the property whose status is being reported. """ if type(what) is int: code = what error = None message = responsecode.RESPONSES[code] elif isinstance(what, Failure): code = statusForFailure(what) error = errorForFailure(what) message = messageForFailure(what) else: raise AssertionError("Unknown data type: %r" % (what,)) if len(property.children) > 0: # Re-instantiate as empty element. property = element.WebDAVUnknownElement.withName(property.namespace, property.name) if code > 400: # Error codes only log.error("Error during %s for %s: %s" % (self.method, property, message)) children = [] children.append(element.PropertyContainer(property)) children.append(element.Status.fromResponseCode(code)) if error is not None: children.append(error) if message is not None: children.append(element.ResponseDescription(message)) self.propstats.append(element.PropertyStatus(*children)) def error(self): """ Convert any 2xx codes in the propstat responses to 424 Failed Dependency. 
""" for index, propstat in enumerate(self.propstats): # Check the status changed_status = False newchildren = [] for child in propstat.children: if isinstance(child, element.Status) and (child.code / 100 == 2): # Change the code newchildren.append(element.Status.fromResponseCode(responsecode.FAILED_DEPENDENCY)) changed_status = True elif changed_status and isinstance(child, element.ResponseDescription): newchildren.append(element.ResponseDescription(responsecode.RESPONSES[responsecode.FAILED_DEPENDENCY])) else: newchildren.append(child) self.propstats[index] = element.PropertyStatus(*newchildren) def response(self): """ Generate a response from the responses contained in the queue or, if there are no such responses, return the C{success_response} provided to L{__init__}. @return: a L{element.PropertyStatusResponse}. """ if self.propstats: return element.PropertyStatusResponse( element.HRef(self.uri), *self.propstats ) else: return element.StatusResponse( element.HRef(self.uri), element.Status.fromResponseCode(self.success_response) ) ## # Exceptions and response codes ## def statusForFailure(failure, what=None): """ @param failure: a L{Failure}. @param what: a decription of what was going on when the failure occurred. If what is not C{None}, emit a cooresponding message via L{log.err}. @return: a response code cooresponding to the given C{failure}. 
""" def msg(err): if what is not None: log.debug("%s while %s" % (err, what)) if failure.check(IOError, OSError): e = failure.value[0] if e == errno.EACCES or e == errno.EPERM: msg("Permission denied") return responsecode.FORBIDDEN elif e == errno.ENOSPC: msg("Out of storage space") return responsecode.INSUFFICIENT_STORAGE_SPACE elif e == errno.ENOENT: msg("Not found") return responsecode.NOT_FOUND else: failure.raiseException() elif failure.check(NotImplementedError): msg("Unimplemented error") return responsecode.NOT_IMPLEMENTED elif failure.check(InsecurePath): msg("Insecure path") return responsecode.FORBIDDEN elif failure.check(HTTPError): code = IResponse(failure.value.response).code msg("%d response" % (code,)) return code else: failure.raiseException() def errorForFailure(failure): if failure.check(HTTPError) and isinstance(failure.value.response, ErrorResponse): return element.Error(failure.value.response.error) else: return None def messageForFailure(failure): if failure.check(HTTPError): if isinstance(failure.value.response, ErrorResponse): return failure.value.response.description elif isinstance(failure.value.response, StatusResponse): return failure.value.response.description return str(failure) calendarserver-5.2+dfsg/twext/web2/dav/method/0000755000175000017500000000000012322625325020355 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/dav/method/report_acl_principal_prop_set.py0000644000175000017500000001371612263343324027045 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_report_expand -*- ## # Copyright (c) 2006-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV acl-prinicpal-prop-set report """ __all__ = ["report_DAV__acl_principal_prop_set"] from twisted.internet.defer import deferredGenerator, waitForDeferred from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.http import HTTPError, StatusResponse from txdav.xml import element as davxml from twext.web2.dav.http import ErrorResponse from twext.web2.dav.http import MultiStatusResponse from twext.web2.dav.method import prop_common from twext.web2.dav.method.report import NumberOfMatchesWithinLimits from twext.web2.dav.method.report import max_number_of_matches log = Logger() def report_DAV__acl_principal_prop_set(self, request, acl_prinicpal_prop_set): """ Generate an acl-prinicpal-prop-set REPORT. 
(RFC 3744, section 9.2) """ # Verify root element if not isinstance(acl_prinicpal_prop_set, davxml.ACLPrincipalPropSet): raise ValueError("%s expected as root element, not %s." % (davxml.ACLPrincipalPropSet.sname(), acl_prinicpal_prop_set.sname())) # Depth must be "0" depth = request.headers.getHeader("depth", "0") if depth != "0": log.error("Error in prinicpal-prop-set REPORT, Depth set to %s" % (depth,)) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,))) # # Check authentication and access controls # x = waitForDeferred(self.authorize(request, (davxml.ReadACL(),))) yield x x.getResult() # Get a single DAV:prop element from the REPORT request body propertiesForResource = None propElement = None for child in acl_prinicpal_prop_set.children: if child.qname() == ("DAV:", "prop"): if propertiesForResource is not None: log.error("Only one DAV:prop element allowed") raise HTTPError(StatusResponse( responsecode.BAD_REQUEST, "Only one DAV:prop element allowed" )) propertiesForResource = prop_common.propertyListForResource propElement = child if propertiesForResource is None: log.error("Error in acl-principal-prop-set REPORT, no DAV:prop element") raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "No DAV:prop element")) # Enumerate principals on ACL in current resource principals = [] acl = waitForDeferred(self.accessControlList(request)) yield acl acl = acl.getResult() for ace in acl.children: resolved = waitForDeferred(self.resolvePrincipal(ace.principal.children[0], request)) yield resolved resolved = resolved.getResult() if resolved is not None and resolved not in principals: principals.append(resolved) # Run report for each referenced principal try: responses = [] matchcount = 0 for principal in principals: # Check size of results is within limit matchcount += 1 if matchcount > max_number_of_matches: raise NumberOfMatchesWithinLimits(max_number_of_matches) resource = 
waitForDeferred(request.locateResource(str(principal))) yield resource resource = resource.getResult() if resource is not None: # # Check authentication and access controls # x = waitForDeferred(resource.authorize(request, (davxml.Read(),))) yield x try: x.getResult() except HTTPError: responses.append(davxml.StatusResponse( principal, davxml.Status.fromResponseCode(responsecode.FORBIDDEN) )) else: d = waitForDeferred(prop_common.responseForHref( request, responses, principal, resource, propertiesForResource, propElement )) yield d d.getResult() else: log.error("Requested principal resource not found: %s" % (str(principal),)) responses.append(davxml.StatusResponse( principal, davxml.Status.fromResponseCode(responsecode.NOT_FOUND) )) except NumberOfMatchesWithinLimits: log.error("Too many matching components") raise HTTPError(ErrorResponse( responsecode.FORBIDDEN, davxml.NumberOfMatchesWithinLimits() )) yield MultiStatusResponse(responses) report_DAV__acl_principal_prop_set = deferredGenerator(report_DAV__acl_principal_prop_set) calendarserver-5.2+dfsg/twext/web2/dav/method/put.py0000644000175000017500000000742012263343324021542 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_put -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV PUT method """ __all__ = ["preconditions_PUT", "http_PUT"] from twisted.internet.defer import deferredGenerator, waitForDeferred from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.http import HTTPError, StatusResponse from txdav.xml import element as davxml from twext.web2.dav.method import put_common from twext.web2.dav.util import parentForURL log = Logger() def preconditions_PUT(self, request): # # Check authentication and access controls # if self.exists(): x = waitForDeferred(self.authorize(request, (davxml.WriteContent(),))) yield x x.getResult() else: parent = waitForDeferred(request.locateResource(parentForURL(request.uri))) yield parent parent = parent.getResult() if not parent.exists(): raise HTTPError( StatusResponse( responsecode.CONFLICT, "cannot PUT to non-existent parent")) x = waitForDeferred(parent.authorize(request, (davxml.Bind(),))) yield x x.getResult() # # HTTP/1.1 (RFC 2068, section 9.6) requires that we respond with a Not # Implemented error if we get a Content-* header which we don't # recognize and handle properly. # for header, value in request.headers.getAllRawHeaders(): if header.startswith("Content-") and header not in ( #"Content-Base", # Doesn't make sense in PUT? #"Content-Encoding", # Requires that we decode it? "Content-Language", "Content-Length", #"Content-Location", # Doesn't make sense in PUT? 
"Content-MD5", #"Content-Range", # FIXME: Need to implement this "Content-Type", ): log.error("Client sent unrecognized content header in PUT request: %s" % (header,)) raise HTTPError(StatusResponse( responsecode.NOT_IMPLEMENTED, "Unrecognized content header %r in request." % (header,) )) preconditions_PUT = deferredGenerator(preconditions_PUT) def http_PUT(self, request): """ Respond to a PUT request. (RFC 2518, section 8.7) """ log.info("Writing request stream to %s" % (self,)) # # Don't pass in the request URI, since PUT isn't specified to be able # to return a MULTI_STATUS response, which is WebDAV-specific (and PUT is # not). # #return put(request.stream, self.fp) return put_common.storeResource(request, destination=self, destination_uri=request.uri) calendarserver-5.2+dfsg/twext/web2/dav/method/mkcol.py0000644000175000017500000000550112263343324022035 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_mkcol -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV MKCOL method """ __all__ = ["http_MKCOL"] from twisted.internet.defer import deferredGenerator, waitForDeferred from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.http import HTTPError, StatusResponse from txdav.xml import element as davxml from twext.web2.dav.fileop import mkcollection from twext.web2.dav.util import noDataFromStream, parentForURL log = Logger() def http_MKCOL(self, request): """ Respond to a MKCOL request. (RFC 2518, section 8.3) """ parent = waitForDeferred(request.locateResource(parentForURL(request.uri))) yield parent parent = parent.getResult() x = waitForDeferred(parent.authorize(request, (davxml.Bind(),))) yield x x.getResult() if self.exists(): log.error("Attempt to create collection where file exists: %s" % (self,)) raise HTTPError(responsecode.NOT_ALLOWED) if not parent.isCollection(): log.error("Attempt to create collection with non-collection parent: %s" % (self,)) raise HTTPError(StatusResponse( responsecode.CONFLICT, "Parent resource is not a collection." )) # # Read request body # x = waitForDeferred(noDataFromStream(request.stream)) yield x try: x.getResult() except ValueError, e: log.error("Error while handling MKCOL body: %s" % (e,)) raise HTTPError(responsecode.UNSUPPORTED_MEDIA_TYPE) response = waitForDeferred(mkcollection(self.fp)) yield response yield response.getResult() http_MKCOL = deferredGenerator(http_MKCOL) calendarserver-5.2+dfsg/twext/web2/dav/method/copymove.py0000644000175000017500000002222412263343324022572 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_copy,twext.web2.dav.test.test_move -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. 
All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV COPY and MOVE methods. """ __all__ = ["http_COPY", "http_MOVE"] from twisted.internet.defer import waitForDeferred, deferredGenerator from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.dav.fileop import move from twext.web2.http import HTTPError, StatusResponse from twext.web2.filter.location import addLocation from txdav.xml import element as davxml from twext.web2.dav.idav import IDAVResource from twext.web2.dav.method import put_common from twext.web2.dav.util import parentForURL # FIXME: This is circular import twext.web2.dav.static log = Logger() def http_COPY(self, request): """ Respond to a COPY request. 
(RFC 2518, section 8.8) """ r = waitForDeferred(prepareForCopy(self, request)) yield r r = r.getResult() destination, destination_uri, depth = r # # Check authentication and access controls # x = waitForDeferred(self.authorize(request, (davxml.Read(),), recurse=True)) yield x x.getResult() if destination.exists(): x = waitForDeferred(destination.authorize( request, (davxml.WriteContent(), davxml.WriteProperties()), recurse=True )) yield x x.getResult() else: destparent = waitForDeferred(request.locateResource(parentForURL(destination_uri))) yield destparent destparent = destparent.getResult() x = waitForDeferred(destparent.authorize(request, (davxml.Bind(),))) yield x x.getResult() # May need to add a location header addLocation(request, destination_uri) #x = waitForDeferred(copy(self.fp, destination.fp, destination_uri, depth)) x = waitForDeferred(put_common.storeResource(request, source=self, source_uri=request.uri, destination=destination, destination_uri=destination_uri, deletesource=False, depth=depth )) yield x yield x.getResult() http_COPY = deferredGenerator(http_COPY) def http_MOVE(self, request): """ Respond to a MOVE request. 
(RFC 2518, section 8.9) """ r = waitForDeferred(prepareForCopy(self, request)) yield r r = r.getResult() destination, destination_uri, depth = r # # Check authentication and access controls # parentURL = parentForURL(request.uri) parent = waitForDeferred(request.locateResource(parentURL)) yield parent parent = parent.getResult() x = waitForDeferred(parent.authorize(request, (davxml.Unbind(),))) yield x x.getResult() if destination.exists(): x = waitForDeferred(destination.authorize( request, (davxml.Bind(), davxml.Unbind()), recurse=True )) yield x x.getResult() else: destparentURL = parentForURL(destination_uri) destparent = waitForDeferred(request.locateResource(destparentURL)) yield destparent destparent = destparent.getResult() x = waitForDeferred(destparent.authorize(request, (davxml.Bind(),))) yield x x.getResult() # May need to add a location header addLocation(request, destination_uri) # # RFC 2518, section 8.9 says that we must act as if the Depth header is set # to infinity, and that the client must omit the Depth header or set it to # infinity. # # This seems somewhat at odds with the notion that a bad request should be # rejected outright; if the client sends a bad depth header, the client is # broken, and section 8 suggests that a bad request should be rejected... # # Let's play it safe for now and ignore broken clients. # if self.isCollection() and depth != "infinity": msg = "Client sent illegal depth header value for MOVE: %s" % (depth,) log.error(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) # Lets optimise a move within the same directory to a new resource as a simple move # rather than using the full transaction based storeResource api. This allows simple # "rename" operations to work quickly. 
if (not destination.exists()) and destparent == parent: x = waitForDeferred(move(self.fp, request.uri, destination.fp, destination_uri, depth)) else: x = waitForDeferred(put_common.storeResource(request, source=self, source_uri=request.uri, destination=destination, destination_uri=destination_uri, deletesource=True, depth=depth)) yield x yield x.getResult() http_MOVE = deferredGenerator(http_MOVE) def prepareForCopy(self, request): # # Get the depth # depth = request.headers.getHeader("depth", "infinity") if depth not in ("0", "infinity"): msg = ("Client sent illegal depth header value: %s" % (depth,)) log.error(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) # # Verify this resource exists # if not self.exists(): log.error("File not found: %s" % (self,)) raise HTTPError(StatusResponse( responsecode.NOT_FOUND, "Source resource %s not found." % (request.uri,) )) # # Get the destination # destination_uri = request.headers.getHeader("destination") if not destination_uri: msg = "No destination header in %s request." % (request.method,) log.error(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) d = request.locateResource(destination_uri) d.addCallback(_prepareForCopy, destination_uri, request, depth) return d def _prepareForCopy(destination, destination_uri, request, depth): # # Destination must be a DAV resource # try: destination = IDAVResource(destination) except TypeError: log.error("Attempt to %s to a non-DAV resource: (%s) %s" % (request.method, destination.__class__, destination_uri)) raise HTTPError(StatusResponse( responsecode.FORBIDDEN, "Destination %s is not a WebDAV resource." % (destination_uri,) )) # # FIXME: Right now we don't know how to copy to a non-DAVFile resource. # We may need some more API in IDAVResource. 
# So far, we need: .exists(), .fp.parent() # if not isinstance(destination, twext.web2.dav.static.DAVFile): log.error("DAV copy between non-DAVFile DAV resources isn't implemented") raise HTTPError(StatusResponse( responsecode.NOT_IMPLEMENTED, "Destination %s is not a DAVFile resource." % (destination_uri,) )) # # Check for existing destination resource # overwrite = request.headers.getHeader("overwrite", True) if destination.exists() and not overwrite: log.error("Attempt to %s onto existing file without overwrite flag enabled: %s" % (request.method, destination)) raise HTTPError(StatusResponse( responsecode.PRECONDITION_FAILED, "Destination %s already exists." % (destination_uri,) )) # # Make sure destination's parent exists # if not destination.parent().isCollection(): log.error("Attempt to %s to a resource with no parent: %s" % (request.method, destination.fp.path)) raise HTTPError(StatusResponse(responsecode.CONFLICT, "No parent collection.")) return destination, destination_uri, depth calendarserver-5.2+dfsg/twext/web2/dav/method/report_principal_property_search.py0000644000175000017500000001746412263343324027610 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_report_expand -*- ## # Copyright (c) 2006-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV prinicpal-property-search report """ __all__ = ["report_DAV__principal_property_search"] from twisted.internet.defer import deferredGenerator, waitForDeferred from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.http import HTTPError, StatusResponse from txdav.xml.base import PCDATAElement from txdav.xml import element from txdav.xml.element import dav_namespace from twext.web2.dav.http import ErrorResponse, MultiStatusResponse from twext.web2.dav.method import prop_common from twext.web2.dav.method.report import NumberOfMatchesWithinLimits from twext.web2.dav.method.report import max_number_of_matches from twext.web2.dav.resource import isPrincipalResource log = Logger() def report_DAV__principal_property_search(self, request, principal_property_search): """ Generate a principal-property-search REPORT. (RFC 3744, section 9.4) """ # Verify root element if not isinstance(principal_property_search, element.PrincipalPropertySearch): raise ValueError("%s expected as root element, not %s." 
% (element.PrincipalPropertySearch.sname(), principal_property_search.sname())) # Only handle Depth: 0 depth = request.headers.getHeader("depth", "0") if depth != "0": log.error("Error in prinicpal-property-search REPORT, Depth set to %s" % (depth,)) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,))) # Get a single DAV:prop element from the REPORT request body propertiesForResource = None propElement = None propertySearches = [] applyTo = False for child in principal_property_search.children: if child.qname() == (dav_namespace, "prop"): propertiesForResource = prop_common.propertyListForResource propElement = child elif child.qname() == (dav_namespace, "apply-to-principal-collection-set"): applyTo = True elif child.qname() == (dav_namespace, "property-search"): props = child.childOfType(element.PropertyContainer) props.removeWhitespaceNodes() match = child.childOfType(element.Match) propertySearches.append((props.children, str(match).lower())) def nodeMatch(node, match): """ See if the content of the supplied node matches the supplied text. Try to follow the matching guidance in rfc3744 section 9.4.1. @param prop: the property element to match. @param match: the text to match against. @return: True if the property matches, False otherwise. """ node.removeWhitespaceNodes() for child in node.children: if isinstance(child, PCDATAElement): comp = str(child).lower() if comp.find(match) != -1: return True else: return nodeMatch(child, match) else: return False def propertySearch(resource, request): """ Test the resource to see if it contains properties matching the property-search specification in this report. @param resource: the L{DAVFile} for the resource to test. @param request: the current request. @return: True if the resource has matching properties, False otherwise. 
""" for props, match in propertySearches: # Test each property for prop in props: try: propvalue = waitForDeferred(resource.readProperty(prop.qname(), request)) yield propvalue propvalue = propvalue.getResult() if propvalue and not nodeMatch(propvalue, match): yield False return except HTTPError: # No property => no match yield False return yield True propertySearch = deferredGenerator(propertySearch) # Run report try: resources = [] responses = [] matchcount = 0 if applyTo: for principalCollection in self.principalCollections(): uri = principalCollection.principalCollectionURL() resource = waitForDeferred(request.locateResource(uri)) yield resource resource = resource.getResult() if resource: resources.append((resource, uri)) else: resources.append((self, request.uri)) # Loop over all collections and principal resources within for resource, ruri in resources: # Do some optimisation of access control calculation by determining any inherited ACLs outside of # the child resource loop and supply those to the checkPrivileges on each child. 
filteredaces = waitForDeferred(resource.inheritedACEsforChildren(request)) yield filteredaces filteredaces = filteredaces.getResult() children = [] d = waitForDeferred(resource.findChildren("infinity", request, lambda x, y: children.append((x,y)), privileges=(element.Read(),), inherited_aces=filteredaces)) yield d d.getResult() for child, uri in children: if isPrincipalResource(child): d = waitForDeferred(propertySearch(child, request)) yield d d = d.getResult() if d: # Check size of results is within limit matchcount += 1 if matchcount > max_number_of_matches: raise NumberOfMatchesWithinLimits(max_number_of_matches) d = waitForDeferred(prop_common.responseForHref( request, responses, element.HRef.fromString(uri), child, propertiesForResource, propElement )) yield d d.getResult() except NumberOfMatchesWithinLimits: log.error("Too many matching components in prinicpal-property-search report") raise HTTPError(ErrorResponse( responsecode.FORBIDDEN, element.NumberOfMatchesWithinLimits() )) yield MultiStatusResponse(responses) report_DAV__principal_property_search = deferredGenerator(report_DAV__principal_property_search) calendarserver-5.2+dfsg/twext/web2/dav/method/delete_common.py0000644000175000017500000000464312263343324023550 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_delete -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. ## """ WebDAV DELETE method """ __all__ = ["deleteResource"] from twisted.internet.defer import waitForDeferred, deferredGenerator from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.http import HTTPError from twext.web2.dav.fileop import delete log = Logger() def deleteResource(request, resource, resource_uri, depth="0"): """ Handle a resource delete with proper quota etc updates """ if not resource.exists(): log.error("File not found: %s" % (resource,)) raise HTTPError(responsecode.NOT_FOUND) # Do quota checks before we start deleting things myquota = waitForDeferred(resource.quota(request)) yield myquota myquota = myquota.getResult() if myquota is not None: old_size = waitForDeferred(resource.quotaSize(request)) yield old_size old_size = old_size.getResult() else: old_size = 0 # Do delete x = waitForDeferred(delete(resource_uri, resource.fp, depth)) yield x result = x.getResult() # Adjust quota if myquota is not None: d = waitForDeferred(resource.quotaSizeAdjust(request, -old_size)) yield d d.getResult() yield result deleteResource = deferredGenerator(deleteResource) calendarserver-5.2+dfsg/twext/web2/dav/method/put_common.py0000644000175000017500000002551112263343324023113 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # DRI: Cyrus Daboo, cdaboo@apple.com ## """ PUT/COPY/MOVE common behavior. """ __version__ = "0.0" __all__ = ["storeResource"] from twisted.python.failure import Failure from twext.python.filepath import CachingFilePath as FilePath from twisted.internet.defer import deferredGenerator, maybeDeferred, waitForDeferred from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.dav.fileop import copy, delete, put from twext.web2.dav.http import ErrorResponse from twext.web2.dav.resource import TwistedGETContentMD5 from twext.web2.stream import MD5Stream from twext.web2.http import HTTPError from twext.web2.http_headers import generateContentType from twext.web2.iweb import IResponse from twext.web2.stream import MemoryStream from txdav.xml import element as davxml from txdav.xml.base import dav_namespace log = Logger() def storeResource( request, source=None, source_uri=None, data=None, destination=None, destination_uri=None, deletesource=False, depth="0" ): """ Function that does common PUT/COPY/MOVE behaviour. @param request: the L{twext.web2.server.Request} for the current HTTP request. @param source: the L{DAVFile} for the source resource to copy from, or None if source data is to be read from the request. @param source_uri: the URI for the source resource. @param data: a C{str} to copy data from instead of the request stream. @param destination: the L{DAVFile} for the destination resource to copy into. @param destination_uri: the URI for the destination resource. 
@param deletesource: True if the source resource is to be deleted on successful completion, False otherwise. @param depth: a C{str} containing the COPY/MOVE Depth header value. @return: status response. """ try: assert request is not None and destination is not None and destination_uri is not None assert (source is None) or (source is not None and source_uri is not None) assert not deletesource or (deletesource and source is not None) except AssertionError: log.error("Invalid arguments to storeResource():") log.error("request=%s\n" % (request,)) log.error("source=%s\n" % (source,)) log.error("source_uri=%s\n" % (source_uri,)) log.error("data=%s\n" % (data,)) log.error("destination=%s\n" % (destination,)) log.error("destination_uri=%s\n" % (destination_uri,)) log.error("deletesource=%s\n" % (deletesource,)) log.error("depth=%s\n" % (depth,)) raise class RollbackState(object): """ This class encapsulates the state needed to rollback the entire PUT/COPY/MOVE transaction, leaving the server state the same as it was before the request was processed. The DoRollback method will actually execute the rollback operations. """ def __init__(self): self.active = True self.source_copy = None self.destination_copy = None self.destination_created = False self.source_deleted = False def Rollback(self): """ Rollback the server state. Do not allow this to raise another exception. If rollback fails then we are going to be left in an awkward state that will need to be cleaned up eventually. 
""" if self.active: self.active = False log.error("Rollback: rollback") try: if self.source_copy and self.source_deleted: self.source_copy.moveTo(source.fp) log.error("Rollback: source restored %s to %s" % (self.source_copy.path, source.fp.path)) self.source_copy = None self.source_deleted = False if self.destination_copy: destination.fp.remove() log.error("Rollback: destination restored %s to %s" % (self.destination_copy.path, destination.fp.path)) self.destination_copy.moveTo(destination.fp) self.destination_copy = None elif self.destination_created: destination.fp.remove() log.error("Rollback: destination removed %s" % (destination.fp.path,)) self.destination_created = False except: log.error("Rollback: exception caught and not handled: %s" % Failure()) def Commit(self): """ Commit the resource changes by wiping the rollback state. """ if self.active: log.error("Rollback: commit") self.active = False if self.source_copy: self.source_copy.remove() log.error("Rollback: removed source backup %s" % (self.source_copy.path,)) self.source_copy = None if self.destination_copy: self.destination_copy.remove() log.error("Rollback: removed destination backup %s" % (self.destination_copy.path,)) self.destination_copy = None self.destination_created = False self.source_deleted = False rollback = RollbackState() try: """ Handle validation operations here. """ """ Handle rollback setup here. 
""" # Do quota checks on destination and source before we start messing with adding other files destquota = waitForDeferred(destination.quota(request)) yield destquota destquota = destquota.getResult() if destquota is not None and destination.exists(): old_dest_size = waitForDeferred(destination.quotaSize(request)) yield old_dest_size old_dest_size = old_dest_size.getResult() else: old_dest_size = 0 if source is not None: sourcequota = waitForDeferred(source.quota(request)) yield sourcequota sourcequota = sourcequota.getResult() if sourcequota is not None and source.exists(): old_source_size = waitForDeferred(source.quotaSize(request)) yield old_source_size old_source_size = old_source_size.getResult() else: old_source_size = 0 else: sourcequota = None old_source_size = 0 # We may need to restore the original resource data if the PUT/COPY/MOVE fails, # so rename the original file in case we need to rollback. overwrite = destination.exists() if overwrite: rollback.destination_copy = FilePath(destination.fp.path) rollback.destination_copy.path += ".rollback" destination.fp.copyTo(rollback.destination_copy) else: rollback.destination_created = True if deletesource: rollback.source_copy = FilePath(source.fp.path) rollback.source_copy.path += ".rollback" source.fp.copyTo(rollback.source_copy) """ Handle actual store operations here. 
""" # Do put or copy based on whether source exists if source is not None: response = maybeDeferred(copy, source.fp, destination.fp, destination_uri, depth) else: datastream = request.stream if data is not None: datastream = MemoryStream(data) md5 = MD5Stream(datastream) response = maybeDeferred(put, md5, destination.fp) response = waitForDeferred(response) yield response response = response.getResult() # Update the MD5 value on the resource if source is not None: # Copy MD5 value from source to destination if source.hasDeadProperty(TwistedGETContentMD5): md5 = source.readDeadProperty(TwistedGETContentMD5) destination.writeDeadProperty(md5) else: # Finish MD5 calc and write dead property md5.close() md5 = md5.getMD5() destination.writeDeadProperty(TwistedGETContentMD5.fromString(md5)) # Update the content-type value on the resource if it is not been copied or moved if source is None: content_type = request.headers.getHeader("content-type") if content_type is not None: destination.writeDeadProperty(davxml.GETContentType.fromString(generateContentType(content_type))) response = IResponse(response) # Do quota check on destination if destquota is not None: # Get size of new/old resources new_dest_size = waitForDeferred(destination.quotaSize(request)) yield new_dest_size new_dest_size = new_dest_size.getResult() diff_size = new_dest_size - old_dest_size if diff_size >= destquota[0]: log.error("Over quota: available %d, need %d" % (destquota[0], diff_size)) raise HTTPError(ErrorResponse( responsecode.INSUFFICIENT_STORAGE_SPACE, (dav_namespace, "quota-not-exceeded") )) d = waitForDeferred(destination.quotaSizeAdjust(request, diff_size)) yield d d.getResult() if deletesource: # Delete the source resource if sourcequota is not None: delete_size = 0 - old_source_size d = waitForDeferred(source.quotaSizeAdjust(request, delete_size)) yield d d.getResult() delete(source_uri, source.fp, depth) rollback.source_deleted = True # Can now commit changes and forget the rollback 
details rollback.Commit() yield response return except: # Roll back changes to original server state. Note this may do nothing # if the rollback has already ocurred or changes already committed. rollback.Rollback() raise storeResource = deferredGenerator(storeResource) calendarserver-5.2+dfsg/twext/web2/dav/method/delete.py0000644000175000017500000000451412263343324022175 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_delete -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV DELETE method """ __all__ = ["http_DELETE"] from twisted.internet.defer import waitForDeferred, deferredGenerator from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.http import HTTPError from txdav.xml import element as davxml from twext.web2.dav.method.delete_common import deleteResource from twext.web2.dav.util import parentForURL log = Logger() def http_DELETE(self, request): """ Respond to a DELETE request. (RFC 2518, section 8.6) """ if not self.exists(): log.error("File not found: %s" % (self,)) raise HTTPError(responsecode.NOT_FOUND) depth = request.headers.getHeader("depth", "infinity") # # Check authentication and access controls # parent = waitForDeferred(request.locateResource(parentForURL(request.uri))) yield parent parent = parent.getResult() x = waitForDeferred(parent.authorize(request, (davxml.Unbind(),))) yield x x.getResult() x = waitForDeferred(deleteResource(request, self, request.uri, depth)) yield x yield x.getResult() http_DELETE = deferredGenerator(http_DELETE) calendarserver-5.2+dfsg/twext/web2/dav/method/prop_common.py0000644000175000017500000000741712147725751023301 0ustar rahulrahul## # Cyrus Daboo, cdaboo@apple.com # Copyright 2006-2012 Apple Computer, Inc. All Rights Reserved. 
## __all__ = [ "responseForHref", "propertyListForResource", ] from twisted.internet.defer import deferredGenerator, waitForDeferred from twisted.python.failure import Failure from twext.python.log import Logger from twext.web2 import responsecode from txdav.xml import element from twext.web2.dav.http import statusForFailure from twext.web2.dav.method.propfind import propertyName log = Logger() def responseForHref(request, responses, href, resource, propertiesForResource, propertyreq): if propertiesForResource is not None: properties_by_status = waitForDeferred(propertiesForResource(request, propertyreq, resource)) yield properties_by_status properties_by_status = properties_by_status.getResult() propstats = [] for status in properties_by_status: properties = properties_by_status[status] if properties: xml_status = element.Status.fromResponseCode(status) xml_container = element.PropertyContainer(*properties) xml_propstat = element.PropertyStatus(xml_container, xml_status) propstats.append(xml_propstat) if propstats: responses.append(element.PropertyStatusResponse(href, *propstats)) else: responses.append( element.StatusResponse( href, element.Status.fromResponseCode(responsecode.OK), ) ) responseForHref = deferredGenerator(responseForHref) def propertyListForResource(request, prop, resource): """ Return the specified properties on the specified resource. @param request: the L{IRequest} for the current request. @param prop: the L{PropertyContainer} element for the properties of interest. @param resource: the L{DAVFile} for the targetted resource. @return: a map of OK and NOT FOUND property values. """ return _namedPropertiesForResource(request, prop.children, resource) def _namedPropertiesForResource(request, props, resource): """ Return the specified properties on the specified resource. @param request: the L{IRequest} for the current request. @param props: a list of property elements or qname tuples for the properties of interest. 
@param resource: the L{DAVFile} for the targetted resource. @return: a map of OK and NOT FOUND property values. """ properties_by_status = { responsecode.OK : [], responsecode.NOT_FOUND : [], } for property in props: if isinstance(property, element.WebDAVElement): qname = property.qname() else: qname = property props = waitForDeferred(resource.listProperties(request)) yield props props = props.getResult() if qname in props: try: prop = waitForDeferred(resource.readProperty(qname, request)) yield prop prop = prop.getResult() properties_by_status[responsecode.OK].append(prop) except: f = Failure() status = statusForFailure(f, "getting property: %s" % (qname,)) if status != responsecode.NOT_FOUND: log.error("Error reading property %r for resource %s: %s" % (qname, request.uri, f.value)) if status not in properties_by_status: properties_by_status[status] = [] properties_by_status[status].append(propertyName(qname)) else: properties_by_status[responsecode.NOT_FOUND].append(propertyName(qname)) yield properties_by_status _namedPropertiesForResource = deferredGenerator(_namedPropertiesForResource) calendarserver-5.2+dfsg/twext/web2/dav/method/get.py0000644000175000017500000000426112263343324021511 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_lock -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV GET and HEAD methods """ __all__ = ["http_OPTIONS", "http_HEAD", "http_GET"] import twext from txdav.xml import element as davxml from twext.web2.dav.util import parentForURL def http_OPTIONS(self, request): d = authorize(self, request) d.addCallback(lambda _: super(twext.web2.dav.resource.DAVResource, self).http_OPTIONS(request)) return d def http_HEAD(self, request): d = authorize(self, request) d.addCallback(lambda _: super(twext.web2.dav.resource.DAVResource, self).http_HEAD(request)) return d def http_GET(self, request): d = authorize(self, request) d.addCallback(lambda _: super(twext.web2.dav.resource.DAVResource, self).http_GET(request)) return d def authorize(self, request): if self.exists(): d = self.authorize(request, (davxml.Read(),)) else: d = request.locateResource(parentForURL(request.uri)) d.addCallback(lambda parent: parent.authorize(request, (davxml.Bind(),))) return d calendarserver-5.2+dfsg/twext/web2/dav/method/report_principal_match.py0000644000175000017500000002137412263343324025466 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_report_expand -*- ## # Copyright (c) 2006-2014 Apple Computer, Inc. All rights reserved. 
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##

"""
WebDAV principal-match report
"""

__all__ = ["report_DAV__principal_match"]

from twisted.internet.defer import deferredGenerator, waitForDeferred

from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import StatusResponse, HTTPError
from txdav.xml import element
from txdav.xml.element import dav_namespace
from twext.web2.dav.http import ErrorResponse, MultiStatusResponse
from twext.web2.dav.method import prop_common
from twext.web2.dav.method.report import NumberOfMatchesWithinLimits
from twext.web2.dav.method.report import max_number_of_matches
from twext.web2.dav.resource import isPrincipalResource

log = Logger()


def report_DAV__principal_match(self, request, principal_match):
    """
    Generate a principal-match REPORT (RFC 3744, section 9.3).

    The children of the request body select the mode:

      - DAV:self (also the default when neither mode is given): consider the
        current principal and each of its group memberships, and report each
        one that has a URI (principalURL() or an alternate URI) starting with
        the request URI.
      - DAV:principal-property: walk the resources returned by
        self.findChildren("infinity", ...) with DAV:read required, read the
        named property from each, and report those whose value is a single
        DAV:href to a principal resource for which principalMatch() of the
        current principal's URL is true.

    An optional DAV:prop child selects the properties returned per match.

    Raises ValueError if the root element is not DAV:principal-match, and
    HTTPError with 400 for a Depth other than "0" or a DAV:principal-property
    that does not hold exactly one property. Raises HTTPError with 403
    (DAV:number-of-matches-within-limits) when more than
    max_number_of_matches results are found.
    """
    # Verify root element
    if not isinstance(principal_match, element.PrincipalMatch):
        raise ValueError("%s expected as root element, not %s."
                         % (element.PrincipalMatch.sname(), principal_match.sname()))

    # Only Depth: 0 is permitted for this report
    requestedDepth = request.headers.getHeader("depth", "0")
    if requestedDepth != "0":
        log.error("Non-zero depth is not allowed: %s" % (requestedDepth,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (requestedDepth,)))

    # Decode the request options
    propertyLister = None       # only set when a DAV:prop child is present
    requestedProps = None       # the DAV:prop element itself
    principalProperty = None    # the one property named by DAV:principal-property
    matchSelf = True            # DAV:self mode unless DAV:principal-property appears

    for option in principal_match.children:
        optionName = option.qname()
        if optionName == (dav_namespace, "prop"):
            propertyLister = prop_common.propertyListForResource
            requestedProps = option
        elif optionName == (dav_namespace, "self"):
            matchSelf = True
        elif optionName == (dav_namespace, "principal-property"):
            # Exactly one property may be named here
            propertyCount = len(option.children)
            if propertyCount != 1:
                log.error("Wrong number of properties in DAV:principal-property: %s"
                          % (propertyCount,))
                raise HTTPError(StatusResponse(
                    responsecode.BAD_REQUEST,
                    "DAV:principal-property must contain exactly one property"
                ))
            matchSelf = False
            principalProperty = option.children[0]

    try:
        hrefResponses = []
        matchCount = 0

        currentPrincipalURL = self.currentPrincipal(request).children[0]

        def respondFor(resource, uri):
            # Append the (property) response for one matching resource to hrefResponses
            return prop_common.responseForHref(
                request,
                hrefResponses,
                element.HRef.fromString(uri),
                resource,
                propertyLister,
                requestedProps
            )

        if matchSelf:
            # The current principal first, then every group it belongs to
            me = waitForDeferred(request.locateResource(str(currentPrincipalURL)))
            yield me
            me = me.getResult()
            candidates = [me]

            groups = waitForDeferred(me.groupMemberships())
            yield groups
            candidates.extend(groups.getResult())

            for candidate in candidates:
                # FIXME: making the assumption that the principalURL() is the URL of the resource we found
                uris = [candidate.principalURL()]
                uris.extend(candidate.alternateURIs())

                # At most one URI per principal: the first one under the request URI
                matchingURI = next((u for u in uris if u.startswith(request.uri)), None)
                if matchingURI is None:
                    continue

                matchCount += 1
                if matchCount > max_number_of_matches:
                    raise NumberOfMatchesWithinLimits(max_number_of_matches)

                done = waitForDeferred(respondFor(candidate, matchingURI))
                yield done
                done.getResult()
        else:
            # Compute the inherited ACEs once, outside the per-child work, and hand
            # them to findChildren so each child's privilege check can reuse them.
            inheritedACEs = waitForDeferred(self.inheritedACEsforChildren(request))
            yield inheritedACEs
            inheritedACEs = inheritedACEs.getResult()

            descendants = []

            def collect(resource, uri):
                descendants.append((resource, uri))

            found = waitForDeferred(self.findChildren(
                "infinity", request, collect,
                privileges=(element.Read(),), inherited_aces=inheritedACEs
            ))
            yield found
            found.getResult()

            for descendant, uri in descendants:
                try:
                    # Read the requested property from this resource
                    prop = waitForDeferred(descendant.readProperty(principalProperty.qname(), request))
                    yield prop
                    prop = prop.getResult()
                    if prop:
                        prop.removeWhitespaceNodes()

                    # Only a value consisting of a single DAV:href can name a principal
                    if not prop or len(prop.children) != 1 or not isinstance(prop.children[0], element.HRef):
                        continue

                    target = waitForDeferred(request.locateResource(str(prop.children[0])))
                    yield target
                    target = target.getResult()
                    if not target or not isPrincipalResource(target):
                        continue

                    matched = waitForDeferred(target.principalMatch(currentPrincipalURL))
                    yield matched
                    if not matched.getResult():
                        continue

                    matchCount += 1
                    if matchCount > max_number_of_matches:
                        raise NumberOfMatchesWithinLimits(max_number_of_matches)

                    done = waitForDeferred(respondFor(descendant, uri))
                    yield done
                    done.getResult()
                except HTTPError:
                    # Deliberately best-effort: a property we cannot access counts
                    # the same as one that is absent or does not match.
                    pass

    except NumberOfMatchesWithinLimits:
        log.error("Too many matching components in principal-match report")
        raise HTTPError(ErrorResponse(
            responsecode.FORBIDDEN,
            element.NumberOfMatchesWithinLimits()
        ))

    yield MultiStatusResponse(hrefResponses)

report_DAV__principal_match = deferredGenerator(report_DAV__principal_match)
# calendarserver-5.2+dfsg/twext/web2/dav/method/acl.py0000644000175000017500000000625012263343324021471 0ustar rahulrahul
# -*- test-case-name: twext.web2.dav.test.test_lock -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##

"""
WebDAV ACL method
"""

__all__ = ["http_ACL"]

from twisted.internet.defer import deferredGenerator, waitForDeferred

from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import StatusResponse, HTTPError
from txdav.xml import element as davxml
from twext.web2.dav.http import ErrorResponse
from twext.web2.dav.util import davXMLFromStream

log = Logger()


def http_ACL(self, request):
    """
    Respond to an ACL request (RFC 3744, section 8.1).

    Yields 404 if the resource does not exist. Otherwise it requires
    DAV:write-acl (via self.authorize) and parses the request body, which must
    be a DAV:acl element. The element is passed to
    self.mergeAccessControlList(). A None result yields 200; any other result
    is returned as the body of a 403 ErrorResponse.

    Raises HTTPError(400) when the body cannot be parsed, is missing, or is
    not a DAV:acl element.
    """
    if not self.exists():
        log.error("File not found: %s" % (self,))
        yield responsecode.NOT_FOUND
        return

    #
    # Check authentication and access controls
    #
    x = waitForDeferred(self.authorize(request, (davxml.WriteACL(),)))
    yield x
    x.getResult()

    #
    # Read request body
    #
    doc = waitForDeferred(davXMLFromStream(request.stream))
    yield doc
    try:
        doc = doc.getResult()
    except ValueError as e:
        log.error("Error while handling ACL body: %s" % (e,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e)))

    #
    # A request body is mandatory
    #
    if doc is None:
        error = "Request XML body is required."
        log.error("Error: {err}", err=error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    #
    # Parse request
    #
    acl = doc.root_element
    if not isinstance(acl, davxml.ACL):
        # Plain message: it has no %-conversion specifiers, so applying "%" to
        # it would raise TypeError and turn this 400 into a 500.
        error = "Request XML body must be an acl element."
        log.error("Error: {err}", err=error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    #
    # Do ACL merger
    #
    result = waitForDeferred(self.mergeAccessControlList(acl, request))
    yield result
    result = result.getResult()

    #
    # Return response
    #
    if result is None:
        yield responsecode.OK
    else:
        yield ErrorResponse(responsecode.FORBIDDEN, result)

http_ACL = deferredGenerator(http_ACL)
# calendarserver-5.2+dfsg/twext/web2/dav/method/report_principal_search_property_set.py0000644000175000017500000000554712263343324030462 0ustar rahulrahul
# -*- test-case-name: twext.web2.dav.test.test_report_expand -*-
##
# Copyright (c) 2006-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##

"""
WebDAV principal-search-property-set report
"""

__all__ = ["report_DAV__principal_search_property_set"]

from twisted.internet.defer import deferredGenerator

from twext.python.log import Logger
from twext.web2 import responsecode
from txdav.xml import element as davxml
from twext.web2.http import HTTPError, Response, StatusResponse
from twext.web2.stream import MemoryStream

log = Logger()


def report_DAV__principal_search_property_set(self, request, principal_search_property_set):
    """
    Generate a principal-search-property-set REPORT (RFC 3744, section 9.5).

    Only "Depth: 0" is accepted. The response body is the XML serialisation
    (toxml()) of whatever self.principalSearchPropertySet() returns, sent
    with status 200. A None result means the report is unavailable here and
    produces a 400.

    Raises ValueError when the root element is not
    DAV:principal-search-property-set, and HTTPError(400) for a bad Depth or
    an unsupported resource.
    """
    expectedType = davxml.PrincipalSearchPropertySet
    if not isinstance(principal_search_property_set, expectedType):
        raise ValueError("%s expected as root element, not %s."
                         % (expectedType.sname(), principal_search_property_set.sname()))

    # Reject anything but Depth: 0
    requestedDepth = request.headers.getHeader("depth", "0")
    if requestedDepth != "0":
        log.error("Error in principal-search-property-set REPORT, Depth set to %s" % (requestedDepth,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (requestedDepth,)))

    # The resource decides what (if anything) it can be searched on
    searchSet = self.principalSearchPropertySet()
    if searchSet is None:
        log.error("Error in principal-search-property-set REPORT not supported on: %s" % (self,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Not allowed on this resource"))

    body = MemoryStream(searchSet.toxml())
    yield Response(code=responsecode.OK, stream=body)

report_DAV__principal_search_property_set = deferredGenerator(report_DAV__principal_search_property_set)
# calendarserver-5.2+dfsg/twext/web2/dav/method/report.py0000644000175000017500000001065712263343324022253 0ustar rahulrahul
# -*- test-case-name: twext.web2.dav.test.test_report -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##

"""
WebDAV REPORT method
"""

__all__ = [
    "max_number_of_matches",
    "NumberOfMatchesWithinLimits",
    "http_REPORT",
]

import string

from twisted.internet.defer import deferredGenerator, waitForDeferred

from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from twext.web2.dav.http import ErrorResponse
from twext.web2.dav.util import davXMLFromStream
from txdav.xml import element as davxml
from txdav.xml.element import lookupElement
from txdav.xml.base import encodeXMLName

log = Logger()

# Upper bound on the number of results a single REPORT may produce
max_number_of_matches = 500


class NumberOfMatchesWithinLimits(Exception):
    """
    Signals that a REPORT found more than C{limit} matching results.
    """

    def __init__(self, limit):
        super(NumberOfMatchesWithinLimits, self).__init__()
        self.limit = limit

    def maxLimit(self):
        """
        Return the limit that was exceeded.
        """
        return self.limit


def http_REPORT(self, request):
    """
    Respond to a REPORT request (RFC 3253, section 3.6).

    The root element of the request body selects a handler method named
    "report_" + the element's namespace and name joined with "_", with every
    non-identifier character replaced by "_". The report must also appear in
    self.supportedReports(). Unsupported reports give 403 with
    DAV:supported-report. request.submethod is set to a readable form of the
    report name.

    Raises HTTPError: 404 if the resource does not exist, 400 for an
    unparseable or empty body, 403 for an unsupported report.
    """
    if not self.exists():
        log.error("File not found: %s" % (self,))
        raise HTTPError(responsecode.NOT_FOUND)

    #
    # Check authentication and access controls
    #
    x = waitForDeferred(self.authorize(request, (davxml.Read(),)))
    yield x
    x.getResult()

    #
    # Read request body
    #
    try:
        doc = waitForDeferred(davXMLFromStream(request.stream))
        yield doc
        doc = doc.getResult()
    except ValueError as e:
        log.error("Error while handling REPORT body: %s" % (e,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e)))

    if doc is None:
        raise HTTPError(StatusResponse(
            responsecode.BAD_REQUEST,
            "REPORT request body may not be empty"
        ))

    #
    # Parse request
    #
    namespace = doc.root_element.namespace
    name = doc.root_element.name

    ok = string.ascii_letters + string.digits + "_"

    def to_method(s):
        # Replace every character that cannot appear in an identifier with "_"
        return "report_" + "".join(c if c in ok else "_" for c in s)

    if namespace:
        method_name = to_method("_".join((namespace, name)))

        if namespace == davxml.dav_namespace:
            request.submethod = "DAV:" + name
        else:
            request.submethod = encodeXMLName(namespace, name)
    else:
        method_name = to_method(name)

        request.submethod = name

    try:
        method = getattr(self, method_name)

        # Also double-check via supported-reports property
        reports = self.supportedReports()
        test = lookupElement((namespace, name))
        if not test:
            raise AttributeError()
        test = davxml.Report(test())
        if test not in reports:
            raise AttributeError()
    except AttributeError:
        #
        # Requested report is not supported.
        #
        log.error("Unsupported REPORT %s for resource %s (no method %s)"
                  % (encodeXMLName(namespace, name), self, method_name))

        raise HTTPError(ErrorResponse(
            responsecode.FORBIDDEN,
            davxml.SupportedReport()
        ))

    d = waitForDeferred(method(request, doc.root_element))
    yield d
    yield d.getResult()

http_REPORT = deferredGenerator(http_REPORT)
# calendarserver-5.2+dfsg/twext/web2/dav/method/proppatch.py0000644000175000017500000001637612263343324022744 0ustar rahulrahul
# -*- test-case-name: twext.web2.dav.test.test_prop.PROP.test_PROPPATCH -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##

"""
WebDAV PROPPATCH method
"""

__all__ = ["http_PROPPATCH"]

from twisted.python.failure import Failure
from twisted.internet.defer import deferredGenerator, waitForDeferred

from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from txdav.xml import element as davxml
from twext.web2.dav.http import MultiStatusResponse, PropertyStatusResponseQueue
from twext.web2.dav.util import davXMLFromStream

log = Logger()


def http_PROPPATCH(self, request):
    """
    Respond to a PROPPATCH request (RFC 2518, section 8.2).

    Applies every DAV:set / DAV:remove in the DAV:propertyupdate body while
    recording an undo action for each successful change. If any change
    fails, all recorded undo actions are run and the queued responses are
    marked with responses.error(). An unexpected exception also runs them
    before it propagates. Removing a property that does not exist counts as
    success.

    The response is a 207 multistatus, or a bare 200 when the client sent
    "Prefer: return=minimal" and nothing failed.

    Raises HTTPError: 404 if the resource does not exist, 400 for an
    unparseable, missing or non-DAV:propertyupdate body.
    """
    if not self.exists():
        log.error("File not found: %s" % (self,))
        raise HTTPError(responsecode.NOT_FOUND)

    x = waitForDeferred(self.authorize(request, (davxml.WriteProperties(),)))
    yield x
    x.getResult()

    #
    # Read request body
    #
    try:
        doc = waitForDeferred(davXMLFromStream(request.stream))
        yield doc
        doc = doc.getResult()
    except ValueError as e:
        log.error("Error while handling PROPPATCH body: %s" % (e,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e)))

    if doc is None:
        error = "Request XML body is required."
        # Logger emit methods take the format string plus keyword fields only
        log.error("Error: {err}", err=error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    #
    # Parse request
    #
    update = doc.root_element
    if not isinstance(update, davxml.PropertyUpdate):
        # Plain message: it has no %-conversion specifiers to fill
        error = "Request XML body must be a propertyupdate element."
        log.error("Error: {err}", err=error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    responses = PropertyStatusResponseQueue("PROPPATCH", request.uri, responsecode.NO_CONTENT)
    undoActions = []
    gotError = False

    # Look for Prefer header
    prefer = request.headers.getHeader("prefer", {})
    returnMinimal = any(key == "return" and value == "minimal" for key, value, _ignore_args in prefer)

    try:
        #
        # Update properties
        #
        for setOrRemove in update.children:
            assert len(setOrRemove.children) == 1

            container = setOrRemove.children[0]

            assert isinstance(container, davxml.PropertyContainer)

            properties = container.children

            def do(action, property, removing=False):
                """
                Perform action(property, request) while maintaining an
                undo queue. Yields True on success, False on failure.
                """
                has = waitForDeferred(self.hasProperty(property, request))
                yield has
                has = has.getResult()

                if has:
                    oldProperty = waitForDeferred(self.readProperty(property, request))
                    yield oldProperty
                    oldProperty = oldProperty.getResult()

                    def undo():
                        return self.writeProperty(oldProperty, request)
                else:
                    def undo():
                        return self.removeProperty(property, request)

                try:
                    x = waitForDeferred(action(property, request))
                    yield x
                    x.getResult()
                except KeyError as e:
                    # Removing a non-existent property is OK according to WebDAV
                    if removing:
                        responses.add(responsecode.OK, property)
                        yield True
                        return
                    else:
                        # Convert KeyError exception into HTTPError
                        responses.add(
                            Failure(exc_value=HTTPError(StatusResponse(responsecode.FORBIDDEN, str(e)))),
                            property
                        )
                        yield False
                        return
                except:
                    responses.add(Failure(), property)
                    yield False
                    return
                else:
                    responses.add(responsecode.OK, property)

                    # Only add undo action for those that succeed because those that fail will not have changed
                    undoActions.append(undo)

                    yield True
                    return

            do = deferredGenerator(do)

            if isinstance(setOrRemove, davxml.Set):
                for property in properties:
                    ok = waitForDeferred(do(self.writeProperty, property))
                    yield ok
                    ok = ok.getResult()
                    if not ok:
                        gotError = True
            elif isinstance(setOrRemove, davxml.Remove):
                for property in properties:
                    ok = waitForDeferred(do(self.removeProperty, property, True))
                    yield ok
                    ok = ok.getResult()
                    if not ok:
                        gotError = True
            else:
                raise AssertionError("Unknown child of PropertyUpdate: %s" % (setOrRemove,))
    except:
        #
        # PROPPATCH is all-or-nothing: back out every change already made,
        # then propagate the original error. Capture it before the undo loop
        # yields; on Python 2 a bare "raise" after a generator is resumed may
        # no longer refer to the exception that brought us here.
        #
        failure = Failure()
        for action in undoActions:
            x = waitForDeferred(action())
            yield x
            x.getResult()
        failure.raiseException()

    #
    # If we had an error we need to undo any changes that did succeed and change status of
    # those to 424 Failed Dependency.
    #
    if gotError:
        for action in undoActions:
            x = waitForDeferred(action())
            yield x
            x.getResult()
        responses.error()

    #
    # Return response - use 200 if Prefer:return=minimal set and no errors
    #
    if returnMinimal and not gotError:
        yield responsecode.OK
    else:
        yield MultiStatusResponse([responses.response()])

http_PROPPATCH = deferredGenerator(http_PROPPATCH)
# calendarserver-5.2+dfsg/twext/web2/dav/method/report_expand.py0000644000175000017500000001643512263343324023612 0ustar rahulrahul
# -*- test-case-name: twext.web2.dav.test.test_report_expand -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
##

"""
WebDAV expand-property report
"""

__all__ = ["report_DAV__expand_property"]

from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.python.failure import Failure

from twext.python.log import Logger
from twext.web2 import responsecode
from txdav.xml import element
from txdav.xml.element import dav_namespace
from twext.web2.dav.http import statusForFailure, MultiStatusResponse
from twext.web2.dav.method import prop_common
from twext.web2.dav.method.propfind import propertyName
from twext.web2.dav.resource import AccessDeniedError
from twext.web2.dav.util import parentForURL
from twext.web2.http import HTTPError, StatusResponse

log = Logger()


@inlineCallbacks
def report_DAV__expand_property(self, request, expand_property):
    """
    Generate an expand-property REPORT (RFC 3253, section 3.8).

    Each top-level DAV:property child names a property of this resource.
    Every DAV:href in its value is dereferenced, and the nested DAV:property
    elements name what to return for each referenced resource. Per href the
    result is either those properties, 404 if the resource is missing, or 403
    if DAV:read is denied on the resource or on its parent. A top-level
    property that cannot be read is reported under the status derived from
    the failure.

    Only "Depth: 0" is accepted (400 otherwise). Deeper nesting than one
    level of DAV:property yields 501.

    TODO: for simplicity we will only support one level of expansion.
    """
    # Verify root element
    if not isinstance(expand_property, element.ExpandProperty):
        raise ValueError("%s expected as root element, not %s."
                         % (element.ExpandProperty.sname(), expand_property.sname()))

    # Only handle Depth: 0
    requestedDepth = request.headers.getHeader("depth", "0")
    if requestedDepth != "0":
        log.error("Non-zero depth is not allowed: %s" % (requestedDepth,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (requestedDepth,)))

    def qnameOf(node):
        # A DAV:property element names its target via namespace/name attributes
        return (
            node.attributes.get("namespace", dav_namespace),
            node.attributes.get("name", ""),
        )

    def statusFor(href, code):
        return element.StatusResponse(href, element.Status.fromResponseCode(code))

    #
    # Map each top-level property to the (namespace, name) pairs to fetch from
    # the resources it references, rejecting anything nested deeper.
    #
    expansions = {}
    for topLevel in expand_property.children:
        key = qnameOf(topLevel)
        wanted = []
        for nested in topLevel.children:
            if nested.children:
                log.error("expand-property REPORT only supports single level expansion")
                raise HTTPError(StatusResponse(
                    responsecode.NOT_IMPLEMENTED,
                    "expand-property REPORT only supports single level expansion"
                ))
            wanted.append(qnameOf(nested))
        expansions[key] = wanted

    #
    # Generate the expanded responses status for each top-level property
    #
    propsByStatus = {
        responsecode.OK        : [],
        responsecode.NOT_FOUND : [],
    }

    cachedParent = None
    cachedACEs = None

    for qname in expansions:
        try:
            prop = (yield self.readProperty(qname, request))

            # Form the PROPFIND-style DAV:prop element we need later
            propsToReturn = element.PropertyContainer(*expansions[qname])

            hrefResults = []
            for href in prop.children:
                if not isinstance(href, element.HRef):
                    continue

                # Locate the referenced resource and its parent
                resourceURI = str(href)
                target = (yield request.locateResource(resourceURI))
                if not target or not target.exists():
                    hrefResults.append(statusFor(href, responsecode.NOT_FOUND))
                    continue

                parent = (yield request.locateResource(parentForURL(resourceURI)))

                # DAV:read is required on the parent...
                try:
                    yield parent.checkPrivileges(request, (element.Read(),))
                except AccessDeniedError:
                    hrefResults.append(statusFor(href, responsecode.FORBIDDEN))
                    continue

                # Inherited ACEs depend only on the parent, so recompute them
                # just when the parent differs from the previous href's.
                if cachedParent != parent:
                    cachedParent = parent
                    cachedACEs = (yield parent.inheritedACEsforChildren(request))

                # ...and on the referenced resource itself
                try:
                    yield target.checkPrivileges(request, (element.Read(),), inherited_aces=cachedACEs)
                except AccessDeniedError:
                    hrefResults.append(statusFor(href, responsecode.FORBIDDEN))
                    continue

                # Now retrieve all the requested properties on the HRef resource
                yield prop_common.responseForHref(
                    request,
                    hrefResults,
                    href,
                    target,
                    prop_common.propertyListForResource,
                    propsToReturn,
                )

            prop.children = hrefResults
            propsByStatus[responsecode.OK].append(prop)
        except:
            failure = Failure()

            log.error("Error reading property %r for resource %s: %s" % (qname, request.uri, failure.value))

            status = statusForFailure(failure, "getting property: %s" % (qname,))
            propsByStatus.setdefault(status, []).append(propertyName(qname))

    # Build the overall response, skipping statuses with no properties
    propstats = [
        element.PropertyStatus(
            element.PropertyContainer(*props),
            element.Status.fromResponseCode(status)
        )
        for status, props in propsByStatus.items() if props
    ]

    returnValue(MultiStatusResponse((element.PropertyStatusResponse(element.HRef(request.uri), *propstats),)))
# calendarserver-5.2+dfsg/twext/web2/dav/method/__init__.py0000644000175000017500000000320612263343324022467 0ustar rahulrahul
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV methods. Modules in this package provide the implementation of twext.web2.dav.static.DAVFile's dispatched methods. """ __all__ = [ "acl", "copymove", "delete", "get", "lock", "mkcol", "propfind", "proppatch", "prop_common", "put", "put_common", "report", "report_acl_principal_prop_set", "report_expand", "report_principal_match", "report_principal_property_search", "report_principal_search_property_set", ] calendarserver-5.2+dfsg/twext/web2/dav/method/lock.py0000644000175000017500000000315212263343324021660 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_lock -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##

"""
WebDAV LOCK and UNLOCK methods
"""

__all__ = ["http_LOCK", "http_UNLOCK"]

from twext.web2 import responsecode


def http_LOCK(self, request):
    """
    Handle a LOCK request (RFC 2518, section 8.10).

    Locking is not implemented, so the answer is always 501 Not Implemented.
    """
    unsupported = responsecode.NOT_IMPLEMENTED
    return unsupported


def http_UNLOCK(self, request):
    """
    Handle an UNLOCK request (RFC 2518, section 8.11).

    Locking is not implemented, so the answer is always 501 Not Implemented.
    """
    unsupported = responsecode.NOT_IMPLEMENTED
    return unsupported
# calendarserver-5.2+dfsg/twext/web2/dav/method/propfind.py0000644000175000017500000002133712263343324022556 0ustar rahulrahul
# -*- test-case-name: twext.web2.dav.test.test_prop.PROP.test_PROPFIND -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV PROPFIND method """ __all__ = [ "http_PROPFIND", "propertyName", ] from twisted.python.failure import Failure from twisted.internet.defer import deferredGenerator, waitForDeferred from twext.python.log import Logger from twext.web2.http import HTTPError from twext.web2 import responsecode from twext.web2.http import StatusResponse from txdav.xml import element as davxml from twext.web2.dav.http import MultiStatusResponse, statusForFailure, \ ErrorResponse from twext.web2.dav.util import normalizeURL, davXMLFromStream log = Logger() def http_PROPFIND(self, request): """ Respond to a PROPFIND request. 
(RFC 2518, section 8.1) """ if not self.exists(): log.error("File not found: %s" % (self,)) raise HTTPError(responsecode.NOT_FOUND) # # Check authentication and access controls # x = waitForDeferred(self.authorize(request, (davxml.Read(),))) yield x x.getResult() # # Read request body # try: doc = waitForDeferred(davXMLFromStream(request.stream)) yield doc doc = doc.getResult() except ValueError, e: log.error("Error while handling PROPFIND body: %s" % (e,)) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e))) if doc is None: # No request body means get all properties. search_properties = "all" else: # # Parse request # find = doc.root_element if not isinstance(find, davxml.PropertyFind): error = ("Non-%s element in PROPFIND request body: %s" % (davxml.PropertyFind.sname(), find)) log.error("Error: {err}", err=error) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error)) container = find.children[0] if isinstance(container, davxml.AllProperties): # Get all properties search_properties = "all" elif isinstance(container, davxml.PropertyName): # Get names only search_properties = "names" elif isinstance(container, davxml.PropertyContainer): properties = container.children search_properties = [(p.namespace, p.name) for p in properties] else: raise AssertionError("Unexpected element type in %s: %s" % (davxml.PropertyFind.sname(), container)) # # Generate XML output stream # request_uri = request.uri depth = request.headers.getHeader("depth", "infinity") # By policy we will never allow a depth:infinity propfind if depth == "infinity": raise HTTPError(ErrorResponse(responsecode.FORBIDDEN, davxml.PropfindFiniteDepth())) # Look for Prefer header first, then try Brief prefer = request.headers.getHeader("prefer", {}) returnMinimal = any([key == "return" and value == "minimal" for key, value, _ignore_args in prefer]) noRoot = any([key == "depth-noroot" and value is None for key, value, _ignore_args in prefer]) if not returnMinimal: returnMinimal = 
request.headers.getHeader("brief", False) xml_responses = [] # FIXME: take advantage of the new generative properties of findChildren my_url = normalizeURL(request_uri) if self.isCollection() and not my_url.endswith("/"): my_url += "/" # Do some optimisation of access control calculation by determining any inherited ACLs outside of # the child resource loop and supply those to the checkPrivileges on each child. filtered_aces = waitForDeferred(self.inheritedACEsforChildren(request)) yield filtered_aces filtered_aces = filtered_aces.getResult() if depth in ("1", "infinity") and noRoot: resources = [] else: resources = [(self, my_url)] d = self.findChildren(depth, request, lambda x, y: resources.append((x, y)), (davxml.Read(),), inherited_aces=filtered_aces) x = waitForDeferred(d) yield x x.getResult() for resource, uri in resources: if search_properties is "names": try: resource_properties = waitForDeferred(resource.listProperties(request)) yield resource_properties resource_properties = resource_properties.getResult() except: log.error("Unable to get properties for resource %r" % (resource,)) raise properties_by_status = { responsecode.OK: [propertyName(p) for p in resource_properties] } else: properties_by_status = { responsecode.OK : [], responsecode.NOT_FOUND : [], } if search_properties is "all": properties_to_enumerate = waitForDeferred(resource.listAllprop(request)) yield properties_to_enumerate properties_to_enumerate = properties_to_enumerate.getResult() else: properties_to_enumerate = search_properties for property in properties_to_enumerate: has = waitForDeferred(resource.hasProperty(property, request)) yield has has = has.getResult() if has: try: resource_property = waitForDeferred(resource.readProperty(property, request)) yield resource_property resource_property = resource_property.getResult() except: f = Failure() status = statusForFailure(f, "getting property: %s" % (property,)) if status not in properties_by_status: properties_by_status[status] = [] 
if not returnMinimal or status != responsecode.NOT_FOUND: properties_by_status[status].append(propertyName(property)) else: if resource_property is not None: properties_by_status[responsecode.OK].append(resource_property) elif not returnMinimal: properties_by_status[responsecode.NOT_FOUND].append(propertyName(property)) elif not returnMinimal: properties_by_status[responsecode.NOT_FOUND].append(propertyName(property)) propstats = [] for status in properties_by_status: properties = properties_by_status[status] if not properties: continue xml_status = davxml.Status.fromResponseCode(status) xml_container = davxml.PropertyContainer(*properties) xml_propstat = davxml.PropertyStatus(xml_container, xml_status) propstats.append(xml_propstat) # Always need to have at least one propstat present (required by Prefer header behavior) if len(propstats) == 0: propstats.append(davxml.PropertyStatus( davxml.PropertyContainer(), davxml.Status.fromResponseCode(responsecode.OK) )) xml_resource = davxml.HRef(uri) xml_response = davxml.PropertyStatusResponse(xml_resource, *propstats) xml_responses.append(xml_response) # # Return response # yield MultiStatusResponse(xml_responses) http_PROPFIND = deferredGenerator(http_PROPFIND) ## # Utilities ## def propertyName(name): property_namespace, property_name = name pname = davxml.WebDAVUnknownElement() pname.namespace = property_namespace pname.name = property_name return pname calendarserver-5.2+dfsg/twext/web2/dav/resource.py0000644000175000017500000030111012263343324021272 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_resource -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##

"""
WebDAV resources.
"""

# The docstring must come before this import: a string literal placed after a
# statement is not treated as the module docstring.
from __future__ import print_function

__all__ = [
    "DAVPropertyMixIn",
    "DAVResource",
    "DAVLeafResource",
    "DAVPrincipalResource",
    "DAVPrincipalCollectionResource",
    "AccessDeniedError",
    "isPrincipalResource",
    "TwistedACLInheritable",
    "TwistedGETContentMD5",
    "TwistedQuotaRootProperty",
    "allACL",
    "readonlyACL",
    "davPrivilegeSet",
    "unauthenticatedPrincipal",
]

import cPickle as pickle
import urllib

from zope.interface import implements

from twisted.cred.error import LoginFailed, UnauthorizedLogin
from twisted.python.failure import Failure
from twisted.internet.defer import (
    Deferred, maybeDeferred, succeed, inlineCallbacks, returnValue
)
from twisted.internet import reactor

from twext.python.log import Logger
from txdav.xml import element
from txdav.xml.base import encodeXMLName
from txdav.xml.element import WebDAVElement, WebDAVEmptyElement, WebDAVTextElement
from txdav.xml.element import dav_namespace
from txdav.xml.element import twisted_dav_namespace, twisted_private_namespace
from txdav.xml.element import registerElement, lookupElement
from twext.web2 import responsecode
from twext.web2.http import HTTPError, RedirectResponse, StatusResponse
from twext.web2.http_headers import generateContentType
from twext.web2.iweb import IResponse
from twext.web2.resource import LeafResource
from twext.web2.server import NoURLForResourceError
from twext.web2.static import MetaDataMixin, StaticRenderMixin
from twext.web2.auth.wrapper import UnauthorizedResponse
from twext.web2.dav.idav import IDAVResource, IDAVPrincipalResource, IDAVPrincipalCollectionResource
from twext.web2.dav.http import NeedPrivilegesResponse
from twext.web2.dav.noneprops import NonePropertyStore
from twext.web2.dav.util import unimplemented, parentForURL, joinURL
from twext.web2.dav.auth import PrincipalCredentials
from twistedcaldav import customxml

log = Logger()


class DAVPropertyMixIn (MetaDataMixin):
    """
    Mix-in class which implements the DAV property access API in
    L{IDAVResource}.

    There are three categories of DAV properties, for the purposes of how
    this class manages them. A X{property} is either a X{live property} or a
    X{dead property}, and live properties are split into two categories:

     1. Dead properties. These are properties that the server simply stores
        as opaque data. They are stored in the X{dead property store}, which
        is provided by subclasses via the L{deadProperties} method.

     2. Live properties which are always computed. These properties aren't
        stored anywhere (by this class) but instead are derived from the
        resource state or from data that is persisted elsewhere. These are
        listed by the L{liveProperties} method and are handled explicitly by
        the L{readProperty} method.

     3. Live properties which may be acted on specially and are stored in
        the X{dead property store}. These are not listed by the
        L{liveProperties} method, but may be handled specially by the
        property access methods. For example, L{writeProperty} might
        validate the data and refuse to write data it deems inappropriate
        for a given property.

    There are two sets of property access methods. The first group
    (L{hasProperty}, etc.) provides access to all properties. They
    automatically figure out which category a property falls into and act
    accordingly.

    The second group (L{hasDeadProperty}, etc.) accesses the dead property
    store directly and bypasses any live property logic that exists in the
    first group of methods. These methods are used by the first group of
    methods, and there are cases where they may be needed by other methods.
    I{Accessing dead properties directly should be done with caution.}
    Bypassing the live property logic means that values may not be the
    correct ones for use in DAV requests such as PROPFIND, and may be
    bypassing security checks. In general, one should never bypass the live
    property logic as part of a client request for property data.

    Properties in the L{twisted_private_namespace} namespace are internal to
    the server and should not be exposed to clients. They can only be
    accessed via the dead property store.
    """
    # Note: The DAV:owner and DAV:group live properties are only
    # meaningful if you are using ACL semantics (ie. Unix-like) which
    # use them. This (generic) class does not.

    def liveProperties(self):
        """
        @return: a tuple of the qnames of the live properties this class
            computes in L{readProperty}.
        """
        return (
            (dav_namespace, "resourcetype"),
            (dav_namespace, "getetag"),
            (dav_namespace, "getcontenttype"),
            (dav_namespace, "getcontentlength"),
            (dav_namespace, "getlastmodified"),
            (dav_namespace, "creationdate"),
            (dav_namespace, "displayname"),
            (dav_namespace, "supportedlock"),
            (dav_namespace, "supported-report-set"), # RFC 3253, section 3.1.5
           #(dav_namespace, "owner"                     ), # RFC 3744, section 5.1
           #(dav_namespace, "group"                     ), # RFC 3744, section 5.2
            (dav_namespace, "supported-privilege-set"), # RFC 3744, section 5.3
            (dav_namespace, "current-user-privilege-set"), # RFC 3744, section 5.4
            (dav_namespace, "current-user-principal"), # RFC 5397, Section 3
            (dav_namespace, "acl"), # RFC 3744, section 5.5
            (dav_namespace, "acl-restrictions"), # RFC 3744, section 5.6
            (dav_namespace, "inherited-acl-set"), # RFC 3744, section 5.7
            (dav_namespace, "principal-collection-set"), # RFC 3744, section 5.8
            (dav_namespace, "quota-available-bytes"), # RFC 4331, section 3
            (dav_namespace, "quota-used-bytes"), # RFC 4331, section 4

            (twisted_dav_namespace, "resource-class"),
        )

    def deadProperties(self):
        """
        Provides internal access to the WebDAV dead property store.
        You probably shouldn't be calling this directly if you can use the
        property accessors in the L{IDAVResource} API instead. However, a
        subclass must override this method to provide its own dead property
        store.

        This implementation returns an instance of L{NonePropertyStore},
        which cannot store dead properties. Subclasses must override this
        method if they wish to store dead properties.

        @return: a dict-like object from which one can read and to which one
            can write dead properties. Keys are qname tuples (i.e.
            C{(namespace, name)}) as returned by L{WebDAVElement.qname()} and
            values are L{WebDAVElement} instances.
        """
        # Created lazily and cached on the instance.
        if not hasattr(self, "_dead_properties"):
            self._dead_properties = NonePropertyStore(self)
        return self._dead_properties

    def hasProperty(self, property, request):
        """
        See L{IDAVResource.hasProperty}.

        @param property: a qname tuple or a property element.
        @return: a L{Deferred} firing with a boolean.
        """
        if type(property) is tuple:
            qname = property
        else:
            qname = (property.namespace, property.name)

        # Private properties are never visible through this API.
        if qname[0] == twisted_private_namespace:
            return succeed(False)

        # Need to special case the dynamic live properties: the quota
        # properties are present only when hasQuota() says so.
        namespace, name = qname
        if namespace == dav_namespace:
            if name in ("quota-available-bytes", "quota-used-bytes"):
                return self.hasQuota(request)

        return succeed(
            qname in self.liveProperties() or
            self.deadProperties().contains(qname)
        )

    def readProperty(self, property, request):
        """
        See L{IDAVResource.readProperty}.

        Live properties are computed here; anything else is read from the
        dead property store.

        @param property: a qname tuple or a property element.
        @return: a L{Deferred} firing with the property element (or C{None}
            for some live properties that have no value).
        """
        @inlineCallbacks
        def defer():
            if type(property) is tuple:
                qname = property
                sname = encodeXMLName(*property)
            else:
                qname = property.qname()
                sname = property.sname()

            namespace, name = qname

            if namespace == dav_namespace:
                if name == "resourcetype":
                    # Allow live property to be overridden by dead property
                    if self.deadProperties().contains(qname):
                        returnValue(self.deadProperties().get(qname))
                    if self.isCollection():
                        returnValue(element.ResourceType.collection) #@UndefinedVariable
                    returnValue(element.ResourceType.empty) #@UndefinedVariable

                if name == "getetag":
                    etag = (yield self.etag())
                    if etag is None:
                        returnValue(None)
                    returnValue(element.GETETag(etag.generate()))

                if name == "getcontenttype":
                    mimeType = self.contentType()
                    if mimeType is None:
                        returnValue(None)
                    returnValue(element.GETContentType(generateContentType(mimeType)))

                if name == "getcontentlength":
                    length = self.contentLength()
                    if length is None:
                        # TODO: really we should "render" the resource and
                        # determine its size from that but for now we just
                        # return an empty element.
                        returnValue(element.GETContentLength(""))
                    else:
                        returnValue(element.GETContentLength(str(length)))

                if name == "getlastmodified":
                    lastModified = self.lastModified()
                    if lastModified is None:
                        returnValue(None)
                    returnValue(element.GETLastModified.fromDate(lastModified))

                if name == "creationdate":
                    creationDate = self.creationDate()
                    if creationDate is None:
                        returnValue(None)
                    returnValue(element.CreationDate.fromDate(creationDate))

                if name == "displayname":
                    displayName = self.displayName()
                    if displayName is None:
                        returnValue(None)
                    returnValue(element.DisplayName(displayName))

                if name == "supportedlock":
                    returnValue(element.SupportedLock(
                        element.LockEntry(
                            element.LockScope.exclusive, #@UndefinedVariable
                            element.LockType.write #@UndefinedVariable
                        ),
                        element.LockEntry(
                            element.LockScope.shared, #@UndefinedVariable
                            element.LockType.write #@UndefinedVariable
                        ),
                    ))

                if name == "supported-report-set":
                    returnValue(element.SupportedReportSet(*[
                        element.SupportedReport(report,)
                        for report in self.supportedReports()
                    ]))

                if name == "supported-privilege-set":
                    returnValue((yield self.supportedPrivileges(request)))

                if name == "acl-restrictions":
                    returnValue(element.ACLRestrictions())

                if name == "inherited-acl-set":
                    returnValue(element.InheritedACLSet(*self.inheritedACLSet()))

                if name == "principal-collection-set":
                    returnValue(element.PrincipalCollectionSet(*[
                        element.HRef(
                            principalCollection.principalCollectionURL()
                        )
                        for principalCollection in self.principalCollections()
                    ]))

                @inlineCallbacks
                def ifAllowed(privileges, callback):
                    # Run callback only if the current user holds the given
                    # privileges; map an access denial to an HTTP error.
                    try:
                        yield self.checkPrivileges(request, privileges)
                        result = yield callback()
                    except AccessDeniedError:
                        raise HTTPError(StatusResponse(
                            responsecode.UNAUTHORIZED,
                            "Access denied while reading property %s."
                            % (sname,)
                        ))

                    returnValue(result)

                if name == "current-user-privilege-set":
                    @inlineCallbacks
                    def callback():
                        privs = yield self.currentPrivileges(request)
                        returnValue(element.CurrentUserPrivilegeSet(*privs))
                    returnValue((yield ifAllowed(
                        (element.ReadCurrentUserPrivilegeSet(),),
                        callback
                    )))

                if name == "acl":
                    @inlineCallbacks
                    def callback():
                        acl = yield self.accessControlList(request)
                        if acl is None:
                            acl = element.ACL()
                        returnValue(acl)
                    returnValue(
                        (yield ifAllowed((element.ReadACL(),), callback))
                    )

                if name == "current-user-principal":
                    returnValue(element.CurrentUserPrincipal(
                        self.currentPrincipal(request).children[0]
                    ))

                if name == "quota-available-bytes":
                    qvalue = yield self.quota(request)
                    if qvalue is None:
                        raise HTTPError(StatusResponse(
                            responsecode.NOT_FOUND,
                            "Property %s does not exist." % (sname,)
                        ))
                    else:
                        returnValue(element.QuotaAvailableBytes(str(qvalue[0])))

                if name == "quota-used-bytes":
                    qvalue = yield self.quota(request)
                    if qvalue is None:
                        raise HTTPError(StatusResponse(
                            responsecode.NOT_FOUND,
                            "Property %s does not exist." % (sname,)
                        ))
                    else:
                        returnValue(element.QuotaUsedBytes(str(qvalue[1])))

            elif namespace == twisted_dav_namespace:
                if name == "resource-class":
                    returnValue(ResourceClass(self.__class__.__name__))

            elif namespace == twisted_private_namespace:
                # Report the namespace (not the full property name), matching
                # removeProperty.
                raise HTTPError(StatusResponse(
                    responsecode.FORBIDDEN,
                    "Properties in the %s namespace are private to the server."
                    % (namespace,)
                ))

            returnValue(self.deadProperties().get(qname))

        return defer()

    def writeProperty(self, property, request):
        """
        See L{IDAVResource.writeProperty}.

        @param property: a L{WebDAVElement} to store.
        @return: a L{Deferred} firing with the dead property store's result,
            or failing with L{HTTPError} (FORBIDDEN) for protected or private
            properties.
        """
        assert isinstance(property, WebDAVElement), (
            "Not a property: %r" % (property,)
        )

        def defer():
            if property.protected:
                raise HTTPError(StatusResponse(
                    responsecode.FORBIDDEN,
                    "Protected property %s may not be set."
                    % (property.sname(),)
                ))

            if property.namespace == twisted_private_namespace:
                # Report the namespace (not the full property name), matching
                # removeProperty.
                raise HTTPError(StatusResponse(
                    responsecode.FORBIDDEN,
                    "Properties in the %s namespace are private to the server."
                    % (property.namespace,)
                ))

            return self.deadProperties().set(property)

        return maybeDeferred(defer)

    def removeProperty(self, property, request):
        """
        See L{IDAVResource.removeProperty}.

        @param property: a qname tuple or a property element.
        @return: a L{Deferred}; fails with L{HTTPError} (FORBIDDEN) for live
            or private properties.
        """
        def defer():
            if type(property) is tuple:
                qname = property
                sname = encodeXMLName(*property)
            else:
                qname = property.qname()
                sname = property.sname()

            if qname in self.liveProperties():
                raise HTTPError(StatusResponse(
                    responsecode.FORBIDDEN,
                    "Live property %s cannot be deleted." % (sname,)
                ))

            if qname[0] == twisted_private_namespace:
                raise HTTPError(StatusResponse(
                    responsecode.FORBIDDEN,
                    "Properties in the %s namespace are private to the server."
                    % (qname[0],)
                ))

            return self.deadProperties().delete(qname)

        return maybeDeferred(defer)

    @inlineCallbacks
    def listProperties(self, request):
        """
        See L{IDAVResource.listProperties}.

        @return: a L{Deferred} firing with a C{set} of qnames: the live
            properties (minus absent quota properties) plus all non-private
            dead properties.
        """
        qnames = set(self.liveProperties())

        # Add dynamic live properties that exist
        dynamicLiveProperties = (
            (dav_namespace, "quota-available-bytes"),
            (dav_namespace, "quota-used-bytes"),
        )
        for dqname in dynamicLiveProperties:
            has = (yield self.hasProperty(dqname, request))
            if not has:
                # discard(): a subclass's liveProperties() need not list the
                # quota properties, and remove() would raise KeyError.
                qnames.discard(dqname)

        for qname in self.deadProperties().list():
            if qname[0] != twisted_private_namespace:
                qnames.add(qname)

        returnValue(qnames)

    def listAllprop(self, request):
        """
        Some DAV properties should not be returned to a C{DAV:allprop} query.
        RFC 3253 defines several such properties. This method computes a
        subset of the property qnames returned by L{listProperties} by
        filtering out elements whose class have the C{.hidden} attribute set
        to C{True}.

        @return: a L{Deferred} firing with a list of qnames of properties
            which are defined and are appropriate for use in response to a
            C{DAV:allprop} query.
        """
        def doList(qnames):
            result = []

            for qname in qnames:
                try:
                    hidden = lookupElement(qname).hidden
                except KeyError:
                    # Unknown element: not registered, so never hidden
                    hidden = False
                if not hidden:
                    result.append(qname)

            return result

        d = self.listProperties(request)
        d.addCallback(doList)
        return d

    def hasDeadProperty(self, property):
        """
        Same as L{hasProperty}, but bypasses the live property store and
        checks directly from the dead property store.
        """
        if type(property) is tuple:
            qname = property
        else:
            qname = property.qname()

        return self.deadProperties().contains(qname)

    def readDeadProperty(self, property):
        """
        Same as L{readProperty}, but bypasses the live property store and
        reads directly from the dead property store.
        """
        if type(property) is tuple:
            qname = property
        else:
            qname = property.qname()

        return self.deadProperties().get(qname)

    def writeDeadProperty(self, property):
        """
        Same as L{writeProperty}, but bypasses the live property store and
        writes directly to the dead property store.
        Note that this should not be used unless you know that you are
        writing to an overrideable live property, as this bypasses the
        logic which protects protected properties. The result of writing to
        a non-overrideable live property with this method is undefined; the
        value in the dead property store may or may not be ignored when
        reading the property with L{readProperty}.
        """
        self.deadProperties().set(property)

    def removeDeadProperty(self, property):
        """
        Same as L{removeProperty}, but bypasses the live property store and
        acts directly on the dead property store.
        Does nothing if the property is not present.
        """
        if self.hasDeadProperty(property):
            if type(property) is tuple:
                qname = property
            else:
                qname = property.qname()

            self.deadProperties().delete(qname)

    #
    # Overrides some methods in MetaDataMixin in order to allow DAV properties
    # to override the values of some HTTP metadata.
# def contentType(self): if self.hasDeadProperty((element.dav_namespace, "getcontenttype")): return self.readDeadProperty( (element.dav_namespace, "getcontenttype") ).mimeType() else: return super(DAVPropertyMixIn, self).contentType() def displayName(self): if self.hasDeadProperty((element.dav_namespace, "displayname")): return str(self.readDeadProperty( (element.dav_namespace, "displayname") )) else: return super(DAVPropertyMixIn, self).displayName() class DAVResource (DAVPropertyMixIn, StaticRenderMixin): """ WebDAV resource. """ implements(IDAVResource) def __init__(self, principalCollections=None): """ @param principalCollections: an iterable of L{IDAVPrincipalCollectionResource}s which contain principals to be used in ACLs for this resource. """ if principalCollections is not None: self._principalCollections = frozenset([ IDAVPrincipalCollectionResource(principalCollection) for principalCollection in principalCollections ]) ## # DAV ## def davComplianceClasses(self): """ This implementation raises L{NotImplementedError}. @return: a sequence of strings denoting WebDAV compliance classes. For example, a DAV level 2 server might return ("1", "2"). """ unimplemented(self) def isCollection(self): """ See L{IDAVResource.isCollection}. This implementation raises L{NotImplementedError}; a subclass must override this method. """ unimplemented(self) def findChildren( self, depth, request, callback, privileges=None, inherited_aces=None ): """ See L{IDAVResource.findChildren}. This implementation works for C{depth} values of C{"0"}, C{"1"}, and C{"infinity"}. 
As long as C{self.listChildren} is implemented """ assert depth in ("0", "1", "infinity"), "Invalid depth: %s" % (depth,) if depth == "0" or not self.isCollection(): return succeed(None) completionDeferred = Deferred() basepath = request.urlForResource(self) children = [] def checkPrivilegesError(failure): failure.trap(AccessDeniedError) reactor.callLater(0, getChild) def checkPrivileges(child): if child is None: return None if privileges is None: return child d = child.checkPrivileges( request, privileges, inherited_aces=inherited_aces ) d.addCallback(lambda _: child) return d def gotChild(child, childpath): if child is None: callback(None, childpath + "/") else: if child.isCollection(): callback(child, childpath + "/") if depth == "infinity": d = child.findChildren( depth, request, callback, privileges ) d.addCallback(lambda x: reactor.callLater(0, getChild)) return d else: callback(child, childpath) reactor.callLater(0, getChild) def getChild(): try: childname = children.pop() except IndexError: completionDeferred.callback(None) else: childpath = joinURL(basepath, urllib.quote(childname)) d = request.locateChildResource(self, childname) d.addCallback(checkPrivileges) d.addCallbacks(gotChild, checkPrivilegesError, (childpath,)) d.addErrback(completionDeferred.errback) def gotChildren(listChildrenResult): children[:] = list(listChildrenResult) getChild() maybeDeferred(self.listChildren).addCallback(gotChildren) return completionDeferred @inlineCallbacks def findChildrenFaster( self, depth, request, okcallback, badcallback, missingcallback, unavailablecallback, names, privileges, inherited_aces ): """ See L{IDAVResource.findChildren}. This implementation works for C{depth} values of C{"0"}, C{"1"}, and C{"infinity"}. As long as C{self.listChildren} is implemented @param depth: a C{str} for the depth: "0", "1" and "infinity" only allowed. 
@param request: the L{Request} for the current request in progress @param okcallback: a callback function used on all resources that pass the privilege check, or C{None} @param badcallback: a callback function used on all resources that fail the privilege check, or C{None} @param missingcallback: a callback function used on all resources that are missing, or C{None} @param names: a C{list} of C{str}'s containing the names of the child resources to lookup. If empty or C{None} all children will be examined, otherwise only the ones in the list. @param privileges: a list of privileges to check. @param inherited_aces: the list of parent ACEs that are inherited by all children. """ assert depth in ("0", "1", "infinity"), "Invalid depth: %s" % (depth,) if depth == "0" or not self.isCollection(): returnValue(None) # First find all depth 1 children names1 = [] namesDeep = [] collections1 = [] if names: for name in names: (names1 if name.rstrip("/").find("/") == -1 else namesDeep).append(name.rstrip("/")) #children = [] #yield self.findChildren("1", request, lambda x, y: children.append((x, y)), privileges=None, inherited_aces=None) children = [] basepath = request.urlForResource(self) childnames = list((yield self.listChildren())) for childname in childnames: childpath = joinURL(basepath, urllib.quote(childname)) try: child = (yield request.locateChildResource(self, childname)) except HTTPError, e: log.error("Resource cannot be located: %s" % (str(e),)) if unavailablecallback: unavailablecallback(childpath) continue if child is not None: if child.isCollection(): collections1.append((child, childpath + "/")) if names and childname not in names1: continue if child.isCollection(): children.append((child, childpath + "/")) else: children.append((child, childpath)) if missingcallback: for name in set(names1) - set(childnames): missingcallback(joinURL(basepath, urllib.quote(name))) # Generate (acl,supported_privs) map aclmap = {} for resource, url in children: acl = (yield 
resource.accessControlList( request, inheritance=False, inherited_aces=inherited_aces )) supportedPrivs = (yield resource.supportedPrivileges(request)) aclmap.setdefault( (pickle.dumps(acl), supportedPrivs), (acl, supportedPrivs, []) )[2].append((resource, url)) # Now determine whether each ace satisfies privileges #print(aclmap) for items in aclmap.itervalues(): checked = (yield self.checkACLPrivilege( request, items[0], items[1], privileges, inherited_aces )) if checked: for resource, url in items[2]: if okcallback: okcallback(resource, url) else: if badcallback: for resource, url in items[2]: badcallback(resource, url) if depth == "infinity": # Split names into child collection groups child_collections = {} for name in namesDeep: collection, name = name.split("/", 1) child_collections.setdefault(collection, []).append(name) for collection, url in collections1: collection_name = url.split("/")[-2] if collection_name in child_collections: collection_inherited_aces = ( yield collection.inheritedACEsforChildren(request) ) yield collection.findChildrenFaster( depth, request, okcallback, badcallback, missingcallback, unavailablecallback, child_collections[collection_name] if names else None, privileges, inherited_aces=collection_inherited_aces ) returnValue(None) @inlineCallbacks def checkACLPrivilege( self, request, acl, privyset, privileges, inherited_aces ): if acl is None: returnValue(False) principal = self.currentPrincipal(request) # Other principal types don't make sense as actors. 
assert principal.children[0].name in ("unauthenticated", "href"), ( "Principal is not an actor: %r" % (principal,) ) acl = self.fullAccessControlList(acl, inherited_aces) pending = list(privileges) denied = [] for ace in acl.children: for privilege in tuple(pending): if not self.matchPrivilege( element.Privilege(privilege), ace.privileges, privyset ): continue match = (yield self.matchPrincipal( principal, ace.principal, request )) if match: if ace.invert: continue else: if not ace.invert: continue pending.remove(privilege) if not ace.allow: denied.append(privilege) returnValue(len(denied) + len(pending) == 0) def fullAccessControlList(self, acl, inherited_aces): """ See L{IDAVResource.accessControlList}. This implementation looks up the ACL in the private property C{(L{twisted_private_namespace}, "acl")}. If no ACL has been stored for this resource, it returns the value returned by C{defaultAccessControlList}. If access is disabled it will return C{None}. """ # # Inheritance is problematic. Here is what we do: # # 1. A private element is defined for use inside # of a . This private element is removed when the ACE is # exposed via WebDAV. # # 2. When checking ACLs with inheritance resolution, the server must # examine all parent resources of the current one looking for any # elements. # # If those are defined, the relevant ace is applied to the ACL on the # current resource. # # Dynamically update privileges for those ace's that are inherited. if acl: aces = list(acl.children) else: aces = [] aces.extend(inherited_aces) acl = element.ACL(*aces) return acl def supportedReports(self): """ See L{IDAVResource.supportedReports}. This implementation lists the three main ACL reports and expand-property. 
""" result = [] result.append(element.Report(element.ACLPrincipalPropSet(),)) result.append(element.Report(element.PrincipalMatch(),)) result.append(element.Report(element.PrincipalPropertySearch(),)) result.append(element.Report(element.ExpandProperty(),)) result.append(element.Report(customxml.CalendarServerPrincipalSearch(),)) return result ## # Authentication ## def authorize(self, request, privileges, recurse=False): """ See L{IDAVResource.authorize}. """ def whenAuthenticated(result): privilegeCheck = self.checkPrivileges(request, privileges, recurse) return privilegeCheck.addErrback(whenAccessDenied) def whenAccessDenied(f): f.trap(AccessDeniedError) # If we were unauthenticated to start with (no # Authorization header from client) then we should return # an unauthorized response instead to force the client to # login if it can. # We're not adding the headers here because this response # class is supposed to be a FORBIDDEN status code and # "Authorization will not help" according to RFC2616 def translateError(response): return Failure(HTTPError(response)) if request.authnUser == element.Principal(element.Unauthenticated()): return UnauthorizedResponse.makeResponse( request.credentialFactories, request.remoteAddr).addCallback(translateError) else: return translateError( NeedPrivilegesResponse(request.uri, f.value.errors)) d = self.authenticate(request) d.addCallback(whenAuthenticated) return d def authenticate(self, request): """ Authenticate the given request against the portal, setting both C{request.authzUser} (a C{str}, the username for the purposes of authorization) and C{request.authnUser} (a C{str}, the username for the purposes of authentication) when it has been authenticated. In order to authenticate, the request must have been previously prepared by L{twext.web2.dav.auth.AuthenticationWrapper.hook} to have the necessary authentication metadata. 
If the request was not thusly prepared, both C{authzUser} and C{authnUser} will be L{element.Unauthenticated}. @param request: the request which may contain authentication information and a reference to a portal to authenticate against. @type request: L{twext.web2.iweb.IRequest}. @return: a L{Deferred} which fires with a 2-tuple of C{(authnUser, authzUser)} if either the request is unauthenticated OR contains valid credentials to authenticate as a principal, or errbacks with L{HTTPError} if the authentication scheme is unsupported, or the credentials provided by the request are not valid. """ # Bypass normal authentication if its already been done (by SACL check) if ( hasattr(request, "authnUser") and hasattr(request, "authzUser") and request.authnUser is not None and request.authzUser is not None ): return succeed((request.authnUser, request.authzUser)) if not ( hasattr(request, "portal") and hasattr(request, "credentialFactories") and hasattr(request, "loginInterfaces") ): request.authnUser = element.Principal(element.Unauthenticated()) request.authzUser = element.Principal(element.Unauthenticated()) return succeed((request.authnUser, request.authzUser)) authHeader = request.headers.getHeader("authorization") if authHeader is not None: if authHeader[0] not in request.credentialFactories: log.debug( "Client authentication scheme %s is not provided by server %s" % (authHeader[0], request.credentialFactories.keys()) ) d = UnauthorizedResponse.makeResponse( request.credentialFactories, request.remoteAddr ) return d.addCallback(lambda response: Failure(HTTPError(response))) else: factory = request.credentialFactories[authHeader[0]] def gotCreds(creds): d = self.principalsForAuthID(request, creds.username) d.addCallback(gotDetails, creds) return d # Try to match principals in each principal collection # on the resource def gotDetails(details, creds): if details == (None, None): log.info( "Could not find the principal resource for user id: %s" % (creds.username,) ) 
return Failure(HTTPError(responsecode.UNAUTHORIZED)) authnPrincipal = IDAVPrincipalResource(details[0]) authzPrincipal = IDAVPrincipalResource(details[1]) return PrincipalCredentials(authnPrincipal, authzPrincipal, creds) def login(pcreds): return request.portal.login(pcreds, None, *request.loginInterfaces) def gotAuth(result): request.authnUser = result[1] request.authzUser = result[2] return (request.authnUser, request.authzUser) def translateUnauthenticated(f): f.trap(UnauthorizedLogin, LoginFailed) log.info("Authentication failed: %s" % (f.value,)) d = UnauthorizedResponse.makeResponse( request.credentialFactories, request.remoteAddr ) d.addCallback(lambda response: Failure(HTTPError(response))) return d d = factory.decode(authHeader[1], request) d.addCallback(gotCreds) d.addCallback(login) d.addCallbacks(gotAuth, translateUnauthenticated) return d else: if ( hasattr(request, "checkedWiki") and hasattr(request, "authnUser") and hasattr(request, "authzUser") ): # This request has already been authenticated via the wiki return succeed((request.authnUser, request.authzUser)) request.authnUser = element.Principal(element.Unauthenticated()) request.authzUser = element.Principal(element.Unauthenticated()) return succeed((request.authnUser, request.authzUser)) ## # ACL ## def currentPrincipal(self, request): """ @param request: the request being processed. @return: the current authorized principal, as derived from the given request. """ if hasattr(request, "authzUser"): return request.authzUser else: return unauthenticatedPrincipal def principalCollections(self): """ See L{IDAVResource.principalCollections}. """ if hasattr(self, "_principalCollections"): return self._principalCollections else: return () def defaultRootAccessControlList(self): """ @return: the L{element.ACL} element containing the default access control list for this resource. """ # # The default behaviour is to allow GET access to everything # and deny any type of write access (PUT, DELETE, etc.) 
to # everything. # return readonlyACL def defaultAccessControlList(self): """ @return: the L{element.ACL} element containing the default access control list for this resource. """ # # The default behaviour is no ACL; we should inherit from the parent # collection. # return element.ACL() def setAccessControlList(self, acl): """ See L{IDAVResource.setAccessControlList}. This implementation stores the ACL in the private property C{(L{twisted_private_namespace}, "acl")}. """ self.writeDeadProperty(acl) @inlineCallbacks def mergeAccessControlList(self, new_acl, request): """ Merges the supplied access control list with the one on this resource. Merging means change all the non-inherited and non-protected ace's in the original, and do not allow the new one to specify an inherited or protected access control entry. This is the behaviour required by the C{ACL} request. (RFC 3744, section 8.1). @param new_acl: an L{element.ACL} element @param request: the request being processed. @return: a tuple of the C{DAV:error} precondition element if an error occurred, C{None} otherwise. This implementation stores the ACL in the private property """ # C{(L{twisted_private_namespace}, "acl")}. # Steps for ACL evaluation: # 1. Check that ace's on incoming do not match a protected ace # 2. Check that ace's on incoming do not match an inherited ace # 3. Check that ace's on incoming all have deny before grant # 4. Check that ace's on incoming do not use abstract privilege # 5. Check that ace's on incoming are supported # (and are not inherited themselves) # 6. Check that ace's on incoming have valid principals # 7. Copy the original # 8. Remove all non-inherited and non-protected - and also inherited # 9. Add in ace's from incoming # 10. Verify that new acl is not in conflict with itself # 11. Update acl on the resource # Get the current access control list, preserving any private # properties on the ACEs as we will need to keep those when we # change the ACL. 
old_acl = (yield self.accessControlList(request, expanding=True)) # Check disabled if old_acl is None: returnValue(None) # Need to get list of supported privileges supported = [] def addSupportedPrivilege(sp): """ Add the element in any DAV:Privilege to our list and recurse into any DAV:SupportedPrivilege's """ for item in sp.children: if isinstance(item, element.Privilege): supported.append(item.children[0]) elif isinstance(item, element.SupportedPrivilege): addSupportedPrivilege(item) supportedPrivs = (yield self.supportedPrivileges(request)) for item in supportedPrivs.children: assert isinstance(item, element.SupportedPrivilege), ( "Not a SupportedPrivilege: %r" % (item,) ) addSupportedPrivilege(item) # Steps 1 - 6 got_deny = False for ace in new_acl.children: for old_ace in old_acl.children: if (ace.principal == old_ace.principal): # Step 1 if old_ace.protected: log.error("Attempt to overwrite protected ace %r " "on resource %r" % (old_ace, self)) returnValue(( element.dav_namespace, "no-protected-ace-conflict" )) # Step 2 # # RFC3744 says that we either enforce the # inherited ace conflict or we ignore it but use # access control evaluation to determine whether # there is any impact. Given that we have the # "inheritable" behavior it does not make sense to # disallow overrides of inherited ACEs since # "inheritable" cannot itself be controlled via # protocol. # # Otherwise, we'd use this logic: # #elif old_ace.inherited: # log.error("Attempt to overwrite inherited ace %r " # "on resource %r" % (old_ace, self)) # returnValue(( # element.dav_namespace, # "no-inherited-ace-conflict" # )) # Step 3 if ace.allow and got_deny: log.error("Attempt to set grant ace %r after deny ace " "on resource %r" % (ace, self)) returnValue((element.dav_namespace, "deny-before-grant")) got_deny = not ace.allow # Step 4: ignore as this server has no abstract privileges # (FIXME: none yet?) 
# Step 5 for privilege in ace.privileges: if privilege.children[0] not in supported: log.error("Attempt to use unsupported privilege %r " "in ace %r on resource %r" % (privilege.children[0], ace, self)) returnValue(( element.dav_namespace, "not-supported-privilege" )) if ace.protected: log.error("Attempt to create protected ace %r on resource %r" % (ace, self)) returnValue((element.dav_namespace, "no-ace-conflict")) if ace.inherited: log.error("Attempt to create inherited ace %r on resource %r" % (ace, self)) returnValue((element.dav_namespace, "no-ace-conflict")) # Step 6 valid = (yield self.validPrincipal(ace.principal, request)) if not valid: log.error("Attempt to use unrecognized principal %r " "in ace %r on resource %r" % (ace.principal, ace, self)) returnValue((element.dav_namespace, "recognized-principal")) # Step 8 & 9 # # Iterate through the old ones and replace any that are in the # new set, or remove the non-inherited/non-protected not in # the new set # new_aces = [ace for ace in new_acl.children] new_set = [] for old_ace in old_acl.children: for i, new_ace in enumerate(new_aces): if self.samePrincipal(new_ace.principal, old_ace.principal): new_set.append(new_ace) del new_aces[i] break else: if old_ace.protected and not old_ace.inherited: new_set.append(old_ace) new_set.extend(new_aces) # Step 10 # FIXME: verify acl is self-consistent # Step 11 yield self.writeNewACEs(new_set) returnValue(None) def writeNewACEs(self, new_aces): """ Write a new ACL to the resource's property store. This is a separate method so that it can be overridden by resources that need to do extra processing of ACLs being set via the ACL command. @param new_aces: C{list} of L{ACE} for ACL being set. 
""" return self.setAccessControlList(element.ACL(*new_aces)) def matchPrivilege(self, privilege, ace_privileges, supportedPrivileges): for ace_privilege in ace_privileges: if ( privilege == ace_privilege or ace_privilege.isAggregateOf(privilege, supportedPrivileges) ): return True return False @inlineCallbacks def checkPrivileges( self, request, privileges, recurse=False, principal=None, inherited_aces=None ): """ Check whether the given principal has the given privileges. (RFC 3744, section 5.5) @param request: the request being processed. @param privileges: an iterable of L{WebDAVElement} elements denoting access control privileges. @param recurse: C{True} if a recursive check on all child resources of this resource should be performed as well, C{False} otherwise. @param principal: the L{element.Principal} to check privileges for. If C{None}, it is deduced from C{request} by calling L{currentPrincipal}. @param inherited_aces: a list of L{element.ACE}s corresponding to the pre-computed inheritable aces from the parent resource hierarchy. @return: a L{Deferred} that callbacks with C{None} or errbacks with an L{AccessDeniedError} """ if principal is None: principal = self.currentPrincipal(request) supportedPrivs = (yield self.supportedPrivileges(request)) # Other principals types don't make sense as actors. 
assert principal.children[0].name in ("unauthenticated", "href"), ( "Principal is not an actor: %r" % (principal,) ) errors = [] resources = [(self, None)] if recurse: yield self.findChildren( "infinity", request, lambda x, y: resources.append((x, y)) ) for resource, uri in resources: acl = (yield resource.accessControlList( request, inherited_aces=inherited_aces ) ) # Check for disabled if acl is None: errors.append((uri, list(privileges))) continue pending = list(privileges) denied = [] for ace in acl.children: for privilege in tuple(pending): if not self.matchPrivilege( element.Privilege(privilege), ace.privileges, supportedPrivs ): continue match = (yield self.matchPrincipal(principal, ace.principal, request) ) if match: if ace.invert: continue else: if not ace.invert: continue pending.remove(privilege) if not ace.allow: denied.append(privilege) denied += pending # If no matching ACE, then denied if denied: errors.append((uri, denied)) if errors: raise AccessDeniedError(errors,) returnValue(None) def supportedPrivileges(self, request): """ See L{IDAVResource.supportedPrivileges}. This implementation returns a supported privilege set containing only the DAV:all privilege. """ return succeed(davPrivilegeSet) def currentPrivileges(self, request): """ See L{IDAVResource.currentPrivileges}. This implementation returns a current privilege set containing only the DAV:all privilege. """ current = self.currentPrincipal(request) return self.privilegesForPrincipal(current, request) @inlineCallbacks def accessControlList( self, request, inheritance=True, expanding=False, inherited_aces=None ): """ See L{IDAVResource.accessControlList}. This implementation looks up the ACL in the private property C{(L{twisted_private_namespace}, "acl")}. If no ACL has been stored for this resource, it returns the value returned by C{defaultAccessControlList}. If access is disabled it will return C{None}. """ # # Inheritance is problematic. Here is what we do: # # 1. 
A private element is defined for # use inside of a . This private element is # removed when the ACE is exposed via WebDAV. # # 2. When checking ACLs with inheritance resolution, the # server must examine all parent resources of the current # one looking for any elements. # # If those are defined, the relevant ace is applied to the ACL on the # current resource. # myURL = None def getMyURL(): url = request.urlForResource(self) assert url is not None, ( "urlForResource(self) returned None for resource %s" % (self,) ) return url try: acl = self.readDeadProperty(element.ACL) except HTTPError, e: assert e.response.code == responsecode.NOT_FOUND, ( "Expected %s response from readDeadProperty() exception, " "not %s" % (responsecode.NOT_FOUND, e.response.code) ) # Produce a sensible default for an empty ACL. if myURL is None: myURL = getMyURL() if myURL == "/": # If we get to the root without any ACLs, then use the default. acl = self.defaultRootAccessControlList() else: acl = self.defaultAccessControlList() # Dynamically update privileges for those ace's that are inherited. if inheritance: aces = list(acl.children) if myURL is None: myURL = getMyURL() if inherited_aces is None: if myURL != "/": parentURL = parentForURL(myURL) parent = (yield request.locateResource(parentURL)) if parent: parent_acl = (yield parent.accessControlList( request, inheritance=True, expanding=True ) ) # Check disabled if parent_acl is None: returnValue(None) for ace in parent_acl.children: if ace.inherited: aces.append(ace) elif TwistedACLInheritable() in ace.children: # Adjust ACE for inherit on this resource children = list(ace.children) children.remove(TwistedACLInheritable()) children.append( element.Inherited(element.HRef(parentURL)) ) aces.append(element.ACE(*children)) else: aces.extend(inherited_aces) # Always filter out any remaining private properties when we are # returning the ACL for the final resource after doing parent # inheritance. 
if not expanding: aces = [ element.ACE(*[ c for c in ace.children if c != TwistedACLInheritable() ]) for ace in aces ] acl = element.ACL(*aces) returnValue(acl) def inheritedACEsforChildren(self, request): """ Do some optimisation of access control calculation by determining any inherited ACLs outside of the child resource loop and supply those to the checkPrivileges on each child. @param request: the L{IRequest} for the request in progress. @return: a C{list} of L{Ace}s that child resources of this one will inherit. """ # Get the parent ACLs with inheritance and preserve the # element. def gotACL(parent_acl): # Check disabled if parent_acl is None: return None # Filter out those that are not inheritable (and remove # the inheritable element from those that are) aces = [] for ace in parent_acl.children: if ace.inherited: aces.append(ace) elif TwistedACLInheritable() in ace.children: # Adjust ACE for inherit on this resource children = list(ace.children) children.remove(TwistedACLInheritable()) children.append( element.Inherited( element.HRef(request.urlForResource(self)) ) ) aces.append(element.ACE(*children)) return aces d = self.accessControlList(request, inheritance=True, expanding=True) d.addCallback(gotACL) return d def inheritedACLSet(self): """ @return: a sequence of L{element.HRef}s from which ACLs are inherited. This implementation returns an empty set. """ return [] def principalsForAuthID(self, request, authid): """ Return authentication and authorization principal identifiers for the authentication identifier passed in. In this implementation authn and authz principals are the same. @param request: the L{IRequest} for the request in progress. @param authid: a string containing the authentication/authorization identifier for the principal to lookup. @return: a deferred tuple of two tuples. Each tuple is C{(principal, principalURI)} where: C{principal} is the L{Principal} that is found; {principalURI} is the C{str} URI of the principal. 
The first tuple corresponds to authentication identifiers, the second to authorization identifiers. It will errback with an HTTPError(responsecode.FORBIDDEN) if the principal isn't found. """ authnPrincipal = self.findPrincipalForAuthID(authid) if authnPrincipal is None: return succeed((None, None)) d = self.authorizationPrincipal(request, authid, authnPrincipal) d.addCallback(lambda authzPrincipal: (authnPrincipal, authzPrincipal)) return d def findPrincipalForAuthID(self, authid): """ Return authentication and authorization principal identifiers for the authentication identifier passed in. In this implementation authn and authz principals are the same. @param authid: a string containing the authentication/authorization identifier for the principal to lookup. @return: a tuple of C{(principal, principalURI)} where: C{principal} is the L{Principal} that is found; {principalURI} is the C{str} URI of the principal. If not found return None. """ for collection in self.principalCollections(): principal = collection.principalForUser(authid) if principal is not None: return principal return None def authorizationPrincipal(self, request, authid, authnPrincipal): """ Determine the authorization principal for the given request and authentication principal. This implementation simply uses that authentication principal as the authorization principal. @param request: the L{IRequest} for the request in progress. @param authid: a string containing the authentication/authorization identifier for the principal to lookup. @param authnPrincipal: the L{IDAVPrincipal} for the authenticated principal @return: a deferred result C{tuple} of (L{IDAVPrincipal}, C{str}) containing the authorization principal resource and URI respectively. """ return succeed(authnPrincipal) def samePrincipal(self, principal1, principal2): """ Check whether the two principals are exactly the same in terms of elements and data. @param principal1: a L{Principal} to test. 
@param principal2: a L{Principal} to test. @return: C{True} if they are the same, C{False} otherwise. """ # The interesting part of a principal is it's one child principal1 = principal1.children[0] principal2 = principal2.children[0] if type(principal1) == type(principal2): if isinstance(principal1, element.Property): return ( type(principal1.children[0]) == type(principal2.children[0]) ) elif isinstance(principal1, element.HRef): return ( str(principal1.children[0]) == str(principal2.children[0]) ) else: return True else: return False def matchPrincipal(self, principal1, principal2, request): """ Check whether the principal1 is a principal in the set defined by principal2. @param principal1: a L{Principal} to test. C{principal1} must contain a L{element.HRef} or L{element.Unauthenticated} element. @param principal2: a L{Principal} to test. @param request: the request being processed. @return: C{True} if they match, C{False} otherwise. """ # See RFC 3744, section 5.5.1 # The interesting part of a principal is it's one child principal1 = principal1.children[0] principal2 = principal2.children[0] if not hasattr(request, "matchPrincipals"): request.matchPrincipals = {} cache_key = (str(principal1), str(principal2)) match = request.matchPrincipals.get(cache_key, None) if match is not None: return succeed(match) def doMatch(): if isinstance(principal2, element.All): return succeed(True) elif isinstance(principal2, element.Authenticated): if isinstance(principal1, element.Unauthenticated): return succeed(False) elif isinstance(principal1, element.All): return succeed(False) else: return succeed(True) elif isinstance(principal2, element.Unauthenticated): if isinstance(principal1, element.Unauthenticated): return succeed(True) else: return succeed(False) elif isinstance(principal1, element.Unauthenticated): return succeed(False) assert isinstance(principal1, element.HRef), ( "Not an HRef: %r" % (principal1,) ) def resolved(principal2): assert principal2 is not None, 
"principal2 is None" # Compare two HRefs and do group membership test as well if principal1 == principal2: return True return self.principalIsGroupMember( str(principal1), str(principal2), request ) d = self.resolvePrincipal(principal2, request) d.addCallback(resolved) return d def cache(match): request.matchPrincipals[cache_key] = match return match d = doMatch() d.addCallback(cache) return d @inlineCallbacks def principalIsGroupMember(self, principal1, principal2, request): """ Check whether one principal is a group member of another. @param principal1: C{str} principalURL for principal to test. @param principal2: C{str} principalURL for possible group principal to test against. @param request: the request being processed. @return: L{Deferred} with result C{True} if principal1 is a member of principal2, C{False} otherwise """ resource1 = yield request.locateResource(principal1) resource2 = yield request.locateResource(principal2) if resource2 and isinstance(resource2, DAVPrincipalResource): isContained = yield resource2.containsPrincipal(resource1) returnValue(isContained) returnValue(False) def validPrincipal(self, ace_principal, request): """ Check whether the supplied principal is valid for this resource. @param ace_principal: the L{Principal} element to test @param request: the request being processed. @return C{True} if C{ace_principal} is valid, C{False} otherwise. This implementation tests for a valid element type and checks for an href principal that exists inside of a principal collection. """ def defer(): # # We know that the element contains a valid element type, so all # we need to do is check for a valid property and a valid href. # real_principal = ace_principal.children[0] if isinstance(real_principal, element.Property): # See comments in matchPrincipal(). We probably need # some common code. log.error("Encountered a property principal (%s), " "but handling is not implemented." 
% (real_principal,)) return False if isinstance(real_principal, element.HRef): return self.validHrefPrincipal(real_principal, request) return True return maybeDeferred(defer) def validHrefPrincipal(self, href_principal, request): """ Check whether the supplied principal (in the form of an Href) is valid for this resource. @param href_principal: the L{Href} element to test @param request: the request being processed. @return C{True} if C{href_principal} is valid, C{False} otherwise. This implementation tests for a href element that corresponds to a principal resource and matches the principal-URL. """ # Must have the principal resource type and must match the # principal-URL def _matchPrincipalURL(resource): return ( isPrincipalResource(resource) and resource.principalURL() == str(href_principal) ) d = request.locateResource(str(href_principal)) d.addCallback(_matchPrincipalURL) return d def resolvePrincipal(self, principal, request): """ Resolves a L{element.Principal} element into a L{element.HRef} element if possible. Specifically, the given C{principal}'s contained element is resolved. L{element.Property} is resolved to the URI in the contained property. L{element.Self} is resolved to the URI of this resource. L{element.HRef} elements are returned as-is. All other principals, including meta-principals (eg. L{element.All}), resolve to C{None}. @param principal: the L{element.Principal} child element to resolve. @param request: the request being processed. @return: a deferred L{element.HRef} element or C{None}. """ if isinstance(principal, element.Property): # NotImplementedError("Property principals are not implemented.") # # We can't raise here without potentially crippling the # server in a way that can't be fixed over the wire, so # let's refuse the match and log an error instead. # # Note: When fixing this, also fix validPrincipal() # log.error("Encountered a property principal (%s), " "but handling is not implemented; invalid for ACL use." 
% (principal,)) return succeed(None) # # FIXME: I think this is wrong - we need to get the # namespace and name from the first child of DAV:property # namespace = principal.attributes.get(["namespace"], dav_namespace) name = principal.attributes["name"] def gotPrincipal(principal): try: principal = principal.getResult() except HTTPError, e: assert e.response.code == responsecode.NOT_FOUND, ( "%s (!= %s) status from readProperty() exception" % (e.response.code, responsecode.NOT_FOUND) ) return None if not isinstance(principal, element.Principal): log.error("Non-principal value in property %s " "referenced by property principal." % (encodeXMLName(namespace, name),)) return None if len(principal.children) != 1: return None # The interesting part of a principal is it's one child principal = principal.children[0] # XXXXXX FIXME XXXXXX d = self.readProperty((namespace, name), request) d.addCallback(gotPrincipal) return d elif isinstance(principal, element.Self): try: self = IDAVPrincipalResource(self) except TypeError: log.error("DAV:self ACE is set on non-principal resource %r" % (self,)) return succeed(None) principal = element.HRef(self.principalURL()) if isinstance(principal, element.HRef): return succeed(principal) assert isinstance(principal, ( element.All, element.Authenticated, element.Unauthenticated )), "Not a meta-principal: %r" % (principal,) return succeed(None) @inlineCallbacks def privilegesForPrincipal(self, principal, request): """ See L{IDAVResource.privilegesForPrincipal}. """ # NB Return aggregate privileges expanded. acl = (yield self.accessControlList(request)) # Check disabled if acl is None: returnValue(()) granted = [] denied = [] for ace in acl.children: # First see if the ace's principal affects the principal # being tested. 
FIXME: support the DAV:invert operation match = (yield self.matchPrincipal(principal, ace.principal, request) ) if match: # Expand aggregate privileges ps = [] supportedPrivs = (yield self.supportedPrivileges(request) ) for p in ace.privileges: ps.extend(p.expandAggregate(supportedPrivs)) # Merge grant/deny privileges if ace.allow: granted.extend([p for p in ps if p not in granted]) else: denied.extend([p for p in ps if p not in denied]) # Subtract denied from granted allowed = tuple(p for p in granted if p not in denied) returnValue(allowed) def matchACEinACL(self, acl, ace): """ Find an ACE in the ACL that matches the supplied ACE's principal. @param acl: the L{ACL} to look at. @param ace: the L{ACE} to try and match @return: the L{ACE} in acl that matches, None otherwise. """ for a in acl.children: if self.samePrincipal(a.principal, ace.principal): return a return None def principalSearchPropertySet(self): """ @return: a L{element.PrincipalSearchPropertySet} element describing the principal properties that can be searched on this principal collection, or C{None} if this is not a principal collection. This implementation returns None. Principal collection resources must override and return their own suitable response. """ return None ## # Quota ## """ The basic policy here is to define a private 'quota-root' property on a collection. That property will contain the maximum allowed bytes for the collections and all its contents. In order to determine the quota property values on a resource, the server must look for the private property on that resource and any of its parents. If found on a parent, then that parent should be queried for quota information. If not found, no quota exists for the resource. To determine that actual quota in use we will cache the used byte count on the quota-root collection in another private property. 
It is the servers responsibility to keep that property up to date by adjusting it after every PUT, DELETE, COPY, MOVE, MKCOL, PROPPATCH, ACL, POST or any other method that may affect the size of stored data. If the private property is not present, the server will fall back to getting the size by iterating over all resources (this is done in static.py). """ def quota(self, request): """ Get current available & used quota values for this resource's quota root collection. @return: an L{Deferred} with result C{tuple} containing two C{int}'s the first is quota-available-bytes, the second is quota-used-bytes, or C{None} if quota is not defined on the resource. """ # See if already cached if hasattr(request, "quota"): if self in request.quota: return succeed(request.quota[self]) else: request.quota = {} # Find the quota root for this resource and return its data def gotQuotaRootResource(qroot_resource): if qroot_resource: qroot = qroot_resource.quotaRoot(request) if qroot is not None: def gotUsage(used): available = qroot - used if available < 0: available = 0 request.quota[self] = (available, used) return (available, used) d = qroot_resource.currentQuotaUse(request) d.addCallback(gotUsage) return d request.quota[self] = None return None d = self.quotaRootResource(request) d.addCallback(gotQuotaRootResource) return d def hasQuota(self, request): """ Check whether this resource is under quota control by checking each parent to see if it has a quota root. @return: C{True} if under quota control, C{False} if not. """ def gotQuotaRootResource(qroot_resource): return qroot_resource is not None d = self.quotaRootResource(request) d.addCallback(gotQuotaRootResource) return d def hasQuotaRoot(self, request): """ @return: a C{True} if this resource has quota root, C{False} otherwise. 
""" return self.hasDeadProperty(TwistedQuotaRootProperty) def quotaRoot(self, request): """ @return: a C{int} containing the maximum allowed bytes if this collection is quota-controlled, or C{None} if not quota controlled. """ if self.hasDeadProperty(TwistedQuotaRootProperty): return int(str(self.readDeadProperty(TwistedQuotaRootProperty))) else: return None @inlineCallbacks def quotaRootResource(self, request): """ Return the quota root for this resource. @return: L{DAVResource} or C{None} """ if self.hasQuotaRoot(request): returnValue(self) # Check the next parent try: url = request.urlForResource(self) except NoURLForResourceError: returnValue(None) while (url != "/"): url = parentForURL(url) if url is None: break parent = (yield request.locateResource(url)) if parent is None: break if parent.hasQuotaRoot(request): returnValue(parent) returnValue(None) def setQuotaRoot(self, request, maxsize): """ @param maxsize: a C{int} containing the maximum allowed bytes for the contents of this collection, or C{None} to remove quota restriction. """ assert self.isCollection(), "Only collections can have a quota root" assert maxsize is None or isinstance(maxsize, int), ( "maxsize must be an int or None" ) if maxsize is not None: self.writeDeadProperty(TwistedQuotaRootProperty(str(maxsize))) else: # Remove both the root and the cached used value self.removeDeadProperty(TwistedQuotaRootProperty) self.removeDeadProperty(TwistedQuotaUsedProperty) def quotaSize(self, request): """ Get the size of this resource (if its a collection get total for all children as well). TODO: Take into account size of dead-properties. @return: a C{int} containing the size of the resource. """ unimplemented(self) def checkQuota(self, request, available): """ Check to see whether all quota roots have sufficient available bytes. We currently do not use hierarchical quota checks - i.e. only the most immediate quota root parent is checked for quota. 
@param available: a C{int} containing the additional quota required. @return: C{True} if there is sufficient quota remaining on all quota roots, C{False} otherwise. """ def _defer(quotaroot): if quotaroot: # Check quota on this root (if it has one) quota = quotaroot.quotaRoot(request) if quota is not None: if available > quota[0]: return False return True d = self.quotaRootResource(request) d.addCallback(_defer) return d def quotaSizeAdjust(self, request, adjust): """ Update the quota used value on all quota root parents of this resource. @param adjust: a C{int} containing the number of bytes added (positive) or removed (negative) that should be used to adjust the cached total. """ def _defer(quotaroot): if quotaroot: # Check quota on this root (if it has one) return quotaroot.updateQuotaUse(request, adjust) d = self.quotaRootResource(request) d.addCallback(_defer) return d def currentQuotaUse(self, request): """ Get the cached quota use value, or if not present (or invalid) determine quota use by brute force. @return: an L{Deferred} with a C{int} result containing the current used byte if this collection is quota-controlled, or C{None} if not quota controlled. """ assert self.isCollection(), "Only collections can have a quota root" assert self.hasQuotaRoot(request), ( "Quota use only on quota root collection" ) # Try to get the cached value property if self.hasDeadProperty(TwistedQuotaUsedProperty): return succeed( int(str(self.readDeadProperty(TwistedQuotaUsedProperty))) ) else: # Do brute force size determination and cache the result # in the private property def _defer(result): self.writeDeadProperty(TwistedQuotaUsedProperty(str(result))) return result d = self.quotaSize(request) d.addCallback(_defer) return d def updateQuotaUse(self, request, adjust): """ Update the quota used value on this resource. @param adjust: a C{int} containing the number of bytes added (positive) or removed (negative) that should be used to adjust the cached total. 
@return: an L{Deferred} with a C{int} result containing the current used byte if this collection is quota-controlled, or C{None} if not quota controlled. """ assert self.isCollection(), "Only collections can have a quota root" # Get current value def _defer(size): size += adjust # Sanity check the resulting size if size >= 0: self.writeDeadProperty(TwistedQuotaUsedProperty(str(size))) else: # Remove the dead property and re-read to do brute # force quota calc log.info("Attempt to set quota used to a negative value: %s " "(adjustment: %s)" % (size, adjust,)) self.removeDeadProperty(TwistedQuotaUsedProperty) return self.currentQuotaUse(request) d = self.currentQuotaUse(request) d.addCallback(_defer) return d ## # HTTP ## def renderHTTP(self, request): # FIXME: This is for testing with litmus; comment out when not in use #litmus = request.headers.getRawHeaders("x-litmus") #if litmus: log.info("*** Litmus test: %s ***" % (litmus,)) # # If this is a collection and the URI doesn't end in "/", redirect. # if self.isCollection() and request.path[-1:] != "/": return RedirectResponse( request.unparseURL( path=urllib.quote( urllib.unquote(request.path), safe=':/') + '/' ) ) def setHeaders(response): response = IResponse(response) response.headers.setHeader("dav", self.davComplianceClasses()) # # If this is a collection and the URI doesn't end in "/", # add a Content-Location header. This is needed even if # we redirect such requests (as above) in the event that # this resource was created or modified by the request. # if self.isCollection() and request.path[-1:] != "/" and not response.headers.hasHeader("content-location"): response.headers.setHeader( "content-location", request.path + "/" ) return response def onError(f): # If we get an HTTPError, run its response through # setHeaders() as well. 
f.trap(HTTPError) return setHeaders(f.value.response) d = maybeDeferred(super(DAVResource, self).renderHTTP, request) return d.addCallbacks(setHeaders, onError) class DAVLeafResource (DAVResource, LeafResource): """ DAV resource with no children. """ def findChildren( self, depth, request, callback, privileges=None, inherited_aces=None ): return succeed(None) class DAVPrincipalResource (DAVResource): """ Resource representing a WebDAV principal. (RFC 3744, section 2) """ implements(IDAVPrincipalResource) ## # WebDAV ## def liveProperties(self): return super(DAVPrincipalResource, self).liveProperties() + ( (dav_namespace, "alternate-URI-set"), (dav_namespace, "principal-URL"), (dav_namespace, "group-member-set"), (dav_namespace, "group-membership"), ) def davComplianceClasses(self): return ("1", "access-control",) def isCollection(self): return False def readProperty(self, property, request): def defer(): if type(property) is tuple: qname = property else: qname = property.qname() namespace, name = qname if namespace == dav_namespace: if name == "alternate-URI-set": return element.AlternateURISet(*[ element.HRef(u) for u in self.alternateURIs() ]) if name == "principal-URL": return element.PrincipalURL( element.HRef(self.principalURL()) ) if name == "group-member-set": def callback(members): return element.GroupMemberSet(*[ element.HRef(p.principalURL()) for p in members ]) d = self.groupMembers() d.addCallback(callback) return d if name == "group-membership": def callback(memberships): return element.GroupMembership(*[ element.HRef(g.principalURL()) for g in memberships ]) d = self.groupMemberships() d.addCallback(callback) return d if name == "resourcetype": if self.isCollection(): return element.ResourceType( element.Collection(), element.Principal() ) else: return element.ResourceType(element.Principal()) return super(DAVPrincipalResource, self).readProperty( qname, request ) return maybeDeferred(defer) ## # ACL ## def alternateURIs(self): """ See 
L{IDAVPrincipalResource.alternateURIs}.

        This implementation returns C{()}.  Subclasses should override this
        method to provide alternate URIs for this resource if appropriate.
        """
        return ()

    def principalURL(self):
        """
        See L{IDAVPrincipalResource.principalURL}.

        This implementation raises L{NotImplementedError}.  Subclasses must
        override this method to provide the principal URL for this resource.
        """
        unimplemented(self)

    def groupMembers(self):
        """
        This implementation returns a Deferred which fires with C{()},
        which is appropriate for non-group principals.  Subclasses
        should override this method to provide member URLs for this
        resource if appropriate.

        @see: L{IDAVPrincipalResource.groupMembers}.
        """
        return succeed(())

    def expandedGroupMembers(self):
        """
        This implementation returns a Deferred which fires with C{()},
        which is appropriate for non-group principals.  Subclasses
        should override this method to provide expanded member URLs
        for this resource if appropriate.

        @see: L{IDAVPrincipalResource.expandedGroupMembers}
        """
        return succeed(())

    def groupMemberships(self):
        """
        See L{IDAVPrincipalResource.groupMemberships}.

        This implementation raises L{NotImplementedError}.  Subclasses must
        override this method to provide the group URLs for this resource.
        """
        unimplemented(self)

    def principalMatch(self, href):
        """
        Check whether the supplied principal matches this principal or is a
        member of this principal resource.

        @param href: the L{HRef} to test.
        @return: a L{Deferred} firing with True if there is a match, False
            otherwise.
        """
        uri = str(href)
        if self.principalURL() == uri:
            return succeed(True)
        else:
            # Not this principal: look for it among the expanded members.
            d = self.expandedGroupMembers()
            d.addCallback(
                lambda members:
                uri in [member.principalURL() for member in members]
            )
            return d

    @inlineCallbacks
    def containsPrincipal(self, principal):
        """
        Is the given principal contained within our expanded group membership?

        @param principal: The principal to check
        @type principal: L{DirectoryCalendarPrincipalResource}
        @return: a L{Deferred} firing with True if principal is a member,
            False otherwise
        @rtype: L{Deferred} firing with C{boolean}
        """
        members = yield self.expandedGroupMembers()
        returnValue(principal in members)


class DAVPrincipalCollectionResource (DAVResource):
    """
    WebDAV principal collection resource.  (RFC 3744, section 5.8)

    This is an abstract class; subclasses must implement
    C{principalForUser} in order to properly implement it.
    """

    implements(IDAVPrincipalCollectionResource)

    def __init__(self, url, principalCollections=()):
        """
        @param url: This resource's URL.  It must end in "/".
        @param principalCollections: passed through to L{DAVResource}.
        """
        DAVResource.__init__(self, principalCollections=principalCollections)

        assert url.endswith("/"), "Collection URL must end in '/'"
        self._url = url

    def principalCollectionURL(self):
        """
        Return the URL for this principal collection.
        """
        return self._url

    def principalForUser(self, user):
        """
        Subclasses must implement this method.

        @see: L{IDAVPrincipalCollectionResource.principalForUser}

        @raise: L{NotImplementedError}
        """
        raise NotImplementedError(
            "%s did not implement principalForUser" % (self.__class__)
        )


class AccessDeniedError(Exception):
    def __init__(self, errors):
        """
        An error to be raised when some request fails to meet sufficient
        access privileges for a resource.

        @param errors: sequence of tuples, one for each resource for which one
            or more of the given privileges are not granted, in the form
            C{(uri, privileges)}, where uri is a URL path relative to resource
            or C{None} if the error was in this resource, and privileges is a
            sequence of the privileges (or a subset thereof) which are not
            granted.
        """
        Exception.__init__(self, "Access denied for some resources: %r"
                           % (errors,))
        self.errors = errors

##
# Utilities
##


def isPrincipalResource(resource):
    """
    @return: True if C{resource} can be adapted to
        L{IDAVPrincipalResource}, False otherwise.
    """
    try:
        resource = IDAVPrincipalResource(resource)
    except TypeError:
        # Adaptation failed: not a principal.
        return False
    else:
        return True


class TwistedACLInheritable (WebDAVEmptyElement):
    """
    When set on an ACE, this indicates that the ACE privileges should be
    inherited by all child resources within the resource with this ACE.
    """
    namespace = twisted_dav_namespace
    name = "inheritable"

registerElement(TwistedACLInheritable)
# Permit the inheritable marker as a child of ACE elements.
# NOTE(review): (0, 1) presumably means min 0 / max 1 occurrences — confirm
# against txdav.xml allowed_children semantics.
element.ACE.allowed_children[(twisted_dav_namespace, "inheritable")] = (0, 1)


class TwistedGETContentMD5 (WebDAVTextElement):
    """
    MD5 hash of the resource content.
    """
    namespace = twisted_dav_namespace
    name = "getcontentmd5"

registerElement(TwistedGETContentMD5)


class TwistedQuotaRootProperty (WebDAVTextElement):
    """
    When set on a collection, this property indicates that the collection has a
    quota limit for the size of all resources stored in the collection (and any
    associated meta-data such as properties).  The value is a number - the
    maximum size in bytes allowed.
    """
    namespace = twisted_private_namespace
    name = "quota-root"

registerElement(TwistedQuotaRootProperty)


class TwistedQuotaUsedProperty (WebDAVTextElement):
    """
    When set on a collection, this property contains the cached running total
    of the size of all resources stored in the collection (and any associated
    meta-data such as properties).  The value is a number - the size in bytes
    used.
    """
    namespace = twisted_private_namespace
    name = "quota-used"

registerElement(TwistedQuotaUsedProperty)

# ACL granting every privilege to every principal, inherited by children.
allACL = element.ACL(
    element.ACE(
        element.Principal(element.All()),
        element.Grant(element.Privilege(element.All())),
        element.Protected(),
        TwistedACLInheritable()
    )
)

# ACL granting read to every principal, inherited by children.
readonlyACL = element.ACL(
    element.ACE(
        element.Principal(element.All()),
        element.Grant(element.Privilege(element.Read())),
        element.Protected(),
        TwistedACLInheritable()
    )
)

# Supported-privilege-set containing only the aggregate "all" privilege.
allPrivilegeSet = element.SupportedPrivilegeSet(
    element.SupportedPrivilege(
        element.Privilege(element.All()),
        element.Description("all privileges", **{"xml:lang": "en"})
    )
)

#
# This is one possible graph of the "standard" privileges documented
# in RFC 3744, section 3.
#
davPrivilegeSet = element.SupportedPrivilegeSet(
    element.SupportedPrivilege(
        element.Privilege(element.All()),
        element.Description(
            "all privileges",
            **{"xml:lang": "en"}
        ),
        element.SupportedPrivilege(
            element.Privilege(element.Read()),
            element.Description(
                "read resource",
                **{"xml:lang": "en"}
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.Write()),
            element.Description(
                "write resource",
                **{"xml:lang": "en"}
            ),
            element.SupportedPrivilege(
                element.Privilege(element.WriteProperties()),
                element.Description(
                    "write resource properties",
                    **{"xml:lang": "en"}
                ),
            ),
            element.SupportedPrivilege(
                element.Privilege(element.WriteContent()),
                element.Description(
                    "write resource content",
                    **{"xml:lang": "en"}
                ),
            ),
            element.SupportedPrivilege(
                element.Privilege(element.Bind()),
                element.Description(
                    "add child resource",
                    **{"xml:lang": "en"}
                ),
            ),
            element.SupportedPrivilege(
                element.Privilege(element.Unbind()),
                element.Description(
                    "remove child resource",
                    **{"xml:lang": "en"}
                ),
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.Unlock()),
            element.Description(
                "unlock resource without ownership of lock",
                **{"xml:lang": "en"}
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.ReadACL()),
            element.Description(
                "read resource access control list",
                **{"xml:lang": "en"}
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.WriteACL()),
            element.Description(
                "write resource access control list",
                **{"xml:lang": "en"}
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.ReadCurrentUserPrivilegeSet()),
            element.Description(
                "read privileges for current principal",
                **{"xml:lang": "en"}
            ),
        ),
    ),
)

# Principal element matching unauthenticated requests.
unauthenticatedPrincipal = element.Principal(element.Unauthenticated())


class ResourceClass (WebDAVTextElement):
    namespace = twisted_dav_namespace
    name = "resource-class"
    hidden = False
calendarserver-5.2+dfsg/twext/web2/dav/xattrprops.py0000644000175000017500000002332112263343324021676 0ustar rahulrahul# Copyright (c) 2009 Twisted Matrix Laboratories.
# See LICENSE for details.

##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##

"""
DAV Property store using file system extended attributes.
This API is considered private to static.py and is therefore subject to change. """ __all__ = ["xattrPropertyStore"] import urllib import sys import zlib import errno from operator import setitem from zlib import compress, decompress from cPickle import UnpicklingError, loads as unpickle import xattr if getattr(xattr, 'xattr', None) is None: raise ImportError("wrong xattr package imported") from twisted.python.util import untilConcludes from twisted.python.failure import Failure from twisted.python.log import err from txdav.xml.base import encodeXMLName from txdav.xml.parser import WebDAVDocument from twext.web2 import responsecode from twext.web2.http import HTTPError, StatusResponse from twext.web2.dav.http import statusForFailure # RFC 2518 Section 12.13.1 says that removal of non-existing property # is not an error. python-xattr on Linux fails with ENODATA in this # case. On Darwin and FreeBSD, the xattr library fails with ENOATTR, # which CPython does not expose. Its value is 93. _ATTR_MISSING = (93,) if hasattr(errno, "ENODATA"): _ATTR_MISSING += (errno.ENODATA,) class xattrPropertyStore (object): """ This implementation uses Bob Ippolito's xattr package, available from:: http://undefined.org/python/#xattr Note that the Bob's xattr package is specific to Linux and Darwin, at least presently. """ # # Dead properties are stored as extended attributes on disk. In order to # avoid conflicts with other attributes, prefix dead property names. # deadPropertyXattrPrefix = "WebDAV:" # Linux seems to require that attribute names use a "user." prefix. # FIXME: Is is a system-wide thing, or a per-filesystem thing? # If the latter, how to we detect the file system? if sys.platform == "linux2": deadPropertyXattrPrefix = "user." 
def _encode(clazz, name, uid=None): result = urllib.quote(encodeXMLName(*name), safe='{}:') if uid: result = uid + result r = clazz.deadPropertyXattrPrefix + result return r def _decode(clazz, name): name = urllib.unquote(name[len(clazz.deadPropertyXattrPrefix):]) index1 = name.find("{") index2 = name.find("}") if (index1 is -1 or index2 is -1 or not len(name) > index2): raise ValueError("Invalid encoded name: %r" % (name,)) if index1 == 0: uid = None else: uid = name[:index1] propnamespace = name[index1+1:index2] propname = name[index2+1:] return (propnamespace, propname, uid) _encode = classmethod(_encode) _decode = classmethod(_decode) def __init__(self, resource): self.resource = resource self.attrs = xattr.xattr(self.resource.fp.path) def get(self, qname, uid=None): """ Retrieve the value of a property stored as an extended attribute on the wrapped path. @param qname: The property to retrieve as a two-tuple of namespace URI and local name. @param uid: The per-user identifier for per user properties. @raise HTTPError: If there is no value associated with the given property. @return: A L{WebDAVDocument} representing the value associated with the given property. """ try: data = self.attrs.get(self._encode(qname, uid)) except KeyError: raise HTTPError(StatusResponse( responsecode.NOT_FOUND, "No such property: %s" % (encodeXMLName(*qname),) )) except IOError, e: if e.errno in _ATTR_MISSING or e.errno == errno.ENOENT: raise HTTPError(StatusResponse( responsecode.NOT_FOUND, "No such property: %s" % (encodeXMLName(*qname),) )) else: raise HTTPError(StatusResponse( statusForFailure(Failure()), "Unable to read property: %s" % (encodeXMLName(*qname),) )) # # Unserialize XML data from an xattr. 
The storage format has changed # over time: # # 1- Started with XML # 2- Started compressing the XML due to limits on xattr size # 3- Switched to pickle which is faster, still compressing # 4- Back to compressed XML for interoperability, size # # We only write the current format, but we also read the old # ones for compatibility. # legacy = False try: data = decompress(data) except zlib.error: legacy = True try: doc = WebDAVDocument.fromString(data) except ValueError: try: doc = unpickle(data) except UnpicklingError: format = "Invalid property value stored on server: %s %s" msg = format % (encodeXMLName(*qname), data) err(None, msg) raise HTTPError( StatusResponse(responsecode.INTERNAL_SERVER_ERROR, msg)) else: legacy = True if legacy: self.set(doc.root_element) return doc.root_element def set(self, property, uid=None): """ Store the given property as an extended attribute on the wrapped path. @param uid: The per-user identifier for per user properties. @param property: A L{WebDAVElement} to store. """ key = self._encode(property.qname(), uid) value = compress(property.toxml(pretty=False)) untilConcludes(setitem, self.attrs, key, value) # Update the resource because we've modified it self.resource.fp.restat() def delete(self, qname, uid=None): """ Remove the extended attribute from the wrapped path which stores the property given by C{qname}. @param uid: The per-user identifier for per user properties. @param qname: The property to delete as a two-tuple of namespace URI and local name. """ key = self._encode(qname, uid) try: try: self.attrs.remove(key) except KeyError: pass except IOError, e: if e.errno not in _ATTR_MISSING: raise except: raise HTTPError(StatusResponse( statusForFailure(Failure()), "Unable to delete property: %s", (key,) )) def contains(self, qname, uid=None): """ Determine whether the property given by C{qname} is stored in an extended attribute of the wrapped path. 
@param qname: The property to look up as a two-tuple of namespace URI and local name. @param uid: The per-user identifier for per user properties. @return: C{True} if the property exists, C{False} otherwise. """ key = self._encode(qname, uid) try: self.attrs.get(key) except KeyError: return False except IOError, e: if e.errno in _ATTR_MISSING or e.errno == errno.ENOENT: return False raise HTTPError(StatusResponse( statusForFailure(Failure()), "Unable to read property: %s" % (key,) )) else: return True def list(self, uid=None, filterByUID=True): """ Enumerate the property names stored in extended attributes of the wrapped path. @param uid: The per-user identifier for per user properties. @return: A C{list} of property names as two-tuples of namespace URI and local name. """ prefix = self.deadPropertyXattrPrefix try: attrs = iter(self.attrs) except IOError, e: if e.errno == errno.ENOENT: return [] raise HTTPError(StatusResponse( statusForFailure(Failure()), "Unable to list properties: %s", (self.resource.fp.path,) )) else: results = [ self._decode(name) for name in attrs if name.startswith(prefix) ] if filterByUID: return [ (namespace, name) for namespace, name, propuid in results if propuid == uid ] else: return results calendarserver-5.2+dfsg/twext/web2/dav/noneprops.py0000644000175000017500000000474212263343324021501 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ Empty DAV property store. This API is considered private to static.py and is therefore subject to change. """ __all__ = ["NonePropertyStore"] from twext.web2 import responsecode from twext.web2.http import HTTPError, StatusResponse from txdav.xml.base import encodeXMLName class NonePropertyStore (object): """ DAV property store which contains no properties and does not allow properties to be set. 
""" __singleton = None def __new__(clazz, resource): if NonePropertyStore.__singleton is None: NonePropertyStore.__singleton = object.__new__(clazz) return NonePropertyStore.__singleton def __init__(self, resource): pass def get(self, qname, uid=None): raise HTTPError(StatusResponse( responsecode.NOT_FOUND, "No such property: %s" % (encodeXMLName(*qname),) )) def set(self, property, uid=None): raise HTTPError(StatusResponse( responsecode.FORBIDDEN, "Permission denied for setting property: %s" % (property,) )) def delete(self, qname, uid=None): # RFC 2518 Section 12.13.1 says that removal of # non-existing property is not an error. pass def contains(self, qname, uid=None): return False def list(self, uid=None): return () calendarserver-5.2+dfsg/twext/web2/dav/fileop.py0000644000175000017500000004645112263343324020737 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV file operations This API is considered private to static.py and is therefore subject to change. """ __all__ = [ "delete", "copy", "move", "put", "mkcollection", "rmdir", ] import os import urllib from urlparse import urlsplit from twisted.python.failure import Failure from twisted.internet.defer import succeed, deferredGenerator, waitForDeferred from twext.python.log import Logger from twext.python.filepath import CachingFilePath as FilePath from twext.web2 import responsecode from twext.web2.http import StatusResponse, HTTPError from twext.web2.stream import FileStream, readIntoFile from twext.web2.dav.http import ResponseQueue, statusForFailure log = Logger() def delete(uri, filepath, depth="infinity"): """ Perform a X{DELETE} operation on the given URI, which is backed by the given filepath. @param filepath: the L{FilePath} to delete. @param depth: the recursion X{Depth} for the X{DELETE} operation, which must be "infinity". @raise HTTPError: (containing a response with a status code of L{responsecode.BAD_REQUEST}) if C{depth} is not "infinity". @raise HTTPError: (containing an appropriate response) if the delete operation fails. If C{filepath} is a directory, the response will be a L{MultiStatusResponse}. @return: a deferred response with a status code of L{responsecode.NO_CONTENT} if the X{DELETE} operation succeeds. """ # # Remove the file(s) # # FIXME: defer if filepath.isdir(): # # RFC 2518, section 8.6 says that we must act as if the Depth header is # set to infinity, and that the client must omit the Depth header or set # it to infinity, meaning that for collections, we will delete all # members. # # This seems somewhat at odds with the notion that a bad request should # be rejected outright; if the client sends a bad depth header, the # client is broken, and RFC 2518, section 8 suggests that a bad request # should be rejected... # # Let's play it safe for now and ignore broken clients. 
# if depth != "infinity": msg = ("Client sent illegal depth header value for DELETE: %s" % (depth,)) log.error(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) # # Recursive delete # # RFC 2518, section 8.6 says that if we get an error deleting a resource # other than the collection in the request-URI, that we must respond # with a multi-status response containing error statuses for each # resource that we fail to delete. It also says we should not return # no-content (success) status, which means that we should continue after # errors, rather than aborting right away. This is interesting in that # it's different from how most operating system tools act (eg. rm) when # recursive filsystem deletes fail. # uri_path = urllib.unquote(urlsplit(uri)[2]) if uri_path[-1] == "/": uri_path = uri_path[:-1] log.info("Deleting directory %s" % (filepath.path,)) # NOTE: len(uri_path) is wrong if os.sep is not one byte long... meh. request_basename = filepath.path[:-len(uri_path)] errors = ResponseQueue(request_basename, "DELETE", responsecode.NO_CONTENT) # FIXME: defer this for dir, subdirs, files in os.walk(filepath.path, topdown=False): for filename in files: path = os.path.join(dir, filename) try: os.remove(path) except: errors.add(path, Failure()) for subdir in subdirs: path = os.path.join(dir, subdir) if os.path.islink(path): try: os.remove(path) except: errors.add(path, Failure()) else: try: os.rmdir(path) except: errors.add(path, Failure()) try: os.rmdir(filepath.path) except: raise HTTPError(statusForFailure( Failure(), "deleting directory: %s" % (filepath.path,) )) response = errors.response() else: # # Delete a file; much simpler, eh? 
# log.info("Deleting file %s" % (filepath.path,)) try: os.remove(filepath.path) except: raise HTTPError(statusForFailure( Failure(), "deleting file: %s" % (filepath.path,) )) response = responsecode.NO_CONTENT # Remove stat info for filepath since we deleted the backing file filepath.changed() return succeed(response) def copy(source_filepath, destination_filepath, destination_uri, depth): """ Perform a X{COPY} from the given source and destination filepaths. This will perform a X{DELETE} on the destination if necessary; the caller should check and handle the X{overwrite} header before calling L{copy} (as in L{COPYMOVE.prepareForCopy}). @param source_filepath: a L{FilePath} for the file to copy from. @param destination_filepath: a L{FilePath} for the file to copy to. @param destination_uri: the URI of the destination resource. @param depth: the recursion X{Depth} for the X{COPY} operation, which must be one of "0", "1", or "infinity". @raise HTTPError: (containing a response with a status code of L{responsecode.BAD_REQUEST}) if C{depth} is not "0", "1" or "infinity". @raise HTTPError: (containing an appropriate response) if the operation fails. If C{source_filepath} is a directory, the response will be a L{MultiStatusResponse}. @return: a deferred response with a status code of L{responsecode.CREATED} if the destination already exists, or L{responsecode.NO_CONTENT} if the destination was created by the X{COPY} operation. 
""" if source_filepath.isfile(): # # Copy the file # log.info("Copying file %s to %s" % (source_filepath.path, destination_filepath.path)) try: source_file = source_filepath.open() except: raise HTTPError(statusForFailure( Failure(), "opening file for reading: %s" % (source_filepath.path,) )) source_stream = FileStream(source_file) response = waitForDeferred(put(source_stream, destination_filepath, destination_uri)) yield response try: response = response.getResult() finally: source_stream.close() source_file.close() checkResponse(response, "put", responsecode.NO_CONTENT, responsecode.CREATED) yield response return elif source_filepath.isdir(): if destination_filepath.exists(): # # Delete the destination # response = waitForDeferred(delete(destination_uri, destination_filepath)) yield response response = response.getResult() checkResponse(response, "delete", responsecode.NO_CONTENT) success_code = responsecode.NO_CONTENT else: success_code = responsecode.CREATED # # Copy the directory # log.info("Copying directory %s to %s" % (source_filepath.path, destination_filepath.path)) source_basename = source_filepath.path destination_basename = destination_filepath.path errors = ResponseQueue(source_basename, "COPY", success_code) if destination_filepath.parent().isdir(): if os.path.islink(source_basename): link_destination = os.readlink(source_basename) if link_destination[0] != os.path.sep: link_destination = os.path.join(source_basename, link_destination) try: os.symlink(destination_basename, link_destination) except: errors.add(source_basename, Failure()) else: try: os.mkdir(destination_basename) except: raise HTTPError(statusForFailure( Failure(), "creating directory %s" % (destination_basename,) )) if depth == "0": yield success_code return else: raise HTTPError(StatusResponse( responsecode.CONFLICT, "Parent collection for destination %s does not exist" % (destination_uri,) )) # # Recursive copy # # FIXME: When we report errors, do we report them on the source URI # 
or on the destination URI? We're using the source URI here. # # FIXME: defer the walk? source_basename_len = len(source_basename) def paths(basepath, subpath): source_path = os.path.join(basepath, subpath) assert source_path.startswith(source_basename) destination_path = os.path.join(destination_basename, source_path[source_basename_len+1:]) return source_path, destination_path for dir, subdirs, files in os.walk(source_filepath.path, topdown=True): for filename in files: source_path, destination_path = paths(dir, filename) if not os.path.isdir(os.path.dirname(destination_path)): errors.add(source_path, responsecode.NOT_FOUND) else: response = waitForDeferred(copy(FilePath(source_path), FilePath(destination_path), destination_uri, depth)) yield response response = response.getResult() checkResponse(response, "copy", responsecode.CREATED, responsecode.NO_CONTENT) for subdir in subdirs: source_path, destination_path = paths(dir, subdir) log.info("Copying directory %s to %s" % (source_path, destination_path)) if not os.path.isdir(os.path.dirname(destination_path)): errors.add(source_path, responsecode.CONFLICT) else: if os.path.islink(source_path): link_destination = os.readlink(source_path) if link_destination[0] != os.path.sep: link_destination = os.path.join(source_path, link_destination) try: os.symlink(destination_path, link_destination) except: errors.add(source_path, Failure()) else: try: os.mkdir(destination_path) except: errors.add(source_path, Failure()) yield errors.response() return else: log.error("Unable to COPY to non-file: %s" % (source_filepath.path,)) raise HTTPError(StatusResponse( responsecode.FORBIDDEN, "The requested resource exists but is not backed by a regular file." )) copy = deferredGenerator(copy) def move(source_filepath, source_uri, destination_filepath, destination_uri, depth): """ Perform a X{MOVE} from the given source and destination filepaths. 
This will perform a X{DELETE} on the destination if necessary; the caller should check and handle the X{overwrite} header before calling L{copy} (as in L{COPYMOVE.prepareForCopy}). Following the X{DELETE}, this will attempt an atomic filesystem move. If that fails, a X{COPY} operation followed by a X{DELETE} on the source will be attempted instead. @param source_filepath: a L{FilePath} for the file to copy from. @param destination_filepath: a L{FilePath} for the file to copy to. @param destination_uri: the URI of the destination resource. @param depth: the recursion X{Depth} for the X{MOVE} operation, which must be "infinity". @raise HTTPError: (containing a response with a status code of L{responsecode.BAD_REQUEST}) if C{depth} is not "infinity". @raise HTTPError: (containing an appropriate response) if the operation fails. If C{source_filepath} is a directory, the response will be a L{MultiStatusResponse}. @return: a deferred response with a status code of L{responsecode.CREATED} if the destination already exists, or L{responsecode.NO_CONTENT} if the destination was created by the X{MOVE} operation. 
""" log.info("Moving %s to %s" % (source_filepath.path, destination_filepath.path)) # # Choose a success status # if destination_filepath.exists(): # # Delete the destination # response = waitForDeferred(delete(destination_uri, destination_filepath)) yield response response = response.getResult() checkResponse(response, "delete", responsecode.NO_CONTENT) success_code = responsecode.NO_CONTENT else: success_code = responsecode.CREATED # # See if rename (which is atomic, and fast) works # try: os.rename(source_filepath.path, destination_filepath.path) except OSError: pass else: # Remove stat info from source filepath since we moved it source_filepath.changed() yield success_code return # # Do a copy, then delete the source # response = waitForDeferred(copy(source_filepath, destination_filepath, destination_uri, depth)) yield response response = response.getResult() checkResponse(response, "copy", responsecode.CREATED, responsecode.NO_CONTENT) response = waitForDeferred(delete(source_uri, source_filepath)) yield response response = response.getResult() checkResponse(response, "delete", responsecode.NO_CONTENT) yield success_code move = deferredGenerator(move) def put(stream, filepath, uri=None): """ Perform a PUT of the given data stream into the given filepath. @param stream: the stream to write to the destination. @param filepath: the L{FilePath} of the destination file. @param uri: the URI of the destination resource. If the destination exists, if C{uri} is not C{None}, perform a X{DELETE} operation on the destination, but if C{uri} is C{None}, delete the destination directly. Note that whether a L{put} deletes the destination directly vs. performing a X{DELETE} on the destination affects the response returned in the event of an error during deletion. Specifically, X{DELETE} on collections must return a L{MultiStatusResponse} under certain circumstances, whereas X{PUT} isn't required to do so. 
Therefore, if the caller expects X{DELETE} semantics, it must provide a valid C{uri}. @raise HTTPError: (containing an appropriate response) if the operation fails. @return: a deferred response with a status code of L{responsecode.CREATED} if the destination already exists, or L{responsecode.NO_CONTENT} if the destination was created by the X{PUT} operation. """ log.info("Writing to file %s" % (filepath.path,)) if filepath.exists(): if uri is None: try: if filepath.isdir(): rmdir(filepath.path) else: os.remove(filepath.path) except: raise HTTPError(statusForFailure( Failure(), "writing to file: %s" % (filepath.path,) )) else: response = waitForDeferred(delete(uri, filepath)) yield response response = response.getResult() checkResponse(response, "delete", responsecode.NO_CONTENT) success_code = responsecode.NO_CONTENT else: success_code = responsecode.CREATED # # Write the contents of the request stream to resource's file # try: resource_file = filepath.open("w") except: raise HTTPError(statusForFailure( Failure(), "opening file for writing: %s" % (filepath.path,) )) try: x = waitForDeferred(readIntoFile(stream, resource_file)) yield x x.getResult() except: raise HTTPError(statusForFailure( Failure(), "writing to file: %s" % (filepath.path,) )) # Remove stat info from filepath since we modified the backing file filepath.changed() yield success_code put = deferredGenerator(put) def mkcollection(filepath): """ Perform a X{MKCOL} on the given filepath. @param filepath: the L{FilePath} of the collection resource to create. @raise HTTPError: (containing an appropriate response) if the operation fails. @return: a deferred response with a status code of L{responsecode.CREATED} if the destination already exists, or L{responsecode.NO_CONTENT} if the destination was created by the X{MKCOL} operation. 
""" try: os.mkdir(filepath.path) # Remove stat info from filepath because we modified it filepath.changed() except: raise HTTPError(statusForFailure( Failure(), "creating directory in MKCOL: %s" % (filepath.path,) )) return succeed(responsecode.CREATED) def rmdir(dirname): """ Removes the directory with the given name, as well as its contents. @param dirname: the path to the directory to remove. """ for dir, subdirs, files in os.walk(dirname, topdown=False): for filename in files: os.remove(os.path.join(dir, filename)) for subdir in subdirs: path = os.path.join(dir, subdir) if os.path.islink(path): os.remove(path) else: os.rmdir(path) os.rmdir(dirname) def checkResponse(response, method, *codes): assert response in codes, \ "%s() returned %r, but should have returned one of %r instead" % (method, response, codes) calendarserver-5.2+dfsg/twext/web2/dav/__init__.py0000644000175000017500000000346012263343324021211 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test -*- ## # Copyright (c) 2009 Twisted Matrix Laboratories. # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ WebDAV support for Twext.Web2. See RFC 2616: http://www.ietf.org/rfc/rfc2616.txt (HTTP) See RFC 2518: http://www.ietf.org/rfc/rfc2518.txt (WebDAV) See RFC 3253: http://www.ietf.org/rfc/rfc3253.txt (WebDAV Versioning Extentions) See RFC 3744: http://www.ietf.org/rfc/rfc3744.txt (WebDAV Access Control Protocol) See also: http://skrb.org/ietf/http_errata.html (Errata to RFC 2616) """ __version__ = 'SVN-Trunk' version = __version__ __all__ = [ "auth", "fileop", "davxml", "http", "idav", "noneprops", "resource", "static", "stream", "util", "xattrprops", ] calendarserver-5.2+dfsg/twext/web2/dav/static.py0000644000175000017500000001552312263343324020744 0ustar rahulrahul# -*- test-case-name: twext.web2.dav.test.test_static -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ WebDAV-aware static resources. """ __all__ = ["DAVFile"] from twisted.python.filepath import InsecurePath from twisted.internet.defer import succeed, deferredGenerator, waitForDeferred from twext.python.log import Logger from twext.web2 import http_headers from twext.web2 import responsecode from twext.web2.dav.resource import DAVResource, davPrivilegeSet from twext.web2.dav.resource import TwistedGETContentMD5 from twext.web2.dav.util import bindMethods from twext.web2.http import HTTPError, StatusResponse from twext.web2.static import File log = Logger() try: from twext.web2.dav.xattrprops import xattrPropertyStore as DeadPropertyStore except ImportError: log.info("No dead property store available; using nonePropertyStore.") log.info("Setting of dead properties will not be allowed.") from twext.web2.dav.noneprops import NonePropertyStore as DeadPropertyStore class DAVFile (DAVResource, File): """ WebDAV-accessible File resource. Extends twext.web2.static.File to handle WebDAV methods. """ def __init__( self, path, defaultType="text/plain", indexNames=None, principalCollections=() ): """ @param path: the path of the file backing this resource. @param defaultType: the default mime type (as a string) for this resource and (eg. child) resources derived from it. @param indexNames: a sequence of index file names. @param acl: an L{IDAVAccessControlList} with the . 
""" File.__init__( self, path, defaultType = defaultType, ignoredExts = (), processors = None, indexNames = indexNames, ) DAVResource.__init__(self, principalCollections=principalCollections) def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self.fp.path) ## # WebDAV ## def etag(self): if not self.fp.exists(): return succeed(None) if self.hasDeadProperty(TwistedGETContentMD5): return succeed(http_headers.ETag(str(self.readDeadProperty(TwistedGETContentMD5)))) else: return super(DAVFile, self).etag() def davComplianceClasses(self): return ("1", "access-control") # Add "2" when we have locking def deadProperties(self): if not hasattr(self, "_dead_properties"): self._dead_properties = DeadPropertyStore(self) return self._dead_properties def isCollection(self): """ See L{IDAVResource.isCollection}. """ return self.fp.isdir() ## # ACL ## def supportedPrivileges(self, request): return succeed(davPrivilegeSet) ## # Quota ## def quotaSize(self, request): """ Get the size of this resource. TODO: Take into account size of dead-properties. Does stat include xattrs size? @return: an L{Deferred} with a C{int} result containing the size of the resource. """ if self.isCollection(): def walktree(top): """ Recursively descend the directory tree rooted at top, calling the callback function for each regular file @param top: L{FilePath} for the directory to walk. """ total = 0 for f in top.listdir(): child = top.child(f) if child.isdir(): # It's a directory, recurse into it result = waitForDeferred(walktree(child)) yield result total += result.getResult() elif child.isfile(): # It's a file, call the callback function total += child.getsize() else: # Unknown file type, print a message pass yield total walktree = deferredGenerator(walktree) return walktree(self.fp) else: return succeed(self.fp.getsize()) ## # Workarounds for issues with File ## def ignoreExt(self, ext): """ Does nothing; doesn't apply to this subclass. 
""" pass def locateChild(self, req, segments): """ See L{IResource}C{.locateChild}. """ # If getChild() finds a child resource, return it try: child = self.getChild(segments[0]) if child is not None: return (child, segments[1:]) except InsecurePath: raise HTTPError(StatusResponse(responsecode.FORBIDDEN, "Invalid URL path")) # If we're not backed by a directory, we have no children. # But check for existance first; we might be a collection resource # that the request wants created. self.fp.restat(False) if self.fp.exists() and not self.fp.isdir(): return (None, ()) # OK, we need to return a child corresponding to the first segment path = segments[0] if path == "": # Request is for a directory (collection) resource return (self, ()) return (self.createSimilarFile(self.fp.child(path).path), segments[1:]) def createSimilarFile(self, path): return self.__class__( path, defaultType=self.defaultType, indexNames=self.indexNames[:], principalCollections=self.principalCollections()) # # Attach method handlers to DAVFile # import twext.web2.dav.method bindMethods(twext.web2.dav.method, DAVFile) calendarserver-5.2+dfsg/twext/web2/dav/util.py0000644000175000017500000001520512263343324020427 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_util -*- ## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # DRI: Wilfredo Sanchez, wsanchez@apple.com ## """ Utilities This API is considered private to static.py and is therefore subject to change. """ __all__ = [ "allDataFromStream", "davXMLFromStream", "noDataFromStream", "normalizeURL", "joinURL", "parentForURL", "unimplemented", "bindMethods", ] import urllib from urlparse import urlsplit, urlunsplit import posixpath # Careful; this module is not documented as public API from twisted.python.failure import Failure from twisted.internet.defer import succeed from twext.python.log import Logger from twext.web2.stream import readStream from txdav.xml.parser import WebDAVDocument log = Logger() ## # Reading request body ## def allDataFromStream(stream, filter=None): data = [] def gotAllData(_): if not data: return None result = "".join([str(x) for x in data]) if filter is None: return result else: return filter(result) return readStream(stream, data.append).addCallback(gotAllData) def davXMLFromStream(stream): # FIXME: # This reads the request body into a string and then parses it. # A better solution would parse directly and incrementally from the # request stream. 
if stream is None: return succeed(None) def parse(xml): try: doc = WebDAVDocument.fromString(xml) doc.root_element.validate() return doc except ValueError: log.error("Bad XML:\n%s" % (xml,)) raise return allDataFromStream(stream, parse) def noDataFromStream(stream): def gotData(data): if data: raise ValueError("Stream contains unexpected data.") return readStream(stream, gotData) ## # URLs ## def normalizeURL(url): """ Normalized a URL. @param url: a URL. @return: the normalized representation of C{url}. The returned URL will never contain a trailing C{"/"}; it is up to the caller to determine whether the resource referred to by the URL is a collection and add a trailing C{"/"} if so. """ def cleanup(path): # For some silly reason, posixpath.normpath doesn't clean up '//' at the # start of a filename, so let's clean it up here. if path[0] == "/": count = 0 for char in path: if char != "/": break count += 1 path = path[count - 1:] return path (scheme, host, path, query, fragment) = urlsplit(cleanup(url)) path = cleanup(posixpath.normpath(urllib.unquote(path))) return urlunsplit((scheme, host, urllib.quote(path), query, fragment)) def joinURL(*urls): """ Appends URLs in series. @param urls: URLs to join. @return: the normalized URL formed by combining each URL in C{urls}. The returned URL will contain a trailing C{"/"} if and only if the last given URL contains a trailing C{"/"}. """ if len(urls) > 0 and len(urls[-1]) > 0 and urls[-1][-1] == "/": trailing = "/" else: trailing = "" url = normalizeURL("/".join([url for url in urls])) if url == "/": return "/" else: return url + trailing def parentForURL(url): """ Extracts the URL of the containing collection resource for the resource corresponding to a given URL. This removes any query or fragment pieces. @param url: an absolute (server-relative is OK) URL. @return: the normalized URL of the collection resource containing the resource corresponding to C{url}. The returned URL will always contain a trailing C{"/"}. 
""" (scheme, host, path, _ignore_query, _ignore_fragment) = urlsplit(normalizeURL(url)) index = path.rfind("/") if index is 0: if path == "/": return None else: path = "/" else: if index is -1: raise ValueError("Invalid URL: %s" % (url,)) else: path = path[:index] + "/" return urlunsplit((scheme, host, path, None, None)) ## # Python magic ## def unimplemented(obj): """ Throw an exception signifying that the current method is unimplemented and should not have been invoked. """ import inspect caller = inspect.getouterframes(inspect.currentframe())[1][3] raise NotImplementedError("Method %s is unimplemented in subclass %s" % (caller, obj.__class__)) def bindMethods(module, clazz, prefixes=("preconditions_", "http_", "report_")): """ Binds all functions in the given module (as defined by that module's C{__all__} attribute) which start with any of the given prefixes as methods of the given class. @param module: the module in which to search for functions. @param clazz: the class to bind found functions to as methods. @param prefixes: a sequence of prefixes to match found functions against. """ for submodule_name in module.__all__: try: __import__(module.__name__ + "." + submodule_name) except ImportError: log.error("Unable to import module %s" % (module.__name__ + "." + submodule_name,)) Failure().raiseException() submodule = getattr(module, submodule_name) for method_name in submodule.__all__: for prefix in prefixes: if method_name.startswith(prefix): method = getattr(submodule, method_name) setattr(clazz, method_name, method) break calendarserver-5.2+dfsg/twext/web2/dav/auth.py0000644000175000017500000001441112263343324020411 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. ## __all__ = [ "IPrincipal", "DavRealm", "IPrincipalCredentials", "PrincipalCredentials", "AuthenticationWrapper", ] from zope.interface import implements, Interface from twisted.internet import defer from twisted.cred import checkers, error, portal from twext.web2.resource import WrapperResource from txdav.xml.element import twisted_private_namespace, registerElement from txdav.xml.element import WebDAVTextElement, Principal, HRef class AuthenticationWrapper(WrapperResource): def __init__(self, resource, portal, wireEncryptedCredentialFactories, wireUnencryptedCredentialFactories, loginInterfaces): """ Wrap the given resource and use the parameters to set up the request to allow anyone to challenge and handle authentication. 
@param resource: L{DAVResource} FIXME: This should get promoted to twext.web2.auth @param portal: The cred portal @param wireEncryptedCredentialFactories: Sequence of credentialFactories that can be used to authenticate by resources in this tree over a wire-encrypted channel (SSL). @param wireUnencryptedCredentialFactories: Sequence of credentialFactories that can be used to authenticate by resources in this tree over a wire-unencrypted channel (non-SSL). @param loginInterfaces: More cred stuff """ super(AuthenticationWrapper, self).__init__(resource) self.portal = portal self.wireEncryptedCredentialFactories = dict([(factory.scheme, factory) for factory in wireEncryptedCredentialFactories]) self.wireUnencryptedCredentialFactories = dict([(factory.scheme, factory) for factory in wireUnencryptedCredentialFactories]) self.loginInterfaces = loginInterfaces # FIXME: some unit tests access self.credentialFactories, so assigning here self.credentialFactories = self.wireEncryptedCredentialFactories def hook(self, req): req.portal = self.portal req.loginInterfaces = self.loginInterfaces # If not using SSL, use the factory list which excludes "Basic" if getattr(req, "chanRequest", None) is None: # This is only None in unit tests secureConnection = True else: ignored, secureConnection = req.chanRequest.getHostInfo() req.credentialFactories = ( self.wireEncryptedCredentialFactories if secureConnection else self.wireUnencryptedCredentialFactories ) class IPrincipal(Interface): pass class DavRealm(object): implements(portal.IRealm) def requestAvatar(self, avatarId, mind, *interfaces): if IPrincipal in interfaces: return IPrincipal, Principal(HRef(avatarId[0])), Principal(HRef(avatarId[1])) raise NotImplementedError("Only IPrincipal interface is supported") class IPrincipalCredentials(Interface): pass class PrincipalCredentials(object): implements(IPrincipalCredentials) def __init__(self, authnPrincipal, authzPrincipal, credentials): """ Initialize with both authentication and 
authorization values. Note that in most cases theses will be the same since HTTP auth makes no distinction between the two - but we may be layering some addition auth on top of this (.e.g.. proxy auth, cookies, forms etc) that make result in authentication and authorization being different. @param authnPrincipal: L{IDAVPrincipalResource} for the authenticated principal. @param authnURI: C{str} containing the URI of the authenticated principal. @param authzPrincipal: L{IDAVPrincipalResource} for the authorized principal. @param authzURI: C{str} containing the URI of the authorized principal. @param credentials: L{ICredentials} for the authentication credentials. """ self.authnPrincipal = authnPrincipal self.authzPrincipal = authzPrincipal self.credentials = credentials def checkPassword(self, password): return self.credentials.checkPassword(password) class TwistedPropertyChecker(object): implements(checkers.ICredentialsChecker) credentialInterfaces = (IPrincipalCredentials,) def _cbPasswordMatch(self, matched, principalURIs): if matched: # We return both URIs return principalURIs else: raise error.UnauthorizedLogin("Bad credentials for: %s" % (principalURIs[0],)) def requestAvatarId(self, credentials): pcreds = IPrincipalCredentials(credentials) pswd = str(pcreds.authnPrincipal.readDeadProperty(TwistedPasswordProperty)) d = defer.maybeDeferred(credentials.checkPassword, pswd) d.addCallback(self._cbPasswordMatch, ( pcreds.authnPrincipal.principalURL(), pcreds.authzPrincipal.principalURL(), pcreds.authnPrincipal, pcreds.authzPrincipal, )) return d ## # Utilities ## class TwistedPasswordProperty (WebDAVTextElement): namespace = twisted_private_namespace name = "password" registerElement(TwistedPasswordProperty) calendarserver-5.2+dfsg/twext/web2/http_headers.py0000644000175000017500000015620712263343324021362 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_http_headers -*- ## # Copyright (c) 2008 Twisted Matrix Laboratories. 
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## from __future__ import print_function """ HTTP header representation, parsing, and serialization. 
""" import time from calendar import timegm import base64 import re def dashCapitalize(s): ''' Capitalize a string, making sure to treat - as a word separator ''' return '-'.join([x.capitalize() for x in s.split('-')]) # datetime parsing and formatting weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] weekdayname_lower = [name.lower() for name in weekdayname] monthname = [None, 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] monthname_lower = [name and name.lower() for name in monthname] # HTTP Header parsing API header_case_mapping = {} def casemappingify(d): global header_case_mapping newd = dict([(key.lower(), key) for key in d.keys()]) header_case_mapping.update(newd) def lowerify(d): return dict([(key.lower(), value) for key, value in d.items()]) class HeaderHandler(object): """HeaderHandler manages header generating and parsing functions. """ HTTPParsers = {} HTTPGenerators = {} def __init__(self, parsers=None, generators=None): """ @param parsers: A map of header names to parsing functions. @type parsers: L{dict} @param generators: A map of header names to generating functions. @type generators: L{dict} """ if parsers: self.HTTPParsers.update(parsers) if generators: self.HTTPGenerators.update(generators) def parse(self, name, header): """ Parse the given header based on its given name. @param name: The header name to parse. @type name: C{str} @param header: A list of unparsed headers. @type header: C{list} of C{str} @return: The return value is the parsed header representation, it is dependent on the header. See the HTTP Headers document. """ parser = self.HTTPParsers.get(name, None) if parser is None: raise ValueError("No header parser for header '%s', either add one or use getHeaderRaw." 
% (name,)) try: for p in parser: # print("Parsing %s: %s(%s)" % (name, repr(p), repr(h))) header = p(header) # if isinstance(h, types.GeneratorType): # h=list(h) except ValueError: header = None return header def generate(self, name, header): """ Generate the given header based on its given name. @param name: The header name to generate. @type name: C{str} @param header: A parsed header, such as the output of L{HeaderHandler}.parse. @return: C{list} of C{str} each representing a generated HTTP header. """ generator = self.HTTPGenerators.get(name, None) if generator is None: # print(self.generators) raise ValueError("No header generator for header '%s', either add one or use setHeaderRaw." % (name,)) for g in generator: header = g(header) # self._raw_headers[name] = h return header def updateParsers(self, parsers): """Update en masse the parser maps. @param parsers: Map of header names to parser chains. @type parsers: C{dict} """ casemappingify(parsers) self.HTTPParsers.update(lowerify(parsers)) def addParser(self, name, value): """Add an individual parser chain for the given header. @param name: Name of the header to add @type name: C{str} @param value: The parser chain @type value: C{str} """ self.updateParsers({name: value}) def updateGenerators(self, generators): """Update en masse the generator maps. @param parsers: Map of header names to generator chains. @type parsers: C{dict} """ casemappingify(generators) self.HTTPGenerators.update(lowerify(generators)) def addGenerators(self, name, value): """Add an individual generator chain for the given header. @param name: Name of the header to add @type name: C{str} @param value: The generator chain @type value: C{str} """ self.updateGenerators({name: value}) def update(self, parsers, generators): """Conveniently update parsers and generators all at once. 
""" self.updateParsers(parsers) self.updateGenerators(generators) DefaultHTTPHandler = HeaderHandler() # # HTTP DateTime parser def parseDateTime(dateString): """Convert an HTTP date string (one of three formats) to seconds since epoch.""" parts = dateString.split() if not parts[0][0:3].lower() in weekdayname_lower: # Weekday is stupid. Might have been omitted. try: return parseDateTime("Sun, " + dateString) except ValueError: # Guess not. pass partlen = len(parts) if (partlen == 5 or partlen == 6) and parts[1].isdigit(): # 1st date format: Sun, 06 Nov 1994 08:49:37 GMT # (Note: "GMT" is literal, not a variable timezone) # (also handles without "GMT") # This is the normal format day = parts[1] month = parts[2] year = parts[3] time = parts[4] elif (partlen == 3 or partlen == 4) and parts[1].find('-') != -1: # 2nd date format: Sunday, 06-Nov-94 08:49:37 GMT # (Note: "GMT" is literal, not a variable timezone) # (also handles without without "GMT") # Two digit year, yucko. day, month, year = parts[1].split('-') time = parts[2] year = int(year) if year < 69: year = year + 2000 elif year < 100: year = year + 1900 elif len(parts) == 5: # 3rd date format: Sun Nov 6 08:49:37 1994 # ANSI C asctime() format. 
day = parts[2] month = parts[1] year = parts[4] time = parts[3] else: raise ValueError("Unknown datetime format %r" % dateString) day = int(day) month = int(monthname_lower.index(month.lower())) year = int(year) hour, min, sec = map(int, time.split(':')) return int(timegm((year, month, day, hour, min, sec))) ##### HTTP tokenizer class Token(str): __slots__ = [] tokens = {} def __new__(self, char): token = Token.tokens.get(char) if token is None: Token.tokens[char] = token = str.__new__(self, char) return token def __repr__(self): return "Token(%s)" % str.__repr__(self) # RFC 2616 section 2.2 http_tokens = " \t\"()<>@,;:\\/[]?={}" http_ctls = "\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x7f" def tokenize(header, foldCase=True): """Tokenize a string according to normal HTTP header parsing rules. In particular: - Whitespace is irrelevant and eaten next to special separator tokens. Its existance (but not amount) is important between character strings. - Quoted string support including embedded backslashes. - Case is insignificant (and thus lowercased), except in quoted strings. (unless foldCase=False) - Multiple headers are concatenated with ',' NOTE: not all headers can be parsed with this function. Takes a raw header value (list of strings), and Returns a generator of strings and Token class instances. 
""" tokens = http_tokens ctls = http_ctls string = ",".join(header) start = 0 cur = 0 quoted = False qpair = False inSpaces = -1 qstring = None for x in string: if quoted: if qpair: qpair = False qstring = qstring + string[start:cur - 1] + x start = cur + 1 elif x == '\\': qpair = True elif x == '"': quoted = False yield qstring + string[start:cur] qstring = None start = cur + 1 elif x in tokens: if start != cur: if foldCase: yield string[start:cur].lower() else: yield string[start:cur] start = cur + 1 if x == '"': quoted = True qstring = "" inSpaces = False elif x in " \t": if inSpaces is False: inSpaces = True else: inSpaces = -1 yield Token(x) elif x in ctls: raise ValueError("Invalid control character: %d in header" % ord(x)) else: if inSpaces is True: yield Token(' ') inSpaces = False inSpaces = False cur = cur + 1 if qpair: raise ValueError("Missing character after '\\'") if quoted: raise ValueError("Missing end quote") if start != cur: if foldCase: yield string[start:cur].lower() else: yield string[start:cur] def split(seq, delim): """The same as str.split but works on arbitrary sequences. Too bad it's not builtin to python!""" cur = [] for item in seq: if item == delim: yield cur cur = [] else: cur.append(item) yield cur # def find(seq, *args): # """The same as seq.index but returns -1 if not found, instead # Too bad it's not builtin to python!""" # try: # return seq.index(value, *args) # except ValueError: # return -1 def filterTokens(seq): """Filter out instances of Token, leaving only a list of strings. Used instead of a more specific parsing method (e.g. splitting on commas) when only strings are expected, so as to be a little lenient. Apache does it this way and has some comments about broken clients which forget commas (?), so I'm doing it the same way. It shouldn't hurt anything, in any case. 
""" l = [] for x in seq: if not isinstance(x, Token): l.append(x) return l ##### parser utilities: def checkSingleToken(tokens): if len(tokens) != 1: raise ValueError("Expected single token, not %s." % (tokens,)) return tokens[0] def parseKeyValue(val): if len(val) == 1: return val[0], None elif len(val) == 3 and val[1] == Token('='): return val[0], val[2] raise ValueError("Expected key or key=value, but got %s." % (val,)) def parseArgs(field): args = split(field, Token(';')) val = args.next() args = [parseKeyValue(arg) for arg in args] return val, args def listParser(fun): """Return a function which applies 'fun' to every element in the comma-separated list""" def listParserHelper(tokens): fields = split(tokens, Token(',')) for field in fields: if len(field) != 0: yield fun(field) return listParserHelper def last(seq): """Return seq[-1]""" return seq[-1] ##### Generation utilities def quoteString(s): """ Quote a string according to the rules for the I{quoted-string} production in RFC 2616 section 2.2. @type s: C{str} @rtype: C{str} """ return '"%s"' % s.replace('\\', '\\\\').replace('"', '\\"') def listGenerator(fun): """Return a function which applies 'fun' to every element in the given list, then joins the result with generateList""" def listGeneratorHelper(l): return generateList([fun(e) for e in l]) return listGeneratorHelper def generateList(seq): return ", ".join(seq) def singleHeader(item): return [item] _seperators = re.compile('[' + re.escape(http_tokens) + ']') def generateKeyValues(parameters): """ Format an iterable of key/value pairs. Although each header in HTTP 1.1 redefines the grammar for the formatting of its parameters, the grammar defined by almost all headers conforms to the specification given in RFC 2046. Note also that RFC 2616 section 19.2 note 2 points out that many implementations fail if the value is quoted, therefore this function only quotes the value when it is necessary. 
@param parameters: An iterable of C{tuple} of a C{str} parameter name and C{str} or C{None} parameter value which will be formated. @return: The formatted result. @rtype: C{str} """ l = [] for k, v in parameters: if v is None: l.append('%s' % k) else: if _seperators.search(v) is not None: v = quoteString(v) l.append('%s=%s' % (k, v)) return ";".join(l) class MimeType(object): def fromString(cls, mimeTypeString): """Generate a MimeType object from the given string. @param mimeTypeString: The mimetype to parse @return: L{MimeType} """ return DefaultHTTPHandler.parse('content-type', [mimeTypeString]) fromString = classmethod(fromString) def __init__(self, mediaType, mediaSubtype, params={}, **kwargs): """ @type mediaType: C{str} @type mediaSubtype: C{str} @type params: C{dict} """ self.mediaType = mediaType self.mediaSubtype = mediaSubtype self.params = dict(params) if kwargs: self.params.update(kwargs) def __eq__(self, other): if not isinstance(other, MimeType): return NotImplemented return (self.mediaType == other.mediaType and self.mediaSubtype == other.mediaSubtype and self.params == other.params) def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "MimeType(%r, %r, %r)" % (self.mediaType, self.mediaSubtype, self.params) def __hash__(self): return hash(self.mediaType) ^ hash(self.mediaSubtype) ^ hash(tuple(self.params.iteritems())) class MimeDisposition(object): def fromString(cls, dispositionString): """Generate a MimeDisposition object from the given string. 
@param dispositionString: The disposition to parse @return: L{MimeDisposition} """ return DefaultHTTPHandler.parse('content-disposition', [dispositionString]) fromString = classmethod(fromString) def __init__(self, dispositionType, params={}, **kwargs): """ @type mediaType: C{str} @type mediaSubtype: C{str} @type params: C{dict} """ self.dispositionType = dispositionType self.params = dict(params) if kwargs: self.params.update(kwargs) def __eq__(self, other): if not isinstance(other, MimeDisposition): return NotImplemented return (self.dispositionType == other.dispositionType and self.params == other.params) def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "MimeDisposition(%r, %r)" % (self.dispositionType, self.params) def __hash__(self): return hash(self.dispositionType) ^ hash(tuple(self.params.iteritems())) ##### Specific header parsers. def parseAccept(field): atype, args = parseArgs(field) if len(atype) != 3 or atype[1] != Token('/'): raise ValueError("MIME Type " + str(atype) + " invalid.") # okay, this spec is screwy. A 'q' parameter is used as the separator # between MIME parameters and (as yet undefined) additional HTTP # parameters. num = 0 for arg in args: if arg[0] == 'q': mimeparams = tuple(args[0:num]) params = args[num:] break num = num + 1 else: mimeparams = tuple(args) params = [] # Default values for parameters: qval = 1.0 # Parse accept parameters: for param in params: if param[0] == 'q': qval = float(param[1]) else: # Warn? ignored parameter. 
pass ret = MimeType(atype[0], atype[2], mimeparams), qval return ret def parseAcceptQvalue(field): atype, args = parseArgs(field) atype = checkSingleToken(atype) qvalue = 1.0 # Default qvalue is 1 for arg in args: if arg[0] == 'q': qvalue = float(arg[1]) return atype, qvalue def addDefaultCharset(charsets): if charsets.get('*') is None and charsets.get('iso-8859-1') is None: charsets['iso-8859-1'] = 1.0 return charsets def addDefaultEncoding(encodings): if encodings.get('*') is None and encodings.get('identity') is None: # RFC doesn't specify a default value for identity, only that it # "is acceptable" if not mentioned. Thus, give it a very low qvalue. encodings['identity'] = .0001 return encodings def parseContentType(header): # Case folding is disabled for this header, because of use of # Content-Type: multipart/form-data; boundary=CaSeFuLsTuFf # So, we need to explicitly .lower() the ctype and arg keys. ctype, args = parseArgs(header) if len(ctype) != 3 or ctype[1] != Token('/'): raise ValueError("MIME Type " + str(ctype) + " invalid.") args = [(kv[0].lower(), kv[1]) for kv in args] return MimeType(ctype[0].lower(), ctype[2].lower(), tuple(args)) def parseContentDisposition(header): # Case folding is disabled for this header, because of use of # So, we need to explicitly .lower() the dtype and arg keys. dtype, args = parseArgs(header) if len(dtype) != 1: raise ValueError("Content-Disposition " + str(dtype) + " invalid.") args = [(kv[0].lower(), kv[1]) for kv in args] return MimeDisposition(dtype[0].lower(), tuple(args)) def parseContentMD5(header): try: return base64.decodestring(header) except Exception, e: raise ValueError(e) def parseContentRange(header): """Parse a content-range header into (kind, start, end, realLength). realLength might be None if real length is not known ('*'). 
start and end might be None if start,end unspecified (for response code 416) """ kind, other = header.strip().split() if kind.lower() != "bytes": raise ValueError("a range of type %r is not supported") startend, realLength = other.split("/") if startend.strip() == '*': start, end = None, None else: start, end = map(int, startend.split("-")) if realLength == "*": realLength = None else: realLength = int(realLength) return (kind, start, end, realLength) def parseExpect(field): etype, args = parseArgs(field) etype = parseKeyValue(etype) return (etype[0], (lambda *args: args)(etype[1], *args)) def parseExpires(header): # """HTTP/1.1 clients and caches MUST treat other invalid date formats, # especially including the value 0, as in the past (i.e., "already expired").""" try: return parseDateTime(header) except ValueError: return 0 def parseIfModifiedSince(header): # Ancient versions of netscape and *current* versions of MSIE send # If-Modified-Since: Thu, 05 Aug 2004 12:57:27 GMT; length=123 # which is blantantly RFC-violating and not documented anywhere # except bug-trackers for web frameworks. # So, we'll just strip off everything after a ';'. return parseDateTime(header.split(';', 1)[0]) def parseIfRange(headers): try: return ETag.parse(tokenize(headers)) except ValueError: return parseDateTime(last(headers)) def parseRange(crange): crange = list(crange) if len(crange) < 3 or crange[1] != Token('='): raise ValueError("Invalid range header format: %s" % (crange,)) rtype = crange[0] if rtype != 'bytes': raise ValueError("Unknown range unit: %s." 
% (rtype,)) rangeset = split(crange[2:], Token(',')) ranges = [] for byterangespec in rangeset: if len(byterangespec) != 1: raise ValueError("Invalid range header format: %s" % (crange,)) start, end = byterangespec[0].split('-') if not start and not end: raise ValueError("Invalid range header format: %s" % (crange,)) if start: start = int(start) else: start = None if end: end = int(end) else: end = None if start and end and start > end: raise ValueError("Invalid range header, start > end: %s" % (crange,)) ranges.append((start, end)) return rtype, ranges def parseRetryAfter(header): try: # delta seconds return time.time() + int(header) except ValueError: # or datetime return parseDateTime(header) # WWW-Authenticate and Authorization def parseWWWAuthenticate(tokenized): headers = [] tokenList = list(tokenized) while tokenList: scheme = tokenList.pop(0) challenge = {} last = None kvChallenge = False while tokenList: token = tokenList.pop(0) if token == Token('='): kvChallenge = True challenge[last] = tokenList.pop(0) last = None elif token == Token(','): if kvChallenge: if len(tokenList) > 1 and tokenList[1] != Token('='): break else: break else: last = token if last and scheme and not challenge and not kvChallenge: challenge = last last = None headers.append((scheme, challenge)) if last and last not in (Token('='), Token(',')): if headers[-1] == (scheme, challenge): scheme = last challenge = {} headers.append((scheme, challenge)) return headers def parseAuthorization(header): scheme, rest = header.split(' ', 1) # this header isn't tokenized because it may eat characters # in the unquoted base64 encoded credentials return scheme.lower(), rest def parsePrefer(field): etype, args = parseArgs(field) etype = parseKeyValue(etype) return (etype[0], etype[1], args) #### Header generators def generateAccept(accept): mimeType, q = accept out = "%s/%s" % (mimeType.mediaType, mimeType.mediaSubtype) if mimeType.params: out += ';' + generateKeyValues(mimeType.params.iteritems()) 
if q != 1.0: out += (';q=%.3f' % (q,)).rstrip('0').rstrip('.') return out def removeDefaultEncoding(seq): for item in seq: if item[0] != 'identity' or item[1] != .0001: yield item def generateAcceptQvalue(keyvalue): if keyvalue[1] == 1.0: return "%s" % keyvalue[0:1] else: return ("%s;q=%.3f" % keyvalue).rstrip('0').rstrip('.') def parseCacheControl(kv): k, v = parseKeyValue(kv) if k == 'max-age' or k == 'min-fresh' or k == 's-maxage': # Required integer argument if v is None: v = 0 else: v = int(v) elif k == 'max-stale': # Optional integer argument if v is not None: v = int(v) elif k == 'private' or k == 'no-cache': # Optional list argument if v is not None: v = [field.strip().lower() for field in v.split(',')] return k, v def generateCacheControl((k, v)): if v is None: return str(k) else: if k == 'no-cache' or k == 'private': # quoted list of values v = quoteString(generateList( [header_case_mapping.get(name) or dashCapitalize(name) for name in v])) return '%s=%s' % (k, v) def generateContentRange(tup): """tup is (rtype, start, end, rlen) rlen can be None. 
""" rtype, start, end, rlen = tup if rlen == None: rlen = '*' else: rlen = int(rlen) if start == None and end == None: startend = '*' else: startend = '%d-%d' % (start, end) return '%s %s/%s' % (rtype, startend, rlen) def generateDateTime(secSinceEpoch): """Convert seconds since epoch to HTTP datetime string.""" year, month, day, hh, mm, ss, wd, _ignore_y, _ignore_z = time.gmtime(secSinceEpoch) s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( weekdayname[wd], day, monthname[month], year, hh, mm, ss) return s def generateExpect(item): if item[1][0] is None: out = '%s' % (item[0],) else: out = '%s=%s' % (item[0], item[1][0]) if len(item[1]) > 1: out += ';' + generateKeyValues(item[1][1:]) return out def generateRange(crange): def noneOr(s): if s is None: return '' return s rtype, ranges = crange if rtype != 'bytes': raise ValueError("Unknown range unit: " + rtype + ".") return (rtype + '=' + ','.join(['%s-%s' % (noneOr(startend[0]), noneOr(startend[1])) for startend in ranges])) def generateRetryAfter(when): # always generate delta seconds format return str(int(when - time.time())) def generateContentType(mimeType): out = "%s/%s" % (mimeType.mediaType, mimeType.mediaSubtype) if mimeType.params: out += ';' + generateKeyValues(mimeType.params.iteritems()) return out def generateContentDisposition(disposition): out = disposition.dispositionType if disposition.params: out += ';' + generateKeyValues(disposition.params.iteritems()) return out def generateIfRange(dateOrETag): if isinstance(dateOrETag, ETag): return dateOrETag.generate() else: return generateDateTime(dateOrETag) # WWW-Authenticate and Authorization def generateWWWAuthenticate(headers): _generated = [] for seq in headers: scheme, challenge = seq[0], seq[1] # If we're going to parse out to something other than a dict # we need to be able to generate from something other than a dict try: l = [] for k, v in dict(challenge).iteritems(): l.append("%s=%s" % (k, quoteString(v))) _generated.append("%s %s" % (scheme, ", 
".join(l))) except ValueError: _generated.append("%s %s" % (scheme, challenge)) return _generated def generateAuthorization(seq): return [' '.join(seq)] def generatePrefer(items): key, value, args = items if value is None: out = '%s' % (key,) else: out = '%s=%s' % (key, value) if args: out += ';' + generateKeyValues(args) return out #### class ETag(object): def __init__(self, tag, weak=False): self.tag = str(tag) self.weak = weak def match(self, other, strongCompare): # Sec 13.3. # The strong comparison function: in order to be considered equal, both # validators MUST be identical in every way, and both MUST NOT be weak. # # The weak comparison function: in order to be considered equal, both # validators MUST be identical in every way, but either or both of # them MAY be tagged as "weak" without affecting the result. if not isinstance(other, ETag) or other.tag != self.tag: return False if strongCompare and (other.weak or self.weak): return False return True def __eq__(self, other): return isinstance(other, ETag) and other.tag == self.tag and other.weak == self.weak def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "Etag(%r, weak=%r)" % (self.tag, self.weak) def parse(tokens): tokens = tuple(tokens) if len(tokens) == 1 and not isinstance(tokens[0], Token): return ETag(tokens[0]) if(len(tokens) == 3 and tokens[0] == "w" and tokens[1] == Token('/')): return ETag(tokens[2], weak=True) raise ValueError("Invalid ETag.") parse = staticmethod(parse) def generate(self): if self.weak: return 'W/' + quoteString(self.tag) else: return quoteString(self.tag) def parseStarOrETag(tokens): tokens = tuple(tokens) if tokens == ('*',): return '*' else: return ETag.parse(tokens) def generateStarOrETag(etag): if etag == '*': return etag else: return etag.generate() #### Cookies. Blech! 
class Cookie(object): # __slots__ = ['name', 'value', 'path', 'domain', 'ports', 'expires', 'discard', 'secure', 'comment', 'commenturl', 'version'] def __init__(self, name, value, path=None, domain=None, ports=None, expires=None, discard=False, secure=False, comment=None, commenturl=None, version=0): self.name = name self.value = value self.path = path self.domain = domain self.ports = ports self.expires = expires self.discard = discard self.secure = secure self.comment = comment self.commenturl = commenturl self.version = version def __repr__(self): s = "Cookie(%r=%r" % (self.name, self.value) if self.path is not None: s += ", path=%r" % (self.path,) if self.domain is not None: s += ", domain=%r" % (self.domain,) if self.ports is not None: s += ", ports=%r" % (self.ports,) if self.expires is not None: s += ", expires=%r" % (self.expires,) if self.secure is not False: s += ", secure=%r" % (self.secure,) if self.comment is not None: s += ", comment=%r" % (self.comment,) if self.commenturl is not None: s += ", commenturl=%r" % (self.commenturl,) if self.version != 0: s += ", version=%r" % (self.version,) s += ")" return s def __eq__(self, other): return (isinstance(other, Cookie) and other.path == self.path and other.domain == self.domain and other.ports == self.ports and other.expires == self.expires and other.secure == self.secure and other.comment == self.comment and other.commenturl == self.commenturl and other.version == self.version) def __ne__(self, other): return not self.__eq__(other) def parseCookie(headers): """Bleargh, the cookie spec sucks. This surely needs interoperability testing. There are two specs that are supported: Version 0) http://wp.netscape.com/newsref/std/cookie_spec.html Version 1) http://www.faqs.org/rfcs/rfc2965.html """ cookies = [] # There can't really be multiple cookie headers according to RFC, because # if multiple headers are allowed, they must be joinable with ",". # Neither new RFC2965 cookies nor old netscape cookies are. 
header = ';'.join(headers) if header[0:8].lower() == "$version": # RFC2965 cookie h = tokenize([header], foldCase=False) r_cookies = split(h, Token(',')) for r_cookie in r_cookies: last_cookie = None rr_cookies = split(r_cookie, Token(';')) for cookie in rr_cookies: nameval = tuple(split(cookie, Token('='))) if len(nameval) == 2: (name,), (value,) = nameval else: (name,), = nameval value = None name = name.lower() if name == '$version': continue if name[0] == '$': if last_cookie is not None: if name == '$path': last_cookie.path = value elif name == '$domain': last_cookie.domain = value elif name == '$port': if value is None: last_cookie.ports = () else: last_cookie.ports = tuple([int(s) for s in value.split(',')]) else: last_cookie = Cookie(name, value, version=1) cookies.append(last_cookie) else: # Oldstyle cookies don't do quoted strings or anything sensible. # All characters are valid for names except ';' and '=', and all # characters are valid for values except ';'. Spaces are stripped, # however. r_cookies = header.split(';') for r_cookie in r_cookies: name, value = r_cookie.split('=', 1) name = name.strip(' \t') value = value.strip(' \t') cookies.append(Cookie(name, value)) return cookies cookie_validname = "[^" + re.escape(http_tokens + http_ctls) + "]*$" cookie_validname_re = re.compile(cookie_validname) cookie_validvalue = cookie_validname + '|"([^"]|\\\\")*"$' cookie_validvalue_re = re.compile(cookie_validvalue) def generateCookie(cookies): # There's a fundamental problem with the two cookie specifications. # They both use the "Cookie" header, and the RFC Cookie header only allows # one version to be specified. Thus, when you have a collection of V0 and # V1 cookies, you have to either send them all as V0 or send them all as # V1. # I choose to send them all as V1. # You might think converting a V0 cookie to a V1 cookie would be lossless, # but you'd be wrong. 
If you do the conversion, and a V0 parser tries to # read the cookie, it will see a modified form of the cookie, in cases # where quotes must be added to conform to proper V1 syntax. # (as a real example: "Cookie: cartcontents=oid:94680,qty:1,auto:0,esp:y") # However, that is what we will do, anyways. It has a high probability of # breaking applications that only handle oldstyle cookies, where some other # application set a newstyle cookie that is applicable over for site # (or host), AND where the oldstyle cookie uses a value which is invalid # syntax in a newstyle cookie. # Also, the cookie name *cannot* be quoted in V1, so some cookies just # cannot be converted at all. (e.g. "Cookie: phpAds_capAd[32]=2"). These # are just dicarded during conversion. # As this is an unsolvable problem, I will pretend I can just say # OH WELL, don't do that, or else upgrade your old applications to have # newstyle cookie parsers. # I will note offhandedly that there are *many* sites which send V0 cookies # that are not valid V1 cookie syntax. About 20% for my cookies file. # However, they do not generally mix them with V1 cookies, so this isn't # an issue, at least right now. I have not tested to see how many of those # webapps support RFC2965 V1 cookies. I suspect not many. max_version = max([cookie.version for cookie in cookies]) if max_version == 0: # no quoting or anything. return ';'.join(["%s=%s" % (cookie.name, cookie.value) for cookie in cookies]) else: str_cookies = ['$Version="1"'] for cookie in cookies: if cookie.version == 0: # Version 0 cookie: we make sure the name and value are valid # V1 syntax. # If they are, we use them as is. This means in *most* cases, # the cookie will look literally the same on output as it did # on input. # If it isn't a valid name, ignore the cookie. # If it isn't a valid value, quote it and hope for the best on # the other side. 
if cookie_validname_re.match(cookie.name) is None: continue value = cookie.value if cookie_validvalue_re.match(cookie.value) is None: value = quoteString(value) str_cookies.append("%s=%s" % (cookie.name, value)) else: # V1 cookie, nice and easy str_cookies.append("%s=%s" % (cookie.name, quoteString(cookie.value))) if cookie.path: str_cookies.append("$Path=%s" % quoteString(cookie.path)) if cookie.domain: str_cookies.append("$Domain=%s" % quoteString(cookie.domain)) if cookie.ports is not None: if len(cookie.ports) == 0: str_cookies.append("$Port") else: str_cookies.append("$Port=%s" % quoteString(",".join([str(x) for x in cookie.ports]))) return ';'.join(str_cookies) def parseSetCookie(headers): setCookies = [] for header in headers: try: parts = header.split(';') l = [] for part in parts: namevalue = part.split('=', 1) if len(namevalue) == 1: name = namevalue[0] value = None else: name, value = namevalue value = value.strip(' \t') name = name.strip(' \t') l.append((name, value)) setCookies.append(makeCookieFromList(l, True)) except ValueError: # If we can't parse one Set-Cookie, ignore it, # but not the rest of Set-Cookies. pass return setCookies def parseSetCookie2(toks): outCookies = [] for cookie in [[parseKeyValue(x) for x in split(y, Token(';'))] for y in split(toks, Token(','))]: try: outCookies.append(makeCookieFromList(cookie, False)) except ValueError: # Again, if we can't handle one cookie -- ignore it. pass return outCookies def makeCookieFromList(tup, netscapeFormat): name, value = tup[0] if name is None or value is None: raise ValueError("Cookie has missing name or value") if name.startswith("$"): raise ValueError("Invalid cookie name: %r, starts with '$'." 
% name) cookie = Cookie(name, value) hadMaxAge = False for name, value in tup[1:]: name = name.lower() if value is None: if name in ("discard", "secure"): # Boolean attrs value = True elif name != "port": # Can be either boolean or explicit continue if name in ("comment", "commenturl", "discard", "domain", "path", "secure"): # simple cases setattr(cookie, name, value) elif name == "expires" and not hadMaxAge: if netscapeFormat and value[0] == '"' and value[-1] == '"': value = value[1:-1] cookie.expires = parseDateTime(value) elif name == "max-age": hadMaxAge = True cookie.expires = int(value) + time.time() elif name == "port": if value is None: cookie.ports = () else: if netscapeFormat and value[0] == '"' and value[-1] == '"': value = value[1:-1] cookie.ports = tuple([int(s) for s in value.split(',')]) elif name == "version": cookie.version = int(value) return cookie def generateSetCookie(cookies): setCookies = [] for cookie in cookies: out = ["%s=%s" % (cookie.name, cookie.value)] if cookie.expires: out.append("expires=%s" % generateDateTime(cookie.expires)) if cookie.path: out.append("path=%s" % cookie.path) if cookie.domain: out.append("domain=%s" % cookie.domain) if cookie.secure: out.append("secure") setCookies.append('; '.join(out)) return setCookies def generateSetCookie2(cookies): setCookies = [] for cookie in cookies: out = ["%s=%s" % (cookie.name, quoteString(cookie.value))] if cookie.comment: out.append("Comment=%s" % quoteString(cookie.comment)) if cookie.commenturl: out.append("CommentURL=%s" % quoteString(cookie.commenturl)) if cookie.discard: out.append("Discard") if cookie.domain: out.append("Domain=%s" % quoteString(cookie.domain)) if cookie.expires: out.append("Max-Age=%s" % (cookie.expires - time.time())) if cookie.path: out.append("Path=%s" % quoteString(cookie.path)) if cookie.ports is not None: if len(cookie.ports) == 0: out.append("Port") else: out.append("Port=%s" % quoteString(",".join([str(x) for x in cookie.ports]))) if cookie.secure: 
out.append("Secure") out.append('Version="1"') setCookies.append('; '.join(out)) return setCookies def parseDepth(depth): if depth not in ("0", "1", "infinity"): raise ValueError("Invalid depth header value: %s" % (depth,)) return depth def parseOverWrite(overwrite): if overwrite == "F": return False elif overwrite == "T": return True raise ValueError("Invalid overwrite header value: %s" % (overwrite,)) def generateOverWrite(overwrite): if overwrite: return "T" else: return "F" def parseBrief(brief): # We accept upper or lower case if brief.upper() == "F": return False elif brief.upper() == "T": return True raise ValueError("Invalid brief header value: %s" % (brief,)) def generateBrief(brief): # MS definition uses lower case return "t" if brief else "f" ##### Random stuff that looks useful. # def sortMimeQuality(s): # def sorter(item1, item2): # if item1[0] == '*': # if item2[0] == '*': # return 0 # def sortQuality(s): # def sorter(item1, item2): # if item1[1] < item2[1]: # return -1 # if item1[1] < item2[1]: # return 1 # if item1[0] == item2[0]: # return 0 # def getMimeQuality(mimeType, accepts): # type,args = parseArgs(mimeType) # type=type.split(Token('/')) # if len(type) != 2: # raise ValueError, "MIME Type "+s+" invalid." # for accept in accepts: # accept,acceptQual=accept # acceptType=accept[0:1] # acceptArgs=accept[2] # if ((acceptType == type or acceptType == (type[0],'*') or acceptType==('*','*')) and # (args == acceptArgs or len(acceptArgs) == 0)): # return acceptQual # def getQuality(type, accepts): # qual = accepts.get(type) # if qual is not None: # return qual # return accepts.get('*') # Headers object class __RecalcNeeded(object): def __repr__(self): return "" _RecalcNeeded = __RecalcNeeded() class Headers(object): """ This class stores the HTTP headers as both a parsed representation and the raw string representation. It converts between the two on demand. 
""" def __init__(self, headers=None, rawHeaders=None, handler=DefaultHTTPHandler): self._raw_headers = {} self._headers = {} self.handler = handler if headers is not None: for key, value in headers.iteritems(): self.setHeader(key, value) if rawHeaders is not None: for key, value in rawHeaders.iteritems(): self.setRawHeaders(key, value) def _setRawHeaders(self, headers): self._raw_headers = headers self._headers = {} def _toParsed(self, name): r = self._raw_headers.get(name, None) h = self.handler.parse(name, r) if h is not None: self._headers[name] = h return h def _toRaw(self, name): h = self._headers.get(name, None) r = self.handler.generate(name, h) if r is not None: self._raw_headers[name] = r return r def hasHeader(self, name): """Does a header with the given name exist?""" name = name.lower() return name in self._raw_headers def getRawHeaders(self, name, default=None): """Returns a list of headers matching the given name as the raw string given.""" name = name.lower() raw_header = self._raw_headers.get(name, default) if raw_header is not _RecalcNeeded: return raw_header return self._toRaw(name) def getHeader(self, name, default=None): """Ret9urns the parsed representation of the given header. The exact form of the return value depends on the header in question. If no parser for the header exists, raise ValueError. If the header doesn't exist, return default (or None if not specified) """ name = name.lower() parsed = self._headers.get(name, default) if parsed is not _RecalcNeeded: return parsed return self._toParsed(name) def setRawHeaders(self, name, value): """Sets the raw representation of the given header. Value should be a list of strings, each being one header of the given name. """ name = name.lower() self._raw_headers[name] = value self._headers[name] = _RecalcNeeded def setHeader(self, name, value): """Sets the parsed representation of the given header. Value should be a list of objects whose exact form depends on the header in question. 
""" name = name.lower() self._raw_headers[name] = _RecalcNeeded self._headers[name] = value def addRawHeader(self, name, value): """ Add a raw value to a header that may or may not already exist. If it exists, add it as a separate header to output; do not replace anything. """ name = name.lower() raw_header = self._raw_headers.get(name) if raw_header is None: # No header yet raw_header = [] self._raw_headers[name] = raw_header elif raw_header is _RecalcNeeded: raw_header = self._toRaw(name) raw_header.append(value) self._headers[name] = _RecalcNeeded def removeHeader(self, name): """Removes the header named.""" name = name.lower() if name in self._raw_headers: del self._raw_headers[name] del self._headers[name] def __repr__(self): return '' % (self._raw_headers, self._headers) def canonicalNameCaps(self, name): """Return the name with the canonical capitalization, if known, otherwise, Caps-After-Dashes""" return header_case_mapping.get(name) or dashCapitalize(name) def getAllRawHeaders(self): """Return an iterator of key,value pairs of all headers contained in this object, as strings. The keys are capitalized in canonical capitalization.""" for k, v in self._raw_headers.iteritems(): if v is _RecalcNeeded: v = self._toRaw(k) yield self.canonicalNameCaps(k), v def makeImmutable(self): """Make this header set immutable. All mutating operations will raise an exception.""" self.setHeader = self.setRawHeaders = self.removeHeader = self._mutateRaise def _mutateRaise(self, *args): raise AttributeError("This header object is immutable as the headers have already been sent.") """The following dicts are all mappings of header to list of operations to perform. The first operation should generally be 'tokenize' if the header can be parsed according to the normal tokenization rules. If it cannot, generally the first thing you want to do is take only the last instance of the header (in case it was sent multiple times, which is strictly an error, but we're nice.). 
""" iteritems = lambda x: x.iteritems() parser_general_headers = { 'Cache-Control': (tokenize, listParser(parseCacheControl), dict), 'Connection': (tokenize, filterTokens), 'Date': (last, parseDateTime), # 'Pragma': tokenize # 'Trailer': tokenize 'Transfer-Encoding': (tokenize, filterTokens), # 'Upgrade': tokenize # 'Via': tokenize,stripComment # 'Warning': tokenize } generator_general_headers = { 'Cache-Control': (iteritems, listGenerator(generateCacheControl), singleHeader), 'Connection': (generateList, singleHeader), 'Date': (generateDateTime, singleHeader), # 'Pragma': # 'Trailer': 'Transfer-Encoding': (generateList, singleHeader), # 'Upgrade': # 'Via': # 'Warning': } parser_request_headers = { 'Accept': (tokenize, listParser(parseAccept), dict), 'Accept-Charset': (tokenize, listParser(parseAcceptQvalue), dict, addDefaultCharset), 'Accept-Encoding': (tokenize, listParser(parseAcceptQvalue), dict, addDefaultEncoding), 'Accept-Language': (tokenize, listParser(parseAcceptQvalue), dict), 'Authorization': (last, parseAuthorization), 'Cookie': (parseCookie,), 'Expect': (tokenize, listParser(parseExpect), dict), 'From': (last,), 'Host': (last,), 'If-Match': (tokenize, listParser(parseStarOrETag), list), 'If-Modified-Since': (last, parseIfModifiedSince), 'If-None-Match': (tokenize, listParser(parseStarOrETag), list), 'If-Range': (parseIfRange,), 'If-Unmodified-Since': (last, parseDateTime), 'Max-Forwards': (last, int), 'Prefer': (tokenize, listParser(parsePrefer), list), # 'Proxy-Authorization': str, # what is "credentials" 'Range': (tokenize, parseRange), 'Referer': (last, str), # TODO: URI object? 
'TE': (tokenize, listParser(parseAcceptQvalue), dict), 'User-Agent': (last, str), } generator_request_headers = { 'Accept': (iteritems, listGenerator(generateAccept), singleHeader), 'Accept-Charset': (iteritems, listGenerator(generateAcceptQvalue), singleHeader), 'Accept-Encoding': (iteritems, removeDefaultEncoding, listGenerator(generateAcceptQvalue), singleHeader), 'Accept-Language': (iteritems, listGenerator(generateAcceptQvalue), singleHeader), 'Authorization': (generateAuthorization,), # what is "credentials" 'Cookie': (generateCookie, singleHeader), 'Expect': (iteritems, listGenerator(generateExpect), singleHeader), 'From': (str, singleHeader), 'Host': (str, singleHeader), 'If-Match': (listGenerator(generateStarOrETag), singleHeader), 'If-Modified-Since': (generateDateTime, singleHeader), 'If-None-Match': (listGenerator(generateStarOrETag), singleHeader), 'If-Range': (generateIfRange, singleHeader), 'If-Unmodified-Since': (generateDateTime, singleHeader), 'Max-Forwards': (str, singleHeader), 'Prefer': (listGenerator(generatePrefer), singleHeader), # 'Proxy-Authorization': str, # what is "credentials" 'Range': (generateRange, singleHeader), 'Referer': (str, singleHeader), 'TE': (iteritems, listGenerator(generateAcceptQvalue), singleHeader), 'User-Agent': (str, singleHeader), } parser_response_headers = { 'Accept-Ranges': (tokenize, filterTokens), 'Age': (last, int), 'ETag': (tokenize, ETag.parse), 'Location': (last,), # TODO: URI object? 
# 'Proxy-Authenticate' 'Retry-After': (last, parseRetryAfter), 'Server': (last,), 'Set-Cookie': (parseSetCookie,), 'Set-Cookie2': (tokenize, parseSetCookie2), 'Vary': (tokenize, filterTokens), 'WWW-Authenticate': (lambda h: tokenize(h, foldCase=False), parseWWWAuthenticate,) } generator_response_headers = { 'Accept-Ranges': (generateList, singleHeader), 'Age': (str, singleHeader), 'ETag': (ETag.generate, singleHeader), 'Location': (str, singleHeader), # 'Proxy-Authenticate' 'Retry-After': (generateRetryAfter, singleHeader), 'Server': (str, singleHeader), 'Set-Cookie': (generateSetCookie,), 'Set-Cookie2': (generateSetCookie2,), 'Vary': (generateList, singleHeader), 'WWW-Authenticate': (generateWWWAuthenticate,) } parser_entity_headers = { 'Allow': (lambda hdr: tokenize(hdr, foldCase=False), filterTokens), 'Content-Disposition': (lambda hdr: tokenize(hdr, foldCase=False), parseContentDisposition), 'Content-Encoding': (tokenize, filterTokens), 'Content-Language': (tokenize, filterTokens), 'Content-Length': (last, int), 'Content-Location': (last,), # TODO: URI object? 
'Content-MD5': (last, parseContentMD5), 'Content-Range': (last, parseContentRange), 'Content-Type': (lambda hdr: tokenize(hdr, foldCase=False), parseContentType), 'Expires': (last, parseExpires), 'Last-Modified': (last, parseDateTime), } generator_entity_headers = { 'Allow': (generateList, singleHeader), 'Content-Disposition': (generateContentDisposition, singleHeader), 'Content-Encoding': (generateList, singleHeader), 'Content-Language': (generateList, singleHeader), 'Content-Length': (str, singleHeader), 'Content-Location': (str, singleHeader), 'Content-MD5': (base64.encodestring, lambda x: x.strip("\n"), singleHeader), 'Content-Range': (generateContentRange, singleHeader), 'Content-Type': (generateContentType, singleHeader), 'Expires': (generateDateTime, singleHeader), 'Last-Modified': (generateDateTime, singleHeader), } parser_dav_headers = { 'Brief' : (last, parseBrief), 'DAV' : (tokenize, list), 'Depth' : (last, parseDepth), 'Destination' : (last,), # TODO: URI object? # 'If' : (), # 'Lock-Token' : (), 'Overwrite' : (last, parseOverWrite), # 'Status-URI' : (), # 'Timeout' : (), } generator_dav_headers = { 'Brief' : (), 'DAV' : (generateList, singleHeader), 'Depth' : (singleHeader), 'Destination' : (singleHeader), # 'If' : (), # 'Lock-Token' : (), 'Overwrite' : (), # 'Status-URI' : (), # 'Timeout' : (), } DefaultHTTPHandler.updateParsers(parser_general_headers) DefaultHTTPHandler.updateParsers(parser_request_headers) DefaultHTTPHandler.updateParsers(parser_response_headers) DefaultHTTPHandler.updateParsers(parser_entity_headers) DefaultHTTPHandler.updateParsers(parser_dav_headers) DefaultHTTPHandler.updateGenerators(generator_general_headers) DefaultHTTPHandler.updateGenerators(generator_request_headers) DefaultHTTPHandler.updateGenerators(generator_response_headers) DefaultHTTPHandler.updateGenerators(generator_entity_headers) DefaultHTTPHandler.updateGenerators(generator_dav_headers) # casemappingify(DefaultHTTPParsers) # 
casemappingify(DefaultHTTPGenerators) # lowerify(DefaultHTTPParsers) # lowerify(DefaultHTTPGenerators) calendarserver-5.2+dfsg/twext/web2/error.py0000644000175000017500000002117712263343324020036 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_log -*- ## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ Default error output filter for twext.web2. 
""" from twext.web2 import stream, http_headers from twext.web2.responsecode import ( MOVED_PERMANENTLY, FOUND, SEE_OTHER, USE_PROXY, TEMPORARY_REDIRECT, BAD_REQUEST, UNAUTHORIZED, PAYMENT_REQUIRED, FORBIDDEN, NOT_FOUND, NOT_ALLOWED, NOT_ACCEPTABLE, PROXY_AUTH_REQUIRED, REQUEST_TIMEOUT, CONFLICT, GONE, LENGTH_REQUIRED, PRECONDITION_FAILED, REQUEST_ENTITY_TOO_LARGE, REQUEST_URI_TOO_LONG, UNSUPPORTED_MEDIA_TYPE, REQUESTED_RANGE_NOT_SATISFIABLE, EXPECTATION_FAILED, INTERNAL_SERVER_ERROR, NOT_IMPLEMENTED, BAD_GATEWAY, SERVICE_UNAVAILABLE, GATEWAY_TIMEOUT, HTTP_VERSION_NOT_SUPPORTED, INSUFFICIENT_STORAGE_SPACE, NOT_EXTENDED, RESPONSES, ) from twisted.web.template import Element, flattenString, XMLString, renderer # 300 - Should include entity with choices # 301 - # 304 - Must include Date, ETag, Content-Location, Expires, Cache-Control, Vary. # 401 - Must include WWW-Authenticate. # 405 - Must include Allow. # 406 - Should include entity describing allowable characteristics # 407 - Must include Proxy-Authenticate # 413 - May include Retry-After # 416 - Should include Content-Range # 503 - Should include Retry-After ERROR_MESSAGES = { # 300 # no MULTIPLE_CHOICES MOVED_PERMANENTLY: 'The document has permanently moved here' '.', FOUND: 'The document has temporarily moved here' '.', SEE_OTHER: 'The results are available here' '.', # no NOT_MODIFIED USE_PROXY: 'Access to this resource must be through the proxy ' '.', # 306 unused TEMPORARY_REDIRECT: 'The document has temporarily moved ' 'here.', # 400 BAD_REQUEST: 'Your browser sent an invalid request.', UNAUTHORIZED: 'You are not authorized to view the resource at . 
' "Perhaps you entered a wrong password, or perhaps your browser doesn't " 'support authentication.', PAYMENT_REQUIRED: 'Payment Required (useful result code, this...).', FORBIDDEN: 'You don\'t have permission to access .', NOT_FOUND: 'The resource cannot be found.', NOT_ALLOWED: 'The requested method is not supported by ' '.', NOT_ACCEPTABLE: 'No representation of that is acceptable to your ' 'client could be found.', PROXY_AUTH_REQUIRED: 'You are not authorized to view the resource at . ' 'Perhaps you entered a wrong password, or perhaps your browser doesn\'t ' 'support authentication.', REQUEST_TIMEOUT: 'Server timed out waiting for your client to finish sending the request.', CONFLICT: 'Conflict (?)', GONE: 'The resource has been permanently removed.', LENGTH_REQUIRED: 'The resource requires a Content-Length header.', PRECONDITION_FAILED: 'A precondition evaluated to false.', REQUEST_ENTITY_TOO_LARGE: 'The provided request entity data is too longer than the maximum for ' 'the method at .', REQUEST_URI_TOO_LONG: 'The request URL is longer than the maximum on this server.', UNSUPPORTED_MEDIA_TYPE: 'The provided request data has a format not understood by the resource ' 'at .', REQUESTED_RANGE_NOT_SATISFIABLE: 'None of the ranges given in the Range request header are satisfiable by ' 'the resource .', EXPECTATION_FAILED: 'The server does support one of the expectations given in the Expect ' 'header.', # 500 INTERNAL_SERVER_ERROR: 'An internal error occurred trying to process your request. 
Sorry.', NOT_IMPLEMENTED: 'Some functionality requested is not implemented on this server.', BAD_GATEWAY: 'An upstream server returned an invalid response.', SERVICE_UNAVAILABLE: 'This server cannot service your request becaues it is overloaded.', GATEWAY_TIMEOUT: 'An upstream server is not responding.', HTTP_VERSION_NOT_SUPPORTED: 'HTTP Version not supported.', INSUFFICIENT_STORAGE_SPACE: 'There is insufficient storage space available to perform that request.', NOT_EXTENDED: 'This server does not support the a mandatory extension requested.' } class DefaultErrorElement(Element): """ An L{ErrorElement} is an L{Element} that renders some HTML for the default rendering of an error page. """ loader = XMLString(""" <t:slot name="code"/> <t:slot name="title"/>

""") def __init__(self, request, response): super(DefaultErrorElement, self).__init__() self.request = request self.response = response @renderer def error(self, request, tag): """ Top-level renderer for page. """ return tag.fillSlots( code=str(self.response.code), title=RESPONSES.get(self.response.code), message=self.loadMessage(self.response.code).fillSlots( uri=self.request.uri, location=self.response.headers.getHeader('location'), method=self.request.method, ) ) def loadMessage(self, code): tag = XMLString(('') + ERROR_MESSAGES.get(code, "") + '').load()[0] return tag def defaultErrorHandler(request, response): """ Handle errors which do not have any stream (i.e. output) associated with them, so that users will see a nice message in their browser. This is used as a response filter in L{twext.web2.server.Request}. """ if response.stream is not None: # Already got an error message return response if response.code < 300: # We only do error messages return response message = ERROR_MESSAGES.get(response.code, None) if message is None: # No message specified for that code return response message = message % { 'uri': request.uri, 'location': response.headers.getHeader('location'), 'method': request.method, } data = [] error = [] (flattenString(request, DefaultErrorElement(request, response)) .addCallbacks(data.append, error.append)) # No deferreds from our renderers above, so this has always already fired. 
if data: subtype = 'html' body = data[0] else: subtype = 'error' body = 'Error in default error handler:\n' + error[0].getTraceback() ctype = http_headers.MimeType('text', subtype, {'charset':'utf-8'}) response.headers.setHeader("content-type", ctype) response.stream = stream.MemoryStream(body) return response defaultErrorHandler.handleErrors = True __all__ = ['defaultErrorHandler',] calendarserver-5.2+dfsg/twext/web2/test/0000755000175000017500000000000012322625325017302 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/test/test_fileupload.py0000644000175000017500000002100011337102650023024 0ustar rahulrahul# Copyright (c) 2001-2007 Twisted Matrix Laboratories. # See LICENSE for details. """ Tests for L{twext.web2.fileupload} and its different parsing functions. """ from twisted.internet import defer from twisted.trial import unittest from twisted.internet.defer import waitForDeferred, deferredGenerator from twext.web2 import stream, fileupload from twext.web2.http_headers import MimeType class TestStream(stream.SimpleStream): """ A stream that reads less data at a time than it could. """ def __init__(self, mem, maxReturn=1000, start=0, length=None): self.mem = mem self.start = start self.maxReturn = maxReturn if length is None: self.length = len(mem) - start else: if len(mem) < length: raise ValueError("len(mem) < start + length") self.length = length def read(self): if self.mem is None: return None if self.length == 0: result = None else: amtToRead = min(self.maxReturn, self.length) result = buffer(self.mem, self.start, amtToRead) self.length -= amtToRead self.start += amtToRead return result def close(self): self.mem = None stream.SimpleStream.close(self) class MultipartTests(unittest.TestCase): def doTestError(self, boundary, data, expected_error): # Test different amounts of data at a time. 
ds = [fileupload.parseMultipartFormData(TestStream(data, maxReturn=bytes), boundary) for bytes in range(1, 20)] d = defer.DeferredList(ds, consumeErrors=True) d.addCallback(self._assertFailures, expected_error) return d def _assertFailures(self, failures, *expectedFailures): for flag, failure in failures: self.failUnlessEqual(flag, defer.FAILURE) failure.trap(*expectedFailures) def doTest(self, boundary, data, expected_args, expected_files): #import time, gc, cgi, cStringIO for bytes in range(1, 20): #s = TestStream(data, maxReturn=bytes) s = stream.IStream(data) #t=time.time() d = waitForDeferred(fileupload.parseMultipartFormData(s, boundary)) yield d; args, files = d.getResult() #e=time.time() #print "%.2g"%(e-t) self.assertEquals(args, expected_args) # Read file data back into memory to compare. out = {} for name, l in files.items(): out[name] = [(filename, ctype, f.read()) for (filename, ctype, f) in l] self.assertEquals(out, expected_files) #data=cStringIO.StringIO(data) #t=time.time() #d=cgi.parse_multipart(data, {'boundary':boundary}) #e=time.time() #print "CGI: %.2g"%(e-t) doTest = deferredGenerator(doTest) def testNormalUpload(self): return self.doTest( '---------------------------155781040421463194511908194298', """-----------------------------155781040421463194511908194298\r Content-Disposition: form-data; name="foo"\r \r Foo Bar\r -----------------------------155781040421463194511908194298\r Content-Disposition: form-data; name="file"; filename="filename"\r Content-Type: text/html\r \r Contents of a file blah blah\r -----------------------------155781040421463194511908194298--\r """, {'foo':['Foo Bar']}, {'file':[('filename', MimeType('text', 'html'), "Contents of a file\nblah\nblah")]}) def testMultipleUpload(self): return self.doTest( 'xyz', """--xyz\r Content-Disposition: form-data; name="foo"\r \r Foo Bar\r --xyz\r Content-Disposition: form-data; name="foo"\r \r Baz\r --xyz\r Content-Disposition: form-data; name="file"; filename="filename"\r 
Content-Type: text/html\r \r blah\r --xyz\r Content-Disposition: form-data; name="file"; filename="filename"\r Content-Type: text/plain\r \r bleh\r --xyz--\r """, {'foo':['Foo Bar', 'Baz']}, {'file':[('filename', MimeType('text', 'html'), "blah"), ('filename', MimeType('text', 'plain'), "bleh")]}) def testStupidFilename(self): return self.doTest( '----------0xKhTmLbOuNdArY', """------------0xKhTmLbOuNdArY\r Content-Disposition: form-data; name="file"; filename="foo"; name="foobar.txt"\r Content-Type: text/plain\r \r Contents of a file blah blah\r ------------0xKhTmLbOuNdArY--\r """, {}, {'file':[('foo"; name="foobar.txt', MimeType('text', 'plain'), "Contents of a file\nblah\nblah")]}) def testEmptyFilename(self): return self.doTest( 'curlPYafCMnsamUw9kSkJJkSen41sAV', """--curlPYafCMnsamUw9kSkJJkSen41sAV\r cONTENT-tYPE: application/octet-stream\r cONTENT-dISPOSITION: FORM-DATA; NAME="foo"; FILENAME=""\r \r qwertyuiop\r --curlPYafCMnsamUw9kSkJJkSen41sAV--\r """, {}, {'foo':[('', MimeType('application', 'octet-stream'), "qwertyuiop")]}) # Failing parses def testMissingContentDisposition(self): return self.doTestError( '----------0xKhTmLbOuNdArY', """------------0xKhTmLbOuNdArY\r Content-Type: text/html\r \r Blah blah I am a stupid webbrowser\r ------------0xKhTmLbOuNdArY--\r """, fileupload.MimeFormatError) def testRandomData(self): return self.doTestError( 'boundary', """--sdkjsadjlfjlj skjsfdkljsd sfdkjsfdlkjhsfadklj sffkj""", fileupload.MimeFormatError) def test_tooBigUpload(self): """ Test that a too big form post fails. 
""" boundary = '---------------------------155781040421463194511908194298' data = """-----------------------------155781040421463194511908194298\r Content-Disposition: form-data; name="foo"\r \r Foo Bar\r -----------------------------155781040421463194511908194298\r Content-Disposition: form-data; name="file"; filename="filename"\r Content-Type: text/html\r \r Contents of a file blah blah\r -----------------------------155781040421463194511908194298--\r """ s = stream.IStream(data) return self.assertFailure( fileupload.parseMultipartFormData(s, boundary, maxSize=200), fileupload.MimeFormatError) def test_tooManyFields(self): """ Test when breaking the maximum number of fields. """ boundary = 'xyz' data = """--xyz\r Content-Disposition: form-data; name="foo"\r \r Foo Bar\r --xyz\r Content-Disposition: form-data; name="foo"\r \r Baz\r --xyz\r Content-Disposition: form-data; name="file"; filename="filename"\r Content-Type: text/html\r \r blah\r --xyz\r Content-Disposition: form-data; name="file"; filename="filename"\r Content-Type: text/plain\r \r bleh\r --xyz--\r """ s = stream.IStream(data) return self.assertFailure( fileupload.parseMultipartFormData(s, boundary, maxFields=3), fileupload.MimeFormatError) def test_maxMem(self): """ An attachment with no filename goes to memory: check that the C{maxMem} parameter limits the size of this kind of attachment. 
""" boundary = '---------------------------155781040421463194511908194298' data = """-----------------------------155781040421463194511908194298\r Content-Disposition: form-data; name="foo"\r \r Foo Bar and more content\r -----------------------------155781040421463194511908194298\r Content-Disposition: form-data; name="file"; filename="filename"\r Content-Type: text/html\r \r Contents of a file blah blah\r -----------------------------155781040421463194511908194298--\r """ s = stream.IStream(data) return self.assertFailure( fileupload.parseMultipartFormData(s, boundary, maxMem=10), fileupload.MimeFormatError) class TestURLEncoded(unittest.TestCase): def doTest(self, data, expected_args): for bytes in range(1, 20): s = TestStream(data, maxReturn=bytes) d = waitForDeferred(fileupload.parse_urlencoded(s)) yield d; args = d.getResult() self.assertEquals(args, expected_args) doTest = deferredGenerator(doTest) def test_parseValid(self): self.doTest("a=b&c=d&c=e", {'a':['b'], 'c':['d', 'e']}) self.doTest("a=b&c=d&c=e", {'a':['b'], 'c':['d', 'e']}) self.doTest("a=b+c%20d", {'a':['b c d']}) def test_parseInvalid(self): self.doTest("a&b=c", {'b':['c']}) calendarserver-5.2+dfsg/twext/web2/test/test_log.py0000644000175000017500000001076511340046753021506 0ustar rahulrahul# Copyright (c) 2001-2007 Twisted Matrix Laboratories. # See LICENSE for details. from twisted.python.log import addObserver, removeObserver from twext.web2.log import BaseCommonAccessLoggingObserver, LogWrapperResource from twext.web2.http import Response from twext.web2.resource import Resource, WrapperResource from twext.web2.test.test_server import BaseCase, BaseTestResource class BufferingLogObserver(BaseCommonAccessLoggingObserver): """ A web2 log observer that buffer messages. """ messages = [] def logMessage(self, message): self.messages.append(message) class SetDateWrapperResource(WrapperResource): """ A resource wrapper which sets the date header. 
""" def hook(self, req): def _filter(req, resp): resp.headers.setHeader('date', 0.0) return resp _filter.handleErrors = True req.addResponseFilter(_filter, atEnd=True) class NoneStreamResource(Resource): """ A basic empty resource. """ def render(self, req): return Response(200) class TestLogging(BaseCase): def setUp(self): self.blo = BufferingLogObserver() addObserver(self.blo.emit) # some default resource setup self.resrc = BaseTestResource() self.resrc.child_emptystream = NoneStreamResource() self.root = SetDateWrapperResource(LogWrapperResource(self.resrc)) def tearDown(self): removeObserver(self.blo.emit) def assertLogged(self, **expected): """ Check that logged messages matches expected format. """ if 'date' not in expected: epoch = BaseCommonAccessLoggingObserver().logDateString(0) expected['date'] = epoch if 'user' not in expected: expected['user'] = '-' if 'referer' not in expected: expected['referer'] = '-' if 'user-agent' not in expected: expected['user-agent'] = '-' if 'version' not in expected: expected['version'] = '1.1' if 'remotehost' not in expected: expected['remotehost'] = 'remotehost' messages = self.blo.messages[:] del self.blo.messages[:] expectedLog = ('%(remotehost)s - %(user)s [%(date)s] "%(method)s ' '%(uri)s HTTP/%(version)s" %(status)d %(length)d ' '"%(referer)s" "%(user-agent)s"') if expected.get('logged', True): # Ensure there weren't other messages hanging out self.assertEquals(len(messages), 1, "len(%r) != 1" % (messages, )) self.assertEquals(messages[0], expectedLog % expected) else: self.assertEquals(len(messages), 0, "len(%r) != 0" % (messages, )) def test_logSimpleRequest(self): """ Check the log for a simple request. 
""" uri = 'http://localhost/' method = 'GET' def _cbCheckLog(response): self.assertLogged(method=method, uri=uri, status=response[0], length=response[1].getHeader('content-length')) d = self.getResponseFor(self.root, uri, method=method) d.addCallback(_cbCheckLog) return d def test_logErrors(self): """ Test the error log. """ def test(_, uri, method, **expected): expected['uri'] = uri expected['method'] = method def _cbCheckLog(response): self.assertEquals(response[0], expected['status']) self.assertLogged( length=response[1].getHeader('content-length'), **expected) return self.getResponseFor(self.root, uri, method=method).addCallback(_cbCheckLog) uri = 'http://localhost/foo' # doesn't exist method = 'GET' d = test(None, uri, method, status=404, logged=True) # no host. this should result in a 400 which doesn't get logged uri = 'http:///' d.addCallback(test, uri, method, status=400, logged=False) return d def test_logNoneResponseStream(self): """ Test the log of an empty resource. """ uri = 'http://localhost/emptystream' method = 'GET' def _cbCheckLog(response): self.assertLogged(method=method, uri=uri, status=200, length=0) d = self.getResponseFor(self.root, uri, method=method) d.addCallback(_cbCheckLog) return d calendarserver-5.2+dfsg/twext/web2/test/test_static.py0000644000175000017500000001176111667476304022224 0ustar rahulrahul# Copyright (c) 2008-2011 Twisted Matrix Laboratories. # See LICENSE for details. """ Tests for L{twext.web2.static}. 
""" import os from twext.web2.test.test_server import BaseCase from twext.web2 import static from twext.web2 import http_headers from twext.web2 import stream from twext.web2 import iweb class TestData(BaseCase): def setUp(self): self.text = "Hello, World\n" self.data = static.Data(self.text, "text/plain") def test_dataState(self): """ Test the internal state of the Data object """ self.assert_(hasattr(self.data, "created_time")) self.assertEquals(self.data.data, self.text) self.assertEquals(self.data.type, http_headers.MimeType("text", "plain")) self.assertEquals(self.data.contentType(), http_headers.MimeType("text", "plain")) def test_etag(self): """ Test that we can get an ETag """ def _defer(result): self.failUnless(result) d = self.data.etag().addCallback(_defer) return d def test_render(self): """ Test that the result from Data.render is acceptable, including the response code, the content-type header, and the actual response body itself. """ response = iweb.IResponse(self.data.render(None)) self.assertEqual(response.code, 200) self.assert_(response.headers.hasHeader("content-type")) self.assertEqual(response.headers.getHeader("content-type"), http_headers.MimeType("text", "plain")) def checkStream(data): self.assertEquals(str(data), self.text) return stream.readStream(iweb.IResponse(self.data.render(None)).stream, checkStream) class TestFileSaver(BaseCase): def setUp(self): """ Create an empty directory and a resource which will save uploads to that directory. 
""" self.tempdir = self.mktemp() os.mkdir(self.tempdir) self.root = static.FileSaver(self.tempdir, expectedFields=['FileNameOne'], maxBytes=16) self.root.addSlash = True def uploadFile(self, fieldname, filename, mimetype, content, resrc=None, host='foo', path='/'): if not resrc: resrc = self.root ctype = http_headers.MimeType('multipart', 'form-data', (('boundary', '---weeboundary'),)) return self.getResponseFor(resrc, '/', headers={'host': 'foo', 'content-type': ctype }, length=len(content), method='POST', content="""-----weeboundary\r Content-Disposition: form-data; name="%s"; filename="%s"\r Content-Type: %s\r \r %s\r -----weeboundary--\r """ % (fieldname, filename, mimetype, content)) def _CbAssertInResponse(self, (code, headers, data, failed), expected_response, expectedFailure=False): expected_code, expected_headers, expected_data = expected_response self.assertEquals(code, expected_code) if expected_data is not None: self.failUnlessSubstring(expected_data, data) for key, value in expected_headers.iteritems(): self.assertEquals(headers.getHeader(key), value) self.assertEquals(failed, expectedFailure) def fileNameFromResponse(self, response): (code, headers, data, failure) = response return data[data.index('Saved file')+11:data.index('
')] def assertInResponse(self, response, expected_response, failure=False): d = response d.addCallback(self._CbAssertInResponse, expected_response, failure) return d def test_enforcesMaxBytes(self): return self.assertInResponse( self.uploadFile('FileNameOne', 'myfilename', 'text/html', 'X'*32), (200, {}, 'exceeds maximum length')) def test_enforcesMimeType(self): return self.assertInResponse( self.uploadFile('FileNameOne', 'myfilename', 'application/x-python', 'X'), (200, {}, 'type not allowed')) def test_invalidField(self): return self.assertInResponse( self.uploadFile('NotARealField', 'myfilename', 'text/html', 'X'), (200, {}, 'not a valid field')) def test_reportFileSave(self): return self.assertInResponse( self.uploadFile('FileNameOne', 'myfilename', 'text/plain', 'X'), (200, {}, 'Saved file')) def test_compareFileContents(self): def gotFname(fname): contents = file(fname, 'rb').read() self.assertEquals(contents, 'Test contents\n') d = self.uploadFile('FileNameOne', 'myfilename', 'text/plain', 'Test contents\n') d.addCallback(self.fileNameFromResponse) d.addCallback(gotFname) return d calendarserver-5.2+dfsg/twext/web2/test/test_http.py0000644000175000017500000013417612212514344021702 0ustar rahulrahul from __future__ import nested_scopes import time, sys, os from zope.interface import implements from twisted.trial import unittest from twext.web2 import http, http_headers, responsecode, iweb, stream from twext.web2 import channel from twisted.internet import reactor, protocol, address, interfaces, utils from twisted.internet import defer from twisted.internet.defer import waitForDeferred, deferredGenerator from twisted.protocols import loopback from twisted.python import util, runtime from twext.web2.channel.http import SSLRedirectRequest, HTTPFactory, HTTPChannel from twisted.internet.task import deferLater class RedirectResponseTestCase(unittest.TestCase): def testTemporary(self): """ Verify the "temporary" parameter sets the appropriate response code """ req 
= http.RedirectResponse("http://example.com/", temporary=False) self.assertEquals(req.code, responsecode.MOVED_PERMANENTLY) req = http.RedirectResponse("http://example.com/", temporary=True) self.assertEquals(req.code, responsecode.TEMPORARY_REDIRECT) class PreconditionTestCase(unittest.TestCase): def checkPreconditions(self, request, response, expectedResult, expectedCode, **kw): preconditionsPass = True try: http.checkPreconditions(request, response, **kw) except http.HTTPError, e: preconditionsPass = False self.assertEquals(e.response.code, expectedCode) self.assertEquals(preconditionsPass, expectedResult) def testWithoutHeaders(self): request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers()) out_headers = http_headers.Headers() response = http.Response(responsecode.OK, out_headers, None) self.checkPreconditions(request, response, True, responsecode.OK) out_headers.setHeader("ETag", http_headers.ETag('foo')) self.checkPreconditions(request, response, True, responsecode.OK) out_headers.removeHeader("ETag") out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT self.checkPreconditions(request, response, True, responsecode.OK) out_headers.setHeader("ETag", http_headers.ETag('foo')) self.checkPreconditions(request, response, True, responsecode.OK) def testIfMatch(self): request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers()) out_headers = http_headers.Headers() response = http.Response(responsecode.OK, out_headers, None) # Behavior with no ETag set, should be same as with an ETag request.headers.setRawHeaders("If-Match", ('*',)) self.checkPreconditions(request, response, True, responsecode.OK) self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED, entityExists=False) # Ask for tag, but no etag set. 
request.headers.setRawHeaders("If-Match", ('"frob"',)) self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED) ## Actually set the ETag header out_headers.setHeader("ETag", http_headers.ETag('foo')) out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT # behavior of entityExists request.headers.setRawHeaders("If-Match", ('*',)) self.checkPreconditions(request, response, True, responsecode.OK) self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED, entityExists=False) # tag matches request.headers.setRawHeaders("If-Match", ('"frob", "foo"',)) self.checkPreconditions(request, response, True, responsecode.OK) # none match request.headers.setRawHeaders("If-Match", ('"baz", "bob"',)) self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED) # But if we have an error code already, ignore this header response.code = responsecode.INTERNAL_SERVER_ERROR self.checkPreconditions(request, response, True, responsecode.INTERNAL_SERVER_ERROR) response.code = responsecode.OK # Must only compare strong tags out_headers.setHeader("ETag", http_headers.ETag('foo', weak=True)) request.headers.setRawHeaders("If-Match", ('W/"foo"',)) self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED) def testIfUnmodifiedSince(self): request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers()) out_headers = http_headers.Headers() response = http.Response(responsecode.OK, out_headers, None) # No Last-Modified => always fail. 
request.headers.setRawHeaders("If-Unmodified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED) # Set output headers out_headers.setHeader("ETag", http_headers.ETag('foo')) out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT request.headers.setRawHeaders("If-Unmodified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, True, responsecode.OK) request.headers.setRawHeaders("If-Unmodified-Since", ('Sat, 01 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED) # But if we have an error code already, ignore this header response.code = responsecode.INTERNAL_SERVER_ERROR self.checkPreconditions(request, response, True, responsecode.INTERNAL_SERVER_ERROR) response.code = responsecode.OK # invalid date => header ignored request.headers.setRawHeaders("If-Unmodified-Since", ('alalalalalalalalalala',)) self.checkPreconditions(request, response, True, responsecode.OK) def testIfModifiedSince(self): if time.time() < 946771200: self.fail(RuntimeError("Your computer's clock is way wrong, " "this test will be invalid.")) request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers()) out_headers = http_headers.Headers() response = http.Response(responsecode.OK, out_headers, None) # No Last-Modified => always succeed request.headers.setRawHeaders("If-Modified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, True, responsecode.OK) # Set output headers out_headers.setHeader("ETag", http_headers.ETag('foo')) out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT request.headers.setRawHeaders("If-Modified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED) # With a non-GET method request.method="PUT" self.checkPreconditions(request, response, 
False, responsecode.NOT_MODIFIED) request.method="GET" request.headers.setRawHeaders("If-Modified-Since", ('Sat, 01 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, True, responsecode.OK) # But if we have an error code already, ignore this header response.code = responsecode.INTERNAL_SERVER_ERROR self.checkPreconditions(request, response, True, responsecode.INTERNAL_SERVER_ERROR) response.code = responsecode.OK # invalid date => header ignored request.headers.setRawHeaders("If-Modified-Since", ('alalalalalalalalalala',)) self.checkPreconditions(request, response, True, responsecode.OK) # date in the future => assume modified request.headers.setHeader("If-Modified-Since", time.time() + 500) self.checkPreconditions(request, response, True, responsecode.OK) def testIfNoneMatch(self): request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers()) out_headers = http_headers.Headers() response = http.Response(responsecode.OK, out_headers, None) request.headers.setRawHeaders("If-None-Match", ('"foo"',)) self.checkPreconditions(request, response, True, responsecode.OK) out_headers.setHeader("ETag", http_headers.ETag('foo')) out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT # behavior of entityExists request.headers.setRawHeaders("If-None-Match", ('*',)) request.method="PUT" self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED) request.method="GET" self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED) self.checkPreconditions(request, response, True, responsecode.OK, entityExists=False) # tag matches request.headers.setRawHeaders("If-None-Match", ('"frob", "foo"',)) request.method="PUT" self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED) request.method="GET" self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED) # now with IMS, also: request.headers.setRawHeaders("If-Modified-Since", ('Mon, 03 Jan 2000 
00:00:00 GMT',)) request.method="PUT" self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED) request.method="GET" self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED) request.headers.setRawHeaders("If-Modified-Since", ('Sat, 01 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, True, responsecode.OK) request.headers.removeHeader("If-Modified-Since") # none match request.headers.setRawHeaders("If-None-Match", ('"baz", "bob"',)) self.checkPreconditions(request, response, True, responsecode.OK) # now with IMS, also: request.headers.setRawHeaders("If-Modified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, True, responsecode.OK) request.headers.setRawHeaders("If-Modified-Since", ('Sat, 01 Jan 2000 00:00:00 GMT',)) self.checkPreconditions(request, response, True, responsecode.OK) request.headers.removeHeader("If-Modified-Since") # But if we have an error code already, ignore this header response.code = responsecode.INTERNAL_SERVER_ERROR self.checkPreconditions(request, response, True, responsecode.INTERNAL_SERVER_ERROR) response.code = responsecode.OK # Weak tags okay for GET out_headers.setHeader("ETag", http_headers.ETag('foo', weak=True)) request.headers.setRawHeaders("If-None-Match", ('W/"foo"',)) self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED) # Weak tags not okay for other methods request.method="PUT" out_headers.setHeader("ETag", http_headers.ETag('foo', weak=True)) request.headers.setRawHeaders("If-None-Match", ('W/"foo"',)) self.checkPreconditions(request, response, True, responsecode.OK) def testNoResponse(self): # Ensure that passing etag/lastModified arguments instead of response works. 
request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers()) request.method="PUT" request.headers.setRawHeaders("If-None-Match", ('"foo"',)) self.checkPreconditions(request, None, True, responsecode.OK) self.checkPreconditions(request, None, False, responsecode.PRECONDITION_FAILED, etag=http_headers.ETag('foo'), lastModified=946771200) # Make sure that, while you shoudn't do this, that it doesn't cause an error request.method="GET" self.checkPreconditions(request, None, False, responsecode.NOT_MODIFIED, etag=http_headers.ETag('foo')) class IfRangeTestCase(unittest.TestCase): def testIfRange(self): request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers()) response = TestResponse() self.assertEquals(http.checkIfRange(request, response), True) request.headers.setRawHeaders("If-Range", ('"foo"',)) self.assertEquals(http.checkIfRange(request, response), False) response.headers.setHeader("ETag", http_headers.ETag('foo')) self.assertEquals(http.checkIfRange(request, response), True) request.headers.setRawHeaders("If-Range", ('"bar"',)) response.headers.setHeader("ETag", http_headers.ETag('foo')) self.assertEquals(http.checkIfRange(request, response), False) request.headers.setRawHeaders("If-Range", ('W/"foo"',)) response.headers.setHeader("ETag", http_headers.ETag('foo', weak=True)) self.assertEquals(http.checkIfRange(request, response), False) request.headers.setRawHeaders("If-Range", ('"foo"',)) response.headers.removeHeader("ETag") self.assertEquals(http.checkIfRange(request, response), False) request.headers.setRawHeaders("If-Range", ('Sun, 02 Jan 2000 00:00:00 GMT',)) response.headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT self.assertEquals(http.checkIfRange(request, response), True) request.headers.setRawHeaders("If-Range", ('Sun, 02 Jan 2000 00:00:01 GMT',)) response.headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT self.assertEquals(http.checkIfRange(request, 
response), False) request.headers.setRawHeaders("If-Range", ('Sun, 01 Jan 2000 23:59:59 GMT',)) response.headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT self.assertEquals(http.checkIfRange(request, response), False) request.headers.setRawHeaders("If-Range", ('Sun, 01 Jan 2000 23:59:59 GMT',)) response.headers.removeHeader("Last-Modified") self.assertEquals(http.checkIfRange(request, response), False) request.headers.setRawHeaders("If-Range", ('jwerlqjL#$Y*KJAN',)) self.assertEquals(http.checkIfRange(request, response), False) class LoopbackRelay(loopback.LoopbackRelay): implements(interfaces.IProducer) def pauseProducing(self): self.paused = True def resumeProducing(self): self.paused = False def stopProducing(self): self.loseConnection() def loseWriteConnection(self): # HACK. self.loseConnection() def abortConnection(self): self.aborted = True def getHost(self): """ Synthesize a slightly more realistic 'host' thing. """ return address.IPv4Address('TCP', 'localhost', 4321) class TestRequestMixin(object): def __init__(self, *args, **kwargs): super(TestRequestMixin, self).__init__(*args, **kwargs) self.cmds = [] headers = list(self.headers.getAllRawHeaders()) headers.sort() self.cmds.append(('init', self.method, self.uri, self.clientproto, self.stream.length, tuple(headers))) def process(self): pass def handleContentChunk(self, data): self.cmds.append(('contentChunk', data)) def handleContentComplete(self): self.cmds.append(('contentComplete',)) def connectionLost(self, reason): self.cmds.append(('connectionLost', reason)) def _finished(self, x): self._reallyFinished(x) class TestRequest(TestRequestMixin, http.Request): """ Stub request for testing. """ class TestSSLRedirectRequest(TestRequestMixin, SSLRedirectRequest): """ Stub request for HSTS testing. 
""" class TestResponse(object): implements(iweb.IResponse) code = responsecode.OK headers = None def __init__(self): self.headers = http_headers.Headers() self.stream = stream.ProducerStream() def write(self, data): self.stream.write(data) def finish(self): self.stream.finish() class TestClient(protocol.Protocol): data = "" done = False def dataReceived(self, data): self.data+=data def write(self, data): self.transport.write(data) def connectionLost(self, reason): self.done = True self.transport.loseConnection() def loseConnection(self): self.done = True self.transport.loseConnection() class TestConnection: def __init__(self): self.requests = [] self.client = None self.callLaters = [] def fakeCallLater(self, secs, f): assert secs == 0 self.callLaters.append(f) class HTTPTests(unittest.TestCase): requestClass = TestRequest def setUp(self): super(HTTPTests, self).setUp() # We always need this set to True - previous tests may have changed it HTTPChannel.allowPersistentConnections = True def connect(self, logFile=None, **protocol_kwargs): cxn = TestConnection() def makeTestRequest(*args): cxn.requests.append(self.requestClass(*args)) return cxn.requests[-1] factory = channel.HTTPFactory(requestFactory=makeTestRequest, _callLater=cxn.fakeCallLater, **protocol_kwargs) cxn.client = TestClient() cxn.server = factory.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 2345)) cxn.serverToClient = LoopbackRelay(cxn.client, logFile) cxn.clientToServer = LoopbackRelay(cxn.server, logFile) cxn.server.makeConnection(cxn.serverToClient) cxn.client.makeConnection(cxn.clientToServer) return cxn def iterate(self, cxn): callLaters = cxn.callLaters cxn.callLaters = [] for f in callLaters: f() cxn.serverToClient.clearBuffer() cxn.clientToServer.clearBuffer() if cxn.serverToClient.shouldLose: cxn.serverToClient.clearBuffer() if cxn.clientToServer.shouldLose: cxn.clientToServer.clearBuffer() def compareResult(self, cxn, cmds, data): self.iterate(cxn) for receivedRequest, 
expectedCommands in map(None, cxn.requests, cmds): sortedHeaderCommands = [] for cmd in expectedCommands: if len(cmd) == 6: sortedHeaders = list(cmd[5]) sortedHeaders.sort() sortedHeaderCommands.append(cmd[:5] + (tuple(sortedHeaders),)) else: sortedHeaderCommands.append(cmd) self.assertEquals(receivedRequest.cmds, sortedHeaderCommands) self.assertEquals(cxn.client.data, data) def assertDone(self, cxn, done=True): self.iterate(cxn) self.assertEquals(cxn.client.done, done) class GracefulShutdownTestCase(HTTPTests): def _callback(self, result): self.callbackFired = True def testAllConnectionsClosedWithoutConnectedChannels(self): """ allConnectionsClosed( ) should fire right away if no connected channels """ self.callbackFired = False factory = HTTPFactory(None) factory.allConnectionsClosed().addCallback(self._callback) self.assertTrue(self.callbackFired) # now! def testallConnectionsClosedWithConnectedChannels(self): """ allConnectionsClosed( ) should only fire after all connected channels have been removed """ self.callbackFired = False factory = HTTPFactory(None) factory.addConnectedChannel("A") factory.addConnectedChannel("B") factory.addConnectedChannel("C") factory.allConnectionsClosed().addCallback(self._callback) factory.removeConnectedChannel("A") self.assertFalse(self.callbackFired) # wait for it... factory.removeConnectedChannel("B") self.assertFalse(self.callbackFired) # wait for it... factory.removeConnectedChannel("C") self.assertTrue(self.callbackFired) # now! class CoreHTTPTestCase(HTTPTests): # Note: these tests compare the client output using string # matching. It is acceptable for this to change and break # the test if you know what you are doing. 
def testHTTP0_9(self, nouri=False): cxn = self.connect() cmds = [[]] data = "" if nouri: cxn.client.write("GET\r\n") else: cxn.client.write("GET /\r\n") # Second request which should not be handled cxn.client.write("GET /two\r\n") cmds[0] += [('init', 'GET', '/', (0,9), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.headers.setRawHeaders("Yo", ("One", "Two")) cxn.requests[0].writeResponse(response) response.write("") self.compareResult(cxn, cmds, data) response.write("Output") data += "Output" self.compareResult(cxn, cmds, data) response.finish() self.compareResult(cxn, cmds, data) self.assertDone(cxn) def testHTTP0_9_nouri(self): self.testHTTP0_9(True) def testHTTP1_0(self): cxn = self.connect() cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.0\r\nContent-Length: 5\r\nHost: localhost\r\n\r\nInput") # Second request which should not be handled cxn.client.write("GET /two HTTP/1.0\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1,0), 5, (('Host', ['localhost']),)), ('contentChunk', 'Input'), ('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.headers.setRawHeaders("Yo", ("One", "Two")) cxn.requests[0].writeResponse(response) response.write("") data += "HTTP/1.1 200 OK\r\nYo: One\r\nYo: Two\r\nConnection: close\r\n\r\n" self.compareResult(cxn, cmds, data) response.write("Output") data += "Output" self.compareResult(cxn, cmds, data) response.finish() self.compareResult(cxn, cmds, data) self.assertDone(cxn) def testHTTP1_0_keepalive(self): cxn = self.connect() cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.0\r\nConnection: keep-alive\r\nContent-Length: 5\r\nHost: localhost\r\n\r\nInput") cxn.client.write("GET /two HTTP/1.0\r\n\r\n") # Third request shouldn't be handled cxn.client.write("GET /three HTTP/1.0\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1,0), 5, (('Host', ['localhost']),)), ('contentChunk', 'Input'), ('contentComplete',)] self.compareResult(cxn, cmds, data) 
response0 = TestResponse() response0.headers.setRawHeaders("Content-Length", ("6", )) response0.headers.setRawHeaders("Yo", ("One", "Two")) cxn.requests[0].writeResponse(response0) response0.write("") data += "HTTP/1.1 200 OK\r\nContent-Length: 6\r\nYo: One\r\nYo: Two\r\nConnection: Keep-Alive\r\n\r\n" self.compareResult(cxn, cmds, data) response0.write("Output") data += "Output" self.compareResult(cxn, cmds, data) response0.finish() # Now for second request: cmds.append([]) cmds[1] += [('init', 'GET', '/two', (1,0), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) response1 = TestResponse() response1.headers.setRawHeaders("Content-Length", ("0", )) cxn.requests[1].writeResponse(response1) response1.write("") data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" self.compareResult(cxn, cmds, data) response1.finish() self.assertDone(cxn) def testHTTP1_1_pipelining(self): cxn = self.connect(maxPipeline=2) cmds = [] data = "" # Both these show up immediately. cxn.client.write("GET / HTTP/1.1\r\nContent-Length: 5\r\nHost: localhost\r\n\r\nInput") cxn.client.write("GET /two HTTP/1.1\r\nHost: localhost\r\n\r\n") # Doesn't show up until the first is done. cxn.client.write("GET /three HTTP/1.1\r\nHost: localhost\r\n\r\n") # Doesn't show up until the second is done. 
cxn.client.write("GET /four HTTP/1.1\r\nHost: localhost\r\n\r\n") cmds.append([]) cmds[0] += [('init', 'GET', '/', (1,1), 5, (('Host', ['localhost']),)), ('contentChunk', 'Input'), ('contentComplete',)] cmds.append([]) cmds[1] += [('init', 'GET', '/two', (1,1), 0, (('Host', ['localhost']),)), ('contentComplete',)] self.compareResult(cxn, cmds, data) response0 = TestResponse() response0.headers.setRawHeaders("Content-Length", ("6", )) cxn.requests[0].writeResponse(response0) response0.write("") data += "HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\n" self.compareResult(cxn, cmds, data) response0.write("Output") data += "Output" self.compareResult(cxn, cmds, data) response0.finish() # Now the third request gets read: cmds.append([]) cmds[2] += [('init', 'GET', '/three', (1,1), 0, (('Host', ['localhost']),)), ('contentComplete',)] self.compareResult(cxn, cmds, data) # Let's write out the third request before the second. # This should not cause anything to be written to the client. response2 = TestResponse() response2.headers.setRawHeaders("Content-Length", ("5", )) cxn.requests[2].writeResponse(response2) response2.write("Three") response2.finish() self.compareResult(cxn, cmds, data) response1 = TestResponse() response1.headers.setRawHeaders("Content-Length", ("3", )) cxn.requests[1].writeResponse(response1) response1.write("Two") data += "HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nTwo" self.compareResult(cxn, cmds, data) response1.finish() # Fourth request shows up cmds.append([]) cmds[3] += [('init', 'GET', '/four', (1,1), 0, (('Host', ['localhost']),)), ('contentComplete',)] data += "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nThree" self.compareResult(cxn, cmds, data) response3 = TestResponse() response3.headers.setRawHeaders("Content-Length", ("0",)) cxn.requests[3].writeResponse(response3) response3.finish() data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" self.compareResult(cxn, cmds, data) self.assertDone(cxn, done=False) cxn.client.loseConnection() 
self.assertDone(cxn) def testHTTP1_1_chunking(self, extraHeaders=""): cxn = self.connect() cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\nHost: localhost\r\n\r\n5\r\nInput\r\n") cmds[0] += [('init', 'GET', '/', (1,1), None, (('Host', ['localhost']),)), ('contentChunk', 'Input')] self.compareResult(cxn, cmds, data) cxn.client.write("1; blahblahblah\r\na\r\n10\r\nabcdefghijklmnop\r\n") cmds[0] += [('contentChunk', 'a'),('contentChunk', 'abcdefghijklmnop')] self.compareResult(cxn, cmds, data) cxn.client.write("0\r\nRandom-Ignored-Trailer: foo\r\n\r\n") cmds[0] += [('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() cxn.requests[0].writeResponse(response) response.write("Output") expected = ["HTTP/1.1 200 OK"] if extraHeaders: expected.append(extraHeaders) expected.extend([ "Transfer-Encoding: chunked", "", "6", "Output", "", ]) data += "\r\n".join(expected) self.compareResult(cxn, cmds, data) response.write("blahblahblah") data += "C\r\nblahblahblah\r\n" self.compareResult(cxn, cmds, data) response.finish() data += "0\r\n\r\n" self.compareResult(cxn, cmds, data) cxn.client.loseConnection() self.assertDone(cxn) def testHTTP1_1_expect_continue(self): cxn = self.connect() cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\nContent-Length: 5\r\nHost: localhost\r\nExpect: 100-continue\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1,1), 5, (('Expect', ['100-continue']), ('Host', ['localhost'])))] self.compareResult(cxn, cmds, data) cxn.requests[0].stream.read() data += "HTTP/1.1 100 Continue\r\n\r\n" self.compareResult(cxn, cmds, data) cxn.client.write("Input") cmds[0] += [('contentChunk', 'Input'), ('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.headers.setRawHeaders("Content-Length", ("6",)) cxn.requests[0].writeResponse(response) response.write("Output") response.finish() data += "HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nOutput" 
self.compareResult(cxn, cmds, data) cxn.client.loseConnection() self.assertDone(cxn) def testHTTP1_1_expect_continue_early_reply(self): cxn = self.connect() cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\nContent-Length: 5\r\nHost: localhost\r\nExpect: 100-continue\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1,1), 5, (('Host', ['localhost']), ('Expect', ['100-continue'])))] self.compareResult(cxn, cmds, data) response = TestResponse() response.headers.setRawHeaders("Content-Length", ("6",)) cxn.requests[0].writeResponse(response) response.write("Output") response.finish() cmds[0] += [('contentComplete',)] data += "HTTP/1.1 200 OK\r\nContent-Length: 6\r\nConnection: close\r\n\r\nOutput" self.compareResult(cxn, cmds, data) cxn.client.loseConnection() self.assertDone(cxn) def testHeaderContinuation(self): cxn = self.connect() cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\nHost: localhost\r\nFoo: yada\r\n yada\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1,1), 0, (('Host', ['localhost']), ('Foo', ['yada yada']),)), ('contentComplete',)] self.compareResult(cxn, cmds, data) cxn.client.loseConnection() self.assertDone(cxn) def testTimeout_immediate(self): # timeout 0 => timeout on first iterate call cxn = self.connect(inputTimeOut = 0) return deferLater(reactor, 0, self.assertDone, cxn) def testTimeout_inRequest(self): cxn = self.connect(inputTimeOut = 0.3) cxn.client.write("GET / HTTP/1.1\r\n") return deferLater(reactor, 0.5, self.assertDone, cxn) def testTimeout_betweenRequests(self): cxn = self.connect(betweenRequestsTimeOut = 0.3) cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1,1), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.headers.setRawHeaders("Content-Length", ("0",)) cxn.requests[0].writeResponse(response) response.finish() data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" self.compareResult(cxn, cmds, data) return 
deferLater(reactor, 0.5, self.assertDone, cxn) # Wait for timeout def testTimeout_idleRequest(self): cxn = self.connect(idleTimeOut=0.3) cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1, 1), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) return deferLater(reactor, 0.5, self.assertDone, cxn) # Wait for timeout def testTimeout_abortRequest(self): cxn = self.connect(allowPersistentConnections=False, closeTimeOut=0.3) cxn.client.transport.loseConnection = lambda : None cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1, 1), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.headers.setRawHeaders("Content-Length", ("0",)) cxn.requests[0].writeResponse(response) response.finish() data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" self.compareResult(cxn, cmds, data) def _check(cxn): self.assertDone(cxn) self.assertTrue(cxn.serverToClient.aborted) return deferLater(reactor, 0.5, self.assertDone, cxn) # Wait for timeout def testConnectionCloseRequested(self): cxn = self.connect() cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1,1), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) cxn.client.write("GET / HTTP/1.1\r\nConnection: close\r\n\r\n") cmds.append([]) cmds[1] += [('init', 'GET', '/', (1,1), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.headers.setRawHeaders("Content-Length", ("0",)) cxn.requests[0].writeResponse(response) response.finish() data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" response = TestResponse() response.headers.setRawHeaders("Content-Length", ("0",)) cxn.requests[1].writeResponse(response) response.finish() data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" self.compareResult(cxn, cmds, data) 
self.assertDone(cxn) def testConnectionKeepAliveOff(self): cxn = self.connect(allowPersistentConnections=False) cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1, 1), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.headers.setRawHeaders("Content-Length", ("0",)) cxn.requests[0].writeResponse(response) response.finish() data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" self.compareResult(cxn, cmds, data) self.assertDone(cxn) def testExtraCRLFs(self): cxn = self.connect() cmds = [[]] data = "" # Some broken clients (old IEs) send an extra CRLF after post cxn.client.write("POST / HTTP/1.1\r\nContent-Length: 5\r\nHost: localhost\r\n\r\nInput\r\n") cmds[0] += [('init', 'POST', '/', (1,1), 5, (('Host', ['localhost']),)), ('contentChunk', 'Input'), ('contentComplete',)] self.compareResult(cxn, cmds, data) cxn.client.write("GET /two HTTP/1.1\r\n\r\n") cmds.append([]) cmds[1] += [('init', 'GET', '/two', (1,1), 0, ()), ('contentComplete',)] self.compareResult(cxn, cmds, data) cxn.client.loseConnection() self.assertDone(cxn) def testDisallowPersistentConnections(self): cxn = self.connect(allowPersistentConnections=False) cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nGET / HTTP/1.1\r\nHost: localhost\r\n\r\n") cmds[0] += [('init', 'GET', '/', (1,1), 0, (('Host', ['localhost']),)), ('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.finish() cxn.requests[0].writeResponse(response) data += 'HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n' self.compareResult(cxn, cmds, data) self.assertDone(cxn) def testIgnoreBogusContentLength(self): # Ensure that content-length is ignored when transfer-encoding # is also specified. 
cxn = self.connect() cmds = [[]] data = "" cxn.client.write("GET / HTTP/1.1\r\nContent-Length: 100\r\nTransfer-Encoding: chunked\r\nHost: localhost\r\n\r\n5\r\nInput\r\n") cmds[0] += [('init', 'GET', '/', (1,1), None, (('Host', ['localhost']),)), ('contentChunk', 'Input')] self.compareResult(cxn, cmds, data) cxn.client.write("0\r\n\r\n") cmds[0] += [('contentComplete',)] self.compareResult(cxn, cmds, data) response = TestResponse() response.finish() cxn.requests[0].writeResponse(response) data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" self.compareResult(cxn, cmds, data) cxn.client.loseConnection() self.assertDone(cxn) class ErrorTestCase(HTTPTests): def assertStartsWith(self, first, second, msg=None): self.assert_(first.startswith(second), '%r.startswith(%r)' % (first, second)) def checkError(self, cxn, code): self.iterate(cxn) self.assertStartsWith(cxn.client.data, "HTTP/1.1 %d "%code) self.assertIn("\r\nConnection: close\r\n", cxn.client.data) # Ensure error messages have a defined content-length. 
self.assertIn("\r\nContent-Length:", cxn.client.data) self.assertDone(cxn) def testChunkingError1(self): cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\nasdf\r\n") self.checkError(cxn, 400) def testChunkingError2(self): cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n1\r\nblahblah\r\n") self.checkError(cxn, 400) def testChunkingError3(self): cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n-1\r\nasdf\r\n") self.checkError(cxn, 400) def testTooManyHeaders(self): cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\n") cxn.client.write("Foo: Bar\r\n"*5000) self.checkError(cxn, 400) def testLineTooLong(self): cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\n") cxn.client.write("Foo: "+("Bar"*10000)) self.checkError(cxn, 400) def testLineTooLong2(self): cxn = self.connect() cxn.client.write("GET "+("/Bar")*10000 +" HTTP/1.1\r\n") self.checkError(cxn, 414) def testNoColon(self): cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\n") cxn.client.write("Blahblah\r\n\r\n") self.checkError(cxn, 400) def test_nonAsciiHeader(self): """ As per U{RFC 822 section 3, }, headers are ASCII only. 
""" cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\nX-Extra-Header: \xff\r\n\r\n") self.checkError(cxn, responsecode.BAD_REQUEST) cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\nX-E\xfftra-Header: foo\r\n\r\n") self.checkError(cxn, responsecode.BAD_REQUEST) def testBadRequest(self): cxn = self.connect() cxn.client.write("GET / more HTTP/1.1\r\n") self.checkError(cxn, 400) def testWrongProtocol(self): cxn = self.connect() cxn.client.write("GET / Foobar/1.0\r\n") self.checkError(cxn, 400) def testBadProtocolVersion(self): cxn = self.connect() cxn.client.write("GET / HTTP/1\r\n") self.checkError(cxn, 400) def testBadProtocolVersion2(self): cxn = self.connect() cxn.client.write("GET / HTTP/-1.0\r\n") self.checkError(cxn, 400) def testWrongProtocolVersion(self): cxn = self.connect() cxn.client.write("GET / HTTP/2.0\r\n") self.checkError(cxn, 505) def testUnsupportedTE(self): cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\n") cxn.client.write("Transfer-Encoding: blahblahblah, chunked\r\n\r\n") self.checkError(cxn, 501) def testTEWithoutChunked(self): cxn = self.connect() cxn.client.write("GET / HTTP/1.1\r\n") cxn.client.write("Transfer-Encoding: gzip\r\n\r\n") self.checkError(cxn, 400) class PipelinedErrorTestCase(ErrorTestCase): # Make sure that even low level reading errors don't corrupt the data stream, # but always wait until their turn to respond. 
def connect(self): cxn = ErrorTestCase.connect(self) cxn.client.write("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n") cmds = [[('init', 'GET', '/', (1,1), 0, (('Host', ['localhost']),)), ('contentComplete', )]] data = "" self.compareResult(cxn, cmds, data) return cxn def checkError(self, cxn, code): self.iterate(cxn) self.assertEquals(cxn.client.data, '') response = TestResponse() response.headers.setRawHeaders("Content-Length", ("0",)) cxn.requests[0].writeResponse(response) response.write('') data = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" self.iterate(cxn) self.assertEquals(cxn.client.data, data) # Reset the data so the checkError's startswith test can work right. cxn.client.data = "" response.finish() ErrorTestCase.checkError(self, cxn, code) class SimpleFactory(channel.HTTPFactory): def buildProtocol(self, addr): # Do a bunch of crazy crap just so that the test case can know when the # connection is done. p = channel.HTTPFactory.buildProtocol(self, addr) cl = p.connectionLost def newCl(reason): reactor.callLater(0, lambda: self.testcase.connlost.callback(None)) return cl(reason) p.connectionLost = newCl self.conn = p return p class SimpleRequest(http.Request): def process(self): response = TestResponse() if self.uri == "/error": response.code=402 elif self.uri == "/forbidden": response.code=403 else: response.code=404 response.write("URI %s unrecognized." 
% self.uri) response.finish() self.writeResponse(response) class AbstractServerTestMixin: type = None def testBasicWorkingness(self): args = ('-u', util.sibpath(__file__, "simple_client.py"), "basic", str(self.port), self.type) d = waitForDeferred( utils.getProcessOutputAndValue(sys.executable, args=args, env=os.environ) ) yield d; out,err,code = d.getResult() self.assertEquals(code, 0, "Error output:\n%s" % (err,)) self.assertEquals(out, "HTTP/1.1 402 Payment Required\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") testBasicWorkingness = deferredGenerator(testBasicWorkingness) def testLingeringClose(self): args = ('-u', util.sibpath(__file__, "simple_client.py"), "lingeringClose", str(self.port), self.type) d = waitForDeferred( utils.getProcessOutputAndValue(sys.executable, args=args, env=os.environ) ) yield d; out,err,code = d.getResult() self.assertEquals(code, 0, "Error output:\n%s" % (err,)) self.assertEquals(out, "HTTP/1.1 402 Payment Required\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") testLingeringClose = deferredGenerator(testLingeringClose) class TCPServerTest(unittest.TestCase, AbstractServerTestMixin): type = 'tcp' def setUp(self): factory=SimpleFactory(requestFactory=SimpleRequest) factory.testcase = self self.factory = factory self.connlost = defer.Deferred() self.socket = reactor.listenTCP(0, factory) self.port = self.socket.getHost().port def tearDown(self): # Make sure the listening port is closed d = defer.maybeDeferred(self.socket.stopListening) def finish(v): # And make sure the established connection is, too self.factory.conn.transport.loseConnection() return self.connlost return d.addCallback(finish) try: from twisted.internet import ssl ssl # pyflakes except ImportError: # happens the first time the interpreter tries to import it ssl = None if ssl and not ssl.supported: # happens second and later times ssl = None certPath = util.sibpath(__file__, "server.pem") class SSLServerTest(unittest.TestCase, AbstractServerTestMixin): type = 
'ssl' def setUp(self): sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath) factory=SimpleFactory(requestFactory=SimpleRequest) factory.testcase = self self.factory = factory self.connlost = defer.Deferred() self.socket = reactor.listenSSL(0, factory, sCTX) self.port = self.socket.getHost().port def tearDown(self): # Make sure the listening port is closed d = defer.maybeDeferred(self.socket.stopListening) def finish(v): # And make sure the established connection is, too self.factory.conn.transport.loseConnection() return self.connlost return d.addCallback(finish) def testLingeringClose(self): return super(SSLServerTest, self).testLingeringClose() if runtime.platform.isWindows(): # This may not just be Windows, but all platforms with more recent # versions of OpenSSL. Do some more experimentation... testLingeringClose.todo = "buffering kills the connection too early; test this some other way" if interfaces.IReactorProcess(reactor, None) is None: TCPServerTest.skip = SSLServerTest.skip = "Required process support missing from reactor" elif interfaces.IReactorSSL(reactor, None) is None: SSLServerTest.skip = "Required SSL support missing from reactor" elif ssl is None: SSLServerTest.skip = "SSL not available, cannot test SSL." calendarserver-5.2+dfsg/twext/web2/test/test_resource.py0000644000175000017500000001605511457335713022560 0ustar rahulrahul# Copyright (c) 2001-2007 Twisted Matrix Laboratories. # See LICENSE for details. """ A test harness for twext.web2.resource. 
""" from sets import Set as set from zope.interface import implements from twisted.internet.defer import succeed, fail, inlineCallbacks from twisted.trial import unittest from twext.web2 import responsecode from twext.web2.iweb import IResource from twext.web2.http import Response from twext.web2.stream import MemoryStream from twext.web2.resource import RenderMixin, LeafResource from twext.web2.server import Site, StopTraversal from twext.web2.test.test_server import SimpleRequest class PreconditionError (Exception): "Precondition Failure" class TestResource (RenderMixin): implements(IResource) def _handler(self, request): if request is None: return responsecode.INTERNAL_SERVER_ERROR return responsecode.NO_CONTENT http_BLEARGH = _handler http_HUCKHUCKBLORP = _handler http_SWEETHOOKUPS = _handler http_HOOKUPS = _handler def preconditions_BLEARGH(self, request): raise PreconditionError() def precondition_HUCKHUCKBLORP(self, request): return fail(None) def preconditions_SWEETHOOKUPS(self, request): return None def preconditions_HOOKUPS(self, request): return succeed(None) renderOutput = "Snootch to the hootch" def render(self, request): response = Response() response.stream = MemoryStream(self.renderOutput) return response def generateResponse(method): resource = TestResource() method = getattr(resource, "http_" + method) return method(SimpleRequest(Site(resource), method, "/")) class RenderMixInTestCase (unittest.TestCase): """ Test RenderMixin. 
""" _my_allowed_methods = set(( "HEAD", "OPTIONS", "GET", "BLEARGH", "HUCKHUCKBLORP", "SWEETHOOKUPS", "HOOKUPS", )) def test_allowedMethods(self): """ RenderMixin.allowedMethods() """ self.assertEquals( set(TestResource().allowedMethods()), self._my_allowed_methods ) @inlineCallbacks def test_checkPreconditions_raises(self): """ RenderMixin.checkPreconditions() Exception raised in checkPreconditions() """ resource = TestResource() request = SimpleRequest(Site(resource), "BLEARGH", "/") # Check that checkPreconditions raises as expected self.assertRaises( PreconditionError, resource.checkPreconditions, request ) # Check that renderHTTP calls checkPreconditions yield self.failUnlessFailure( resource.renderHTTP(request), PreconditionError ) @inlineCallbacks def test_checkPreconditions_none(self): """ RenderMixin.checkPreconditions() checkPreconditions() returns None """ resource = TestResource() request = SimpleRequest(Site(resource), "SWEETHOOKUPS", "/") # Check that checkPreconditions without a raise doesn't barf self.assertEquals( (yield resource.renderHTTP(request)), responsecode.NO_CONTENT ) def test_checkPreconditions_deferred(self): """ RenderMixin.checkPreconditions() checkPreconditions() returns a deferred """ resource = TestResource() request = SimpleRequest(Site(resource), "HOOKUPS", "/") # Check that checkPreconditions without a raise doesn't barf def checkResponse(response): self.assertEquals(response, responsecode.NO_CONTENT) d = resource.renderHTTP(request) d.addCallback(checkResponse) def test_OPTIONS_status(self): """ RenderMixin.http_OPTIONS() Response code is OK """ response = generateResponse("OPTIONS") self.assertEquals(response.code, responsecode.OK) def test_OPTIONS_allow(self): """ RenderMixin.http_OPTIONS() Allow header indicates allowed methods """ response = generateResponse("OPTIONS") self.assertEquals( set(response.headers.getHeader("allow")), self._my_allowed_methods ) def test_TRACE_status(self): """ RenderMixin.http_TRACE() Response 
code is OK """ response = generateResponse("TRACE") self.assertEquals(response.code, responsecode.OK) test_TRACE_status.skip = "TRACE is disabled now." def test_TRACE_body(self): """ RenderMixin.http_TRACE() Check body for traciness """ raise NotImplementedError() test_TRACE_body.todo = "Someone should write this test" def test_HEAD_status(self): """ RenderMixin.http_HEAD() Response code is OK """ response = generateResponse("HEAD") self.assertEquals(response.code, responsecode.OK) def test_HEAD_body(self): """ RenderMixin.http_HEAD() Check body is empty """ response = generateResponse("HEAD") self.assertEquals(response.stream.length, 0) test_HEAD_body.todo = ( "http_HEAD is implemented in a goober way that " "relies on the server code to clean up after it." ) def test_GET_status(self): """ RenderMixin.http_GET() Response code is OK """ response = generateResponse("GET") self.assertEquals(response.code, responsecode.OK) def test_GET_body(self): """ RenderMixin.http_GET() Check body is empty """ response = generateResponse("GET") self.assertEquals( str(response.stream.read()), TestResource.renderOutput ) class ResourceTestCase (unittest.TestCase): """ Test Resource. """ def test_addSlash(self): # I think this would include a test of http_GET() raise NotImplementedError() test_addSlash.todo = "Someone should write this test" def test_locateChild(self): raise NotImplementedError() test_locateChild.todo = "Someone should write this test" def test_child_nonsense(self): raise NotImplementedError() test_child_nonsense.todo = "Someone should write this test" class PostableResourceTestCase (unittest.TestCase): """ Test PostableResource. """ def test_POST(self): raise NotImplementedError() test_POST.todo = "Someone should write this test" class LeafResourceTestCase (unittest.TestCase): """ Test LeafResource. 
""" def test_locateChild(self): resource = LeafResource() child, segments = ( resource.locateChild( SimpleRequest(Site(resource), "GET", "/"), ("", "foo"), ) ) self.assertEquals(child, resource) self.assertEquals(segments, StopTraversal) class WrapperResourceTestCase (unittest.TestCase): """ Test WrapperResource. """ def test_hook(self): raise NotImplementedError() test_hook.todo = "Someone should write this test" calendarserver-5.2+dfsg/twext/web2/test/simple_client.py0000644000175000017500000000156111337102650022502 0ustar rahulrahulimport socket, sys test_type = sys.argv[1] port = int(sys.argv[2]) socket_type = sys.argv[3] s = socket.socket(socket.AF_INET) s.connect(("127.0.0.1", port)) s.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 40000) if socket_type == 'ssl': s2 = socket.ssl(s) send=s2.write recv=s2.read else: send=s.send recv=s.recv print >> sys.stderr, ">> Making %s request to port %d" % (socket_type, port) send("GET /error HTTP/1.0\r\n") send("Host: localhost\r\n") if test_type == "lingeringClose": print >> sys.stderr, ">> Sending lots of data" send("Content-Length: 1000000\r\n\r\n") send("X"*1000000) else: send('\r\n') #import time #time.sleep(5) print >> sys.stderr, ">> Getting data" data='' while len(data) < 299999: try: x=recv(10000) except: break if x == '': break data+=x sys.stdout.write(data) calendarserver-5.2+dfsg/twext/web2/test/stream_data.txt0000644000175000017500000000002511337102650022320 0ustar rahulrahulWe've got some text! calendarserver-5.2+dfsg/twext/web2/test/test_metafd.py0000644000175000017500000002452012306427141022154 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for twext.web2.metafd. """ from socket import error as SocketError, AF_INET from errno import ENOTCONN from twext.internet import sendfdport from twext.web2 import metafd from twext.web2.channel.http import HTTPChannel from twext.web2.metafd import ReportingHTTPService, ConnectionLimiter from twisted.internet.tcp import Server from twisted.application.service import Service from twext.internet.test.test_sendfdport import ReaderAdder from twext.web2.metafd import WorkerStatus from twisted.trial.unittest import TestCase class FakeSocket(object): """ A fake socket for testing. """ def __init__(self, test): self.test = test def fileno(self): return "not a socket" def setblocking(self, blocking): return def getpeername(self): if self.test.peerNameSucceed: return ("4.3.2.1", 4321) else: raise SocketError(ENOTCONN, "Transport endpoint not connected") def getsockname(self): return ("4.3.2.1", 4321) class InheritedPortForTesting(sendfdport.InheritedPort): """ L{sendfdport.InheritedPort} subclass that prevents certain I/O operations for better unit testing. """ def startReading(self): "Do nothing." def stopReading(self): "Do nothing." def startWriting(self): "Do nothing." def stopWriting(self): "Do nothing." class ServerTransportForTesting(Server): """ tcp.Server replacement for testing purposes. """ def startReading(self): "Do nothing." def stopReading(self): "Do nothing." def startWriting(self): "Do nothing." def stopWriting(self): "Do nothing." 
def __init__(self, *a, **kw): super(ServerTransportForTesting, self).__init__(*a, **kw) self.reactor = None class ReportingHTTPServiceTests(TestCase): """ Tests for L{ReportingHTTPService} """ peerNameSucceed = True def setUp(self): def fakefromfd(fd, addressFamily, socketType): return FakeSocket(self) def fakerecvfd(fd): return "not an fd", "not a description" def fakeclose(fd): "" def fakegetsockfam(fd): return AF_INET self.patch(sendfdport, 'recvfd', fakerecvfd) self.patch(sendfdport, 'fromfd', fakefromfd) self.patch(sendfdport, 'close', fakeclose) self.patch(sendfdport, 'getsockfam', fakegetsockfam) self.patch(metafd, 'InheritedPort', InheritedPortForTesting) self.patch(metafd, 'Server', ServerTransportForTesting) # This last stubbed out just to prevent dirty reactor warnings. self.patch(HTTPChannel, "callLater", lambda *a, **k: None) self.svc = ReportingHTTPService(None, None, None) self.svc.startService() def test_quickClosedSocket(self): """ If a socket is closed very quickly after being {accept()}ed, requesting its peer (or even host) address may fail with C{ENOTCONN}. If this happens, its transport should be supplied with a dummy peer address. """ self.peerNameSucceed = False self.svc.reportingFactory.inheritedPort.doRead() channels = self.svc.reportingFactory.connectedChannels self.assertEqual(len(channels), 1) self.assertEqual(list(channels)[0].transport.getPeer().host, "0.0.0.0") class ConnectionLimiterTests(TestCase): """ Tests for L{ConnectionLimiter} """ def test_loadReducedStartsReadingAgain(self): """ L{ConnectionLimiter.statusesChanged} determines whether the current "load" of all subprocesses - that is, the total outstanding request count - is high enough that the listening ports attached to it should be suspended. 
""" builder = LimiterBuilder(self) builder.fillUp() self.assertEquals(builder.port.reading, False) # sanity check self.assertEquals(builder.highestLoad(), builder.requestsPerSocket) builder.loadDown() self.assertEquals(builder.port.reading, True) def test_processRestartedStartsReadingAgain(self): """ L{ConnectionLimiter.statusesChanged} determines whether the current number of outstanding requests is above the limit, and either stops or resumes reading on the listening port. """ builder = LimiterBuilder(self) builder.fillUp() self.assertEquals(builder.port.reading, False) self.assertEquals(builder.highestLoad(), builder.requestsPerSocket) builder.processRestart() self.assertEquals(builder.port.reading, True) def test_unevenLoadDistribution(self): """ Subprocess sockets should be selected for subsequent socket sends by ascending status. Status should sum sent and successfully subsumed sockets. """ builder = LimiterBuilder(self) # Give one simulated worker a higher acknowledged load than the other. builder.fillUp(True, 1) # There should still be plenty of spare capacity. self.assertEquals(builder.port.reading, True) # Then slam it with a bunch of incoming requests. builder.fillUp(False, builder.limiter.maxRequests - 1) # Now capacity is full. self.assertEquals(builder.port.reading, False) # And everyone should have an even amount of work. self.assertEquals(builder.highestLoad(), builder.requestsPerSocket) def test_processStopsReadingEvenWhenConnectionsAreNotAcknowledged(self): """ L{ConnectionLimiter.statusesChanged} determines whether the current number of outstanding requests is above the limit. 
""" builder = LimiterBuilder(self) builder.fillUp(acknowledged=False) self.assertEquals(builder.highestLoad(), builder.requestsPerSocket) self.assertEquals(builder.port.reading, False) builder.processRestart() self.assertEquals(builder.port.reading, True) def test_workerStatusRepr(self): """ L{WorkerStatus.__repr__} will show all the values associated with the status of the worker. """ self.assertEquals(repr(WorkerStatus(1, 2, 3, 4, 5, 6, 7, 8)), "") def test_workerStatusNonNegative(self): """ L{WorkerStatus.__repr__} will show all the values associated with the status of the worker. """ w = WorkerStatus() w.adjust( acknowledged=1, unacknowledged=-1, total=1, ) self.assertEquals(w.acknowledged, 1) self.assertEquals(w.unacknowledged, 0) self.assertEquals(w.total, 1) class LimiterBuilder(object): """ A L{LimiterBuilder} can build a L{ConnectionLimiter} and associated objects for a given unit test. """ def __init__(self, test, requestsPerSocket=3, socketCount=2): # Similar to MaxRequests in the configuration. self.requestsPerSocket = requestsPerSocket # Similar to ProcessCount in the configuration. self.socketCount = socketCount self.limiter = ConnectionLimiter( 2, maxRequests=requestsPerSocket * socketCount ) self.dispatcher = self.limiter.dispatcher self.dispatcher.reactor = ReaderAdder() self.service = Service() self.limiter.addPortService("TCP", 4321, "127.0.0.1", 5, self.serverServiceMakerMaker(self.service)) for ignored in xrange(socketCount): subskt = self.dispatcher.addSocket() subskt.start() subskt.restarted() # Has to be running in order to add stuff. self.limiter.startService() self.port = self.service.myPort def highestLoad(self): return max( skt.status.effective() for skt in self.limiter.dispatcher._subprocessSockets ) def serverServiceMakerMaker(self, s): """ Make a serverServiceMaker for use with L{ConnectionLimiter.addPortService}. 
""" class NotAPort(object): def startReading(self): self.reading = True def stopReading(self): self.reading = False def serverServiceMaker(port, factory, *a, **k): s.factory = factory s.myPort = NotAPort() # TODO: technically, the following should wait for startService s.myPort.startReading() factory.myServer = s return s return serverServiceMaker def fillUp(self, acknowledged=True, count=0): """ Fill up all the slots on the connection limiter. @param acknowledged: Should the virtual connections created by this method send a message back to the dispatcher indicating that the subprocess has acknowledged receipt of the file descriptor? @param count: Amount of load to add; default to the maximum that the limiter. """ for _ignore_x in range(count or self.limiter.maxRequests): self.dispatcher.sendFileDescriptor(None, "SSL") if acknowledged: self.dispatcher.statusMessage( self.dispatcher._subprocessSockets[0], "+" ) def processRestart(self): self.dispatcher._subprocessSockets[0].stop() self.dispatcher._subprocessSockets[0].start() self.dispatcher.statusMessage( self.dispatcher._subprocessSockets[0], "0" ) def loadDown(self): self.dispatcher.statusMessage( self.dispatcher._subprocessSockets[0], "-" ) calendarserver-5.2+dfsg/twext/web2/test/test_stream.py0000644000175000017500000005242711736444441022226 0ustar rahulrahul# Copyright (c) 2008 Twisted Matrix Laboratories. # See LICENSE for details. """ Tests for the stream implementations in L{twext.web2}. """ import tempfile, sys, os from zope.interface import implements from twisted.python.util import sibpath sibpath # sibpath is *not* unused - the doctests use it. 
from twisted.python.hashlib import md5 from twisted.internet import reactor, defer, interfaces from twisted.trial import unittest from twext.web2 import stream def bufstr(data): try: return str(buffer(data)) except TypeError: raise TypeError("%s doesn't conform to the buffer interface" % (data,)) class SimpleStreamTests: text = '1234567890' def test_split(self): for point in range(10): s = self.makeStream(0) a,b = s.split(point) if point > 0: self.assertEquals(bufstr(a.read()), self.text[:point]) self.assertEquals(a.read(), None) if point < len(self.text): self.assertEquals(bufstr(b.read()), self.text[point:]) self.assertEquals(b.read(), None) for point in range(7): s = self.makeStream(2, 6) self.assertEquals(s.length, 6) a,b = s.split(point) if point > 0: self.assertEquals(bufstr(a.read()), self.text[2:point+2]) self.assertEquals(a.read(), None) if point < 6: self.assertEquals(bufstr(b.read()), self.text[point+2:8]) self.assertEquals(b.read(), None) def test_read(self): s = self.makeStream() self.assertEquals(s.length, len(self.text)) self.assertEquals(bufstr(s.read()), self.text) self.assertEquals(s.read(), None) s = self.makeStream(0, 4) self.assertEquals(s.length, 4) self.assertEquals(bufstr(s.read()), self.text[0:4]) self.assertEquals(s.read(), None) self.assertEquals(s.length, 0) s = self.makeStream(4, 6) self.assertEquals(s.length, 6) self.assertEquals(bufstr(s.read()), self.text[4:10]) self.assertEquals(s.read(), None) self.assertEquals(s.length, 0) class FileStreamTest(SimpleStreamTests, unittest.TestCase): def makeStream(self, *args, **kw): return stream.FileStream(self.f, *args, **kw) def setUp(self): """ Create a file containing C{self.text} to be streamed. 
""" f = tempfile.TemporaryFile('w+') f.write(self.text) f.seek(0, 0) self.f = f def test_close(self): s = self.makeStream() s.close() self.assertEquals(s.length, 0) # Make sure close doesn't close file # would raise exception if f is closed self.f.seek(0, 0) def test_read2(self): s = self.makeStream(0) s.CHUNK_SIZE = 6 self.assertEquals(s.length, 10) self.assertEquals(bufstr(s.read()), self.text[0:6]) self.assertEquals(bufstr(s.read()), self.text[6:10]) self.assertEquals(s.read(), None) s = self.makeStream(0) s.CHUNK_SIZE = 5 self.assertEquals(s.length, 10) self.assertEquals(bufstr(s.read()), self.text[0:5]) self.assertEquals(bufstr(s.read()), self.text[5:10]) self.assertEquals(s.read(), None) s = self.makeStream(0, 20) self.assertEquals(s.length, 20) self.assertEquals(bufstr(s.read()), self.text) self.assertRaises(RuntimeError, s.read) # ran out of data class MMapFileStreamTest(SimpleStreamTests, unittest.TestCase): text = SimpleStreamTests.text text = text * (stream.MMAP_THRESHOLD // len(text) + 1) def makeStream(self, *args, **kw): return stream.FileStream(self.f, *args, **kw) def setUp(self): """ Create a file containing C{self.text}, which should be long enough to trigger the mmap-case in L{stream.FileStream}. """ f = tempfile.TemporaryFile('w+') f.write(self.text) f.seek(0, 0) self.f = f def test_mmapwrapper(self): self.assertRaises(TypeError, stream.mmapwrapper) self.assertRaises(TypeError, stream.mmapwrapper, offset = 0) self.assertRaises(TypeError, stream.mmapwrapper, offset = None) if not stream.mmap: test_mmapwrapper.skip = 'mmap not supported here' class MemoryStreamTest(SimpleStreamTests, unittest.TestCase): def makeStream(self, *args, **kw): return stream.MemoryStream(self.text, *args, **kw) def test_close(self): s = self.makeStream() s.close() self.assertEquals(s.length, 0) def test_read2(self): self.assertRaises(ValueError, self.makeStream, 0, 20) testdata = """I was angry with my friend: I told my wrath, my wrath did end. 
I was angry with my foe: I told it not, my wrath did grow. And I water'd it in fears, Night and morning with my tears; And I sunned it with smiles, And with soft deceitful wiles. And it grew both day and night, Till it bore an apple bright; And my foe beheld it shine, And he knew that is was mine, And into my garden stole When the night had veil'd the pole: In the morning glad I see My foe outstretch'd beneath the tree""" class TestBufferedStream(unittest.TestCase): def setUp(self): self.data = testdata.replace('\n', '\r\n') s = stream.MemoryStream(self.data) self.s = stream.BufferedStream(s) def _cbGotData(self, data, expected): self.assertEqual(data, expected) def test_readline(self): """Test that readline reads a line.""" d = self.s.readline() d.addCallback(self._cbGotData, 'I was angry with my friend:\r\n') return d def test_readlineWithSize(self): """Test the size argument to readline""" d = self.s.readline(size = 5) d.addCallback(self._cbGotData, 'I was') return d def test_readlineWithBigSize(self): """Test the size argument when it's bigger than the length of the line.""" d = self.s.readline(size = 40) d.addCallback(self._cbGotData, 'I was angry with my friend:\r\n') return d def test_readlineWithZero(self): """Test readline with size = 0.""" d = self.s.readline(size = 0) d.addCallback(self._cbGotData, '') return d def test_readlineFinished(self): """Test readline on a finished stream.""" nolines = len(self.data.split('\r\n')) for i in range(nolines): self.s.readline() d = self.s.readline() d.addCallback(self._cbGotData, '') return d def test_readlineNegSize(self): """Ensure that readline with a negative size raises an exception.""" self.assertRaises(ValueError, self.s.readline, size = -1) def test_readlineSizeInDelimiter(self): """ Test behavior of readline when size falls inside the delimiter. 
""" d = self.s.readline(size=28) d.addCallback(self._cbGotData, "I was angry with my friend:\r") d.addCallback(lambda _: self.s.readline()) d.addCallback(self._cbGotData, "\nI told my wrath, my wrath did end.\r\n") def test_readExactly(self): """Make sure readExactly with no arg reads all the data.""" d = self.s.readExactly() d.addCallback(self._cbGotData, self.data) return d def test_readExactlyLimited(self): """ Test readExactly with a number. """ d = self.s.readExactly(10) d.addCallback(self._cbGotData, self.data[:10]) return d def test_readExactlyBig(self): """ Test readExactly with a number larger than the size of the datastream. """ d = self.s.readExactly(100000) d.addCallback(self._cbGotData, self.data) return d def test_read(self): """ Make sure read() also functions. (note that this test uses an implementation detail of this particular stream. s.read() isn't guaranteed to return self.data on all streams.) """ self.assertEqual(str(self.s.read()), self.data) class TestStreamer: implements(stream.IStream, stream.IByteStream) length = None readCalled=0 closeCalled=0 def __init__(self, list): self.list = list def read(self): self.readCalled+=1 if self.list: return self.list.pop(0) return None def close(self): self.closeCalled+=1 self.list = [] class FallbackSplitTest(unittest.TestCase): def test_split(self): s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) left,right = stream.fallbackSplit(s, 5) self.assertEquals(left.length, 5) self.assertEquals(right.length, None) self.assertEquals(bufstr(left.read()), 'abcd') d = left.read() d.addCallback(self._cbSplit, left, right) return d def _cbSplit(self, result, left, right): self.assertEquals(bufstr(result), 'e') self.assertEquals(left.read(), None) self.assertEquals(bufstr(right.read().result), 'fgh') self.assertEquals(bufstr(right.read()), 'ijkl') self.assertEquals(right.read(), None) def test_split2(self): s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) left,right = stream.fallbackSplit(s, 4) 
self.assertEquals(left.length, 4) self.assertEquals(right.length, None) self.assertEquals(bufstr(left.read()), 'abcd') self.assertEquals(left.read(), None) self.assertEquals(bufstr(right.read().result), 'efgh') self.assertEquals(bufstr(right.read()), 'ijkl') self.assertEquals(right.read(), None) def test_splitsplit(self): s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) left,right = stream.fallbackSplit(s, 5) left,middle = left.split(3) self.assertEquals(left.length, 3) self.assertEquals(middle.length, 2) self.assertEquals(right.length, None) self.assertEquals(bufstr(left.read()), 'abc') self.assertEquals(left.read(), None) self.assertEquals(bufstr(middle.read().result), 'd') self.assertEquals(bufstr(middle.read().result), 'e') self.assertEquals(middle.read(), None) self.assertEquals(bufstr(right.read().result), 'fgh') self.assertEquals(bufstr(right.read()), 'ijkl') self.assertEquals(right.read(), None) def test_closeboth(self): s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) left,right = stream.fallbackSplit(s, 5) left.close() self.assertEquals(s.closeCalled, 0) right.close() # Make sure nothing got read self.assertEquals(s.readCalled, 0) self.assertEquals(s.closeCalled, 1) def test_closeboth_rev(self): s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) left,right = stream.fallbackSplit(s, 5) right.close() self.assertEquals(s.closeCalled, 0) left.close() # Make sure nothing got read self.assertEquals(s.readCalled, 0) self.assertEquals(s.closeCalled, 1) def test_closeleft(self): s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) left,right = stream.fallbackSplit(s, 5) left.close() d = right.read() d.addCallback(self._cbCloseleft, right) return d def _cbCloseleft(self, result, right): self.assertEquals(bufstr(result), 'fgh') self.assertEquals(bufstr(right.read()), 'ijkl') self.assertEquals(right.read(), None) def test_closeright(self): s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) left,right = stream.fallbackSplit(s, 3) 
right.close() self.assertEquals(bufstr(left.read()), 'abc') self.assertEquals(left.read(), None) self.assertEquals(s.closeCalled, 1) class ProcessStreamerTest(unittest.TestCase): if interfaces.IReactorProcess(reactor, None) is None: skip = "Platform lacks spawnProcess support, can't test process streaming." def runCode(self, code, inputStream=None): if inputStream is None: inputStream = stream.MemoryStream("") return stream.ProcessStreamer(inputStream, sys.executable, [sys.executable, "-u", "-c", code], os.environ) def test_output(self): p = self.runCode("import sys\nfor i in range(100): sys.stdout.write('x' * 1000)") l = [] d = stream.readStream(p.outStream, l.append) def verify(_): self.assertEquals("".join(l), ("x" * 1000) * 100) d2 = p.run() return d.addCallback(verify).addCallback(lambda _: d2) def test_errouput(self): p = self.runCode("import sys\nfor i in range(100): sys.stderr.write('x' * 1000)") l = [] d = stream.readStream(p.errStream, l.append) def verify(_): self.assertEquals("".join(l), ("x" * 1000) * 100) p.run() return d.addCallback(verify) def test_input(self): p = self.runCode("import sys\nsys.stdout.write(sys.stdin.read())", "hello world") l = [] d = stream.readStream(p.outStream, l.append) d2 = p.run() def verify(_): self.assertEquals("".join(l), "hello world") return d2 return d.addCallback(verify) def test_badexit(self): p = self.runCode("raise ValueError") l = [] from twisted.internet.error import ProcessTerminated def verify(_): self.assertEquals(l, [1]) self.assert_(p.outStream.closed) self.assert_(p.errStream.closed) return p.run().addErrback(lambda _: _.trap(ProcessTerminated) and l.append(1)).addCallback(verify) def test_inputerror(self): p = self.runCode("import sys\nsys.stdout.write(sys.stdin.read())", TestStreamer(["hello", defer.fail(ZeroDivisionError())])) l = [] d = stream.readStream(p.outStream, l.append) d2 = p.run() def verify(_): self.assertEquals("".join(l), "hello") return d2 def cbVerified(ignored): excs = 
self.flushLoggedErrors(ZeroDivisionError) self.assertEqual(len(excs), 1) return d.addCallback(verify).addCallback(cbVerified) def test_processclosedinput(self): p = self.runCode("import sys; sys.stdout.write(sys.stdin.read(3));" + "sys.stdin.close(); sys.stdout.write('def')", "abc123") l = [] d = stream.readStream(p.outStream, l.append) def verify(_): self.assertEquals("".join(l), "abcdef") d2 = p.run() return d.addCallback(verify).addCallback(lambda _: d2) class AdapterTestCase(unittest.TestCase): def test_adapt(self): fName = self.mktemp() f = file(fName, "w") f.write("test") f.close() for i in ("test", buffer("test"), file(fName)): s = stream.IByteStream(i) self.assertEquals(str(s.read()), "test") self.assertEquals(s.read(), None) class ReadStreamTestCase(unittest.TestCase): def test_pull(self): l = [] s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) return stream.readStream(s, l.append).addCallback( lambda _: self.assertEquals(l, ["abcd", "efgh", "ijkl"])) def test_pullFailure(self): l = [] s = TestStreamer(['abcd', defer.fail(RuntimeError()), 'ijkl']) def test(result): result.trap(RuntimeError) self.assertEquals(l, ["abcd"]) return stream.readStream(s, l.append).addErrback(test) def test_pullException(self): class Failer: def read(self): raise RuntimeError return stream.readStream(Failer(), lambda _: None).addErrback( lambda _: _.trap(RuntimeError)) def test_processingException(self): s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl']) return stream.readStream(s, lambda x: 1/0).addErrback( lambda _: _.trap(ZeroDivisionError)) class ProducerStreamTestCase(unittest.TestCase): def test_failfinish(self): p = stream.ProducerStream() p.write("hello") p.finish(RuntimeError()) self.assertEquals(p.read(), "hello") d = p.read() l = [] d.addErrback(lambda _: (l.append(1), _.trap(RuntimeError))).addCallback( lambda _: self.assertEquals(l, [1])) return d class CompoundStreamTest: """ CompoundStream lets you combine many streams into one continuous stream. 
For example, let's make a stream: >>> s = stream.CompoundStream() Then, add a couple streams: >>> s.addStream(stream.MemoryStream("Stream1")) >>> s.addStream(stream.MemoryStream("Stream2")) The length is the sum of all the streams: >>> s.length 14 We can read data from the stream: >>> str(s.read()) 'Stream1' After having read some data, length is now smaller, as you might expect: >>> s.length 7 So, continue reading... >>> str(s.read()) 'Stream2' Now that the stream is exhausted: >>> s.read() is None True >>> s.length 0 We can also create CompoundStream more easily like so: >>> s = stream.CompoundStream(['hello', stream.MemoryStream(' world')]) >>> str(s.read()) 'hello' >>> str(s.read()) ' world' For a more complicated example, let's try reading from a file: >>> s = stream.CompoundStream() >>> s.addStream(stream.FileStream(open(sibpath(__file__, "stream_data.txt")))) >>> s.addStream("================") >>> s.addStream(stream.FileStream(open(sibpath(__file__, "stream_data.txt")))) Again, the length is the sum: >>> int(s.length) 58 >>> str(s.read()) "We've got some text!\\n" >>> str(s.read()) '================' What if you close the stream? >>> s.close() >>> s.read() is None True >>> s.length 0 Error handling works using Deferreds: >>> m = stream.MemoryStream("after") >>> s = stream.CompoundStream([TestStreamer([defer.fail(ZeroDivisionError())]), m]) # z< >>> l = []; x = s.read().addErrback(lambda _: l.append(1)) >>> l [1] >>> s.length 0 >>> m.length # streams after the failed one got closed 0 """ class AsynchronousDummyStream(object): """ An L{IByteStream} implementation which always returns a L{defer.Deferred} from C{read} and lets an external driver fire them. """ def __init__(self): self._readResults = [] def read(self): result = defer.Deferred() self._readResults.append(result) return result def _write(self, bytes): self._readResults.pop(0).callback(bytes) class MD5StreamTest(unittest.TestCase): """ Tests for L{stream.MD5Stream}. 
""" data = "I am sorry Dave, I can't do that.\n--HAL 9000" digest = md5(data).hexdigest() def test_synchronous(self): """ L{stream.MD5Stream} computes the MD5 hash of the contents of the stream around which it is wrapped. It supports L{IByteStream} providers which return C{str} from their C{read} method. """ dataStream = stream.MemoryStream(self.data) md5Stream = stream.MD5Stream(dataStream) self.assertEquals(str(md5Stream.read()), self.data) self.assertIdentical(md5Stream.read(), None) md5Stream.close() self.assertEquals(self.digest, md5Stream.getMD5()) def test_asynchronous(self): """ L{stream.MD5Stream} also supports L{IByteStream} providers which return L{Deferreds} from their C{read} method. """ dataStream = AsynchronousDummyStream() md5Stream = stream.MD5Stream(dataStream) result = md5Stream.read() dataStream._write(self.data) result.addCallback(self.assertEquals, self.data) def cbRead(ignored): result = md5Stream.read() dataStream._write(None) result.addCallback(self.assertIdentical, None) return result result.addCallback(cbRead) def cbClosed(ignored): md5Stream.close() self.assertEquals(md5Stream.getMD5(), self.digest) result.addCallback(cbClosed) return result def test_getMD5FailsBeforeClose(self): """ L{stream.MD5Stream.getMD5} raises L{RuntimeError} if called before L{stream.MD5Stream.close}. """ dataStream = stream.MemoryStream(self.data) md5Stream = stream.MD5Stream(dataStream) self.assertRaises(RuntimeError, md5Stream.getMD5) def test_initializationFailsWithoutStream(self): """ L{stream.MD5Stream.__init__} raises L{ValueError} if passed C{None} as the stream to wrap. """ self.assertRaises(ValueError, stream.MD5Stream, None) def test_readAfterClose(self): """ L{stream.MD5Stream.read} raises L{RuntimeError} if called after L{stream.MD5Stream.close}. 
""" dataStream = stream.MemoryStream(self.data) md5Stream = stream.MD5Stream(dataStream) md5Stream.close() self.assertRaises(RuntimeError, md5Stream.read) __doctests__ = ['twext.web2.test.test_stream', 'twext.web2.stream'] # TODO: # CompoundStreamTest # more tests for ProducerStreamTest # StreamProducerTest calendarserver-5.2+dfsg/twext/web2/test/test_httpauth.py0000644000175000017500000010250612103053166022553 0ustar rahulrahul# Copyright (c) 2006-2009 Twisted Matrix Laboratories. # See LICENSE for details. from twisted.python.hashlib import md5 from twisted.internet import address from twisted.trial import unittest from twisted.cred import error from twext.web2 import http, responsecode from twext.web2.auth import basic, digest, wrapper from twext.web2.auth.interfaces import IAuthenticatedRequest, IHTTPUser from twext.web2.test.test_server import SimpleRequest from twext.web2.test import test_server import base64 _trivial_GET = SimpleRequest(None, 'GET', '/') FAKE_STATIC_NONCE = '178288758716122392881254770685' def makeDigestDeterministic(twistedDigestFactory, key="0", nonce=FAKE_STATIC_NONCE, time=0): """ Patch up various bits of private state to make a digest credential factory (the one that comes from Twisted) behave deterministically. """ def _fakeStaticNonce(): """ Generate a static nonce """ return nonce def _fakeStaticTime(): """ Return a stable time """ return time twistedDigestFactory.privateKey = key # FIXME: These tests are somewhat redundant with the tests for Twisted's # built-in digest auth; these private values need to be patched to # create deterministic results, but at some future point the whole # digest module should be removed from twext.web2 (as all of twext.web2 # should be removed) and we can just get rid of this. 
twistedDigestFactory._generateNonce = _fakeStaticNonce twistedDigestFactory._getTime = _fakeStaticTime class FakeDigestCredentialFactory(digest.DigestCredentialFactory): """ A Fake Digest Credential Factory that generates a predictable nonce and opaque """ def __init__(self, *args, **kwargs): super(FakeDigestCredentialFactory, self).__init__(*args, **kwargs) makeDigestDeterministic(self._real, self._fakeStaticPrivateKey) _fakeStaticPrivateKey = "0" class BasicAuthTestCase(unittest.TestCase): def setUp(self): self.credentialFactory = basic.BasicCredentialFactory('foo') self.username = 'dreid' self.password = 'S3CuR1Ty' def test_usernamePassword(self): """ Test acceptance of username/password in basic auth. """ response = base64.encodestring('%s:%s' % ( self.username, self.password)) d = self.credentialFactory.decode(response, _trivial_GET) return d.addCallback( lambda creds: self.failUnless(creds.checkPassword(self.password))) def test_incorrectPassword(self): """ Incorrect passwords cause auth to fail. """ response = base64.encodestring('%s:%s' % ( self.username, 'incorrectPassword')) d = self.credentialFactory.decode(response, _trivial_GET) return d.addCallback( lambda creds: self.failIf(creds.checkPassword(self.password))) def test_incorrectPadding(self): """ Responses that have incorrect padding cause auth to fail. """ response = base64.encodestring('%s:%s' % ( self.username, self.password)) response = response.strip('=') d = self.credentialFactory.decode(response, _trivial_GET) def _test(creds): self.failUnless(creds.checkPassword(self.password)) return d.addCallback(_test) def test_invalidCredentials(self): """ Auth attempts with no password should fail. 
""" response = base64.encodestring(self.username) d = self.credentialFactory.decode(response, _trivial_GET) self.assertFailure(d, error.LoginFailed) clientAddress = address.IPv4Address('TCP', '127.0.0.1', 80) challengeOpaque = ('75c4bd95b96b7b7341c646c6502f0833-MTc4Mjg4NzU' '4NzE2MTIyMzkyODgxMjU0NzcwNjg1LHJlbW90ZWhvc3Q' 'sMA==') challengeNonce = '178288758716122392881254770685' challengeResponse = ('digest', {'nonce': challengeNonce, 'qop': 'auth', 'realm': 'test realm', 'algorithm': 'md5', 'opaque': challengeOpaque}) cnonce = "29fc54aa1641c6fa0e151419361c8f23" authRequest1 = ('username="username", realm="test realm", nonce="%s", ' 'uri="/write/", response="%s", opaque="%s", algorithm="md5", ' 'cnonce="29fc54aa1641c6fa0e151419361c8f23", nc=00000001, ' 'qop="auth"') authRequest2 = ('username="username", realm="test realm", nonce="%s", ' 'uri="/write/", response="%s", opaque="%s", algorithm="md5", ' 'cnonce="29fc54aa1641c6fa0e151419361c8f23", nc=00000002, ' 'qop="auth"') namelessAuthRequest = 'realm="test realm",nonce="doesn\'t matter"' class DigestAuthTestCase(unittest.TestCase): """ Test the behavior of DigestCredentialFactory """ def setUp(self): """ Create a DigestCredentialFactory for testing """ self.credentialFactory = digest.DigestCredentialFactory('md5', 'test realm') def getDigestResponse(self, challenge, ncount): """ Calculate the response for the given challenge """ nonce = challenge.get('nonce') algo = challenge.get('algorithm').lower() qop = challenge.get('qop') expected = digest.calcResponse( digest.calcHA1(algo, "username", "test realm", "password", nonce, cnonce), algo, nonce, ncount, cnonce, qop, "GET", "/write/", None ) return expected def test_getChallenge(self): """ Test that all the required fields exist in the challenge, and that the information matches what we put into our DigestCredentialFactory """ d = self.credentialFactory.getChallenge(clientAddress) def _test(challenge): self.assertEquals(challenge['qop'], 'auth') 
self.assertEquals(challenge['realm'], 'test realm') self.assertEquals(challenge['algorithm'], 'md5') self.assertTrue(challenge.has_key("nonce")) self.assertTrue(challenge.has_key("opaque")) return d.addCallback(_test) def _createAndDecodeChallenge(self, chalID="00000001", req=_trivial_GET): d = self.credentialFactory.getChallenge(clientAddress) def _getChallenge(challenge): return authRequest1 % ( challenge['nonce'], self.getDigestResponse(challenge, chalID), challenge['opaque']) def _getResponse(clientResponse): return self.credentialFactory.decode(clientResponse, req) return d.addCallback(_getChallenge).addCallback(_getResponse) def test_response(self): """ Test that we can decode a valid response to our challenge """ d = self._createAndDecodeChallenge() def _test(creds): self.failUnless(creds.checkPassword('password')) return d.addCallback(_test) def test_multiResponse(self): """ Test that multiple responses to to a single challenge are handled successfully. """ d = self._createAndDecodeChallenge() def _test(creds): self.failUnless(creds.checkPassword('password')) def _test2(_): d2 = self._createAndDecodeChallenge("00000002") return d2.addCallback(_test) return d.addCallback(_test) def test_failsWithDifferentMethod(self): """ Test that the response fails if made for a different request method than it is being issued for. """ d = self._createAndDecodeChallenge(req=SimpleRequest(None, 'POST', '/')) def _test(creds): self.failIf(creds.checkPassword('password')) return d.addCallback(_test) def test_noUsername(self): """ Test that login fails when our response does not contain a username, or the username field is empty. 
""" # Check for no username e = self.assertRaises(error.LoginFailed, self.credentialFactory.decode, namelessAuthRequest, _trivial_GET) self.assertEquals(str(e), "Invalid response, no username given.") # Check for an empty username e = self.assertRaises(error.LoginFailed, self.credentialFactory.decode, namelessAuthRequest + ',username=""', _trivial_GET) self.assertEquals(str(e), "Invalid response, no username given.") def test_noNonce(self): """ Test that login fails when our response does not contain a nonce """ e = self.assertRaises(error.LoginFailed, self.credentialFactory.decode, 'realm="Test",username="Foo",opaque="bar"', _trivial_GET) self.assertEquals(str(e), "Invalid response, no nonce given.") def test_noOpaque(self): """ Test that login fails when our response does not contain a nonce """ e = self.assertRaises(error.LoginFailed, self.credentialFactory.decode, 'realm="Test",username="Foo"', _trivial_GET) self.assertEquals(str(e), "Invalid response, no opaque given.") def test_checkHash(self): """ Check that given a hash of the form 'username:realm:password' we can verify the digest challenge """ d = self._createAndDecodeChallenge() def _test(creds): self.failUnless(creds.checkHash( md5('username:test realm:password').hexdigest())) self.failIf(creds.checkHash( md5('username:test realm:bogus').hexdigest())) return d.addCallback(_test) def test_invalidOpaque(self): """ Test that login fails when the opaque does not contain all the required parts. 
""" credentialFactory = FakeDigestCredentialFactory('md5', 'test realm') d = credentialFactory.getChallenge(clientAddress) def _test(challenge): self.assertRaises( error.LoginFailed, credentialFactory.verifyOpaque, 'badOpaque', challenge['nonce'], clientAddress.host) badOpaque = ('foo-%s' % ( 'nonce,clientip'.encode('base64').strip('\n'),)) self.assertRaises( error.LoginFailed, credentialFactory.verifyOpaque, badOpaque, challenge['nonce'], clientAddress.host) self.assertRaises( error.LoginFailed, credentialFactory.verifyOpaque, '', challenge['nonce'], clientAddress.host) return d.addCallback(_test) def test_incompatibleNonce(self): """ Test that login fails when the given nonce from the response, does not match the nonce encoded in the opaque. """ credentialFactory = FakeDigestCredentialFactory('md5', 'test realm') d = credentialFactory.getChallenge(clientAddress) def _test(challenge): badNonceOpaque = credentialFactory.generateOpaque( '1234567890', clientAddress.host) self.assertRaises( error.LoginFailed, credentialFactory.verifyOpaque, badNonceOpaque, challenge['nonce'], clientAddress.host) self.assertRaises( error.LoginFailed, credentialFactory.verifyOpaque, badNonceOpaque, '', clientAddress.host) return d.addCallback(_test) def test_incompatibleClientIp(self): """ Test that the login fails when the request comes from a client ip other than what is encoded in the opaque. 
""" credentialFactory = FakeDigestCredentialFactory('md5', 'test realm') d = credentialFactory.getChallenge(clientAddress) def _test(challenge): badNonceOpaque = credentialFactory.generateOpaque( challenge['nonce'], '10.0.0.1') self.assertRaises( error.LoginFailed, credentialFactory.verifyOpaque, badNonceOpaque, challenge['nonce'], clientAddress.host) return d.addCallback(_test) def test_oldNonce(self): """ Test that the login fails when the given opaque is older than DigestCredentialFactory.CHALLENGE_LIFETIME_SECS """ credentialFactory = FakeDigestCredentialFactory('md5', 'test realm') d = credentialFactory.getChallenge(clientAddress) def _test(challenge): key = '%s,%s,%s' % (challenge['nonce'], clientAddress.host, '-137876876') digest = (md5(key + credentialFactory._fakeStaticPrivateKey) .hexdigest()) ekey = key.encode('base64') oldNonceOpaque = '%s-%s' % (digest, ekey.strip('\n')) self.assertRaises( error.LoginFailed, credentialFactory.verifyOpaque, oldNonceOpaque, challenge['nonce'], clientAddress.host) return d.addCallback(_test) def test_mismatchedOpaqueChecksum(self): """ Test that login fails when the opaque checksum fails verification """ credentialFactory = FakeDigestCredentialFactory('md5', 'test realm') d = credentialFactory.getChallenge(clientAddress) def _test(challenge): key = '%s,%s,%s' % (challenge['nonce'], clientAddress.host, '0') digest = md5(key + 'this is not the right pkey').hexdigest() badChecksum = '%s-%s' % (digest, key.encode('base64').strip('\n')) self.assertRaises( error.LoginFailed, credentialFactory.verifyOpaque, badChecksum, challenge['nonce'], clientAddress.host) return d.addCallback(_test) def test_incompatibleCalcHA1Options(self): """ Test that the appropriate error is raised when any of the pszUsername, pszRealm, or pszPassword arguments are specified with the preHA1 keyword argument. 
""" arguments = ( ("user", "realm", "password", "preHA1"), (None, "realm", None, "preHA1"), (None, None, "password", "preHA1"), ) for pszUsername, pszRealm, pszPassword, preHA1 in arguments: self.assertRaises( TypeError, digest.calcHA1, "md5", pszUsername, pszRealm, pszPassword, "nonce", "cnonce", preHA1=preHA1 ) def test_noNewlineOpaque(self): """ L{digest.DigestCredentialFactory._generateOpaque} returns a value without newlines, regardless of the length of the nonce. """ opaque = self.credentialFactory.generateOpaque( "long nonce " * 10, None) self.assertNotIn('\n', opaque) from zope.interface import implements from twisted.cred import portal, checkers class TestHTTPUser(object): """ Test avatar implementation for http auth with cred """ implements(IHTTPUser) username = None def __init__(self, username): """ @param username: The str username sent as part of the HTTP auth response. """ self.username = username class TestAuthRealm(object): """ Test realm that supports the IHTTPUser interface """ implements(portal.IRealm) def requestAvatar(self, avatarId, mind, *interfaces): if IHTTPUser in interfaces: if avatarId == checkers.ANONYMOUS: return IHTTPUser, TestHTTPUser('anonymous') return IHTTPUser, TestHTTPUser(avatarId) raise NotImplementedError("Only IHTTPUser interface is supported") class ProtectedResource(test_server.BaseTestResource): """ A test resource for use with HTTPAuthWrapper that holds on to it's request and segments so we can assert things about them. """ addSlash = True request = None segments = None def render(self, req): self.request = req return super(ProtectedResource, self).render(req) def locateChild(self, req, segments): self.segments = segments return super(ProtectedResource, self).locateChild(req, segments) class NonAnonymousResource(test_server.BaseTestResource): """ A resource that forces authentication by raising an HTTPError with an UNAUTHORIZED code if the request is an anonymous one. 
""" addSlash = True sendOwnHeaders = False def render(self, req): if req.avatar.username == 'anonymous': if not self.sendOwnHeaders: raise http.HTTPError(responsecode.UNAUTHORIZED) else: return http.Response( responsecode.UNAUTHORIZED, {'www-authenticate': [('basic', {'realm': 'foo'})]}) else: return super(NonAnonymousResource, self).render(req) class HTTPAuthResourceTest(test_server.BaseCase): """ Tests for the HTTPAuthWrapper Resource """ def setUp(self): """ Create a portal and add an in memory checker to it. Then set up a protectedResource that will be wrapped in each test. """ self.portal = portal.Portal(TestAuthRealm()) c = checkers.InMemoryUsernamePasswordDatabaseDontUse() c.addUser('username', 'password') self.portal.registerChecker(c) self.credFactory = basic.BasicCredentialFactory('test realm') self.protectedResource = ProtectedResource() self.protectedResource.responseText = "You shouldn't see me." def tearDown(self): """ Clean up by getting rid of the portal, credentialFactory, and protected resource """ del self.portal del self.credFactory del self.protectedResource def test_authenticatedRequest(self): """ Test that after successful authentication the request provides IAuthenticatedRequest and that the request.avatar implements the proper interfaces for this realm and has the proper values for this request. """ self.protectedResource.responseText = "I hope you can see me." 
root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) credentials = base64.encodestring('username:password') d = self.assertResponse((root, 'http://localhost/', {'authorization': ('basic', credentials)}), (200, {}, 'I hope you can see me.')) def checkRequest(result): resource = self.protectedResource self.failUnless(hasattr(resource, "request")) request = resource.request self.failUnless(IAuthenticatedRequest.providedBy(request)) self.failUnless(hasattr(request, "avatar")) self.failUnless(IHTTPUser.providedBy(request.avatar)) self.failUnless(hasattr(request, "avatarInterface")) self.assertEquals(request.avatarInterface, IHTTPUser) self.assertEquals(request.avatar.username, 'username') d.addCallback(checkRequest) return d def test_allowedMethods(self): """ Test that unknown methods result in a 401 instead of a 405 when authentication hasn't been completed. """ self.method = 'PROPFIND' root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) d = self.assertResponse( (root, 'http://localhost/'), (401, {'WWW-Authenticate': [('basic', {'realm': "test realm"})]}, None)) self.method = 'GET' return d def test_unauthorizedResponse(self): """ Test that a request with no credentials results in a valid Unauthorized response. 
""" root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) def makeDeepRequest(res): return self.assertResponse( (root, 'http://localhost/foo/bar/baz/bax'), (401, {'WWW-Authenticate': [('basic', {'realm': "test realm"})]}, None)) d = self.assertResponse( (root, 'http://localhost/'), (401, {'WWW-Authenticate': [('basic', {'realm': "test realm"})]}, None)) return d.addCallback(makeDeepRequest) def test_badCredentials(self): """ Test that a request with bad credentials results in a valid Unauthorized response """ root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) credentials = base64.encodestring('bad:credentials') d = self.assertResponse( (root, 'http://localhost/', {'authorization': [('basic', credentials)]}), (401, {'WWW-Authenticate': [('basic', {'realm': "test realm"})]}, None)) return d def test_successfulLogin(self): """ Test that a request with good credentials results in the appropriate response from the protected resource """ self.protectedResource.responseText = "I hope you can see me." root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) credentials = base64.encodestring('username:password') d = self.assertResponse((root, 'http://localhost/', {'authorization': ('basic', credentials)}), (200, {}, 'I hope you can see me.')) return d def test_wrongScheme(self): """ Test that a request with credentials for a scheme that is not advertised by this resource results in the appropriate unauthorized response. 
""" root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) d = self.assertResponse((root, 'http://localhost/', {'authorization': [('digest', 'realm="foo", response="crap"')]}), (401, {'www-authenticate': [('basic', {'realm': 'test realm'})]}, None)) return d def test_multipleWWWAuthenticateSchemes(self): """ Test that our unauthorized response can contain challenges for multiple authentication schemes. """ root = wrapper.HTTPAuthResource( self.protectedResource, (basic.BasicCredentialFactory('test realm'), FakeDigestCredentialFactory('md5', 'test realm')), self.portal, interfaces=(IHTTPUser,)) d = self.assertResponse((root, 'http://localhost/', {}), (401, {'www-authenticate': [challengeResponse, ('basic', {'realm': 'test realm'})]}, None)) return d def test_authorizationAgainstMultipleSchemes(self): """ Test that we can successfully authenticate when presented with multiple WWW-Authenticate headers """ root = wrapper.HTTPAuthResource( self.protectedResource, (basic.BasicCredentialFactory('test realm'), FakeDigestCredentialFactory('md5', 'test realm')), self.portal, interfaces=(IHTTPUser,)) def respondBasic(ign): credentials = base64.encodestring('username:password') d = self.assertResponse((root, 'http://localhost/', {'authorization': ('basic', credentials)}), (200, {}, None)) return d def respond(ign): d = self.assertResponse((root, 'http://localhost/', {'authorization': authRequest1}), (200, {}, None)) return d.addCallback(respondBasic) d = self.assertResponse((root, 'http://localhost/', {}), (401, {'www-authenticate': [challengeResponse, ('basic', {'realm': 'test realm'})]}, None)) return d def test_wrappedResourceGetsFullSegments(self): """ Test that the wrapped resource gets all the URL segments in it's locateChild. """ self.protectedResource.responseText = "I hope you can see me." 
root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) credentials = base64.encodestring('username:password') d = self.assertResponse((root, 'http://localhost/foo/bar/baz/bax', {'authorization': ('basic', credentials)}), (404, {}, None)) def checkSegments(ign): resource = self.protectedResource self.assertEquals(resource.segments, ['foo', 'bar', 'baz', 'bax']) d.addCallback(checkSegments) return d def test_invalidCredentials(self): """ Malformed or otherwise invalid credentials (as determined by the credential factory) should result in an Unauthorized response """ root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) credentials = base64.encodestring('Not Good Credentials') d = self.assertResponse((root, 'http://localhost/', {'authorization': ('basic', credentials)}), (401, {'WWW-Authenticate': [('basic', {'realm': "test realm"})]}, None)) return d def test_anonymousAuthentication(self): """ If our portal has a credentials checker for IAnonymous credentials authentication succeeds if no Authorization header is present """ self.portal.registerChecker(checkers.AllowAnonymousAccess()) self.protectedResource.responseText = "Anonymous access allowed" root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces=(IHTTPUser,)) def _checkRequest(ign): self.assertEquals( self.protectedResource.request.avatar.username, 'anonymous') d = self.assertResponse((root, 'http://localhost/', {}), (200, {}, "Anonymous access allowed")) d.addCallback(_checkRequest) return d def test_forceAuthentication(self): """ Test that if an HTTPError with an Unauthorized status code is raised from within our protected resource, we add the WWW-Authenticate headers if they do not already exist. 
""" self.portal.registerChecker(checkers.AllowAnonymousAccess()) nonAnonResource = NonAnonymousResource() nonAnonResource.responseText = "We don't like anonymous users" root = wrapper.HTTPAuthResource(nonAnonResource, [self.credFactory], self.portal, interfaces = (IHTTPUser,)) def _tryAuthenticate(result): credentials = base64.encodestring('username:password') d2 = self.assertResponse( (root, 'http://localhost/', {'authorization': ('basic', credentials)}), (200, {}, "We don't like anonymous users")) return d2 d = self.assertResponse( (root, 'http://localhost/', {}), (401, {'WWW-Authenticate': [('basic', {'realm': "test realm"})]}, None)) d.addCallback(_tryAuthenticate) return d def test_responseFilterDoesntClobberHeaders(self): """ Test that if an UNAUTHORIZED response is returned and already has 'WWW-Authenticate' headers we don't add them. """ self.portal.registerChecker(checkers.AllowAnonymousAccess()) nonAnonResource = NonAnonymousResource() nonAnonResource.responseText = "We don't like anonymous users" nonAnonResource.sendOwnHeaders = True root = wrapper.HTTPAuthResource(nonAnonResource, [self.credFactory], self.portal, interfaces = (IHTTPUser,)) d = self.assertResponse( (root, 'http://localhost/', {}), (401, {'WWW-Authenticate': [('basic', {'realm': "foo"})]}, None)) return d def test_renderHTTP(self): """ Test that if the renderHTTP method is ever called we authenticate the request and delegate rendering to the wrapper. """ self.protectedResource.responseText = "I hope you can see me." 
self.protectedResource.addSlash = True root = wrapper.HTTPAuthResource(self.protectedResource, [self.credFactory], self.portal, interfaces = (IHTTPUser,)) request = SimpleRequest(None, "GET", "/") request.prepath = [''] def _gotSecondResponse(response): self.assertEquals(response.code, 200) self.assertEquals(str(response.stream.read()), "I hope you can see me.") def _gotResponse(exception): response = exception.response self.assertEquals(response.code, 401) self.failUnless(response.headers.hasHeader('WWW-Authenticate')) self.assertEquals(response.headers.getHeader('WWW-Authenticate'), [('basic', {'realm': "test realm"})]) credentials = base64.encodestring('username:password') request.headers.setHeader('authorization', ['basic', credentials]) d = root.renderHTTP(request) d.addCallback(_gotSecondResponse) d = self.assertFailure(root.renderHTTP(request), http.HTTPError) d.addCallback(_gotResponse) return d calendarserver-5.2+dfsg/twext/web2/test/test_http_headers.py0000644000175000017500000010070612107006303023357 0ustar rahulrahul# Copyright (c) 2008 Twisted Matrix Laboratories. # See LICENSE for details. """ Tests for L{twext.web2.http_headers}. 
""" from twisted.trial import unittest import random import time from twext.web2 import http_headers from twext.web2.http_headers import Cookie, HeaderHandler, quoteString, generateKeyValues from twisted.python import util class parsedvalue: """Marker class""" def __init__(self, raw): self.raw = raw def __eq__(self, other): return isinstance(other, parsedvalue) and other.raw == self.raw class HeadersAPITest(unittest.TestCase): """Make sure the public API exists and works.""" def testRaw(self): rawvalue = ("value1", "value2") h = http_headers.Headers(handler=HeaderHandler(parsers={}, generators={})) h.setRawHeaders("test", rawvalue) self.assertEquals(h.hasHeader("test"), True) self.assertEquals(h.getRawHeaders("test"), rawvalue) self.assertEquals(list(h.getAllRawHeaders()), [('Test', rawvalue)]) self.assertEquals(h.getRawHeaders("foobar"), None) h.removeHeader("test") self.assertEquals(h.getRawHeaders("test"), None) def testParsed(self): parsed = parsedvalue(("value1", "value2")) h = http_headers.Headers(handler=HeaderHandler(parsers={}, generators={})) h.setHeader("test", parsed) self.assertEquals(h.hasHeader("test"), True) self.assertEquals(h.getHeader("test"), parsed) self.assertEquals(h.getHeader("foobar"), None) h.removeHeader("test") self.assertEquals(h.getHeader("test"), None) def testParsedAndRaw(self): def parse(raw): return parsedvalue(raw) def generate(parsed): return parsed.raw rawvalue = ("value1", "value2") rawvalue2 = ("value3", "value4") handler = HeaderHandler(parsers={'test': (parse,)}, generators={'test': (generate,)}) h = http_headers.Headers(handler=handler) h.setRawHeaders("test", rawvalue) self.assertEquals(h.getHeader("test"), parsedvalue(rawvalue)) h.setHeader("test", parsedvalue(rawvalue2)) self.assertEquals(h.getRawHeaders("test"), rawvalue2) # Check the initializers h = http_headers.Headers(rawHeaders={"test": rawvalue}, handler=handler) self.assertEquals(h.getHeader("test"), parsedvalue(rawvalue)) h = http_headers.Headers({"test": 
parsedvalue(rawvalue2)}, handler=handler) self.assertEquals(h.getRawHeaders("test"), rawvalue2) def testImmutable(self): h = http_headers.Headers(handler=HeaderHandler(parsers={}, generators={})) h.makeImmutable() self.assertRaises(AttributeError, h.setRawHeaders, "test", [1]) self.assertRaises(AttributeError, h.setHeader, "test", 1) self.assertRaises(AttributeError, h.removeHeader, "test") class TokenizerTest(unittest.TestCase): """Test header list parsing functions.""" def testParse(self): parser = lambda val: list(http_headers.tokenize([val, ])) Token = http_headers.Token tests = (('foo,bar', ['foo', Token(','), 'bar']), ('FOO,BAR', ['foo', Token(','), 'bar']), (' \t foo \t bar \t , \t baz ', ['foo', Token(' '), 'bar', Token(','), 'baz']), ('()<>@,;:\\/[]?={}', [Token('('), Token(')'), Token('<'), Token('>'), Token('@'), Token(','), Token(';'), Token(':'), Token('\\'), Token('/'), Token('['), Token(']'), Token('?'), Token('='), Token('{'), Token('}')]), (' "foo" ', ['foo']), ('"FOO(),\\"BAR,"', ['FOO(),"BAR,'])) raiseTests = ('"open quote', '"ending \\', "control character: \x127", "\x00", "\x1f") for test, result in tests: self.assertEquals(parser(test), result) for test in raiseTests: self.assertRaises(ValueError, parser, test) def testGenerate(self): pass def testRoundtrip(self): pass def atSpecifiedTime(when, func): def inner(*a, **kw): orig = time.time time.time = lambda: when try: return func(*a, **kw) finally: time.time = orig return util.mergeFunctionMetadata(func, inner) def parseHeader(name, val): head = http_headers.Headers(handler=http_headers.DefaultHTTPHandler) head.setRawHeaders(name, val) return head.getHeader(name) parseHeader = atSpecifiedTime(999999990, parseHeader) # Sun, 09 Sep 2001 01:46:30 GMT def generateHeader(name, val): head = http_headers.Headers(handler=http_headers.DefaultHTTPHandler) head.setHeader(name, val) return head.getRawHeaders(name) generateHeader = atSpecifiedTime(999999990, generateHeader) # Sun, 09 Sep 2001 01:46:30 GMT 
class HeaderParsingTestBase(unittest.TestCase): def runRoundtripTest(self, headername, table): """ Perform some assertions about the behavior of parsing and generating HTTP headers. Specifically: parse an HTTP header value, assert that the parsed form contains all the available information with the correct structure; generate the HTTP header value from the parsed form, assert that it contains certain literal strings; finally, re-parse the generated HTTP header value and assert that the resulting structured data is the same as the first-pass parsed form. @type headername: C{str} @param headername: The name of the HTTP header L{table} contains values for. @type table: A sequence of tuples describing inputs to and outputs from header parsing and generation. The tuples may be either 2 or 3 elements long. In either case: the first element is a string representing an HTTP-format header value; the second element is a dictionary mapping names of parameters to values of those parameters (the parsed form of the header). If there is a third element, it is a list of strings which must occur exactly in the HTTP header value string which is re-generated from the parsed form. 
""" for row in table: if len(row) == 2: rawHeaderInput, parsedHeaderData = row requiredGeneratedElements = [] elif len(row) == 3: rawHeaderInput, parsedHeaderData, requiredGeneratedElements = row assert isinstance(requiredGeneratedElements, list) # parser parsed = parseHeader(headername, [rawHeaderInput, ]) self.assertEquals(parsed, parsedHeaderData) regeneratedHeaderValue = generateHeader(headername, parsed) if requiredGeneratedElements: # generator for regeneratedElement in regeneratedHeaderValue: reqEle = requiredGeneratedElements[regeneratedHeaderValue.index(regeneratedElement)] elementIndex = regeneratedElement.find(reqEle) self.assertNotEqual( elementIndex, -1, "%r did not appear in generated HTTP header %r: %r" % (reqEle, headername, regeneratedElement)) # parser/generator reparsed = parseHeader(headername, regeneratedHeaderValue) self.assertEquals(parsed, reparsed) def invalidParseTest(self, headername, values): for val in values: parsed = parseHeader(headername, val) self.assertEquals(parsed, None) class GeneralHeaderParsingTests(HeaderParsingTestBase): def testCacheControl(self): table = ( ("no-cache", {'no-cache': None}), ("no-cache, no-store, max-age=5, max-stale=3, min-fresh=5, no-transform, only-if-cached, blahblah-extension-thingy", {'no-cache': None, 'no-store': None, 'max-age': 5, 'max-stale': 3, 'min-fresh': 5, 'no-transform': None, 'only-if-cached': None, 'blahblah-extension-thingy': None}), ("max-stale", {'max-stale': None}), ("public, private, no-cache, no-store, no-transform, must-revalidate, proxy-revalidate, max-age=5, s-maxage=10, blahblah-extension-thingy", {'public': None, 'private': None, 'no-cache': None, 'no-store': None, 'no-transform': None, 'must-revalidate': None, 'proxy-revalidate': None, 'max-age': 5, 's-maxage': 10, 'blahblah-extension-thingy': None}), ('private="Set-Cookie, Set-Cookie2", no-cache="PROXY-AUTHENTICATE"', {'private': ['set-cookie', 'set-cookie2'], 'no-cache': ['proxy-authenticate']}, ['private="Set-Cookie, 
Set-Cookie2"', 'no-cache="Proxy-Authenticate"']), ) self.runRoundtripTest("Cache-Control", table) def testConnection(self): table = ( ("close", ['close', ]), ("close, foo-bar", ['close', 'foo-bar']) ) self.runRoundtripTest("Connection", table) def testDate(self): # Don't need major tests since the datetime parser has its own tests self.runRoundtripTest("Date", (("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),)) # def testPragma(self): # fail # def testTrailer(self): # fail def testTransferEncoding(self): table = ( ('chunked', ['chunked']), ('gzip, chunked', ['gzip', 'chunked']) ) self.runRoundtripTest("Transfer-Encoding", table) # def testUpgrade(self): # fail # def testVia(self): # fail # def testWarning(self): # fail class RequestHeaderParsingTests(HeaderParsingTestBase): # FIXME test ordering too. def testAccept(self): table = ( ("audio/*;q=0.2, audio/basic", {http_headers.MimeType('audio', '*'): 0.2, http_headers.MimeType('audio', 'basic'): 1.0}), ("text/plain;q=0.5, text/html, text/x-dvi;q=0.8, text/x-c", {http_headers.MimeType('text', 'plain'): 0.5, http_headers.MimeType('text', 'html'): 1.0, http_headers.MimeType('text', 'x-dvi'): 0.8, http_headers.MimeType('text', 'x-c'): 1.0}), ("text/*, text/html, text/html;level=1, */*", {http_headers.MimeType('text', '*'): 1.0, http_headers.MimeType('text', 'html'): 1.0, http_headers.MimeType('text', 'html', (('level', '1'),)): 1.0, http_headers.MimeType('*', '*'): 1.0}), ("text/*;q=0.3, text/html;q=0.7, text/html;level=1, text/html;level=2;q=0.4, */*;q=0.5", {http_headers.MimeType('text', '*'): 0.3, http_headers.MimeType('text', 'html'): 0.7, http_headers.MimeType('text', 'html', (('level', '1'),)): 1.0, http_headers.MimeType('text', 'html', (('level', '2'),)): 0.4, http_headers.MimeType('*', '*'): 0.5}), ) self.runRoundtripTest("Accept", table) def testAcceptCharset(self): table = ( ("iso-8859-5, unicode-1-1;q=0.8", {'iso-8859-5': 1.0, 'iso-8859-1': 1.0, 'unicode-1-1': 0.8}, ["iso-8859-5", "unicode-1-1;q=0.8", 
"iso-8859-1"]), ("iso-8859-1;q=0.7", {'iso-8859-1': 0.7}), ("*;q=.7", {'*': 0.7}, ["*;q=0.7"]), ("", {'iso-8859-1': 1.0}, ["iso-8859-1"]), # Yes this is an actual change -- we'll say that's okay. :) ) self.runRoundtripTest("Accept-Charset", table) def testAcceptEncoding(self): table = ( ("compress, gzip", {'compress': 1.0, 'gzip': 1.0, 'identity': 0.0001}), ("", {'identity': 0.0001}), ("*", {'*': 1}), ("compress;q=0.5, gzip;q=1.0", {'compress': 0.5, 'gzip': 1.0, 'identity': 0.0001}, ["compress;q=0.5", "gzip"]), ("gzip;q=1.0, identity;q=0.5, *;q=0", {'gzip': 1.0, 'identity': 0.5, '*': 0}, ["gzip", "identity;q=0.5", "*;q=0"]), ) self.runRoundtripTest("Accept-Encoding", table) def testAcceptLanguage(self): table = ( ("da, en-gb;q=0.8, en;q=0.7", {'da': 1.0, 'en-gb': 0.8, 'en': 0.7}), ("*", {'*': 1}), ) self.runRoundtripTest("Accept-Language", table) def testAuthorization(self): table = ( ("Basic dXNlcm5hbWU6cGFzc3dvcmQ=", ("basic", "dXNlcm5hbWU6cGFzc3dvcmQ="), ["basic dXNlcm5hbWU6cGFzc3dvcmQ="]), ('Digest nonce="bar", realm="foo", username="baz", response="bax"', ('digest', 'nonce="bar", realm="foo", username="baz", response="bax"'), ['digest', 'nonce="bar"', 'realm="foo"', 'username="baz"', 'response="bax"']) ) self.runRoundtripTest("Authorization", table) def testCookie(self): table = ( ('name=value', [Cookie('name', 'value')]), ('"name"="value"', [Cookie('"name"', '"value"')]), ('name,"blah=value,"', [Cookie('name,"blah', 'value,"')]), ('name,"blah = value," ', [Cookie('name,"blah', 'value,"')], ['name,"blah=value,"']), ("`~!@#$%^&*()-_+[{]}\\|:'\",<.>/?=`~!@#$%^&*()-_+[{]}\\|:'\",<.>/?", [Cookie("`~!@#$%^&*()-_+[{]}\\|:'\",<.>/?", "`~!@#$%^&*()-_+[{]}\\|:'\",<.>/?")]), ('name,"blah = value," ; name2=val2', [Cookie('name,"blah', 'value,"'), Cookie('name2', 'val2')], ['name,"blah=value,"', 'name2=val2']), ) self.runRoundtripTest("Cookie", table) # newstyle RFC2965 Cookie table2 = ( ('$Version="1";' 'name="value";$Path="/foo";$Domain="www.local";$Port="80,8000";' 
'name2="value"', [Cookie('name', 'value', path='/foo', domain='www.local', ports=(80, 8000), version=1), Cookie('name2', 'value', version=1)]), ('$Version="1";' 'name="value";$Port', [Cookie('name', 'value', ports=(), version=1)]), ('$Version = 1, NAME = "qq\\"qq",Frob=boo', [Cookie('name', 'qq"qq', version=1), Cookie('frob', 'boo', version=1)], ['$Version="1";name="qq\\"qq";frob="boo"']), ) self.runRoundtripTest("Cookie", table2) # Generate only! # make headers by combining oldstyle and newstyle cookies table3 = ( ([Cookie('name', 'value'), Cookie('name2', 'value2', version=1)], '$Version="1";name=value;name2="value2"'), ([Cookie('name', 'value', path="/foo"), Cookie('name2', 'value2', domain="bar.baz", version=1)], '$Version="1";name=value;$Path="/foo";name2="value2";$Domain="bar.baz"'), ([Cookie('invalid,"name', 'value'), Cookie('name2', 'value2', version=1)], '$Version="1";name2="value2"'), ([Cookie('name', 'qq"qq'), Cookie('name2', 'value2', version=1)], '$Version="1";name="qq\\"qq";name2="value2"'), ) for row in table3: self.assertEquals(generateHeader("Cookie", row[0]), [row[1], ]) def testSetCookie(self): table = ( ('name,"blah=value,; expires=Sun, 09 Sep 2001 01:46:40 GMT; path=/foo; domain=bar.baz; secure', [Cookie('name,"blah', 'value,', expires=1000000000, path="/foo", domain="bar.baz", secure=True)]), ('name,"blah = value, ; expires="Sun, 09 Sep 2001 01:46:40 GMT"', [Cookie('name,"blah', 'value,', expires=1000000000)], ['name,"blah=value,', 'expires=Sun, 09 Sep 2001 01:46:40 GMT']), ) self.runRoundtripTest("Set-Cookie", table) def testSetCookie2(self): table = ( ('name="value"; Comment="YadaYada"; CommentURL="http://frobnotz/"; Discard; Domain="blah.blah"; Max-Age=10; Path="/foo"; Port="80,8080"; Secure; Version="1"', [Cookie("name", "value", comment="YadaYada", commenturl="http://frobnotz/", discard=True, domain="blah.blah", expires=1000000000, path="/foo", ports=(80, 8080), secure=True, version=1)]), ) self.runRoundtripTest("Set-Cookie2", table) def 
testExpect(self): table = ( ("100-continue", {"100-continue": (None,)}), ('foobar=twiddle', {'foobar': ('twiddle',)}), ("foo=bar;a=b;c", {'foo': ('bar', ('a', 'b'), ('c', None))}) ) self.runRoundtripTest("Expect", table) def testPrefer(self): table = ( ("wait", [("wait", None, [])]), ("return = representation", [("return", "representation", [])]), ("return =minimal;arg1;arg2=val2", [("return", "minimal", [("arg1", None), ("arg2", "val2")])]), ) self.runRoundtripTest("Prefer", table) def testFrom(self): self.runRoundtripTest("From", (("webmaster@w3.org", "webmaster@w3.org"),)) def testHost(self): self.runRoundtripTest("Host", (("www.w3.org", "www.w3.org"),)) def testIfMatch(self): table = ( ('"xyzzy"', [http_headers.ETag('xyzzy')]), ('"xyzzy", "r2d2xxxx", "c3piozzzz"', [http_headers.ETag('xyzzy'), http_headers.ETag('r2d2xxxx'), http_headers.ETag('c3piozzzz')]), ('*', ['*']), ) self.runRoundtripTest("If-Match", table) def testIfModifiedSince(self): # Don't need major tests since the datetime parser has its own test # Just test stupid ; length= brokenness. 
table = ( ("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000), ("Sun, 09 Sep 2001 01:46:40 GMT; length=500", 1000000000, ["Sun, 09 Sep 2001 01:46:40 GMT"]), ) self.runRoundtripTest("If-Modified-Since", table) def testIfNoneMatch(self): table = ( ('"xyzzy"', [http_headers.ETag('xyzzy')]), ('W/"xyzzy", "r2d2xxxx", "c3piozzzz"', [http_headers.ETag('xyzzy', weak=True), http_headers.ETag('r2d2xxxx'), http_headers.ETag('c3piozzzz')]), ('W/"xyzzy", W/"r2d2xxxx", W/"c3piozzzz"', [http_headers.ETag('xyzzy', weak=True), http_headers.ETag('r2d2xxxx', weak=True), http_headers.ETag('c3piozzzz', weak=True)]), ('*', ['*']), ) self.runRoundtripTest("If-None-Match", table) def testIfRange(self): table = ( ('"xyzzy"', http_headers.ETag('xyzzy')), ('W/"xyzzy"', http_headers.ETag('xyzzy', weak=True)), ('W/"xyzzy"', http_headers.ETag('xyzzy', weak=True)), ("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000), ) self.runRoundtripTest("If-Range", table) def testIfUnmodifiedSince(self): self.runRoundtripTest("If-Unmodified-Since", (("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),)) def testMaxForwards(self): self.runRoundtripTest("Max-Forwards", (("15", 15),)) # def testProxyAuthorize(self): # fail def testRange(self): table = ( ("bytes=0-499", ('bytes', [(0, 499), ])), ("bytes=500-999", ('bytes', [(500, 999), ])), ("bytes=-500", ('bytes', [(None, 500), ])), ("bytes=9500-", ('bytes', [(9500, None), ])), ("bytes=0-0,-1", ('bytes', [(0, 0), (None, 1)])), ) self.runRoundtripTest("Range", table) def testReferer(self): self.runRoundtripTest("Referer", (("http://www.w3.org/hypertext/DataSources/Overview.html", "http://www.w3.org/hypertext/DataSources/Overview.html"),)) def testTE(self): table = ( ("deflate", {'deflate': 1}), ("", {}), ("trailers, deflate;q=0.5", {'trailers': 1, 'deflate': 0.5}), ) self.runRoundtripTest("TE", table) def testUserAgent(self): self.runRoundtripTest("User-Agent", (("CERN-LineMode/2.15 libwww/2.17b3", "CERN-LineMode/2.15 libwww/2.17b3"),)) class 
ResponseHeaderParsingTests(HeaderParsingTestBase): def testAcceptRanges(self): self.runRoundtripTest("Accept-Ranges", (("bytes", ["bytes"]), ("none", ["none"]))) def testAge(self): self.runRoundtripTest("Age", (("15", 15),)) def testETag(self): table = ( ('"xyzzy"', http_headers.ETag('xyzzy')), ('W/"xyzzy"', http_headers.ETag('xyzzy', weak=True)), ('""', http_headers.ETag('')), ) self.runRoundtripTest("ETag", table) def testLocation(self): self.runRoundtripTest("Location", (("http://www.w3.org/pub/WWW/People.htm", "http://www.w3.org/pub/WWW/People.htm"),)) # def testProxyAuthenticate(self): # fail def testRetryAfter(self): # time() is always 999999990 when being tested. table = ( ("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000, ["10"]), ("120", 999999990 + 120), ) self.runRoundtripTest("Retry-After", table) def testServer(self): self.runRoundtripTest("Server", (("CERN/3.0 libwww/2.17", "CERN/3.0 libwww/2.17"),)) def testVary(self): table = ( ("*", ["*"]), ("Accept, Accept-Encoding", ["accept", "accept-encoding"], ["accept", "accept-encoding"]) ) self.runRoundtripTest("Vary", table) def testWWWAuthenticate(self): digest = ('Digest realm="digest realm", nonce="bAr", qop="auth"', [('Digest', {'realm': 'digest realm', 'nonce': 'bAr', 'qop': 'auth'})], ['Digest', 'realm="digest realm"', 'nonce="bAr"', 'qop="auth"']) basic = ('Basic realm="foo"', [('Basic', {'realm': 'foo'})], ['Basic', 'realm="foo"']) ntlm = ('NTLM', [('NTLM', {})], ['NTLM', '']) negotiate = ('Negotiate SomeGssAPIData', [('Negotiate', 'SomeGssAPIData')], ['Negotiate', 'SomeGssAPIData']) table = (digest, basic, (digest[0] + ', ' + basic[0], digest[1] + basic[1], [digest[2], basic[2]]), ntlm, negotiate, (ntlm[0] + ', ' + basic[0], ntlm[1] + basic[1], [ntlm[2], basic[2]]), (digest[0] + ', ' + negotiate[0], digest[1] + negotiate[1], [digest[2], negotiate[2]]), (negotiate[0] + ', ' + negotiate[0], negotiate[1] + negotiate[1], [negotiate[2] + negotiate[2]]), (ntlm[0] + ', ' + ntlm[0], ntlm[1] + ntlm[1], 
[ntlm[2], ntlm[2]]), (basic[0] + ', ' + ntlm[0], basic[1] + ntlm[1], [basic[2], ntlm[2]]), ) # runRoundtripTest doesn't work because we don't generate a single # header headername = 'WWW-Authenticate' for row in table: rawHeaderInput, parsedHeaderData, requiredGeneratedElements = row parsed = parseHeader(headername, [rawHeaderInput, ]) self.assertEquals(parsed, parsedHeaderData) regeneratedHeaderValue = generateHeader(headername, parsed) for regeneratedElement in regeneratedHeaderValue: requiredElements = requiredGeneratedElements[ regeneratedHeaderValue.index( regeneratedElement)] for reqEle in requiredElements: elementIndex = regeneratedElement.find(reqEle) self.assertNotEqual( elementIndex, -1, "%r did not appear in generated HTTP header %r: %r" % (reqEle, headername, regeneratedElement)) # parser/generator reparsed = parseHeader(headername, regeneratedHeaderValue) self.assertEquals(parsed, reparsed) class EntityHeaderParsingTests(HeaderParsingTestBase): def testAllow(self): # Allow is a silly case-sensitive header unlike all the rest table = ( ("GET", ['GET', ]), ("GET, HEAD, PUT", ['GET', 'HEAD', 'PUT']), ) self.runRoundtripTest("Allow", table) def testContentEncoding(self): table = ( ("gzip", ['gzip', ]), ) self.runRoundtripTest("Content-Encoding", table) def testContentLanguage(self): table = ( ("da", ['da', ]), ("mi, en", ['mi', 'en']), ) self.runRoundtripTest("Content-Language", table) def testContentLength(self): self.runRoundtripTest("Content-Length", (("15", 15),)) self.invalidParseTest("Content-Length", ("asdf",)) def testContentLocation(self): self.runRoundtripTest("Content-Location", (("http://www.w3.org/pub/WWW/People.htm", "http://www.w3.org/pub/WWW/People.htm"),)) def testContentMD5(self): self.runRoundtripTest("Content-MD5", (("Q2hlY2sgSW50ZWdyaXR5IQ==", "Check Integrity!"),)) self.invalidParseTest("Content-MD5", ("sdlaksjdfhlkaj",)) def testContentRange(self): table = ( ("bytes 0-499/1234", ("bytes", 0, 499, 1234)), ("bytes 500-999/1234", 
("bytes", 500, 999, 1234)), ("bytes 500-1233/1234", ("bytes", 500, 1233, 1234)), ("bytes 734-1233/1234", ("bytes", 734, 1233, 1234)), ("bytes 734-1233/*", ("bytes", 734, 1233, None)), ("bytes */1234", ("bytes", None, None, 1234)), ("bytes */*", ("bytes", None, None, None)) ) self.runRoundtripTest("Content-Range", table) def testContentType(self): table = ( ("text/html;charset=iso-8859-4", http_headers.MimeType('text', 'html', (('charset', 'iso-8859-4'),))), ("text/html", http_headers.MimeType('text', 'html')), ) self.runRoundtripTest("Content-Type", table) def testContentDisposition(self): table = ( ("attachment;filename=foo.txt", http_headers.MimeDisposition('attachment', (('filename', 'foo.txt'),))), ("inline", http_headers.MimeDisposition('inline')), ) self.runRoundtripTest("Content-Disposition", table) def testExpires(self): self.runRoundtripTest("Expires", (("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),)) # Invalid expires MUST return date in the past. self.assertEquals(parseHeader("Expires", ["0"]), 0) self.assertEquals(parseHeader("Expires", ["wejthnaljn"]), 0) def testLastModified(self): # Don't need major tests since the datetime parser has its own test self.runRoundtripTest("Last-Modified", (("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),)) class DateTimeTest(unittest.TestCase): """Test date parsing functions.""" def testParse(self): timeNum = 784111777 timeStrs = ('Sun, 06 Nov 1994 08:49:37 GMT', 'Sunday, 06-Nov-94 08:49:37 GMT', 'Sun Nov 6 08:49:37 1994', # Also some non-RFC formats, for good measure. 'Somefakeday 6 Nov 1994 8:49:37', '6 Nov 1994 8:49:37', 'Sun, 6 Nov 1994 8:49:37', '6 Nov 1994 8:49:37 GMT', '06-Nov-94 08:49:37', 'Sunday, 06-Nov-94 08:49:37', '06-Nov-94 08:49:37 GMT', 'Nov 6 08:49:37 1994', ) for timeStr in timeStrs: self.assertEquals(http_headers.parseDateTime(timeStr), timeNum) # Test 2 Digit date wraparound yuckiness. 
self.assertEquals(http_headers.parseDateTime( 'Monday, 11-Oct-04 14:56:50 GMT'), 1097506610) self.assertEquals(http_headers.parseDateTime( 'Monday, 11-Oct-2004 14:56:50 GMT'), 1097506610) def testGenerate(self): self.assertEquals(http_headers.generateDateTime(784111777), 'Sun, 06 Nov 1994 08:49:37 GMT') def testRoundtrip(self): for _ignore in range(2000): randomTime = random.randint(0, 2000000000) timestr = http_headers.generateDateTime(randomTime) time2 = http_headers.parseDateTime(timestr) self.assertEquals(randomTime, time2) class TestMimeType(unittest.TestCase): def testEquality(self): """Test that various uses of the constructer are equal """ kwargMime = http_headers.MimeType('text', 'plain', key='value', param=None) dictMime = http_headers.MimeType('text', 'plain', {'param': None, 'key': 'value'}) tupleMime = http_headers.MimeType('text', 'plain', (('param', None), ('key', 'value'))) stringMime = http_headers.MimeType.fromString('text/plain;key=value;param') self.assertEquals(kwargMime, dictMime) self.assertEquals(dictMime, tupleMime) self.assertEquals(kwargMime, tupleMime) self.assertEquals(kwargMime, stringMime) class TestMimeDisposition(unittest.TestCase): def testEquality(self): """Test that various uses of the constructer are equal """ kwargMime = http_headers.MimeDisposition('attachment', key='value') dictMime = http_headers.MimeDisposition('attachment', {'key': 'value'}) tupleMime = http_headers.MimeDisposition('attachment', (('key', 'value'),)) stringMime = http_headers.MimeDisposition.fromString('attachment;key=value') self.assertEquals(kwargMime, dictMime) self.assertEquals(dictMime, tupleMime) self.assertEquals(kwargMime, tupleMime) self.assertEquals(kwargMime, stringMime) class FormattingUtilityTests(unittest.TestCase): """ Tests for various string formatting functionality required to generate headers. 
""" def test_quoteString(self): """ L{quoteString} returns a string which when interpreted according to the rules for I{quoted-string} (RFC 2616 section 2.2) matches the input string. """ self.assertEqual( quoteString('a\\b"c'), '"a\\\\b\\"c"') def test_generateKeyValues(self): """ L{generateKeyValues} accepts an iterable of parameters and returns a string formatted according to RFC 2045 section 5.1. """ self.assertEqual( generateKeyValues(iter([("foo", "bar"), ("baz", "quux")])), "foo=bar;baz=quux") def test_generateKeyValuesNone(self): """ L{generateKeyValues} accepts C{None} as the 2nd element of a tuple and includes just the 1st element in the output without an C{"="}. """ self.assertEqual( generateKeyValues([("foo", None), ("bar", "baz")]), "foo;bar=baz") def test_generateKeyValuesQuoting(self): """ L{generateKeyValues} quotes the value of the 2nd element of a tuple if it includes a character which cannot be in an HTTP token as defined in RFC 2616 section 2.2. """ for needsQuote in [' ', '\t', '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}']: self.assertEqual( generateKeyValues([("foo", needsQuote)]), 'foo=%s' % (quoteString(needsQuote),)) calendarserver-5.2+dfsg/twext/web2/test/test_client.py0000644000175000017500000003561512113213176022176 0ustar rahulrahul# Copyright (c) 2001-2007 Twisted Matrix Laboratories. # See LICENSE for details. from __future__ import print_function """ Tests for HTTP client. 
""" from twisted.internet import protocol, defer from twext.web2.client import http from twext.web2 import http_headers from twext.web2 import stream from twext.web2.test.test_http import LoopbackRelay, HTTPTests, TestConnection class TestServer(protocol.Protocol): data = "" done = False def dataReceived(self, data): self.data += data def write(self, data): self.transport.write(data) def connectionLost(self, reason): self.done = True self.transport.loseConnection() def loseConnection(self): self.done = True self.transport.loseConnection() class ClientTests(HTTPTests): def connect(self, logFile=None, maxPipeline=4, inputTimeOut=60000, betweenRequestsTimeOut=600000): cxn = TestConnection() cxn.client = http.HTTPClientProtocol() cxn.client.inputTimeOut = inputTimeOut cxn.server = TestServer() cxn.serverToClient = LoopbackRelay(cxn.client, logFile) cxn.clientToServer = LoopbackRelay(cxn.server, logFile) cxn.server.makeConnection(cxn.serverToClient) cxn.client.makeConnection(cxn.clientToServer) return cxn def writeToClient(self, cxn, data): cxn.server.write(data) self.iterate(cxn) def writeLines(self, cxn, lines): self.writeToClient(cxn, '\r\n'.join(lines)) def assertReceived(self, cxn, expectedStatus, expectedHeaders, expectedContent=None): self.iterate(cxn) headers, content = cxn.server.data.split('\r\n\r\n', 1) status, headers = headers.split('\r\n', 1) headers = headers.split('\r\n') # check status line self.assertEquals(status, expectedStatus) # check headers (header order isn't guraunteed so we use # self.assertIn for x in headers: self.assertIn(x, expectedHeaders) if not expectedContent: expectedContent = '' self.assertEquals(content, expectedContent) def assertDone(self, cxn): self.iterate(cxn) self.assertEquals(cxn.server.done, True, 'Connection not closed.') def assertHeaders(self, resp, expectedHeaders): headers = list(resp.headers.getAllRawHeaders()) headers.sort() self.assertEquals(headers, expectedHeaders) def checkResponse(self, resp, code, headers, 
length, data): """ Assert various things about a response: http code, headers, stream length, and data in stream. """ def gotData(gotdata): self.assertEquals(gotdata, data) self.assertEquals(resp.code, code) self.assertHeaders(resp, headers) self.assertEquals(resp.stream.length, length) return defer.maybeDeferred(resp.stream.read).addCallback(gotData) class TestHTTPClient(ClientTests): """ Test that the http client works. """ def test_simpleRequest(self): """ Your basic simple HTTP Request. """ cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) d = cxn.client.submitRequest(req).addCallback(self.checkResponse, 200, [], 10, '1234567890') self.assertReceived(cxn, 'GET / HTTP/1.1', ['Connection: close']) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Content-Length: 10', 'Connection: close', '', '1234567890')) return d.addCallback(lambda _: self.assertDone(cxn)) def test_delayedContent(self): """ Make sure that the client returns the response object as soon as the headers are received, even if the data hasn't arrived yet. """ cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) def gotData(data): self.assertEquals(data, '1234567890') def gotResp(resp): self.assertEquals(resp.code, 200) self.assertHeaders(resp, []) self.assertEquals(resp.stream.length, 10) self.writeToClient(cxn, '1234567890') return defer.maybeDeferred(resp.stream.read).addCallback(gotData) d = cxn.client.submitRequest(req).addCallback(gotResp) self.assertReceived(cxn, 'GET / HTTP/1.1', ['Connection: close']) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Content-Length: 10', 'Connection: close', '\r\n')) return d.addCallback(lambda _: self.assertDone(cxn)) def test_prematurePipelining(self): """ Ensure that submitting a second request before it's allowed results in an AssertionError. 
""" cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) req2 = http.ClientRequest('GET', '/bar', None, None) d = cxn.client.submitRequest(req, closeAfter=False).addCallback( self.checkResponse, 200, [], 0, None) self.assertRaises(AssertionError, cxn.client.submitRequest, req2) self.assertReceived(cxn, 'GET / HTTP/1.1', ['Connection: Keep-Alive']) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Content-Length: 0', 'Connection: close', '\r\n')) return d def test_userHeaders(self): """ Make sure that headers get through in both directions. """ cxn = self.connect(inputTimeOut=None) def submitNext(_): headers = http_headers.Headers( headers={'Accept-Language': {'en': 1.0}}, rawHeaders={'X-My-Other-Header': ['socks']}) req = http.ClientRequest('GET', '/', headers, None) cxn.server.data = '' d = cxn.client.submitRequest(req, closeAfter=True) self.assertReceived(cxn, 'GET / HTTP/1.1', ['Connection: close', 'X-My-Other-Header: socks', 'Accept-Language: en']) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Content-Length: 0', 'Connection: close', '\r\n')) return d req = http.ClientRequest('GET', '/', {'Accept-Language': {'en': 1.0}}, None) d = cxn.client.submitRequest(req, closeAfter=False).addCallback( self.checkResponse, 200, [('X-Foobar', ['Yes'])], 0, None).addCallback( submitNext) self.assertReceived(cxn, 'GET / HTTP/1.1', ['Connection: Keep-Alive', 'Accept-Language: en']) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Content-Length: 0', 'X-Foobar: Yes', '\r\n')) return d.addCallback(lambda _: self.assertDone(cxn)) def test_streamedUpload(self): """ Make sure that sending request content works. 
""" cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('PUT', '/foo', None, 'Helloooo content') d = cxn.client.submitRequest(req).addCallback(self.checkResponse, 202, [], 0, None) self.assertReceived(cxn, 'PUT /foo HTTP/1.1', ['Connection: close', 'Content-Length: 16'], 'Helloooo content') self.writeLines(cxn, ('HTTP/1.1 202 Accepted', 'Content-Length: 0', 'Connection: close', '\r\n')) return d.addCallback(lambda _: self.assertDone(cxn)) def test_sentHead(self): """ Ensure that HEAD requests work, and return Content-Length. """ cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('HEAD', '/', None, None) d = cxn.client.submitRequest(req).addCallback(self.checkResponse, 200, [('Content-Length', ['5'])], 0, None) self.assertReceived(cxn, 'HEAD / HTTP/1.1', ['Connection: close']) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Connection: close', 'Content-Length: 5', '', 'Pants')) # bad server return d.addCallback(lambda _: self.assertDone(cxn)) def test_sentHeadKeepAlive(self): """ Ensure that keepalive works right after a HEAD request. 
""" cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('HEAD', '/', None, None) didIt = [0] def gotData(data): self.assertEquals(data, None) def gotResp(resp): self.assertEquals(resp.code, 200) self.assertEquals(resp.stream.length, 0) self.assertHeaders(resp, []) return defer.maybeDeferred(resp.stream.read).addCallback(gotData) def submitRequest(second): if didIt[0]: return didIt[0] = second if second: keepAlive='close' else: keepAlive='Keep-Alive' cxn.server.data = '' d = cxn.client.submitRequest(req, closeAfter=second).addCallback( self.checkResponse, 200, [('Content-Length', ['5'])], 0, None) self.assertReceived(cxn, 'HEAD / HTTP/1.1', ['Connection: '+ keepAlive]) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Connection: '+ keepAlive, 'Content-Length: 5', '\r\n')) return d.addCallback(lambda _: submitRequest(1)) d = submitRequest(0) return d.addCallback(lambda _: self.assertDone(cxn)) def test_chunkedUpload(self): """ Ensure chunked data is correctly decoded on upload. """ cxn = self.connect(inputTimeOut=None) data = 'Foo bar baz bax' s = stream.ProducerStream(length=None) s.write(data) req = http.ClientRequest('PUT', '/', None, s) d = cxn.client.submitRequest(req) s.finish() self.assertReceived(cxn, 'PUT / HTTP/1.1', ['Connection: close', 'Transfer-Encoding: chunked'], '%X\r\n%s\r\n0\r\n\r\n' % (len(data), data)) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Connection: close', 'Content-Length: 0', '\r\n')) return d.addCallback(lambda _: self.assertDone(cxn)) class TestEdgeCases(ClientTests): def test_serverDoesntSendConnectionClose(self): """ Check that a lost connection is treated as end of response, if we requested connection: close, even if the server didn't respond with connection: close. 
""" cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) d = cxn.client.submitRequest(req).addCallback(self.checkResponse, 200, [], None, 'Some Content') self.assertReceived(cxn, 'GET / HTTP/1.1', ['Connection: close']) self.writeLines(cxn, ('HTTP/1.1 200 OK', '', 'Some Content')) return d.addCallback(lambda _: self.assertDone(cxn)) def test_serverIsntHttp(self): """ Check that an error is returned if the server doesn't talk HTTP. """ cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) def gotResp(r): print(r) d = cxn.client.submitRequest(req).addCallback(gotResp) self.assertFailure(d, http.ProtocolError) self.writeLines(cxn, ('HTTP-NG/1.1 200 OK', '\r\n')) def test_newServer(self): """ Check that an error is returned if the server is a new major version. """ cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) d = cxn.client.submitRequest(req) self.assertFailure(d, http.ProtocolError) self.writeLines(cxn, ('HTTP/2.3 200 OK', '\r\n')) def test_shortStatus(self): """ Check that an error is returned if the response line is invalid. """ cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) d = cxn.client.submitRequest(req) self.assertFailure(d, http.ProtocolError) self.writeLines(cxn, ('HTTP/1.1 200', '\r\n')) def test_errorReadingRequestStream(self): """ Ensure that stream errors are propagated to the response. """ cxn = self.connect(inputTimeOut=None) s = stream.ProducerStream() s.write('Foo') req = http.ClientRequest('GET', '/', None, s) d = cxn.client.submitRequest(req) s.finish(IOError('Test Error')) return self.assertFailure(d, IOError) def test_connectionLost(self): """ Check that closing the connection is propagated to the response deferred. 
""" cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) d = cxn.client.submitRequest(req) self.assertReceived(cxn, 'GET / HTTP/1.1', ['Connection: close']) cxn.client.connectionLost(ValueError("foo")) return self.assertFailure(d, ValueError) def test_connectionLostAfterHeaders(self): """ Test that closing the connection after headers are sent is propagated to the response stream. """ cxn = self.connect(inputTimeOut=None) req = http.ClientRequest('GET', '/', None, None) d = cxn.client.submitRequest(req) self.assertReceived(cxn, 'GET / HTTP/1.1', ['Connection: close']) self.writeLines(cxn, ('HTTP/1.1 200 OK', 'Content-Length: 10', 'Connection: close', '\r\n')) cxn.client.connectionLost(ValueError("foo")) def cb(response): return self.assertFailure(response.stream.read(), ValueError) d.addCallback(cb) return d calendarserver-5.2+dfsg/twext/web2/test/server.pem0000644000175000017500000000400011337102650021301 0ustar rahulrahul-----BEGIN CERTIFICATE----- MIIDBjCCAm+gAwIBAgIBATANBgkqhkiG9w0BAQQFADB7MQswCQYDVQQGEwJTRzER MA8GA1UEChMITTJDcnlwdG8xFDASBgNVBAsTC00yQ3J5cHRvIENBMSQwIgYDVQQD ExtNMkNyeXB0byBDZXJ0aWZpY2F0ZSBNYXN0ZXIxHTAbBgkqhkiG9w0BCQEWDm5n cHNAcG9zdDEuY29tMB4XDTAwMDkxMDA5NTEzMFoXDTAyMDkxMDA5NTEzMFowUzEL MAkGA1UEBhMCU0cxETAPBgNVBAoTCE0yQ3J5cHRvMRIwEAYDVQQDEwlsb2NhbGhv c3QxHTAbBgkqhkiG9w0BCQEWDm5ncHNAcG9zdDEuY29tMFwwDQYJKoZIhvcNAQEB BQADSwAwSAJBAKy+e3dulvXzV7zoTZWc5TzgApr8DmeQHTYC8ydfzH7EECe4R1Xh 5kwIzOuuFfn178FBiS84gngaNcrFi0Z5fAkCAwEAAaOCAQQwggEAMAkGA1UdEwQC MAAwLAYJYIZIAYb4QgENBB8WHU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRl MB0GA1UdDgQWBBTPhIKSvnsmYsBVNWjj0m3M2z0qVTCBpQYDVR0jBIGdMIGagBT7 hyNp65w6kxXlxb8pUU/+7Sg4AaF/pH0wezELMAkGA1UEBhMCU0cxETAPBgNVBAoT CE0yQ3J5cHRvMRQwEgYDVQQLEwtNMkNyeXB0byBDQTEkMCIGA1UEAxMbTTJDcnlw dG8gQ2VydGlmaWNhdGUgTWFzdGVyMR0wGwYJKoZIhvcNAQkBFg5uZ3BzQHBvc3Qx LmNvbYIBADANBgkqhkiG9w0BAQQFAAOBgQA7/CqT6PoHycTdhEStWNZde7M/2Yc6 BoJuVwnW8YxGO8Sn6UJ4FeffZNcYZddSDKosw8LtPOeWoK3JINjAk5jiPQ2cww++ 
7QGG/g5NDjxFZNDJP1dGiLAxPW6JXwov4v0FmdzfLOZ01jDcgQQZqEpYlgpuI5JE WUQ9Ho4EzbYCOQ== -----END CERTIFICATE----- -----BEGIN RSA PRIVATE KEY----- MIIBPAIBAAJBAKy+e3dulvXzV7zoTZWc5TzgApr8DmeQHTYC8ydfzH7EECe4R1Xh 5kwIzOuuFfn178FBiS84gngaNcrFi0Z5fAkCAwEAAQJBAIqm/bz4NA1H++Vx5Ewx OcKp3w19QSaZAwlGRtsUxrP7436QjnREM3Bm8ygU11BjkPVmtrKm6AayQfCHqJoT ZIECIQDW0BoMoL0HOYM/mrTLhaykYAVqgIeJsPjvkEhTFXWBuQIhAM3deFAvWNu4 nklUQ37XsCT2c9tmNt1LAT+slG2JOTTRAiAuXDtC/m3NYVwyHfFm+zKHRzHkClk2 HjubeEgjpj32AQIhAJqMGTaZVOwevTXvvHwNEH+vRWsAYU/gbx+OQB+7VOcBAiEA oolb6NMg/R3enNPvS1O4UU1H8wpaF77L4yiSWlE0p4w= -----END RSA PRIVATE KEY----- -----BEGIN CERTIFICATE REQUEST----- MIIBDTCBuAIBADBTMQswCQYDVQQGEwJTRzERMA8GA1UEChMITTJDcnlwdG8xEjAQ BgNVBAMTCWxvY2FsaG9zdDEdMBsGCSqGSIb3DQEJARYObmdwc0Bwb3N0MS5jb20w XDANBgkqhkiG9w0BAQEFAANLADBIAkEArL57d26W9fNXvOhNlZzlPOACmvwOZ5Ad NgLzJ1/MfsQQJ7hHVeHmTAjM664V+fXvwUGJLziCeBo1ysWLRnl8CQIDAQABoAAw DQYJKoZIhvcNAQEEBQADQQA7uqbrNTjVWpF6By5ZNPvhZ4YdFgkeXFVWi5ao/TaP Vq4BG021fJ9nlHRtr4rotpgHDX1rr+iWeHKsx4+5DRSy -----END CERTIFICATE REQUEST----- calendarserver-5.2+dfsg/twext/web2/test/test_server.py0000644000175000017500000007616012165665515022245 0ustar rahulrahul# Copyright (c) 2001-2007 Twisted Matrix Laboratories. # See LICENSE for details. """ A test harness for the twext.web2 server. """ from zope.interface import implementer from twisted.python import components from twext.web2 import http, http_headers, iweb, server from twext.web2 import resource, stream from twext.web2.dav.test.util import SimpleRequest from twisted.trial import unittest from twisted.internet import reactor, defer, address class NotResource(object): """ Class which does not implement IResource. Used as an adaptee by L{AdaptionTestCase.test_registered} to test that if an object which does not provide IResource is adapted to IResource and there is an adapter to IResource registered, that adapter is used. """ @implementer(iweb.IResource) class ResourceAdapter(object): """ Adapter to IResource. 
Registered as an adapter from NotResource to IResource so that L{AdaptionTestCase.test_registered} can test that such an adapter will be used. """ def __init__(self, original): pass components.registerAdapter(ResourceAdapter, NotResource, iweb.IResource) class NotOldResource(object): """ Class which does not implement IOldNevowResource or IResource. Used as an adaptee by L{AdaptionTestCase.test_transitive} to test that if an object which does not provide IResource or IOldNevowResource is adapted to IResource and there is an adapter to IOldNevowResource registered, first that adapter is used, then the included adapter from IOldNevowResource to IResource is used. """ @implementer(iweb.IOldNevowResource) class OldResourceAdapter(object): """ Adapter to IOldNevowResource. Registered as an adapter from NotOldResource to IOldNevowResource so that L{AdaptionTestCase.test_transitive} can test that such an adapter will be used to allow the initial input to be adapted to IResource. """ def __init__(self, original): pass components.registerAdapter(OldResourceAdapter, NotOldResource, iweb.IOldNevowResource) class AdaptionTestCase(unittest.TestCase): """ Test the adaption of various objects to IResource. Necessary due to the special implementation of __call__ on IResource which extends the behavior provided by the base Interface.__call__. """ def test_unadaptable(self): """ Test that attempting to adapt to IResource an object not adaptable to IResource raises an exception or returns the specified alternate object. """ class Unadaptable(object): pass self.assertRaises(TypeError, iweb.IResource, Unadaptable()) alternate = object() self.assertIdentical(iweb.IResource(Unadaptable(), alternate), alternate) def test_redundant(self): """ Test that the adaption to IResource of an object which provides IResource returns the same object. 
""" @implementer(iweb.IResource) class Resource(object): "" resource = Resource() self.assertIdentical(iweb.IResource(resource), resource) def test_registered(self): """ Test that if an adapter exists which can provide IResource for an object which does not provide it, that adapter is used. """ notResource = NotResource() self.failUnless(isinstance(iweb.IResource(notResource), ResourceAdapter)) @implementer(iweb.IChanRequest) class TestChanRequest: hostInfo = address.IPv4Address('TCP', 'host', 80), False remoteHost = address.IPv4Address('TCP', 'remotehost', 34567) finished = False def __init__(self, site, method, prepath, uri, length=None, headers=None, version=(1,1), content=None): self.producer = None self.site = site self.method = method self.prepath = prepath self.uri = uri if headers is None: headers = http_headers.Headers() self.headers = headers self.http_version = version # Anything below here we do not pass as arguments self.request = server.Request(self, self.method, self.uri, self.http_version, length, self.headers, site=self.site, prepathuri=self.prepath) if content is not None: self.request.handleContentChunk(content) self.request.handleContentComplete() self.code = None self.responseHeaders = None self.data = '' self.deferredFinish = defer.Deferred() def writeIntermediateResponse(code, headers=None): pass def writeHeaders(self, code, headers): self.responseHeaders = headers self.code = code def write(self, data): self.data += data def finish(self, failed=False): result = self.code, self.responseHeaders, self.data, failed self.finished = True self.deferredFinish.callback(result) def abortConnection(self): self.finish(failed=True) def registerProducer(self, producer, streaming): if self.producer is not None: raise ValueError("Producer still set: " + repr(self.producer)) self.producer = producer def unregisterProducer(self): self.producer = None def getHostInfo(self): return self.hostInfo def getRemoteHost(self): return self.remoteHost class 
BaseTestResource(resource.Resource): responseCode = 200 responseText = 'This is a fake resource.' responseHeaders = {} addSlash = False def __init__(self, children=[]): """ @type children: C{list} of C{tuple} @param children: a list of ('path', resource) tuples """ for i in children: self.putChild(i[0], i[1]) def render(self, req): return http.Response(self.responseCode, headers=self.responseHeaders, stream=self.responseStream()) def responseStream(self): return stream.MemoryStream(self.responseText) class MyRenderError(Exception): "" class ErrorWithProducerResource(BaseTestResource): addSlash = True def render(self, req): req.chanRequest.registerProducer(object(), None) return defer.fail(MyRenderError()) def child_(self, request): return self _unset = object() class BaseCase(unittest.TestCase): """ Base class for test cases that involve testing the result of arbitrary HTTP(S) queries. """ method = 'GET' version = (1, 1) wait_timeout = 5.0 def chanrequest(self, root, uri, length, headers, method, version, prepath, content): site = server.Site(root) return TestChanRequest(site, method, prepath, uri, length, headers, version, content) def getResponseFor(self, root, uri, headers={}, method=None, version=None, prepath='', content=None, length=_unset): if not isinstance(headers, http_headers.Headers): headers = http_headers.Headers(headers) if length is _unset: if content is not None: length = len(content) else: length = 0 if method is None: method = self.method if version is None: version = self.version cr = self.chanrequest(root, uri, length, headers, method, version, prepath, content) cr.request.process() return cr.deferredFinish def assertResponse(self, request_data, expected_response, failure=False): """ @type request_data: C{tuple} @type expected_response: C{tuple} @param request_data: A tuple of arguments to pass to L{getResponseFor}: (root, uri, headers, method, version, prepath). Root resource and requested URI are required, and everything else is optional. 
@param expected_response: A 3-tuple of the expected response: (responseCode, headers, htmlData) """ d = self.getResponseFor(*request_data) d.addCallback(self._cbGotResponse, expected_response, failure) return d def _cbGotResponse(self, (code, headers, data, failed), expected_response, expectedfailure=False): expected_code, expected_headers, expected_data = expected_response self.assertEquals(code, expected_code) if expected_data is not None: self.assertEquals(data, expected_data) for key, value in expected_headers.iteritems(): self.assertEquals(headers.getHeader(key), value) self.assertEquals(failed, expectedfailure) class ErrorHandlingTest(BaseCase): """ Tests for error handling. """ def test_processingReallyReallyReallyFailed(self): """ The HTTP connection will be shut down if there's really no way to relay any useful information about the error to the HTTP client. """ root = ErrorWithProducerResource() site = server.Site(root) tcr = TestChanRequest(site, "GET", "/", "http://localhost/") request = server.Request(tcr, "GET", "/", (1, 1), 0, http_headers.Headers( {"host": "localhost"}), site=site) proc = request.process() done = [] proc.addBoth(done.append) self.assertEquals(done, [None]) errs = self.flushLoggedErrors(ValueError) self.assertIn('producer', str(errs[0]).lower()) errs = self.flushLoggedErrors(MyRenderError) self.assertEquals(bool(errs), True) self.assertEquals(tcr.finished, True) class SampleWebTest(BaseCase): class SampleTestResource(BaseTestResource): addSlash = True def child_validChild(self, req): f = BaseTestResource() f.responseCode = 200 f.responseText = 'This is a valid child resource.' 
return f def child_missingChild(self, req): f = BaseTestResource() f.responseCode = 404 f.responseStream = lambda self: None return f def child_remoteAddr(self, req): f = BaseTestResource() f.responseCode = 200 f.responseText = 'Remote Addr: %r' % req.remoteAddr.host return f def setUp(self): self.root = self.SampleTestResource() def test_root(self): return self.assertResponse( (self.root, 'http://host/'), (200, {}, 'This is a fake resource.')) def test_validChild(self): return self.assertResponse( (self.root, 'http://host/validChild'), (200, {}, 'This is a valid child resource.')) def test_invalidChild(self): return self.assertResponse( (self.root, 'http://host/invalidChild'), (404, {}, None)) def test_remoteAddrExposure(self): return self.assertResponse( (self.root, 'http://host/remoteAddr'), (200, {}, "Remote Addr: 'remotehost'")) def test_leafresource(self): class TestResource(resource.LeafResource): def render(self, req): return http.Response(stream="prepath:%s postpath:%s" % ( req.prepath, req.postpath)) return self.assertResponse( (TestResource(), 'http://host/consumed/path/segments'), (200, {}, "prepath:[] postpath:['consumed', 'path', 'segments']")) def test_redirectResource(self): """ Make sure a redirect response has the correct status and Location header. """ redirectResource = resource.RedirectResource(scheme='https', host='localhost', port=443, path='/foo', querystring='bar=baz') return self.assertResponse( (redirectResource, 'http://localhost/'), (301, {'location': 'https://localhost/foo?bar=baz'}, None)) def test_redirectResourceWithSchemeRemapping(self): """ Make sure a redirect response has the correct status and Location header, when SSL is on, and the client request uses scheme http with the SSL port. 
""" def chanrequest2(root, uri, length, headers, method, version, prepath, content): site = server.Site(root) site.EnableSSL = True site.SSLPort = 8443 site.BindSSLPorts = [] return TestChanRequest(site, method, prepath, uri, length, headers, version, content) self.patch(self, "chanrequest", chanrequest2) redirectResource = resource.RedirectResource(path='/foo') return self.assertResponse( (redirectResource, 'http://localhost:8443/'), (301, {'location': 'https://localhost:8443/foo'}, None)) def test_redirectResourceWithoutSchemeRemapping(self): """ Make sure a redirect response has the correct status and Location header, when SSL is on, and the client request uses scheme http with the non-SSL port. """ def chanrequest2(root, uri, length, headers, method, version, prepath, content): site = server.Site(root) site.EnableSSL = True site.SSLPort = 8443 site.BindSSLPorts = [] return TestChanRequest(site, method, prepath, uri, length, headers, version, content) self.patch(self, "chanrequest", chanrequest2) redirectResource = resource.RedirectResource(path='/foo') return self.assertResponse( (redirectResource, 'http://localhost:8008/'), (301, {'location': 'http://localhost:8008/foo'}, None)) def test_redirectResourceWithoutSSLSchemeRemapping(self): """ Make sure a redirect response has the correct status and Location header, when SSL is off, and the client request uses scheme http with the SSL port. 
""" def chanrequest2(root, uri, length, headers, method, version, prepath, content): site = server.Site(root) site.EnableSSL = False site.SSLPort = 8443 site.BindSSLPorts = [] return TestChanRequest(site, method, prepath, uri, length, headers, version, content) self.patch(self, "chanrequest", chanrequest2) redirectResource = resource.RedirectResource(path='/foo') return self.assertResponse( (redirectResource, 'http://localhost:8443/'), (301, {'location': 'http://localhost:8443/foo'}, None)) class URLParsingTest(BaseCase): class TestResource(resource.LeafResource): def render(self, req): return http.Response(stream="Host:%s, Path:%s"%(req.host, req.path)) def setUp(self): self.root = self.TestResource() def test_normal(self): return self.assertResponse( (self.root, '/path', {'Host':'host'}), (200, {}, 'Host:host, Path:/path')) def test_fullurl(self): return self.assertResponse( (self.root, 'http://host/path'), (200, {}, 'Host:host, Path:/path')) def test_strangepath(self): # Ensure that the double slashes don't confuse it return self.assertResponse( (self.root, '//path', {'Host':'host'}), (200, {}, 'Host:host, Path://path')) def test_strangepathfull(self): return self.assertResponse( (self.root, 'http://host//path'), (200, {}, 'Host:host, Path://path')) class TestDeferredRendering(BaseCase): class ResourceWithDeferreds(BaseTestResource): addSlash=True responseText = 'I should be wrapped in a Deferred.' 
def render(self, req): d = defer.Deferred() reactor.callLater( 0, d.callback, BaseTestResource.render(self, req)) return d def child_deferred(self, req): d = defer.Deferred() reactor.callLater(0, d.callback, BaseTestResource()) return d def test_deferredRootResource(self): return self.assertResponse( (self.ResourceWithDeferreds(), 'http://host/'), (200, {}, 'I should be wrapped in a Deferred.')) def test_deferredChild(self): return self.assertResponse( (self.ResourceWithDeferreds(), 'http://host/deferred'), (200, {}, 'This is a fake resource.')) class RedirectResourceTest(BaseCase): def html(url): return "Moved Permanently

Moved Permanently

Document moved to %s.

" % (url,) html = staticmethod(html) def test_noRedirect(self): # This is useless, since it's a loop, but hey ds = [] for url in ("http://host/", "http://host/foo"): ds.append(self.assertResponse( (resource.RedirectResource(), url), (301, {"location": url}, self.html(url)) )) return defer.DeferredList(ds, fireOnOneErrback=True) def test_hostRedirect(self): ds = [] for url1, url2 in ( ("http://host/", "http://other/"), ("http://host/foo", "http://other/foo"), ): ds.append(self.assertResponse( (resource.RedirectResource(host="other"), url1), (301, {"location": url2}, self.html(url2)) )) return defer.DeferredList(ds, fireOnOneErrback=True) def test_pathRedirect(self): root = BaseTestResource() redirect = resource.RedirectResource(path="/other") root.putChild("r", redirect) ds = [] for url1, url2 in ( ("http://host/r", "http://host/other"), ("http://host/r/foo", "http://host/other"), ): ds.append(self.assertResponse( (resource.RedirectResource(path="/other"), url1), (301, {"location": url2}, self.html(url2)) )) return defer.DeferredList(ds, fireOnOneErrback=True) class EmptyResource(resource.Resource): def __init__(self, test): self.test = test def render(self, request): self.test.assertEquals(request.urlForResource(self), self.expectedURI) return 201 class RememberURIs(BaseCase): """ Tests for URI memory and lookup mechanism in server.Request. """ def test_requestedResource(self): """ Test urlForResource() on deeply nested resource looked up via request processing. 
""" root = EmptyResource(self) root.expectedURI = "/" foo = EmptyResource(self) foo.expectedURI = "/foo" root.putChild("foo", foo) bar = EmptyResource(self) bar.expectedURI = foo.expectedURI + "/bar" foo.putChild("bar", bar) baz = EmptyResource(self) baz.expectedURI = bar.expectedURI + "/baz" bar.putChild("baz", baz) ds = [] for uri in (foo.expectedURI, bar.expectedURI, baz.expectedURI): ds.append(self.assertResponse( (root, uri, {'Host':'host'}), (201, {}, None), )) return defer.DeferredList(ds, fireOnOneErrback=True) def test_urlEncoding(self): """ Test to make sure that URL encoding is working. """ root = EmptyResource(self) root.expectedURI = "/" child = EmptyResource(self) child.expectedURI = "/foo%20bar" root.putChild("foo bar", child) return self.assertResponse( (root, child.expectedURI, {'Host':'host'}), (201, {}, None) ) def test_locateResource(self): """ Test urlForResource() on resource looked up via a locateResource() call. """ root = resource.Resource() child = resource.Resource() root.putChild("foo", child) request = SimpleRequest(server.Site(root), "GET", "/") def gotResource(resource): self.assertEquals("/foo", request.urlForResource(resource)) d = defer.maybeDeferred(request.locateResource, "/foo") d.addCallback(gotResource) return d def test_unknownResource(self): """ Test urlForResource() on unknown resource. """ root = resource.Resource() child = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/") self.assertRaises(server.NoURLForResourceError, request.urlForResource, child) def test_locateChildResource(self): """ Test urlForResource() on deeply nested resource looked up via locateChildResource(). 
""" root = EmptyResource(self) root.expectedURI = "/" foo = EmptyResource(self) foo.expectedURI = "/foo" root.putChild("foo", foo) bar = EmptyResource(self) bar.expectedURI = "/foo/bar" foo.putChild("bar", bar) baz = EmptyResource(self) baz.expectedURI = "/foo/bar/b%20a%20z" bar.putChild("b a z", baz) request = SimpleRequest(server.Site(root), "GET", "/") def gotResource(resource): # Make sure locateChildResource() gave us the right answer self.assertEquals(resource, bar) return request.locateChildResource(resource, "b a z").addCallback(gotChildResource) def gotChildResource(resource): # Make sure locateChildResource() gave us the right answer self.assertEquals(resource, baz) self.assertEquals(resource.expectedURI, request.urlForResource(resource)) d = request.locateResource(bar.expectedURI) d.addCallback(gotResource) return d def test_deferredLocateChild(self): """ Test deferred value from locateChild() """ class DeferredLocateChild(resource.Resource): def locateChild(self, req, segments): return defer.maybeDeferred( super(DeferredLocateChild, self).locateChild, req, segments ) root = DeferredLocateChild() child = resource.Resource() root.putChild("foo", child) request = SimpleRequest(server.Site(root), "GET", "/foo") def gotResource(resource): self.assertEquals("/foo", request.urlForResource(resource)) d = request.locateResource("/foo") d.addCallback(gotResource) return d class ParsePostDataTests(unittest.TestCase): """ Tests for L{server.parsePOSTData}. """ def test_noData(self): """ Parsing a request without data should succeed but should not fill the C{args} and C{files} attributes of the request. 
""" root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/") def cb(ign): self.assertEquals(request.args, {}) self.assertEquals(request.files, {}) return server.parsePOSTData(request).addCallback(cb) def test_noContentType(self): """ Parsing a request without content-type should succeed but should not fill the C{args} and C{files} attributes of the request. """ root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", content="foo") def cb(ign): self.assertEquals(request.args, {}) self.assertEquals(request.files, {}) return server.parsePOSTData(request).addCallback(cb) def test_urlencoded(self): """ Test parsing data in urlencoded format: it should end in the C{args} attribute. """ ctype = http_headers.MimeType('application', 'x-www-form-urlencoded') content = "key=value&multiple=two+words&multiple=more%20words" root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) def cb(ign): self.assertEquals(request.files, {}) self.assertEquals(request.args, {'multiple': ['two words', 'more words'], 'key': ['value']}) return server.parsePOSTData(request).addCallback(cb) def test_multipart(self): """ Test parsing data in multipart format: it should fill the C{files} attribute. 
""" ctype = http_headers.MimeType('multipart', 'form-data', (('boundary', '---weeboundary'),)) content="""-----weeboundary\r Content-Disposition: form-data; name="FileNameOne"; filename="myfilename"\r Content-Type: text/html\r \r my great content wooo\r -----weeboundary--\r """ root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) def cb(ign): self.assertEquals(request.args, {}) self.assertEquals(request.files.keys(), ['FileNameOne']) self.assertEquals(request.files.values()[0][0][:2], ('myfilename', http_headers.MimeType('text', 'html', {}))) f = request.files.values()[0][0][2] self.assertEquals(f.read(), "my great content wooo") return server.parsePOSTData(request).addCallback(cb) def test_multipartWithNoBoundary(self): """ If the boundary type is not specified, parsing should fail with a C{http.HTTPError}. """ ctype = http_headers.MimeType('multipart', 'form-data') content="""-----weeboundary\r Content-Disposition: form-data; name="FileNameOne"; filename="myfilename"\r Content-Type: text/html\r \r my great content wooo\r -----weeboundary--\r """ root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) return self.assertFailure(server.parsePOSTData(request), http.HTTPError) def test_wrongContentType(self): """ Check that a content-type not handled raise a C{http.HTTPError}. """ ctype = http_headers.MimeType('application', 'foobar') content = "key=value&multiple=two+words&multiple=more%20words" root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) return self.assertFailure(server.parsePOSTData(request), http.HTTPError) def test_mimeParsingError(self): """ A malformed content should result in a C{http.HTTPError}. The tested content has an invalid closing boundary. 
""" ctype = http_headers.MimeType('multipart', 'form-data', (('boundary', '---weeboundary'),)) content="""-----weeboundary\r Content-Disposition: form-data; name="FileNameOne"; filename="myfilename"\r Content-Type: text/html\r \r my great content wooo\r -----weeoundary--\r """ root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) return self.assertFailure(server.parsePOSTData(request), http.HTTPError) def test_multipartMaxMem(self): """ Check that the C{maxMem} parameter makes the parsing raise an exception if the value is reached. """ ctype = http_headers.MimeType('multipart', 'form-data', (('boundary', '---weeboundary'),)) content="""-----weeboundary\r Content-Disposition: form-data; name="FileNameOne"\r Content-Type: text/html\r \r my great content wooo and even more and more\r -----weeboundary--\r """ root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) def cb(res): self.assertEquals(res.response.description, "Maximum length of 10 bytes exceeded.") return self.assertFailure(server.parsePOSTData(request, maxMem=10), http.HTTPError).addCallback(cb) def test_multipartMaxSize(self): """ Check that the C{maxSize} parameter makes the parsing raise an exception if the data is too big. 
""" ctype = http_headers.MimeType('multipart', 'form-data', (('boundary', '---weeboundary'),)) content="""-----weeboundary\r Content-Disposition: form-data; name="FileNameOne"; filename="myfilename"\r Content-Type: text/html\r \r my great content wooo and even more and more\r -----weeboundary--\r """ root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) def cb(res): self.assertEquals(res.response.description, "Maximum length of 10 bytes exceeded.") return self.assertFailure(server.parsePOSTData(request, maxSize=10), http.HTTPError).addCallback(cb) def test_maxFields(self): """ Check that the C{maxSize} parameter makes the parsing raise an exception if the data contains too many fields. """ ctype = http_headers.MimeType('multipart', 'form-data', (('boundary', '---xyz'),)) content = """-----xyz\r Content-Disposition: form-data; name="foo"\r \r Foo Bar\r -----xyz\r Content-Disposition: form-data; name="foo"\r \r Baz\r -----xyz\r Content-Disposition: form-data; name="file"; filename="filename"\r Content-Type: text/html\r \r blah\r -----xyz\r Content-Disposition: form-data; name="file"; filename="filename"\r Content-Type: text/plain\r \r bleh\r -----xyz--\r """ root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) def cb(res): self.assertEquals(res.response.description, "Maximum number of fields 3 exceeded") return self.assertFailure(server.parsePOSTData(request, maxFields=3), http.HTTPError).addCallback(cb) def test_otherErrors(self): """ Test that errors durign parsing other than C{MimeFormatError} are propagated. 
""" ctype = http_headers.MimeType('multipart', 'form-data', (('boundary', '---weeboundary'),)) # XXX: maybe this is not a good example # parseContentDispositionFormData could handle this problem content="""-----weeboundary\r Content-Disposition: form-data; name="FileNameOne"; filename="myfilename and invalid data \r -----weeboundary--\r """ root = resource.Resource() request = SimpleRequest(server.Site(root), "GET", "/", http_headers.Headers({'content-type': ctype}), content) return self.assertFailure(server.parsePOSTData(request), ValueError) calendarserver-5.2+dfsg/twext/web2/test/__init__.py0000644000175000017500000000024211337102650021405 0ustar rahulrahul# Copyright (c) 2001-2006 Twisted Matrix Laboratories. # See LICENSE for details. """ twext.web2.test: unittests for the Twext.Web2, Web Server Framework """ calendarserver-5.2+dfsg/twext/web2/fileupload.py0000644000175000017500000003212712263343324021026 0ustar rahulrahul## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## from __future__ import print_function import re from zope.interface import implements import urllib import tempfile from twisted.internet import defer from twext.web2.stream import IStream, FileStream, BufferedStream, readStream from twext.web2.stream import generatorToStream, readAndDiscard from twext.web2 import http_headers from cStringIO import StringIO ################################### ##### Multipart MIME Reader ##### ################################### class MimeFormatError(Exception): pass # parseContentDispositionFormData is absolutely horrible, but as # browsers don't seem to believe in sensible quoting rules, it's # really the only way to handle the header. (Quotes can be in the # filename, unescaped) cd_regexp = re.compile( ' *form-data; *name="([^"]*)"(?:; *filename="(.*)")?$', re.IGNORECASE) def parseContentDispositionFormData(value): match = cd_regexp.match(value) if not match: # Error parsing. raise ValueError("Unknown content-disposition format.") name=match.group(1) filename=match.group(2) return name, filename #@defer.deferredGenerator def _readHeaders(stream): """Read the MIME headers. 
Assumes we've just finished reading in the boundary string.""" ctype = fieldname = filename = None headers = [] # Now read headers while 1: line = stream.readline(size=1024) if isinstance(line, defer.Deferred): line = defer.waitForDeferred(line) yield line line = line.getResult() #print("GOT", line) if not line.endswith('\r\n'): if line == "": raise MimeFormatError("Unexpected end of stream.") else: raise MimeFormatError("Header line too long") line = line[:-2] # strip \r\n if line == "": break # End of headers parts = line.split(':', 1) if len(parts) != 2: raise MimeFormatError("Header did not have a :") name, value = parts name = name.lower() headers.append((name, value)) if name == "content-type": ctype = http_headers.parseContentType(http_headers.tokenize((value,), foldCase=False)) elif name == "content-disposition": fieldname, filename = parseContentDispositionFormData(value) if ctype is None: ctype == http_headers.MimeType('application', 'octet-stream') if fieldname is None: raise MimeFormatError('Content-disposition invalid or omitted.') # End of headers, return (field name, content-type, filename) yield fieldname, filename, ctype return _readHeaders = defer.deferredGenerator(_readHeaders) class _BoundaryWatchingStream(object): def __init__(self, stream, boundary): self.stream = stream self.boundary = boundary self.data = '' self.deferred = defer.Deferred() length = None # unknown def read(self): if self.stream is None: if self.deferred is not None: deferred = self.deferred self.deferred = None deferred.callback(None) return None newdata = self.stream.read() if isinstance(newdata, defer.Deferred): return newdata.addCallbacks(self._gotRead, self._gotError) return self._gotRead(newdata) def _gotRead(self, newdata): if not newdata: raise MimeFormatError("Unexpected EOF") # BLECH, converting buffer back into string. 
self.data += str(newdata) data = self.data boundary = self.boundary off = data.find(boundary) if off == -1: # No full boundary, check for the first character off = data.rfind(boundary[0], max(0, len(data)-len(boundary))) if off != -1: # We could have a partial boundary, store it for next time self.data = data[off:] return data[:off] else: self.data = '' return data else: self.stream.pushback(data[off+len(boundary):]) self.stream = None return data[:off] def _gotError(self, err): # Propogate error back to MultipartMimeStream also if self.deferred is not None: deferred = self.deferred self.deferred = None deferred.errback(err) return err def close(self): # Assume error will be raised again and handled by MMS? readAndDiscard(self).addErrback(lambda _: None) class MultipartMimeStream(object): implements(IStream) def __init__(self, stream, boundary): self.stream = BufferedStream(stream) self.boundary = "--"+boundary self.first = True def read(self): """ Return a deferred which will fire with a tuple of: (fieldname, filename, ctype, dataStream) or None when all done. Format errors will be sent to the errback. Returns None when all done. IMPORTANT: you *must* exhaust dataStream returned by this call before calling .read() again! 
""" if self.first: self.first = False d = self._readFirstBoundary() else: d = self._readBoundaryLine() d.addCallback(self._doReadHeaders) d.addCallback(self._gotHeaders) return d def _readFirstBoundary(self): #print("_readFirstBoundary") line = self.stream.readline(size=1024) if isinstance(line, defer.Deferred): line = defer.waitForDeferred(line) yield line line = line.getResult() if line != self.boundary + '\r\n': raise MimeFormatError("Extra data before first boundary: %r looking for: %r" % (line, self.boundary + '\r\n')) self.boundary = "\r\n"+self.boundary yield True return _readFirstBoundary = defer.deferredGenerator(_readFirstBoundary) def _readBoundaryLine(self): #print("_readBoundaryLine") line = self.stream.readline(size=1024) if isinstance(line, defer.Deferred): line = defer.waitForDeferred(line) yield line line = line.getResult() if line == "--\r\n": # THE END! yield False return elif line != "\r\n": raise MimeFormatError("Unexpected data on same line as boundary: %r" % (line,)) yield True return _readBoundaryLine = defer.deferredGenerator(_readBoundaryLine) def _doReadHeaders(self, morefields): #print("_doReadHeaders", morefields) if not morefields: return None return _readHeaders(self.stream) def _gotHeaders(self, headers): if headers is None: return None bws = _BoundaryWatchingStream(self.stream, self.boundary) self.deferred = bws.deferred ret=list(headers) ret.append(bws) return tuple(ret) def readIntoFile(stream, outFile, maxlen): """Read the stream into a file, but not if it's longer than maxlen. Returns Deferred which will be triggered on finish. """ curlen = [0] def done(_): return _ def write(data): curlen[0] += len(data) if curlen[0] > maxlen: raise MimeFormatError("Maximum length of %d bytes exceeded." 
% maxlen) outFile.write(data) return readStream(stream, write).addBoth(done) #@defer.deferredGenerator def parseMultipartFormData(stream, boundary, maxMem=100*1024, maxFields=1024, maxSize=10*1024*1024): # If the stream length is known to be too large upfront, abort immediately if stream.length is not None and stream.length > maxSize: raise MimeFormatError("Maximum length of %d bytes exceeded." % maxSize) mms = MultipartMimeStream(stream, boundary) numFields = 0 args = {} files = {} while 1: datas = mms.read() if isinstance(datas, defer.Deferred): datas = defer.waitForDeferred(datas) yield datas datas = datas.getResult() if datas is None: break numFields+=1 if numFields == maxFields: raise MimeFormatError("Maximum number of fields %d exceeded"%maxFields) # Parse data fieldname, filename, ctype, stream = datas if filename is None: # Not a file outfile = StringIO() maxBuf = min(maxSize, maxMem) else: outfile = tempfile.NamedTemporaryFile() maxBuf = maxSize x = readIntoFile(stream, outfile, maxBuf) if isinstance(x, defer.Deferred): x = defer.waitForDeferred(x) yield x x = x.getResult() if filename is None: # Is a normal form field outfile.seek(0) data = outfile.read() args.setdefault(fieldname, []).append(data) maxMem -= len(data) maxSize -= len(data) else: # Is a file upload maxSize -= outfile.tell() outfile.seek(0) files.setdefault(fieldname, []).append((filename, ctype, outfile)) yield args, files return parseMultipartFormData = defer.deferredGenerator(parseMultipartFormData) ################################### ##### x-www-urlencoded reader ##### ################################### def parse_urlencoded_stream(input, maxMem=100*1024, keep_blank_values=False, strict_parsing=False): lastdata = '' still_going=1 while still_going: try: yield input.wait data = input.next() except StopIteration: pairs = [lastdata] still_going=0 else: maxMem -= len(data) if maxMem < 0: raise MimeFormatError("Maximum length of %d bytes exceeded." 
% maxMem) pairs = str(data).split('&') pairs[0] = lastdata + pairs[0] lastdata=pairs.pop() for name_value in pairs: nv = name_value.split('=', 1) if len(nv) != 2: if strict_parsing: raise MimeFormatError("bad query field: %s") % `name_value` continue if len(nv[1]) or keep_blank_values: name = urllib.unquote(nv[0].replace('+', ' ')) value = urllib.unquote(nv[1].replace('+', ' ')) yield name, value parse_urlencoded_stream = generatorToStream(parse_urlencoded_stream) def parse_urlencoded(stream, maxMem=100*1024, maxFields=1024, keep_blank_values=False, strict_parsing=False): d = {} numFields = 0 s=parse_urlencoded_stream(stream, maxMem, keep_blank_values, strict_parsing) while 1: datas = s.read() if isinstance(datas, defer.Deferred): datas = defer.waitForDeferred(datas) yield datas datas = datas.getResult() if datas is None: break name, value = datas numFields += 1 if numFields == maxFields: raise MimeFormatError("Maximum number of fields %d exceeded"%maxFields) if name in d: d[name].append(value) else: d[name] = [value] yield d return parse_urlencoded = defer.deferredGenerator(parse_urlencoded) if __name__ == '__main__': d = parseMultipartFormData( FileStream(open("upload.txt")), "----------0xKhTmLbOuNdArY") from twext.python.log import Logger log = Logger() d.addErrback(log.err) def pr(s): print(s) d.addCallback(pr) __all__ = ['parseMultipartFormData', 'parse_urlencoded', 'parse_urlencoded_stream', 'MultipartMimeStream', 'MimeFormatError'] calendarserver-5.2+dfsg/twext/web2/http.py0000644000175000017500000004671212263343324017666 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_http -*- ## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """HyperText Transfer Protocol implementation. The second coming. Maintainer: James Y Knight """ # import traceback; log.info(''.join(traceback.format_stack())) import json import time from twisted.internet import interfaces, error from twisted.python import components from twisted.web.template import Element, XMLString, renderer, flattenString from zope.interface import implements from twext.python.log import Logger from twext.web2 import responsecode from twext.web2 import http_headers from twext.web2 import iweb from twext.web2 import stream from twext.web2.stream import IByteStream, readAndDiscard log = Logger() defaultPortForScheme = {'http': 80, 'https': 443, 'ftp': 21} def splitHostPort(scheme, hostport): """Split the host in "host:port" format into host and port fields. If port was not specified, use the default for the given scheme, if known. 
Returns a tuple of (hostname, portnumber).""" # Split hostport into host and port hostport = hostport.split(':', 1) try: if len(hostport) == 2: return hostport[0], int(hostport[1]) except ValueError: pass return hostport[0], defaultPortForScheme.get(scheme, 0) def parseVersion(strversion): """Parse version strings of the form Protocol '/' Major '.' Minor. E.g. 'HTTP/1.1'. Returns (protocol, major, minor). Will raise ValueError on bad syntax.""" proto, strversion = strversion.split('/') major, minor = strversion.split('.') major, minor = int(major), int(minor) if major < 0 or minor < 0: raise ValueError("negative number") return (proto.lower(), major, minor) class HTTPError(Exception): def __init__(self, codeOrResponse): """An Exception for propagating HTTP Error Responses. @param codeOrResponse: The numeric HTTP code or a complete http.Response object. @type codeOrResponse: C{int} or L{http.Response} """ self.response = iweb.IResponse(codeOrResponse) Exception.__init__(self, str(self.response)) def __repr__(self): return "<%s %s>" % (self.__class__.__name__, self.response) class Response(object): """An object representing an HTTP Response to be sent to the client. """ implements(iweb.IResponse) code = responsecode.OK headers = None stream = None def __init__(self, code=None, headers=None, stream=None): """ @param code: The HTTP status code for this Response @type code: C{int} @param headers: Headers to be sent to the client. 
@type headers: C{dict}, L{twext.web2.http_headers.Headers}, or C{None} @param stream: Content body to send to the HTTP client @type stream: L{twext.web2.stream.IByteStream} """ if code is not None: self.code = int(code) if headers is not None: if isinstance(headers, dict): headers = http_headers.Headers(headers) self.headers = headers else: self.headers = http_headers.Headers() if stream is not None: self.stream = IByteStream(stream) def __repr__(self): if self.stream is None: streamlen = None else: streamlen = self.stream.length return "<%s.%s code=%d, streamlen=%s>" % (self.__module__, self.__class__.__name__, self.code, streamlen) class StatusResponseElement(Element): """ Render the HTML for a L{StatusResponse} """ loader = XMLString("""<t:slot name="title" />

""") def __init__(self, title, description): super(StatusResponseElement, self).__init__() self.title = title self.description = description @renderer def response(self, request, tag): """ Top-level renderer. """ return tag.fillSlots(title=self.title, description=self.description) class StatusResponse (Response): """ A L{Response} object which simply contains a status code and a description of what happened. """ def __init__(self, code, description, title=None): """ @param code: a response code in L{responsecode.RESPONSES}. @param description: a string description. @param title: the message title. If not specified or C{None}, defaults to C{responsecode.RESPONSES[code]}. """ if title is None: title = responsecode.RESPONSES[code] element = StatusResponseElement(title, description) out = [] flattenString(None, element).addCallback(out.append) mime_params = {"charset": "utf-8"} super(StatusResponse, self).__init__(code=code, stream=out[0]) self.headers.setHeader( "content-type", http_headers.MimeType("text", "html", mime_params) ) self.description = description def __repr__(self): return "<%s %s %s>" % (self.__class__.__name__, self.code, self.description) class RedirectResponse (StatusResponse): """ A L{Response} object that contains a redirect to another network location. """ def __init__(self, location, temporary=False): """ @param location: the URI to redirect to. @param temporary: whether it's a temporary redirect or permanent """ code = responsecode.TEMPORARY_REDIRECT if temporary else responsecode.MOVED_PERMANENTLY super(RedirectResponse, self).__init__( code, "Document moved to %s." 
% (location,) ) self.headers.setHeader("location", location) def NotModifiedResponse(oldResponse=None): if oldResponse is not None: headers = http_headers.Headers() for header in ( # Required from sec 10.3.5: 'date', 'etag', 'content-location', 'expires', 'cache-control', 'vary', # Others: 'server', 'proxy-authenticate', 'www-authenticate', 'warning'): value = oldResponse.headers.getRawHeaders(header) if value is not None: headers.setRawHeaders(header, value) else: headers = None return Response(code=responsecode.NOT_MODIFIED, headers=headers) def checkPreconditions(request, response=None, entityExists=True, etag=None, lastModified=None): """Check to see if this request passes the conditional checks specified by the client. May raise an HTTPError with result codes L{NOT_MODIFIED} or L{PRECONDITION_FAILED}, as appropriate. This function is called automatically as an output filter for GET and HEAD requests. With GET/HEAD, it is not important for the precondition check to occur before doing the action, as the method is non-destructive. However, if you are implementing other request methods, like PUT for your resource, you will need to call this after determining the etag and last-modified time of the existing resource but before actually doing the requested action. In that case, This examines the appropriate request headers for conditionals, (If-Modified-Since, If-Unmodified-Since, If-Match, If-None-Match, or If-Range), compares with the etag and last and and then sets the response code as necessary. @param response: This should be provided for GET/HEAD methods. If it is specified, the etag and lastModified arguments will be retrieved automatically from the response headers and shouldn't be separately specified. Not providing the response with a GET request may cause the emitted "Not Modified" responses to be non-conformant. @param entityExists: Set to False if the entity in question doesn't yet exist. Necessary for PUT support with 'If-None-Match: *'. 
@param etag: The etag of the resource to check against, or None. @param lastModified: The last modified date of the resource to check against, or None. @raise: HTTPError: Raised when the preconditions fail, in order to abort processing and emit an error page. """ if response: assert etag is None and lastModified is None # if the code is some sort of error code, don't do anything if not ((response.code >= 200 and response.code <= 299) or response.code == responsecode.PRECONDITION_FAILED): return False etag = response.headers.getHeader("etag") lastModified = response.headers.getHeader("last-modified") def matchETag(tags, allowWeak): if entityExists and '*' in tags: return True if etag is None: return False return ((allowWeak or not etag.weak) and ([etagmatch for etagmatch in tags if etag.match(etagmatch, strongCompare=not allowWeak)])) # First check if-match/if-unmodified-since # If either one fails, we return PRECONDITION_FAILED match = request.headers.getHeader("if-match") if match: if not matchETag(match, False): raise HTTPError(StatusResponse(responsecode.PRECONDITION_FAILED, "Requested resource does not have a matching ETag.")) unmod_since = request.headers.getHeader("if-unmodified-since") if unmod_since: if not lastModified or lastModified > unmod_since: raise HTTPError(StatusResponse(responsecode.PRECONDITION_FAILED, "Requested resource has changed.")) # Now check if-none-match/if-modified-since. # This bit is tricky, because of the requirements when both IMS and INM # are present. In that case, you can't return a failure code # unless *both* checks think it failed. # Also, if the INM check succeeds, ignore IMS, because INM is treated # as more reliable. # I hope I got the logic right here...the RFC is quite poorly written # in this area. Someone might want to verify the testcase against # RFC wording. # If IMS header is later than current time, ignore it. 
notModified = None ims = request.headers.getHeader('if-modified-since') if ims: notModified = (ims < time.time() and lastModified and lastModified <= ims) inm = request.headers.getHeader("if-none-match") if inm: if request.method in ("HEAD", "GET"): # If it's a range request, don't allow a weak ETag, as that # would break. canBeWeak = not request.headers.hasHeader('Range') if notModified != False and matchETag(inm, canBeWeak): raise HTTPError(NotModifiedResponse(response)) else: if notModified != False and matchETag(inm, False): raise HTTPError(StatusResponse(responsecode.PRECONDITION_FAILED, "Requested resource has a matching ETag.")) else: if notModified == True: raise HTTPError(NotModifiedResponse(response)) def checkIfRange(request, response): """Checks for the If-Range header, and if it exists, checks if the test passes. Returns true if the server should return partial data.""" ifrange = request.headers.getHeader("if-range") if ifrange is None: return True if isinstance(ifrange, http_headers.ETag): return ifrange.match(response.headers.getHeader("etag"), strongCompare=True) else: return ifrange == response.headers.getHeader("last-modified") class _NotifyingProducerStream(stream.ProducerStream): doStartReading = None def __init__(self, length=None, doStartReading=None): stream.ProducerStream.__init__(self, length=length) self.doStartReading = doStartReading def read(self): if self.doStartReading is not None: doStartReading = self.doStartReading self.doStartReading = None doStartReading() return stream.ProducerStream.read(self) def write(self, data): self.doStartReading = None stream.ProducerStream.write(self, data) def finish(self): self.doStartReading = None stream.ProducerStream.finish(self) # response codes that must have empty bodies NO_BODY_CODES = (responsecode.NO_CONTENT, responsecode.NOT_MODIFIED) class Request(object): """A HTTP request. Subclasses should override the process() method to determine how the request will be processed. 
@ivar method: The HTTP method that was used. @ivar uri: The full URI that was requested (includes arguments). @ivar headers: All received headers @ivar clientproto: client HTTP version @ivar stream: incoming data stream. """ implements(iweb.IRequest, interfaces.IConsumer) known_expects = ('100-continue',) def __init__(self, chanRequest, command, path, version, contentLength, headers): """ @param chanRequest: the channel request we're associated with. """ self.chanRequest = chanRequest self.method = command self.uri = path self.clientproto = version self.headers = headers if '100-continue' in self.headers.getHeader('expect', ()): doStartReading = self._sendContinue else: doStartReading = None self.stream = _NotifyingProducerStream(contentLength, doStartReading) self.stream.registerProducer(self.chanRequest, True) def checkExpect(self): """Ensure there are no expectations that cannot be met. Checks Expect header against self.known_expects.""" expects = self.headers.getHeader('expect', ()) for expect in expects: if expect not in self.known_expects: raise HTTPError(responsecode.EXPECTATION_FAILED) def process(self): """Called by channel to let you process the request. Can be overridden by a subclass to do something useful.""" pass def handleContentChunk(self, data): """Callback from channel when a piece of data has been received. Puts the data in .stream""" self.stream.write(data) def handleContentComplete(self): """Callback from channel when all data has been received. """ self.stream.unregisterProducer() self.stream.finish() def connectionLost(self, reason): """connection was lost""" pass def __repr__(self): return '<%s %s %s>' % (self.method, self.uri, self.clientproto) def _sendContinue(self): self.chanRequest.writeIntermediateResponse(responsecode.CONTINUE) def _reallyFinished(self, x): """We are finished writing data.""" self.chanRequest.finish() def _finished(self, x): """ We are finished writing data. 
But we need to check that we have also finished reading all data as we might have sent a, for example, 401 response before we read any data. To make sure that the stream/producer sequencing works properly we need to discard the remaining data in the request. """ if self.stream.length != 0: return readAndDiscard(self.stream).addCallback(self._reallyFinished).addErrback(self._error) else: self._reallyFinished(x) def _error(self, reason): if reason.check(error.ConnectionLost): log.info("Request error: {message}", message=reason.getErrorMessage()) else: log.failure("Request error", reason) # Only bother with cleanup on errors other than lost connection. self.chanRequest.abortConnection() def writeResponse(self, response): """ Write a response. """ if self.stream.doStartReading is not None: # Expect: 100-continue was requested, but 100 response has not been # sent, and there's a possibility that data is still waiting to be # sent. # # Ideally this means the remote side will not send any data. # However, because of compatibility requirements, it might timeout, # and decide to do so anyways at the same time we're sending back # this response. Thus, the read state is unknown after this. # We must close the connection. self.chanRequest.channel.setReadPersistent(False) # Nothing more will be read self.chanRequest.allContentReceived() if response.code != responsecode.NOT_MODIFIED: # Not modified response is *special* and doesn't get a content-length. if response.stream is None: response.headers.setHeader('content-length', 0) elif response.stream.length is not None: response.headers.setHeader('content-length', response.stream.length) self.chanRequest.writeHeaders(response.code, response.headers) # if this is a "HEAD" request, or a special response code, # don't return any data. 
if self.method == "HEAD" or response.code in NO_BODY_CODES: if response.stream is not None: response.stream.close() self._finished(None) return d = stream.StreamProducer(response.stream).beginProducing(self.chanRequest) d.addCallback(self._finished).addErrback(self._error) class XMLResponse (Response): """ XML L{Response} object. Renders itself as an XML document. """ def __init__(self, code, element): """ @param xml_responses: an iterable of davxml.Response objects. """ Response.__init__(self, code, stream=element.toxml()) self.headers.setHeader("content-type", http_headers.MimeType("text", "xml")) class JSONResponse (Response): """ JSON L{Response} object. Renders itself as an JSON document. """ def __init__(self, code, jobj): """ @param xml_responses: an iterable of davxml.Response objects. """ Response.__init__(self, code, stream=json.dumps(jobj)) self.headers.setHeader("content-type", http_headers.MimeType("application", "json")) components.registerAdapter(Response, int, iweb.IResponse) __all__ = ['HTTPError', 'NotModifiedResponse', 'Request', 'Response', 'StatusResponse', 'RedirectResponse', 'checkIfRange', 'checkPreconditions', 'defaultPortForScheme', 'parseVersion', 'splitHostPort', "XMLResponse", "JSONResponse"] calendarserver-5.2+dfsg/twext/web2/auth/0000755000175000017500000000000012322625325017264 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/auth/basic.py0000644000175000017500000000442512263343324020724 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_httpauth -*- ## # Copyright (c) 2006-2009 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## from twisted.cred import credentials, error from twisted.internet.defer import succeed, fail from twext.web2.auth.interfaces import ICredentialFactory from zope.interface import implements class BasicCredentialFactory(object): """ Credential Factory for HTTP Basic Authentication """ implements(ICredentialFactory) scheme = 'basic' def __init__(self, realm): self.realm = realm def getChallenge(self, peer): """ @see L{ICredentialFactory.getChallenge} """ return succeed({'realm': self.realm}) def decode(self, response, request): """ Decode the credentials for basic auth. 
@see L{ICredentialFactory.decode} """ try: creds = (response + '===').decode('base64') except: raise error.LoginFailed('Invalid credentials') creds = creds.split(':', 1) if len(creds) == 2: return succeed(credentials.UsernamePassword(*creds)) else: return fail(error.LoginFailed('Invalid credentials')) calendarserver-5.2+dfsg/twext/web2/auth/interfaces.py0000644000175000017500000000611112263343324021760 0ustar rahulrahul## # Copyright (c) 2004-2007 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## from zope.interface import Interface, Attribute class ICredentialFactory(Interface): """ A credential factory provides state between stages in HTTP authentication. It is ultimately in charge of creating an ICredential for the specified scheme, that will be used by cred to complete authentication. 
""" scheme = Attribute(("string indicating the authentication scheme " "this factory is associated with.")) def getChallenge(peer): """ Generate a challenge the client may respond to. @type peer: L{twisted.internet.interfaces.IAddress} @param peer: The client's address @rtype: C{dict} @return: Deferred returning dictionary of challenge arguments """ def decode(response, request): """ Create a credentials object from the given response. May raise twisted.cred.error.LoginFailed if the response is invalid. @type response: C{str} @param response: scheme specific response string @type request: L{twext.web2.server.Request} @param request: the request being processed @return: Deferred returning ICredentials """ class IAuthenticatedRequest(Interface): """ A request that has been authenticated with the use of Cred, and holds a reference to the avatar returned by portal.login """ avatarInterface = Attribute(("The credential interface implemented by " "the avatar")) avatar = Attribute("The application specific avatar returned by " "the application's realm") class IHTTPUser(Interface): """ A generic interface that can implemented by an avatar to provide access to the username used when authenticating. """ username = Attribute(("A string representing the username portion of " "the credentials used for authentication")) calendarserver-5.2+dfsg/twext/web2/auth/wrapper.py0000644000175000017500000002256612263343324021331 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_httpauth -*- ## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ Wrapper Resources for rfc2617 HTTP Auth. """ from zope.interface import implements, directlyProvides from twisted.cred import error, credentials from twisted.internet.defer import gatherResults, succeed from twisted.python import failure from twext.web2 import responsecode from twext.web2 import http from twext.web2 import iweb from twext.web2.auth.interfaces import IAuthenticatedRequest class UnauthorizedResponse(http.StatusResponse): """A specialized response class for generating www-authenticate headers from the given L{CredentialFactory} instances """ def __init__(self): super(UnauthorizedResponse, self).__init__( responsecode.UNAUTHORIZED, "You are not authorized to access this resource.") def _generateHeaders(self, factories, remoteAddr=None): """ Set up the response's headers. @param factories: A L{dict} of {'scheme': ICredentialFactory} @param remoteAddr: An L{IAddress} for the connecting client. 
""" schemes = [] challengeDs = [] for factory in factories.itervalues(): schemes.append(factory.scheme) challengeDs.append(factory.getChallenge(remoteAddr)) def _setAuthHeader(challenges): authHeaders = zip(schemes, challenges) self.headers.setHeader('www-authenticate', authHeaders) return gatherResults(challengeDs).addCallback(_setAuthHeader) @classmethod def makeResponse(cls, factories, remoteAddr=None): """ Create an Unauthorized response. @param factories: A L{dict} of {'scheme': ICredentialFactory} @param remoteAddr: An L{IAddress} for the connecting client. @return: a Deferred that fires with the L{UnauthorizedResponse} instance. """ response = UnauthorizedResponse() d = response._generateHeaders(factories, remoteAddr) d.addCallback(lambda _:response) return d class HTTPAuthResource(object): """I wrap a resource to prevent it being accessed unless the authentication can be completed using the credential factory, portal, and interfaces specified. """ implements(iweb.IResource) def __init__(self, wrappedResource, credentialFactories, portal, interfaces): """ @param wrappedResource: A L{twext.web2.iweb.IResource} to be returned from locateChild and render upon successful authentication. @param credentialFactories: A list of instances that implement L{ICredentialFactory}. @type credentialFactories: L{list} @param portal: Portal to handle logins for this resource. @type portal: L{twisted.cred.portal.Portal} @param interfaces: the interfaces that are allowed to log in via the given portal @type interfaces: L{tuple} """ self.wrappedResource = wrappedResource self.credentialFactories = dict([(factory.scheme, factory) for factory in credentialFactories]) self.portal = portal self.interfaces = interfaces def _loginSucceeded(self, avatar, request): """ Callback for successful login. @param avatar: A tuple of the form (interface, avatar) as returned by your realm. @param request: L{IRequest} that encapsulates this auth attempt. 
@return: the IResource in C{self.wrappedResource} """ request.avatarInterface, request.avatar = avatar directlyProvides(request, IAuthenticatedRequest) def _addAuthenticateHeaders(request, response): """ A response filter that adds www-authenticate headers to an outgoing response if it's code is UNAUTHORIZED (401) and it does not already have them. """ if response.code == responsecode.UNAUTHORIZED: if not response.headers.hasHeader('www-authenticate'): d = UnauthorizedResponse.makeResponse( self.credentialFactories, request.remoteAddr) def _respond(newResp): response.headers.setHeader( 'www-authenticate', newResp.headers.getHeader('www-authenticate')) return response d.addCallback(_respond) return d return succeed(response) _addAuthenticateHeaders.handleErrors = True request.addResponseFilter(_addAuthenticateHeaders) return self.wrappedResource def _loginFailed(self, ignored, request): """ Errback for failed login. @param request: L{IRequest} that encapsulates this auth attempt. @return: A Deferred L{Failure} containing an L{HTTPError} containing the L{UnauthorizedResponse} if C{result} is an L{UnauthorizedLogin} or L{UnhandledCredentials} error """ d = UnauthorizedResponse.makeResponse(self.credentialFactories, request.remoteAddr) def _fail(response): return failure.Failure(http.HTTPError(response)) return d.addCallback(_fail) def login(self, factory, response, request): """ @param factory: An L{ICredentialFactory} that understands the given response. @param response: The client's authentication response as a string. @param request: The request that prompted this authentication attempt. 
@return: A L{Deferred} that fires with the wrappedResource on success or a failure containing an L{UnauthorizedResponse} """ d = factory.decode(response, request) def _decodeFailure(err): err.trap(error.LoginFailed) d = UnauthorizedResponse.makeResponse(self.credentialFactories, request.remoteAddr) def _respond(response): return failure.Failure(http.HTTPError(response)) return d.addCallback(_respond) def _login(creds): return self.portal.login(creds, None, *self.interfaces ).addCallbacks(self._loginSucceeded, self._loginFailed, (request,), None, (request,), None) return d.addErrback(_decodeFailure).addCallback(_login) def authenticate(self, request): """ Attempt to authenticate the given request @param request: An L{IRequest} to be authenticated. """ authHeader = request.headers.getHeader('authorization') if authHeader is None: return self.portal.login(credentials.Anonymous(), None, *self.interfaces ).addCallbacks(self._loginSucceeded, self._loginFailed, (request,), None, (request,), None) elif authHeader[0] not in self.credentialFactories: return self._loginFailed(None, request) else: return self.login(self.credentialFactories[authHeader[0]], authHeader[1], request) def locateChild(self, request, seg): """ Authenticate the request then return the C{self.wrappedResource} and the unmodified segments. """ return self.authenticate(request), seg def renderHTTP(self, request): """ Authenticate the request then return the result of calling renderHTTP on C{self.wrappedResource} """ def _renderResource(resource): return resource.renderHTTP(request) d = self.authenticate(request) d.addCallback(_renderResource) return d calendarserver-5.2+dfsg/twext/web2/auth/digest.py0000644000175000017500000001055012263343324021116 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_httpauth -*- ## # Copyright (c) 2006-2009 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ Implementation of RFC2617: HTTP Digest Authentication http://www.faqs.org/rfcs/rfc2617.html """ from zope.interface import implements from twisted.python.hashlib import md5, sha1 from twisted.cred import credentials # FIXME: Technically speaking - although you can't tell from looking at them - # these APIs are private, they're defined within twisted.cred._digest. There # should probably be some upstream bugs agains Twisted to more aggressively hide # implementation details like these if they're not supposed to be used, so we # can see the private-ness more clearly. The fix is really just to eliminate # this whole module though, and use the Twisted stuff via the public interface, # which should be sufficient to do digest auth. 
from twisted.cred.credentials import (calcHA1 as _origCalcHA1, calcResponse as _origCalcResponse, calcHA2 as _origCalcHA2) from twisted.internet.defer import maybeDeferred from twext.web2.auth.interfaces import ICredentialFactory # The digest math algorithms = { 'md5': md5, 'md5-sess': md5, 'sha': sha1, } # DigestCalcHA1 def calcHA1(pszAlg, pszUserName, pszRealm, pszPassword, pszNonce, pszCNonce, preHA1=None): """ @param pszAlg: The name of the algorithm to use to calculate the digest. Currently supported are md5 md5-sess and sha. @param pszUserName: The username @param pszRealm: The realm @param pszPassword: The password @param pszNonce: The nonce @param pszCNonce: The cnonce @param preHA1: If available this is a str containing a previously calculated HA1 as a hex string. If this is given then the values for pszUserName, pszRealm, and pszPassword are ignored. """ return _origCalcHA1(pszAlg, pszUserName, pszRealm, pszPassword, pszNonce, pszCNonce, preHA1) # DigestCalcResponse def calcResponse( HA1, algo, pszNonce, pszNonceCount, pszCNonce, pszQop, pszMethod, pszDigestUri, pszHEntity, ): return _origCalcResponse(HA1, _origCalcHA2(algo, pszMethod, pszDigestUri, pszQop, pszHEntity), algo, pszNonce, pszNonceCount, pszCNonce, pszQop) DigestedCredentials = credentials.DigestedCredentials class DigestCredentialFactory(object): implements(ICredentialFactory) CHALLENGE_LIFETIME_SECS = ( credentials.DigestCredentialFactory.CHALLENGE_LIFETIME_SECS ) def __init__(self, algorithm, realm): self._real = credentials.DigestCredentialFactory(algorithm, realm) scheme = 'digest' def getChallenge(self, peer): return maybeDeferred(self._real.getChallenge, peer.host) def generateOpaque(self, *a, **k): return self._real._generateOpaque(*a, **k) def verifyOpaque(self, opaque, nonce, clientip): return self._real._verifyOpaque(opaque, nonce, clientip) def decode(self, response, request): method = getattr(request, "originalMethod", request.method) host = request.remoteAddr.host return 
self._real.decode(response, method, host) calendarserver-5.2+dfsg/twext/web2/auth/__init__.py0000644000175000017500000000234612263343324021402 0ustar rahulrahul## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ Client and server implementations of http authentication """ calendarserver-5.2+dfsg/twext/web2/server.py0000644000175000017500000006460012263343324020211 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_server -*- ## # Copyright (c) 2001-2008 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ This is a web-server which integrates with the twisted.internet infrastructure. 
""" from __future__ import print_function import cgi, time, urlparse from urllib import quote, unquote from urlparse import urlsplit import weakref from zope.interface import implements from twisted.internet import defer from twisted.python import failure from twext.python.log import Logger from twext.web2 import http, iweb, fileupload, responsecode from twext.web2 import http_headers from twext.web2.filter.range import rangefilter from twext.web2 import error from twext.web2 import __version__ as web2_version from twisted import __version__ as twisted_version VERSION = "Twisted/%s TwistedWeb/%s" % (twisted_version, web2_version) _errorMarker = object() log = Logger() def defaultHeadersFilter(request, response): if not response.headers.hasHeader('server'): response.headers.setHeader('server', VERSION) if not response.headers.hasHeader('date'): response.headers.setHeader('date', time.time()) return response defaultHeadersFilter.handleErrors = True def preconditionfilter(request, response): if request.method in ("GET", "HEAD"): http.checkPreconditions(request, response) return response def doTrace(request): request = iweb.IRequest(request) txt = "%s %s HTTP/%d.%d\r\n" % (request.method, request.uri, request.clientproto[0], request.clientproto[1]) l=[] for name, valuelist in request.headers.getAllRawHeaders(): for value in valuelist: l.append("%s: %s\r\n" % (name, value)) txt += ''.join(l) return http.Response( responsecode.OK, {'content-type': http_headers.MimeType('message', 'http')}, txt) def parsePOSTData(request, maxMem=100*1024, maxFields=1024, maxSize=10*1024*1024): """ Parse data of a POST request. @param request: the request to parse. @type request: L{twext.web2.http.Request}. @param maxMem: maximum memory used during the parsing of the data. @type maxMem: C{int} @param maxFields: maximum number of form fields allowed. @type maxFields: C{int} @param maxSize: maximum size of file upload allowed. 
@type maxSize: C{int} @return: a deferred that will fire when the parsing is done. The deferred itself doesn't hold a return value, the request is modified directly. @rtype: C{defer.Deferred} """ if request.stream.length == 0: return defer.succeed(None) ctype = request.headers.getHeader('content-type') if ctype is None: return defer.succeed(None) def updateArgs(data): args = data request.args.update(args) def updateArgsAndFiles(data): args, files = data request.args.update(args) request.files.update(files) def error(f): f.trap(fileupload.MimeFormatError) raise http.HTTPError( http.StatusResponse(responsecode.BAD_REQUEST, str(f.value))) if (ctype.mediaType == 'application' and ctype.mediaSubtype == 'x-www-form-urlencoded'): d = fileupload.parse_urlencoded(request.stream) d.addCallbacks(updateArgs, error) return d elif (ctype.mediaType == 'multipart' and ctype.mediaSubtype == 'form-data'): boundary = ctype.params.get('boundary') if boundary is None: return defer.fail(http.HTTPError( http.StatusResponse( responsecode.BAD_REQUEST, "Boundary not specified in Content-Type."))) d = fileupload.parseMultipartFormData(request.stream, boundary, maxMem, maxFields, maxSize) d.addCallbacks(updateArgsAndFiles, error) return d else: return defer.fail(http.HTTPError( http.StatusResponse( responsecode.BAD_REQUEST, "Invalid content-type: %s/%s" % ( ctype.mediaType, ctype.mediaSubtype)))) class StopTraversal(object): """ Indicates to Request._handleSegment that it should stop handling path segments. """ pass class Request(http.Request): """ vars: site remoteAddr scheme host port path params querystring args files prepath postpath @ivar path: The path only (arguments not included). @ivar args: All of the arguments, including URL and POST arguments. @type args: A mapping of strings (the argument names) to lists of values. i.e., ?foo=bar&foo=baz&quux=spam results in {'foo': ['bar', 'baz'], 'quux': ['spam']}. 
""" implements(iweb.IRequest) site = None _initialprepath = None responseFilters = [rangefilter, preconditionfilter, error.defaultErrorHandler, defaultHeadersFilter] def __init__(self, *args, **kw): self.timeStamps = [("t", time.time(),)] if kw.has_key('site'): self.site = kw['site'] del kw['site'] if kw.has_key('prepathuri'): self._initialprepath = kw['prepathuri'] del kw['prepathuri'] self._resourcesByURL = {} self._urlsByResource = {} # Copy response filters from the class self.responseFilters = self.responseFilters[:] self.files = {} self.resources = [] http.Request.__init__(self, *args, **kw) try: self.serverInstance = self.chanRequest.channel.transport.server.port except AttributeError: self.serverInstance = "Unknown" def timeStamp(self, tag): self.timeStamps.append((tag, time.time(),)) def addResponseFilter(self, filter, atEnd=False, onlyOnce=False): """ Add a response filter to this request. Response filters are applied to the response to this request in order. @param filter: a callable which takes an response argument and returns a response object. @param atEnd: if C{True}, C{filter} is added at the end of the list of response filters; if C{False}, it is added to the beginning. @param onlyOnce: if C{True}, C{filter} is not added to the list of response filters if it already in the list. """ if onlyOnce and filter in self.responseFilters: return if atEnd: self.responseFilters.append(filter) else: self.responseFilters.insert(0, filter) def unparseURL(self, scheme=None, host=None, port=None, path=None, params=None, querystring=None, fragment=None): """Turn the request path into a url string. For any pieces of the url that are not specified, use the value from the request. 
The arguments have the same meaning as the same named attributes of Request.""" if scheme is None: scheme = self.scheme if host is None: host = self.host if port is None: port = self.port if path is None: path = self.path if params is None: params = self.params if querystring is None: querystring = self.querystring if fragment is None: fragment = '' if port == http.defaultPortForScheme.get(scheme, 0): hostport = host else: hostport = host + ':' + str(port) return urlparse.urlunparse(( scheme, hostport, path, params, querystring, fragment)) def _parseURL(self): if self.uri[0] == '/': # Can't use urlparse for request_uri because urlparse # wants to be given an absolute or relative URI, not just # an abs_path, and thus gets '//foo' wrong. self.scheme = self.host = self.path = self.params = self.querystring = '' if '?' in self.uri: self.path, self.querystring = self.uri.split('?', 1) else: self.path = self.uri if ';' in self.path: self.path, self.params = self.path.split(';', 1) else: # It is an absolute uri, use standard urlparse (self.scheme, self.host, self.path, self.params, self.querystring, fragment) = urlparse.urlparse(self.uri) if self.querystring: self.args = cgi.parse_qs(self.querystring, True) else: self.args = {} path = map(unquote, self.path[1:].split('/')) if self._initialprepath: # We were given an initial prepath -- this is for supporting # CGI-ish applications where part of the path has already # been processed prepath = map(unquote, self._initialprepath[1:].split('/')) if path[:len(prepath)] == prepath: self.prepath = prepath self.postpath = path[len(prepath):] else: self.prepath = [] self.postpath = path else: self.prepath = [] self.postpath = path #print("_parseURL", self.uri, (self.uri, self.scheme, self.host, self.path, self.params, self.querystring)) def _schemeFromPort(self, port): """ Try to determine the scheme matching the supplied server port. This is needed in case where a device in front of the server is changing the scheme (e.g. 
decoding SSL) but not rewriting the scheme in URIs returned in responses (e.g. in Location headers). This could trick clients into using an inappropriate scheme for subsequent requests. What we should do is take the port number from the Host header or request-URI and map that to the scheme that matches the service we configured to listen on that port. @param port: the port number to test @type port: C{int} @return: C{True} if scheme is https (secure), C{False} otherwise @rtype: C{bool} """ #from twistedcaldav.config import config if hasattr(self.site, "EnableSSL") and self.site.EnableSSL: if port == self.site.SSLPort: return True elif port in self.site.BindSSLPorts: return True return False def _fixupURLParts(self): hostaddr, secure = self.chanRequest.getHostInfo() if not self.scheme: self.scheme = ('http', 'https')[secure] if self.host: self.host, self.port = http.splitHostPort(self.scheme, self.host) self.scheme = ('http', 'https')[self._schemeFromPort(self.port)] else: # If GET line wasn't an absolute URL host = self.headers.getHeader('host') if host: self.host, self.port = http.splitHostPort(self.scheme, host) self.scheme = ('http', 'https')[self._schemeFromPort(self.port)] else: # When no hostname specified anywhere, either raise an # error, or use the interface hostname, depending on # protocol version if self.clientproto >= (1,1): raise http.HTTPError(responsecode.BAD_REQUEST) self.host = hostaddr.host self.port = hostaddr.port def process(self): "Process a request." 
log.info("%s %s %s" % ( self.method, self.uri, "HTTP/%s.%s" % self.clientproto )) try: self.checkExpect() resp = self.preprocessRequest() if resp is not None: self._cbFinishRender(resp).addErrback(self._processingFailed) return self._parseURL() self._fixupURLParts() self.remoteAddr = self.chanRequest.getRemoteHost() except: self._processingFailed(failure.Failure()) return d = defer.Deferred() d.addCallback(self._getChild, self.site.resource, self.postpath) d.addCallback(self._rememberResource, "/" + "/".join(quote(s) for s in self.postpath)) d.addCallback(self._processTimeStamp) d.addCallback(lambda res, req: res.renderHTTP(req), self) d.addCallback(self._cbFinishRender) d.addErrback(self._processingFailed) d.callback(None) return d def _processTimeStamp(self, res): self.timeStamp("t-req-proc") return res def preprocessRequest(self): """Do any request processing that doesn't follow the normal resource lookup procedure. "OPTIONS *" is handled here, for example. This would also be the place to do any CONNECT processing.""" if self.method == "OPTIONS" and self.uri == "*": response = http.Response(responsecode.OK) response.headers.setHeader('allow', ('GET', 'HEAD', 'OPTIONS', 'TRACE')) return response elif self.method == "POST": # Allow other methods to tunnel through using POST and a request header. 
# See http://code.google.com/apis/gdata/docs/2.0/basics.html if self.headers.hasHeader("X-HTTP-Method-Override"): intendedMethod = self.headers.getRawHeaders("X-HTTP-Method-Override")[0]; if intendedMethod: self.originalMethod = self.method self.method = intendedMethod # This is where CONNECT would go if we wanted it return None def _getChild(self, _, res, path, updatepaths=True): """Call res.locateChild, and pass the result on to _handleSegment.""" self.resources.append(res) if not path: return res result = res.locateChild(self, path) if isinstance(result, defer.Deferred): return result.addCallback(self._handleSegment, res, path, updatepaths) else: return self._handleSegment(result, res, path, updatepaths) def _handleSegment(self, result, res, path, updatepaths): """Handle the result of a locateChild call done in _getChild.""" newres, newpath = result # If the child resource is None then display a error page if newres is None: raise http.HTTPError(responsecode.NOT_FOUND) # If we got a deferred then we need to call back later, once the # child is actually available. if isinstance(newres, defer.Deferred): return newres.addCallback( lambda actualRes: self._handleSegment( (actualRes, newpath), res, path, updatepaths) ) if path: url = quote("/" + "/".join(path)) else: url = "/" if newpath is StopTraversal: # We need to rethink how to do this. #if newres is res: return res #else: # raise ValueError("locateChild must not return StopTraversal with a resource other than self.") newres = iweb.IResource(newres) if newres is res: assert not newpath is path, "URL traversal cycle detected when attempting to locateChild %r from resource %r." % (path, res) assert len(newpath) < len(path), "Infinite loop impending..." if updatepaths: # We found a Resource... 
update the request.prepath and postpath for x in xrange(len(path) - len(newpath)): self.prepath.append(self.postpath.pop(0)) url = quote("/" + "/".join(self.prepath) + ("/" if self.prepath and self.prepath[-1] else "")) self._rememberResource(newres, url) else: try: previousURL = self.urlForResource(res) url = quote(previousURL + path[0] + ("/" if path[0] and len(path) > 1 else "")) self._rememberResource(newres, url) except NoURLForResourceError: pass child = self._getChild(None, newres, newpath, updatepaths=updatepaths) return child _urlsByResource = weakref.WeakKeyDictionary() def _rememberResource(self, resource, url): """ Remember the URL of a visited resource. """ self._resourcesByURL[url] = resource self._urlsByResource[resource] = url return resource def _forgetResource(self, resource, url): """ Remember the URL of a visited resource. """ del self._resourcesByURL[url] del self._urlsByResource[resource] def urlForResource(self, resource): """ Looks up the URL of the given resource if this resource was found while processing this request. Specifically, this includes the requested resource, and resources looked up via L{locateResource}. Note that a resource may be found at multiple URIs; if the same resource is visited at more than one location while processing this request, this method will return one of those URLs, but which one is not defined, nor whether the same URL is returned in subsequent calls. @param resource: the resource to find a URI for. This resource must have been obtained from the request (i.e. via its C{uri} attribute, or through its C{locateResource} or C{locateChildResource} methods). @return: a valid URL for C{resource} in this request. @raise NoURLForResourceError: if C{resource} has no URL in this request (because it was not obtained from the request). 
""" url = self._urlsByResource.get(resource, None) if url is None: raise NoURLForResourceError(resource) return url def locateResource(self, url): """ Looks up the resource with the given URL. @param uri: The URL of the desired resource. @return: a L{Deferred} resulting in the L{IResource} at the given URL or C{None} if no such resource can be located. @raise HTTPError: If C{url} is not a URL on the site that this request is being applied to. The contained response will have a status code of L{responsecode.BAD_GATEWAY}. @raise HTTPError: If C{url} contains a query or fragment. The contained response will have a status code of L{responsecode.BAD_REQUEST}. """ if url is None: return defer.succeed(None) # # Parse the URL # (scheme, host, path, query, fragment) = urlsplit(url) if query or fragment: raise http.HTTPError(http.StatusResponse( responsecode.BAD_REQUEST, "URL may not contain a query or fragment: %s" % (url,) )) # Look for cached value cached = self._resourcesByURL.get(path, None) if cached is not None: return defer.succeed(cached) segments = unquote(path).split("/") assert segments[0] == "", "URL path didn't begin with '/': %s" % (path,) # Walk the segments up to see if we can find a cached resource to start from preSegments = segments[:-1] postSegments = segments[-1:] cachedParent = None while(len(preSegments)): parentPath = "/".join(preSegments) + "/" cachedParent = self._resourcesByURL.get(parentPath, None) if cachedParent is not None: break else: postSegments.insert(0, preSegments.pop()) if cachedParent is None: cachedParent = self.site.resource postSegments = segments[1:] def notFound(f): f.trap(http.HTTPError) if f.value.response.code != responsecode.NOT_FOUND: return f return None d = defer.maybeDeferred(self._getChild, None, cachedParent, postSegments, updatepaths=False) d.addCallback(self._rememberResource, path) d.addErrback(notFound) return d def locateChildResource(self, parent, childName): """ Looks up the child resource with the given name 
given the parent resource. This is similar to locateResource(), but doesn't have to start the lookup from the root resource, so it is potentially faster. @param parent: the parent of the resource being looked up. This resource must have been obtained from the request (i.e. via its C{uri} attribute, or through its C{locateResource} or C{locateChildResource} methods). @param childName: the name of the child of C{parent} to looked up. to C{parent}. @return: a L{Deferred} resulting in the L{IResource} at the given URL or C{None} if no such resource can be located. @raise NoURLForResourceError: if C{resource} was not obtained from the request. """ if parent is None or childName is None: return None assert "/" not in childName, "Child name may not contain '/': %s" % (childName,) parentURL = self.urlForResource(parent) if not parentURL.endswith("/"): parentURL += "/" url = parentURL + quote(childName) segment = childName def notFound(f): f.trap(http.HTTPError) if f.value.response.code != responsecode.NOT_FOUND: return f return None d = defer.maybeDeferred(self._getChild, None, parent, [segment], updatepaths=False) d.addCallback(self._rememberResource, url) d.addErrback(notFound) return d def _processingFailed(self, reason): if reason.check(http.HTTPError) is not None: # If the exception was an HTTPError, leave it alone d = defer.succeed(reason.value.response) else: # Otherwise, it was a random exception, so give a # ICanHandleException implementer a chance to render the page. def _processingFailed_inner(reason): handler = iweb.ICanHandleException(self, self) return handler.renderHTTP_exception(self, reason) d = defer.maybeDeferred(_processingFailed_inner, reason) d.addCallback(self._cbFinishRender) d.addErrback(self._processingReallyFailed, reason) return d def _processingReallyFailed(self, reason, origReason): """ An error occurred when attempting to report an error to the HTTP client. 
""" log.failure("Exception rendering error page", reason) log.failure("Original exception", origReason) try: body = ( "Internal Server Error" "

Internal Server Error

" "An error occurred rendering the requested page. " "Additionally, an error occurred rendering the error page." "" ) response = http.Response( responsecode.INTERNAL_SERVER_ERROR, {'content-type': http_headers.MimeType('text','html')}, body ) self.writeResponse(response) except: log.failure( "An error occurred. We tried to report that error. " "Reporting that error caused an error. " "In the process of reporting the error-reporting error to " "the client, there was *yet another* error. Here it is. " "I give up." ) self.chanRequest.abortConnection() def _cbFinishRender(self, result): def filterit(response, f): if (hasattr(f, 'handleErrors') or (response.code >= 200 and response.code < 300)): return f(self, response) else: return response response = iweb.IResponse(result, None) if response: d = defer.Deferred() for f in self.responseFilters: d.addCallback(filterit, f) d.addCallback(self.writeResponse) d.callback(response) return d resource = iweb.IResource(result, None) if resource: self.resources.append(resource) d = defer.maybeDeferred(resource.renderHTTP, self) d.addCallback(self._cbFinishRender) return d raise TypeError("html is not a resource or a response") def renderHTTP_exception(self, req, reason): log.failure("Exception rendering request: {request}", reason, request=req) body = ("Internal Server Error" "

Internal Server Error

An error occurred rendering the requested page. More information is available in the server log.") return http.Response( responsecode.INTERNAL_SERVER_ERROR, {'content-type': http_headers.MimeType('text','html')}, body) class Site(object): def __init__(self, resource): """Initialize. """ self.resource = iweb.IResource(resource) def __call__(self, *args, **kwargs): return Request(site=self, *args, **kwargs) class NoURLForResourceError(RuntimeError): def __init__(self, resource): RuntimeError.__init__(self, "Resource %r has no URL in this request." % (resource,)) self.resource = resource __all__ = ['Request', 'Site', 'StopTraversal', 'VERSION', 'defaultHeadersFilter', 'doTrace', 'parsePOSTData', 'preconditionfilter', 'NoURLForResourceError'] calendarserver-5.2+dfsg/twext/web2/stream.py0000644000175000017500000011126612263343324020177 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_stream -*- ## # Copyright (c) 2001-2007 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ The stream module provides a simple abstraction of streaming data. While Twisted already has some provisions for handling this in its Producer/Consumer model, the rather complex interactions between producer and consumer makes it difficult to implement something like the CompoundStream object. Thus, this API. The IStream interface is very simple. It consists of two methods: read, and close. The read method should either return some data, None if there is no data left to read, or a Deferred. Close frees up any underlying resources and causes read to return None forevermore. IByteStream adds a bit more to the API: 1) read is required to return objects conforming to the buffer interface. 2) .length, which may either an integer number of bytes remaining, or None if unknown 3) .split(position). Split takes a position, and splits the stream in two pieces, returning the two new streams. Using the original stream after calling split is not allowed. There are two builtin source stream classes: FileStream and MemoryStream. The first produces data from a file object, the second from a buffer in memory. Any number of these can be combined into one stream with the CompoundStream object. Then, to interface with other parts of Twisted, there are two transcievers: StreamProducer and ProducerStream. The first takes a stream and turns it into an IPushProducer, which will write to a consumer. The second is a consumer which is a stream, so that other producers can write to it. 
""" from __future__ import generators import copy, os, types, sys from zope.interface import Interface, Attribute, implements from twisted.internet.defer import Deferred from twisted.internet import interfaces as ti_interfaces, defer, reactor, protocol, error as ti_error from twisted.python import components from twisted.python.failure import Failure from hashlib import md5 from twext.python.log import Logger log = Logger() # Python 2.4.2 (only) has a broken mmap that leaks a fd every time you call it. if sys.version_info[0:3] != (2,4,2): try: import mmap except ImportError: mmap = None else: mmap = None ############################## #### Interfaces #### ############################## class IStream(Interface): """A stream of arbitrary data.""" def read(): """Read some data. Returns some object representing the data. If there is no more data available, returns None. Can also return a Deferred resulting in one of the above. Errors may be indicated by exception or by a Deferred of a Failure. """ def close(): """Prematurely close. Should also cause further reads to return None.""" class IByteStream(IStream): """A stream which is of bytes.""" length = Attribute("""How much data is in this stream. Can be None if unknown.""") def read(): """Read some data. Returns an object conforming to the buffer interface, or if there is no more data available, returns None. Can also return a Deferred resulting in one of the above. Errors may be indicated by exception or by a Deferred of a Failure. """ def split(point): """Split this stream into two, at byte position 'point'. Returns a tuple of (before, after). After calling split, no other methods should be called on this stream. Doing so will have undefined behavior. If you cannot implement split easily, you may implement it as:: return fallbackSplit(self, point) """ def close(): """Prematurely close this stream. Should also cause further reads to return None. Additionally, .length should be set to 0. 
""" class ISendfileableStream(Interface): def read(sendfile=False): """ Read some data. If sendfile == False, returns an object conforming to the buffer interface, or else a Deferred. If sendfile == True, returns either the above, or a SendfileBuffer. """ class SimpleStream(object): """Superclass of simple streams with a single buffer and a offset and length into that buffer.""" implements(IByteStream) length = None start = None def read(self): return None def close(self): self.length = 0 def split(self, point): if self.length is not None: if point > self.length: raise ValueError("split point (%d) > length (%d)" % (point, self.length)) b = copy.copy(self) self.length = point if b.length is not None: b.length -= point b.start += point return (self, b) ############################## #### FileStream #### ############################## # maximum mmap size MMAP_LIMIT = 4*1024*1024 # minimum mmap size MMAP_THRESHOLD = 8*1024 # maximum sendfile length SENDFILE_LIMIT = 16777216 # minimum sendfile size SENDFILE_THRESHOLD = 256 def mmapwrapper(*args, **kwargs): """ Python's mmap call sucks and ommitted the "offset" argument for no discernable reason. Replace this with a mmap module that has offset. """ offset = kwargs.get('offset', None) if offset in [None, 0]: if 'offset' in kwargs: del kwargs['offset'] else: raise mmap.error("mmap: Python sucks and does not support offset.") return mmap.mmap(*args, **kwargs) class FileStream(SimpleStream): implements(ISendfileableStream) """A stream that reads data from a file. File must be a normal file that supports seek, (e.g. not a pipe or device or socket).""" # 65K, minus some slack CHUNK_SIZE = 2 ** 2 ** 2 ** 2 - 32 f = None def __init__(self, f, start=0, length=None, useMMap=bool(mmap)): """ Create the stream from file f. If you specify start and length, use only that portion of the file. 
""" self.f = f self.start = start if length is None: self.length = os.fstat(f.fileno()).st_size else: self.length = length self.useMMap = useMMap def read(self, sendfile=False): if self.f is None: return None length = self.length if length == 0: self.f = None return None #if sendfile and length > SENDFILE_THRESHOLD: # # XXX: Yay using non-existent sendfile support! # # FIXME: if we return a SendfileBuffer, and then sendfile # # fails, then what? Or, what if file is too short? # readSize = min(length, SENDFILE_LIMIT) # res = SendfileBuffer(self.f, self.start, readSize) # self.length -= readSize # self.start += readSize # return res if self.useMMap and length > MMAP_THRESHOLD: readSize = min(length, MMAP_LIMIT) try: res = mmapwrapper(self.f.fileno(), readSize, access=mmap.ACCESS_READ, offset=self.start) #madvise(res, MADV_SEQUENTIAL) self.length -= readSize self.start += readSize return res except mmap.error: pass # Fall back to standard read. readSize = min(length, self.CHUNK_SIZE) self.f.seek(self.start) b = self.f.read(readSize) bytesRead = len(b) if not bytesRead: raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, length)) else: self.length -= bytesRead self.start += bytesRead return b def close(self): self.f = None SimpleStream.close(self) components.registerAdapter(FileStream, file, IByteStream) ############################## #### MemoryStream #### ############################## class MemoryStream(SimpleStream): """A stream that reads data from a buffer object.""" def __init__(self, mem, start=0, length=None): """ Create the stream from buffer object mem. If you specify start and length, use only that portion of the buffer. 
""" self.mem = mem self.start = start if length is None: self.length = len(mem) - start else: if len(mem) < length: raise ValueError("len(mem) < start + length") self.length = length def read(self): if self.mem is None: return None if self.length == 0: result = None else: result = buffer(self.mem, self.start, self.length) self.mem = None self.length = 0 return result def close(self): self.mem = None SimpleStream.close(self) components.registerAdapter(MemoryStream, str, IByteStream) components.registerAdapter(MemoryStream, types.BufferType, IByteStream) ############################## #### CompoundStream #### ############################## class CompoundStream(object): """A stream which is composed of many other streams. Call addStream to add substreams. """ implements(IByteStream, ISendfileableStream) deferred = None length = 0 def __init__(self, buckets=()): self.buckets = [IByteStream(s) for s in buckets] def addStream(self, bucket): """Add a stream to the output""" bucket = IByteStream(bucket) self.buckets.append(bucket) if self.length is not None: if bucket.length is None: self.length = None else: self.length += bucket.length def read(self, sendfile=False): if self.deferred is not None: raise RuntimeError("Call to read while read is already outstanding") if not self.buckets: return None if sendfile and ISendfileableStream.providedBy(self.buckets[0]): try: result = self.buckets[0].read(sendfile) except: return self._gotFailure(Failure()) else: try: result = self.buckets[0].read() except: return self._gotFailure(Failure()) if isinstance(result, Deferred): self.deferred = result result.addCallbacks(self._gotRead, self._gotFailure, (sendfile,)) return result return self._gotRead(result, sendfile) def _gotFailure(self, f): self.deferred = None del self.buckets[0] self.close() return f def _gotRead(self, result, sendfile): self.deferred = None if result is None: del self.buckets[0] # Next bucket return self.read(sendfile) if self.length is not None: self.length -= 
len(result) return result def split(self, point): num = 0 origPoint = point for bucket in self.buckets: num+=1 if point == 0: b = CompoundStream() b.buckets = self.buckets[num:] del self.buckets[num:] return self,b if bucket.length is None: # Indeterminate length bucket. # give up and use fallback splitter. return fallbackSplit(self, origPoint) if point < bucket.length: before,after = bucket.split(point) b = CompoundStream() b.buckets = self.buckets[num:] b.buckets[0] = after del self.buckets[num+1:] self.buckets[num] = before return self,b point -= bucket.length def close(self): for bucket in self.buckets: bucket.close() self.buckets = [] self.length = 0 ############################## #### readStream #### ############################## class _StreamReader(object): """Process a stream's data using callbacks for data and stream finish.""" def __init__(self, stream, gotDataCallback): self.stream = stream self.gotDataCallback = gotDataCallback self.result = Deferred() def run(self): # self.result may be del'd in _read() result = self.result self._read() return result def _read(self): try: result = self.stream.read() except: self._gotError(Failure()) return if isinstance(result, Deferred): result.addCallbacks(self._gotData, self._gotError) else: self._gotData(result) def _gotError(self, failure): result = self.result del self.result, self.gotDataCallback, self.stream result.errback(failure) def _gotData(self, data): if data is None: result = self.result del self.result, self.gotDataCallback, self.stream result.callback(None) return try: self.gotDataCallback(data) except: self._gotError(Failure()) return reactor.callLater(0, self._read) def readStream(stream, gotDataCallback): """Pass a stream's data to a callback. Returns Deferred which will be triggered on finish. Errors in reading the stream or in processing it will be returned via this Deferred. 
""" return _StreamReader(stream, gotDataCallback).run() def readAndDiscard(stream): """Read all the data from the given stream, and throw it out. Returns Deferred which will be triggered on finish. """ return readStream(stream, lambda _: None) def readIntoFile(stream, outFile): """Read a stream and write it into a file. Returns Deferred which will be triggered on finish. """ def done(_): outFile.close() return _ return readStream(stream, outFile.write).addBoth(done) def connectStream(inputStream, factory): """Connect a protocol constructed from a factory to stream. Returns an output stream from the protocol. The protocol's transport will have a finish() method it should call when done writing. """ # XXX deal better with addresses p = factory.buildProtocol(None) out = ProducerStream() out.disconnecting = False # XXX for LineReceiver suckage p.makeConnection(out) readStream(inputStream, lambda _: p.dataReceived(_)).addCallbacks( lambda _: p.connectionLost(ti_error.ConnectionDone()), lambda _: p.connectionLost(_)) return out ############################## #### fallbackSplit #### ############################## def fallbackSplit(stream, point): after = PostTruncaterStream(stream, point) before = TruncaterStream(stream, point, after) return (before, after) class TruncaterStream(object): def __init__(self, stream, point, postTruncater): self.stream = stream self.length = point self.postTruncater = postTruncater def read(self): if self.length == 0: if self.postTruncater is not None: postTruncater = self.postTruncater self.postTruncater = None postTruncater.sendInitialSegment(self.stream.read()) self.stream = None return None result = self.stream.read() if isinstance(result, Deferred): return result.addCallback(self._gotRead) else: return self._gotRead(result) def _gotRead(self, data): if data is None: raise ValueError("Ran out of data for a split of a indeterminate length source") if self.length >= len(data): self.length -= len(data) return data else: before = buffer(data, 
0, self.length) after = buffer(data, self.length) self.length = 0 if self.postTruncater is not None: postTruncater = self.postTruncater self.postTruncater = None postTruncater.sendInitialSegment(after) self.stream = None return before def split(self, point): if point > self.length: raise ValueError("split point (%d) > length (%d)" % (point, self.length)) post = PostTruncaterStream(self.stream, point) trunc = TruncaterStream(post, self.length - point, self.postTruncater) self.length = point self.postTruncater = post return self, trunc def close(self): if self.postTruncater is not None: self.postTruncater.notifyClosed(self) else: # Nothing cares about the rest of the stream self.stream.close() self.stream = None self.length = 0 class PostTruncaterStream(object): deferred = None sentInitialSegment = False truncaterClosed = None closed = False length = None def __init__(self, stream, point): self.stream = stream self.deferred = Deferred() if stream.length is not None: self.length = stream.length - point def read(self): if not self.sentInitialSegment: self.sentInitialSegment = True if self.truncaterClosed is not None: readAndDiscard(self.truncaterClosed) self.truncaterClosed = None return self.deferred return self.stream.read() def split(self, point): return fallbackSplit(self, point) def close(self): self.closed = True if self.truncaterClosed is not None: # have first half close itself self.truncaterClosed.postTruncater = None self.truncaterClosed.close() elif self.sentInitialSegment: # first half already finished up self.stream.close() self.deferred = None # Callbacks from TruncaterStream def sendInitialSegment(self, data): if self.closed: # First half finished, we don't want data. 
self.stream.close() self.stream = None if self.deferred is not None: if isinstance(data, Deferred): data.chainDeferred(self.deferred) else: self.deferred.callback(data) def notifyClosed(self, truncater): if self.closed: # we are closed, have first half really close truncater.postTruncater = None truncater.close() elif self.sentInitialSegment: # We are trying to read, read up first half readAndDiscard(truncater) else: # Idle, store closed info. self.truncaterClosed = truncater ######################################## #### ProducerStream/StreamProducer #### ######################################## class ProducerStream(object): """Turns producers into a IByteStream. Thus, implements IConsumer and IByteStream.""" implements(IByteStream, ti_interfaces.IConsumer) length = None closed = False failed = False producer = None producerPaused = False deferred = None bufferSize = 5 def __init__(self, length=None): self.buffer = [] self.length = length # IByteStream implementation def read(self): if self.buffer: return self.buffer.pop(0) elif self.closed: self.length = 0 if self.failed: f = self.failure del self.failure return defer.fail(f) return None else: deferred = self.deferred = Deferred() if self.producer is not None and (not self.streamingProducer or self.producerPaused): self.producerPaused = False self.producer.resumeProducing() return deferred def split(self, point): return fallbackSplit(self, point) def close(self): """Called by reader of stream when it is done reading.""" self.buffer=[] self.closed = True if self.producer is not None: self.producer.stopProducing() self.producer = None self.deferred = None # IConsumer implementation def write(self, data): if self.closed: return if self.deferred: deferred = self.deferred self.deferred = None deferred.callback(data) else: self.buffer.append(data) if(self.producer is not None and self.streamingProducer and len(self.buffer) > self.bufferSize): self.producer.pauseProducing() self.producerPaused = True def finish(self, 
failure=None): """Called by producer when it is done. If the optional failure argument is passed a Failure instance, the stream will return it as errback on next Deferred. """ self.closed = True if not self.buffer: self.length = 0 if self.deferred is not None: deferred = self.deferred self.deferred = None if failure is not None: self.failed = True deferred.errback(failure) else: deferred.callback(None) else: if failure is not None: self.failed = True self.failure = failure def registerProducer(self, producer, streaming): if self.producer is not None: raise RuntimeError("Cannot register producer %s, because producer %s was never unregistered." % (producer, self.producer)) if self.closed: producer.stopProducing() else: self.producer = producer self.streamingProducer = streaming if not streaming: producer.resumeProducing() def unregisterProducer(self): self.producer = None class StreamProducer(object): """A push producer which gets its data by reading a stream.""" implements(ti_interfaces.IPushProducer) deferred = None finishedCallback = None paused = False consumer = None def __init__(self, stream, enforceStr=True): self.stream = stream self.enforceStr = enforceStr def beginProducing(self, consumer): if self.stream is None: return defer.succeed(None) self.consumer = consumer finishedCallback = self.finishedCallback = Deferred() self.consumer.registerProducer(self, True) self.resumeProducing() return finishedCallback def resumeProducing(self): self.paused = False if self.deferred is not None: return try: data = self.stream.read() except: self.stopProducing(Failure()) return if isinstance(data, Deferred): self.deferred = data self.deferred.addCallbacks(self._doWrite, self.stopProducing) else: self._doWrite(data) def _doWrite(self, data): if self.consumer is None: return if data is None: # The end. 
if self.consumer is not None: self.consumer.unregisterProducer() if self.finishedCallback is not None: self.finishedCallback.callback(None) self.finishedCallback = self.deferred = self.consumer = self.stream = None return self.deferred = None if self.enforceStr: # XXX: sucks that we have to do this. make transport.write(buffer) work! data = str(buffer(data)) self.consumer.write(data) if not self.paused: self.resumeProducing() def pauseProducing(self): self.paused = True def stopProducing(self, failure=ti_error.ConnectionLost()): if self.consumer is not None: self.consumer.unregisterProducer() if self.finishedCallback is not None: if failure is not None: self.finishedCallback.errback(failure) else: self.finishedCallback.callback(None) self.finishedCallback = None self.paused = True if self.stream is not None: self.stream.close() self.finishedCallback = self.deferred = self.consumer = self.stream = None ############################## #### ProcessStreamer #### ############################## class _ProcessStreamerProtocol(protocol.ProcessProtocol): def __init__(self, inputStream, outStream, errStream): self.inputStream = inputStream self.outStream = outStream self.errStream = errStream self.resultDeferred = defer.Deferred() def connectionMade(self): p = StreamProducer(self.inputStream) # if the process stopped reading from the input stream, # this is not an error condition, so it oughtn't result # in a ConnectionLost() from the input stream: p.stopProducing = lambda err=None: StreamProducer.stopProducing(p, err) d = p.beginProducing(self.transport) d.addCallbacks(lambda _: self.transport.closeStdin(), self._inputError) def _inputError(self, f): log.failure("Error in input stream for transport {transport}", f, transport=self.transport) self.transport.closeStdin() def outReceived(self, data): self.outStream.write(data) def errReceived(self, data): self.errStream.write(data) def outConnectionLost(self): self.outStream.finish() def errConnectionLost(self): 
self.errStream.finish() def processEnded(self, reason): self.resultDeferred.errback(reason) del self.resultDeferred class ProcessStreamer(object): """Runs a process hooked up to streams. Requires an input stream, has attributes 'outStream' and 'errStream' for stdout and stderr. outStream and errStream are public attributes providing streams for stdout and stderr of the process. """ def __init__(self, inputStream, program, args, env={}): self.outStream = ProducerStream() self.errStream = ProducerStream() self._protocol = _ProcessStreamerProtocol(IByteStream(inputStream), self.outStream, self.errStream) self._program = program self._args = args self._env = env def run(self): """Run the process. Returns Deferred which will eventually have errback for non-clean (exit code > 0) exit, with ProcessTerminated, or callback with None on exit code 0. """ # XXX what happens if spawn fails? reactor.spawnProcess(self._protocol, self._program, self._args, env=self._env) del self._env return self._protocol.resultDeferred.addErrback(lambda _: _.trap(ti_error.ProcessDone)) ############################## #### generatorToStream #### ############################## class _StreamIterator(object): done=False def __iter__(self): return self def next(self): if self.done: raise StopIteration return self.value wait=object() class _IteratorStream(object): length = None def __init__(self, fun, stream, args, kwargs): self._stream=stream self._streamIterator = _StreamIterator() self._gen = fun(self._streamIterator, *args, **kwargs) def read(self): try: val = self._gen.next() except StopIteration: return None else: if val is _StreamIterator.wait: newdata = self._stream.read() if isinstance(newdata, defer.Deferred): return newdata.addCallback(self._gotRead) else: return self._gotRead(newdata) return val def _gotRead(self, data): if data is None: self._streamIterator.done=True else: self._streamIterator.value=data return self.read() def close(self): self._stream.close() del self._gen, self._stream, 
self._streamIterator def split(self): return fallbackSplit(self) def generatorToStream(fun): """Converts a generator function into a stream. The function should take an iterator as its first argument, which will be converted *from* a stream by this wrapper, and yield items which are turned *into* the results from the stream's 'read' call. One important point: before every call to input.next(), you *MUST* do a "yield input.wait" first. Yielding this magic value takes care of ensuring that the input is not a deferred before you see it. >>> from twext.web2 import stream >>> from string import maketrans >>> alphabet = 'abcdefghijklmnopqrstuvwxyz' >>> >>> def encrypt(input, key): ... code = alphabet[key:] + alphabet[:key] ... translator = maketrans(alphabet+alphabet.upper(), code+code.upper()) ... yield input.wait ... for s in input: ... yield str(s).translate(translator) ... yield input.wait ... >>> encrypt = stream.generatorToStream(encrypt) >>> >>> plaintextStream = stream.MemoryStream('SampleSampleSample') >>> encryptedStream = encrypt(plaintextStream, 13) >>> encryptedStream.read() 'FnzcyrFnzcyrFnzcyr' >>> >>> plaintextStream = stream.MemoryStream('SampleSampleSample') >>> encryptedStream = encrypt(plaintextStream, 13) >>> evenMoreEncryptedStream = encrypt(encryptedStream, 13) >>> evenMoreEncryptedStream.read() 'SampleSampleSample' """ def generatorToStream_inner(stream, *args, **kwargs): return _IteratorStream(fun, stream, args, kwargs) return generatorToStream_inner ############################## #### BufferedStream #### ############################## class BufferedStream(object): """A stream which buffers its data to provide operations like readline and readExactly.""" data = "" def __init__(self, stream): self.stream = stream def _readUntil(self, f): """Internal helper function which repeatedly calls f each time after more data has been received, until it returns non-None.""" while True: r = f() if r is not None: yield r; return newdata = self.stream.read() if 
isinstance(newdata, defer.Deferred): newdata = defer.waitForDeferred(newdata) yield newdata; newdata = newdata.getResult() if newdata is None: # End Of File newdata = self.data self.data = '' yield newdata; return self.data += str(newdata) _readUntil = defer.deferredGenerator(_readUntil) def readExactly(self, size=None): """Read exactly size bytes of data, or, if size is None, read the entire stream into a string.""" if size is not None and size < 0: raise ValueError("readExactly: size cannot be negative: %s", size) def gotdata(): data = self.data if size is not None and len(data) >= size: pre,post = data[:size], data[size:] self.data = post return pre return self._readUntil(gotdata) def readline(self, delimiter='\r\n', size=None): """ Read a line of data from the string, bounded by delimiter. The delimiter is included in the return value. If size is specified, read and return at most that many bytes, even if the delimiter has not yet been reached. If the size limit falls within a delimiter, the rest of the delimiter, and the next line will be returned together. 
""" if size is not None and size < 0: raise ValueError("readline: size cannot be negative: %s" % (size, )) def gotdata(): data = self.data if size is not None: splitpoint = data.find(delimiter, 0, size) if splitpoint == -1: if len(data) >= size: splitpoint = size else: splitpoint += len(delimiter) else: splitpoint = data.find(delimiter) if splitpoint != -1: splitpoint += len(delimiter) if splitpoint != -1: pre = data[:splitpoint] self.data = data[splitpoint:] return pre return self._readUntil(gotdata) def pushback(self, pushed): """Push data back into the buffer.""" self.data = pushed + self.data def read(self): data = self.data if data: self.data = "" return data return self.stream.read() def _len(self): l = self.stream.length if l is None: return None return l + len(self.data) length = property(_len) def split(self, offset): off = offset - len(self.data) pre, post = self.stream.split(max(0, off)) pre = BufferedStream(pre) post = BufferedStream(post) if off < 0: pre.data = self.data[:-off] post.data = self.data[-off:] else: pre.data = self.data return pre, post ######################### #### MD5Stream #### ######################### class MD5Stream(SimpleStream): """ An wrapper which computes the MD5 hash of the data read from the wrapped stream. """ def __init__(self, wrap): if wrap is None: raise ValueError("Stream to wrap must be provided") self._stream = wrap self._md5 = md5() def _update(self, value): """ Update the MD5 hash object. @param value: L{None} or a L{str} with which to update the MD5 hash object. @return: C{value} """ if value is not None: self._md5.update(value) return value def read(self): """ Read from the wrapped stream and update the MD5 hash object. """ if self._stream is None: raise RuntimeError("Cannot read after stream is closed") b = self._stream.read() if isinstance(b, Deferred): b.addCallback(self._update) else: self._update(b) return b def close(self): """ Compute the final hex digest of the contents of the wrapped stream. 
""" SimpleStream.close(self) self._md5value = self._md5.hexdigest() self._stream = None self._md5 = None def getMD5(self): """ Return the hex encoded MD5 digest of the contents of the wrapped stream. This may only be called after C{close}. @rtype: C{str} @raise RuntimeError: If C{close} has not yet been called. """ if self._md5 is not None: raise RuntimeError("Cannot get MD5 value until stream is closed") return self._md5value __all__ = ['IStream', 'IByteStream', 'FileStream', 'MemoryStream', 'CompoundStream', 'readAndDiscard', 'fallbackSplit', 'ProducerStream', 'StreamProducer', 'BufferedStream', 'MD5Stream', 'readStream', 'ProcessStreamer', 'readIntoFile', 'generatorToStream'] calendarserver-5.2+dfsg/twext/web2/resource.py0000644000175000017500000003004712263343324020530 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_server,twext.web2.test.test_resource -*- ## # Copyright (c) 2001-2007 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ I hold the lowest-level L{Resource} class and related mix-in classes. """ # System Imports from zope.interface import implements from twisted.internet.defer import inlineCallbacks, returnValue from twext.web2 import iweb, http, server, responsecode from twisted.internet.defer import maybeDeferred class RenderMixin(object): """ Mix-in class for L{iweb.IResource} which provides a dispatch mechanism for handling HTTP methods. """ def allowedMethods(self): """ @return: A tuple of HTTP methods that are allowed to be invoked on this resource. """ if not hasattr(self, "_allowed_methods"): self._allowed_methods = tuple([name[5:] for name in dir(self) if name.startswith('http_') and getattr(self, name) is not None]) return self._allowed_methods def checkPreconditions(self, request): """ Checks all preconditions imposed by this resource upon a request made against it. @param request: the request to process. @raise http.HTTPError: if any precondition fails. @return: C{None} or a deferred whose callback value is C{request}. """ # # http.checkPreconditions() gets called by the server after every # GET or HEAD request. # # For other methods, we need to know to bail out before request # processing, especially for methods that modify server state (eg. PUT). # We also would like to do so even for methods that don't, if those # methods might be expensive to process. We're assuming that GET and # HEAD are not expensive. # if request.method not in ("GET", "HEAD"): http.checkPreconditions(request) # Check per-method preconditions method = getattr(self, "preconditions_" + request.method, None) if method: return method(request) @inlineCallbacks def renderHTTP(self, request): """ See L{iweb.IResource.renderHTTP}. 
This implementation will dispatch the given C{request} to another method of C{self} named C{http_}METHOD, where METHOD is the HTTP method used by C{request} (eg. C{http_GET}, C{http_POST}, etc.). Generally, a subclass should implement those methods instead of overriding this one. C{http_*} methods are expected provide the same interface and return the same results as L{iweb.IResource}C{.renderHTTP} (and therefore this method). C{etag} and C{last-modified} are added to the response returned by the C{http_*} header, if known. If an appropriate C{http_*} method is not found, a L{responsecode.NOT_ALLOWED}-status response is returned, with an appropriate C{allow} header. @param request: the request to process. @return: an object adaptable to L{iweb.IResponse}. """ method = getattr(self, "http_" + request.method, None) if method is None: response = http.Response(responsecode.NOT_ALLOWED) response.headers.setHeader("allow", self.allowedMethods()) returnValue(response) yield self.checkPreconditions(request) result = maybeDeferred(method, request) result.addErrback(self.methodRaisedException) returnValue((yield result)) def methodRaisedException(self, failure): """ An C{http_METHOD} method raised an exception; this is an errback for that exception. By default, simply propagate the error up; subclasses may override this for top-level exception handling. """ return failure def http_OPTIONS(self, request): """ Respond to a OPTIONS request. @param request: the request to process. @return: an object adaptable to L{iweb.IResponse}. """ response = http.Response(responsecode.OK) response.headers.setHeader("allow", self.allowedMethods()) return response # def http_TRACE(self, request): # """ # Respond to a TRACE request. # @param request: the request to process. # @return: an object adaptable to L{iweb.IResponse}. # """ # return server.doTrace(request) def http_HEAD(self, request): """ Respond to a HEAD request. @param request: the request to process. 
@return: an object adaptable to L{iweb.IResponse}. """ return self.http_GET(request) def http_GET(self, request): """ Respond to a GET request. This implementation validates that the request body is empty and then dispatches the given C{request} to L{render} and returns its result. @param request: the request to process. @return: an object adaptable to L{iweb.IResponse}. """ if request.stream.length != 0: return responsecode.REQUEST_ENTITY_TOO_LARGE return self.render(request) def render(self, request): """ Subclasses should implement this method to do page rendering. See L{http_GET}. @param request: the request to process. @return: an object adaptable to L{iweb.IResponse}. """ raise NotImplementedError("Subclass must implement render method.") class Resource(RenderMixin): """ An L{iweb.IResource} implementation with some convenient mechanisms for locating children. """ implements(iweb.IResource) addSlash = False def locateChild(self, request, segments): """ Locates a child resource of this resource. @param request: the request to process. @param segments: a sequence of URL path segments. @return: a tuple of C{(child, segments)} containing the child of this resource which matches one or more of the given C{segments} in sequence, and a list of remaining segments. """ w = getattr(self, 'child_%s' % (segments[0],), None) if w: r = iweb.IResource(w, None) if r: return r, segments[1:] return w(request), segments[1:] factory = getattr(self, 'childFactory', None) if factory is not None: r = factory(request, segments[0]) if r: return r, segments[1:] return None, [] def child_(self, request): """ This method locates a child with a trailing C{"/"} in the URL. @param request: the request to process. """ if self.addSlash and len(request.postpath) == 1: return self return None def getChild(self, path): """ Get a static child - when registered using L{putChild}. 
@param path: the name of the child to get @type path: C{str} @return: the child or C{None} if not present @rtype: L{iweb.IResource} """ return getattr(self, 'child_%s' % (path,), None) def putChild(self, path, child): """ Register a static child. This implementation registers children by assigning them to attributes with a C{child_} prefix. C{resource.putChild("foo", child)} is therefore same as C{o.child_foo = child}. @param path: the name of the child to register. You almost certainly don't want C{"/"} in C{path}. If you want to add a "directory" resource (e.g. C{/foo/}) specify C{path} as C{""}. @param child: an object adaptable to L{iweb.IResource}. """ setattr(self, 'child_%s' % (path,), child) def http_GET(self, request): if self.addSlash and request.prepath[-1] != '': # If this is a directory-ish resource... return http.RedirectResponse(request.unparseURL(path=request.path + '/')) return super(Resource, self).http_GET(request) class PostableResource(Resource): """ A L{Resource} capable of handling the POST request method. @cvar maxMem: maximum memory used during the parsing of the data. @type maxMem: C{int} @cvar maxFields: maximum number of form fields allowed. @type maxFields: C{int} @cvar maxSize: maximum size of the whole post allowed. @type maxSize: C{int} """ maxMem = 100 * 1024 maxFields = 1024 maxSize = 10 * 1024 * 1024 def http_POST(self, request): """ Respond to a POST request. Reads and parses the incoming body data then calls L{render}. @param request: the request to process. @return: an object adaptable to L{iweb.IResponse}. """ return server.parsePOSTData(request, self.maxMem, self.maxFields, self.maxSize ).addCallback(lambda res: self.render(request)) class LeafResource(RenderMixin): """ A L{Resource} with no children. """ implements(iweb.IResource) def locateChild(self, request, segments): return self, server.StopTraversal class RedirectResource(LeafResource): """ A L{LeafResource} which always performs a redirect. 
""" implements(iweb.IResource) def __init__(self, *args, **kwargs): """ Parameters are URL components and are the same as those for L{urlparse.urlunparse}. URL components which are not specified will default to the corresponding component of the URL of the request being redirected. """ self._args = args self._kwargs = kwargs def renderHTTP(self, request): return http.RedirectResponse(request.unparseURL(*self._args, **self._kwargs)) class WrapperResource(object): """ An L{iweb.IResource} implementation which wraps a L{RenderMixin} instance and provides a hook in which a subclass can implement logic that is called before request processing on the contained L{Resource}. """ implements(iweb.IResource) def __init__(self, resource): self.resource = resource def hook(self, request): """ Override this method in order to do something before passing control on to the wrapped resource's C{renderHTTP} and C{locateChild} methods. @return: None or a L{Deferred}. If a deferred object is returned, it's value is ignored, but C{renderHTTP} and C{locateChild} are chained onto the deferred as callbacks. """ raise NotImplementedError() def locateChild(self, request, segments): x = self.hook(request) if x is not None: return x.addCallback(lambda data: (self.resource, segments)) return self.resource, segments def renderHTTP(self, request): x = self.hook(request) if x is not None: return x.addCallback(lambda data: self.resource) return self.resource def getChild(self, name): return self.resource.getChild(name) __all__ = ['RenderMixin', 'Resource', 'PostableResource', 'LeafResource', 'WrapperResource'] calendarserver-5.2+dfsg/twext/web2/filter/0000755000175000017500000000000012322625325017610 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/web2/filter/gzip.py0000644000175000017500000000536411340001243021125 0ustar rahulrahulfrom __future__ import generators import struct import zlib from twext.web2 import stream # TODO: ungzip (can any browsers actually generate gzipped # upload data?) 
But it's necessary for client anyways. def gzipStream(input, compressLevel=6): crc, size = zlib.crc32(''), 0 # magic header, compression method, no flags header = '\037\213\010\000' # timestamp header += struct.pack('= size: end = size - 1 if start >= size: raise UnsatisfiableRangeRequest return start,end def makeUnsatisfiable(request, oldresponse): if request.headers.hasHeader('if-range'): return oldresponse # Return resource instead of error response = http.Response(responsecode.REQUESTED_RANGE_NOT_SATISFIABLE) response.headers.setHeader("content-range", ('bytes', None, None, oldresponse.stream.length)) return response def makeSegment(inputStream, lastOffset, start, end): offset = start - lastOffset length = end + 1 - start if offset != 0: before, inputStream = inputStream.split(offset) before.close() return inputStream.split(length) def rangefilter(request, oldresponse): if oldresponse.stream is None: return oldresponse size = oldresponse.stream.length if size is None: # Does not deal with indeterminate length outputs return oldresponse oldresponse.headers.setHeader('accept-ranges',('bytes',)) rangespec = request.headers.getHeader('range') # If we've got a range header and the If-Range header check passes, and # the range type is bytes, do a partial response. 
if (rangespec is not None and http.checkIfRange(request, oldresponse) and rangespec[0] == 'bytes'): # If it's a single range, return a simple response if len(rangespec[1]) == 1: try: start,end = canonicalizeRange(rangespec[1][0], size) except UnsatisfiableRangeRequest: return makeUnsatisfiable(request, oldresponse) response = http.Response(responsecode.PARTIAL_CONTENT, oldresponse.headers) response.headers.setHeader('content-range',('bytes',start, end, size)) content, after = makeSegment(oldresponse.stream, 0, start, end) after.close() response.stream = content return response else: # Return a multipart/byteranges response lastOffset = -1 offsetList = [] for arange in rangespec[1]: try: start,end = canonicalizeRange(arange, size) except UnsatisfiableRangeRequest: continue if start <= lastOffset: # Stupid client asking for out-of-order or overlapping ranges, PUNT! return oldresponse offsetList.append((start,end)) lastOffset = end if not offsetList: return makeUnsatisfiable(request, oldresponse) content_type = oldresponse.headers.getRawHeaders('content-type') boundary = "%x%x" % (int(time.time()*1000000), os.getpid()) response = http.Response(responsecode.PARTIAL_CONTENT, oldresponse.headers) response.headers.setHeader('content-type', http_headers.MimeType('multipart', 'byteranges', [('boundary', boundary)])) response.stream = out = stream.CompoundStream() lastOffset = 0 origStream = oldresponse.stream headerString = "\r\n--%s" % boundary if len(content_type) == 1: headerString+='\r\nContent-Type: %s' % content_type[0] headerString+="\r\nContent-Range: %s\r\n\r\n" for start,end in offsetList: out.addStream(headerString % http_headers.generateContentRange(('bytes', start, end, size))) content, origStream = makeSegment(origStream, lastOffset, start, end) lastOffset = end + 1 out.addStream(content) origStream.close() out.addStream("\r\n--%s--\r\n" % boundary) return response else: return oldresponse __all__ = ['rangefilter'] 
calendarserver-5.2+dfsg/twext/web2/filter/__init__.py0000644000175000017500000000023611337102650021716 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_cgi -*- # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # See LICENSE for details. """ Output filters. """ calendarserver-5.2+dfsg/twext/web2/filter/location.py0000644000175000017500000000170311337102650021767 0ustar rahulrahulfrom twext.web2 import responsecode import urlparse __all__ = ['addLocation'] def addLocation(request, location): """ Add a C{location} header to the response if the response status is CREATED. @param request: L{IRequest} the request being processed @param location: the URI to use in the C{location} header """ def locationFilter(request, response): if (response.code == responsecode.CREATED): # # Check to see whether we have an absolute URI or not. # If not, have the request turn it into an absolute URI. # (scheme, host, path, params, querystring, fragment) = urlparse.urlparse(location) if scheme == "": uri = request.unparseURL(path=location) else: uri = location response.headers.setHeader("location", uri) return response request.addResponseFilter(locationFilter) calendarserver-5.2+dfsg/twext/web2/iweb.py0000644000175000017500000002240512263343324017626 0ustar rahulrahul# -*- test-case-name: twext.web2.test -*- ## # Copyright (c) 2001-2008 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ I contain the interfaces for several web related objects including IRequest and IResource. I am based heavily on ideas from C{nevow.inevow}. """ from zope.interface import Attribute, Interface, interface # server.py interfaces class IResource(Interface): """ An HTTP resource. I serve 2 main purposes: one is to provide a standard representation for what HTTP specification calls an 'entity', and the other is to provide an mechanism for mapping URLs to content. """ def locateChild(req, segments): """ Locate another object which can be adapted to IResource. @return: A 2-tuple of (resource, remaining-path-segments), or a deferred which will fire the above. Causes the object publishing machinery to continue on with specified resource and segments, calling the appropriate method on the specified resource. 
If you return (self, L{server.StopTraversal}), this instructs web2 to immediately stop the lookup stage, and switch to the rendering stage, leaving the remaining path alone for your render function to handle. """ def renderHTTP(req): """ Return an IResponse or a deferred which will fire an IResponse. This response will be written to the web browser which initiated the request. """ # Is there a better way to do this than this funky extra class? _default = object() class SpecialAdaptInterfaceClass(interface.InterfaceClass): # A special adapter for IResource to handle the extra step of adapting # from IOldNevowResource-providing resources. def __call__(self, other, alternate=_default): result = super(SpecialAdaptInterfaceClass, self).__call__(other, alternate) if result is not alternate: return result result = IOldNevowResource(other, alternate) if result is not alternate: result = IResource(result) return result if alternate is not _default: return alternate raise TypeError('Could not adapt', other, self) IResource.__class__ = SpecialAdaptInterfaceClass class IOldNevowResource(Interface): # Shared interface with inevow.IResource """ I am a web resource. """ def locateChild(ctx, segments): """ Locate another object which can be adapted to IResource Return a tuple of resource, path segments """ def renderHTTP(ctx): """ Return a string or a deferred which will fire a string. This string will be written to the web browser which initiated this request. Unlike iweb.IResource, this expects the incoming data to have already been read and parsed into request.args and request.content, and expects to return a string instead of a response object. """ class ICanHandleException(Interface): # Shared interface with inevow.ICanHandleException def renderHTTP_exception(request, failure): """ Render an exception to the given request object. 
""" def renderInlineException(request, reason): """ Return stan representing the exception, to be printed in the page, not replacing the page.""" # http.py interfaces class IResponse(Interface): """ I'm a response. """ code = Attribute("The HTTP response code") headers = Attribute("A http_headers.Headers instance of headers to send") stream = Attribute("A stream.IByteStream of outgoing data, or else None.") class IRequest(Interface): """ I'm a request for a web resource. """ method = Attribute("The HTTP method from the request line, e.g. GET") uri = Attribute("The raw URI from the request line. May or may not include host.") clientproto = Attribute("Protocol from the request line, e.g. HTTP/1.1") headers = Attribute("A http_headers.Headers instance of incoming headers.") stream = Attribute("A stream.IByteStream of incoming data.") def writeResponse(response): """ Write an IResponse object to the client. """ chanRequest = Attribute("The ChannelRequest. I wonder if this is public really?") from twisted.web.iweb import IRequest as IOldRequest class IChanRequestCallbacks(Interface): """ The bits that are required of a Request for interfacing with a IChanRequest object """ def __init__(chanRequest, command, path, version, contentLength, inHeaders): """ Create a new Request object. @param chanRequest: the IChanRequest object creating this request @param command: the HTTP command e.g. GET @param path: the HTTP path e.g. /foo/bar.html @param version: the parsed HTTP version e.g. (1,1) @param contentLength: how much data to expect, or None if unknown @param inHeaders: the request headers""" def process(): """ Process the request. Called as soon as it's possibly reasonable to return a response. L{handleContentComplete} may or may not have been called already. """ def handleContentChunk(data): """ Called when a piece of incoming data has been received. """ def handleContentComplete(): """ Called when the incoming data stream is finished. 
""" def connectionLost(reason): """ Called if the connection was lost. """ class IChanRequest(Interface): def writeIntermediateResponse(code, headers=None): """ Write a non-terminating response. Intermediate responses cannot contain data. If the channel does not support intermediate responses, do nothing. @param code: The response code. Should be in the 1xx range. @type code: int @param headers: the headers to send in the response @type headers: C{twisted.web.http_headers.Headers} """ def writeHeaders(code, headers): """ Write a final response. @param code: The response code. Should not be in the 1xx range. @type code: int @param headers: the headers to send in the response. They will be augmented with any connection-oriented headers as necessary for the protocol. @type headers: C{twisted.web.http_headers.Headers} """ def write(data): """ Write some data. @param data: the data bytes @type data: str """ def finish(): """ Finish the request, and clean up the connection if necessary. """ def abortConnection(): """ Forcibly abort the connection without cleanly closing. Use if, for example, you can't write all the data you promised. """ def registerProducer(producer, streaming): """ Register a producer with the standard API. """ def unregisterProducer(): """ Unregister a producer. """ def getHostInfo(): """ Returns a tuple of (address, socket user connected to, boolean, was it secure). Note that this should not necessarily always return the actual local socket information from twisted. E.g. in a CGI, it should use the variables coming from the invoking script. """ def getRemoteHost(): """ Returns an address of the remote host. Like L{getHostInfo}, this information may come from the real socket, or may come from additional information, depending on the transport. """ persistent = Attribute("""Whether this request supports HTTP connection persistence. May be set to False. 
Should not be set to other values.""") class ISite(Interface): pass __all__ = ['ICanHandleException', 'IChanRequest', 'IChanRequestCallbacks', 'IOldNevowResource', 'IOldRequest', 'IRequest', 'IResource', 'IResponse', 'ISite'] calendarserver-5.2+dfsg/twext/web2/responsecode.py0000644000175000017500000001171312263343324021371 0ustar rahulrahul# -*- test-case-name: twext.web2.test -*- ## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# ## CONTINUE = 100 SWITCHING = 101 OK = 200 CREATED = 201 ACCEPTED = 202 NON_AUTHORITATIVE_INFORMATION = 203 NO_CONTENT = 204 RESET_CONTENT = 205 PARTIAL_CONTENT = 206 MULTI_STATUS = 207 MULTIPLE_CHOICE = 300 MOVED_PERMANENTLY = 301 FOUND = 302 SEE_OTHER = 303 NOT_MODIFIED = 304 USE_PROXY = 305 TEMPORARY_REDIRECT = 307 BAD_REQUEST = 400 UNAUTHORIZED = 401 PAYMENT_REQUIRED = 402 FORBIDDEN = 403 NOT_FOUND = 404 NOT_ALLOWED = 405 NOT_ACCEPTABLE = 406 PROXY_AUTH_REQUIRED = 407 REQUEST_TIMEOUT = 408 CONFLICT = 409 GONE = 410 LENGTH_REQUIRED = 411 PRECONDITION_FAILED = 412 REQUEST_ENTITY_TOO_LARGE = 413 REQUEST_URI_TOO_LONG = 414 UNSUPPORTED_MEDIA_TYPE = 415 REQUESTED_RANGE_NOT_SATISFIABLE = 416 EXPECTATION_FAILED = 417 UNPROCESSABLE_ENTITY = 422 # RFC 2518 LOCKED = 423 # RFC 2518 FAILED_DEPENDENCY = 424 # RFC 2518 INTERNAL_SERVER_ERROR = 500 NOT_IMPLEMENTED = 501 BAD_GATEWAY = 502 SERVICE_UNAVAILABLE = 503 GATEWAY_TIMEOUT = 504 HTTP_VERSION_NOT_SUPPORTED = 505 LOOP_DETECTED = 506 INSUFFICIENT_STORAGE_SPACE = 507 NOT_EXTENDED = 510 RESPONSES = { # 100 CONTINUE: "Continue", SWITCHING: "Switching Protocols", # 200 OK: "OK", CREATED: "Created", ACCEPTED: "Accepted", NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", NO_CONTENT: "No Content", RESET_CONTENT: "Reset Content.", PARTIAL_CONTENT: "Partial Content", MULTI_STATUS: "Multi-Status", # 300 MULTIPLE_CHOICE: "Multiple Choices", MOVED_PERMANENTLY: "Moved Permanently", FOUND: "Found", SEE_OTHER: "See Other", NOT_MODIFIED: "Not Modified", USE_PROXY: "Use Proxy", # 306 unused TEMPORARY_REDIRECT: "Temporary Redirect", # 400 BAD_REQUEST: "Bad Request", UNAUTHORIZED: "Unauthorized", PAYMENT_REQUIRED: "Payment Required", FORBIDDEN: "Forbidden", NOT_FOUND: "Not Found", NOT_ALLOWED: "Method Not Allowed", NOT_ACCEPTABLE: "Not Acceptable", PROXY_AUTH_REQUIRED: "Proxy Authentication Required", REQUEST_TIMEOUT: "Request Time-out", CONFLICT: "Conflict", GONE: "Gone", LENGTH_REQUIRED: "Length Required", 
PRECONDITION_FAILED: "Precondition Failed", REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", REQUEST_URI_TOO_LONG: "Request-URI Too Long", UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range Not Satisfiable", EXPECTATION_FAILED: "Expectation Failed", UNPROCESSABLE_ENTITY: "Unprocessable Entity", LOCKED: "Locked", FAILED_DEPENDENCY: "Failed Dependency", # 500 INTERNAL_SERVER_ERROR: "Internal Server Error", NOT_IMPLEMENTED: "Not Implemented", BAD_GATEWAY: "Bad Gateway", SERVICE_UNAVAILABLE: "Service Unavailable", GATEWAY_TIMEOUT: "Gateway Time-out", HTTP_VERSION_NOT_SUPPORTED: "HTTP Version Not Supported", LOOP_DETECTED: "Loop In Linked or Bound Resource", INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", NOT_EXTENDED: "Not Extended" } # No __all__ necessary -- everything is exported calendarserver-5.2+dfsg/twext/web2/log.py0000644000175000017500000001620612263343324017463 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_log -*- ## # Copyright (c) 2001-2004 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ Logging tools. This is still in flux (even moreso than the rest of web2). NOTE: This is now using twext.python.log new-style logging and observers. """ import time from twisted.internet import defer from twext.python.log import Logger from twext.web2 import iweb, stream, resource from zope.interface import implements, Attribute, Interface log = Logger() class _LogByteCounter(object): implements(stream.IByteStream) def __init__(self, stream, done): self.stream=stream self.done=done self.len=0 length=property(lambda self: self.stream.length) def _callback(self, data): if data is None: if self.done: done=self.done; self.done=None done(True, self.len) else: self.len += len(data) return data def read(self): data = self.stream.read() if isinstance(data, defer.Deferred): return data.addCallback(self._callback) return self._callback(data) def close(self): if self.done: done=self.done; self.done=None done(False, self.len) self.stream.close() class ILogInfo(Interface): """Auxilliary information about the response useful for logging.""" bytesSent=Attribute("Number of bytes sent.") responseCompleted=Attribute("Whether or not the response was completed.") secondsTaken=Attribute("Number of seconds taken to serve the request.") startTime=Attribute("Time at which the request started") class LogInfo(object): implements(ILogInfo) responseCompleted=None secondsTaken=None bytesSent=None startTime=None def logFilter(request, response, startTime=None): if startTime is None: startTime = time.time() def _log(success, length): loginfo=LogInfo() loginfo.bytesSent=length loginfo.responseCompleted=success loginfo.secondsTaken=time.time()-startTime if length: request.timeStamp("t-resp-wr") log.info(interface=iweb.IRequest, 
request=request, response=response, loginfo=loginfo) # Or just... # ILogger(ctx).log(...) ? request.timeStamp("t-resp-gen") if response.stream: response.stream=_LogByteCounter(response.stream, _log) else: _log(True, 0) return response logFilter.handleErrors = True class LogWrapperResource(resource.WrapperResource): def hook(self, request): # Insert logger request.addResponseFilter(logFilter, atEnd=True, onlyOnce=True) monthname = [None, 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] class BaseCommonAccessLoggingObserver(object): """An abstract Twisted-based logger for creating access logs. Derived implementations of this class *must* implement the ``logMessage(message)`` method, which will send the message to an actual log/file or stream. """ logFormat = '%s - %s [%s] "%s" %s %d "%s" "%s"' def logMessage(self, message): raise NotImplemented, 'You must provide an implementation.' def computeTimezoneForLog(self, tz): if tz > 0: neg = 1 else: neg = 0 tz = -tz h, rem = divmod(tz, 3600) m, rem = divmod(rem, 60) if neg: return '-%02d%02d' % (h, m) else: return '+%02d%02d' % (h, m) tzForLog = None tzForLogAlt = None def logDateString(self, when): logtime = time.localtime(when) Y, M, D, h, m, s = logtime[:6] if not time.daylight: tz = self.tzForLog if tz is None: tz = self.computeTimezoneForLog(time.timezone) self.tzForLog = tz else: tz = self.tzForLogAlt if tz is None: tz = self.computeTimezoneForLog(time.altzone) self.tzForLogAlt = tz return '%02d/%s/%02d:%02d:%02d:%02d %s' % ( D, monthname[M], Y, h, m, s, tz) def emit(self, eventDict): if eventDict.get('interface') is not iweb.IRequest: return request = eventDict['request'] response = eventDict['response'] loginfo = eventDict['loginfo'] firstLine = '%s %s HTTP/%s' %( request.method, request.uri, '.'.join([str(x) for x in request.clientproto])) self.logMessage( '%s - %s [%s] "%s" %s %d "%s" "%s"' % ( request.remoteAddr.host, # XXX: Where to get user from? 
"-", self.logDateString( response.headers.getHeader('date', 0)), firstLine, response.code, loginfo.bytesSent, request.headers.getHeader('referer', '-'), request.headers.getHeader('user-agent', '-') ) ) def start(self): """Start observing log events.""" # Use the root publisher to bypass log level filtering log.publisher.addObserver(self.emit, filtered=False) def stop(self): """Stop observing log events.""" log.publisher.removeObserver(self.emit) class FileAccessLoggingObserver(BaseCommonAccessLoggingObserver): """I log requests to a single logfile """ def __init__(self, logpath): self.logpath = logpath def logMessage(self, message): self.f.write(message + '\n') def start(self): super(FileAccessLoggingObserver, self).start() self.f = open(self.logpath, 'a', 1) def stop(self): super(FileAccessLoggingObserver, self).stop() self.f.close() class DefaultCommonAccessLoggingObserver(BaseCommonAccessLoggingObserver): """Log requests to default twisted logfile.""" def logMessage(self, message): log.info(message) calendarserver-5.2+dfsg/twext/web2/metafd.py0000644000175000017500000003640212310121167020132 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_metafd -*- ## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Implementation of dispatching HTTP connections to child processes using L{twext.internet.sendfdport.InheritedSocketDispatcher}. 
""" from __future__ import print_function from zope.interface import implementer from twext.internet.sendfdport import ( InheritedPort, InheritedSocketDispatcher, InheritingProtocolFactory, IStatus) from twext.internet.tcp import MaxAcceptTCPServer from twext.python.log import Logger from twext.web2.channel.http import HTTPFactory from twisted.application.service import MultiService, Service from twisted.internet import reactor from twisted.python.util import FancyStrMixin from twisted.internet.tcp import Server from twext.internet.sendfdport import IStatusWatcher log = Logger() class JustEnoughLikeAPort(object): """ Fake out just enough of L{tcp.Port} to be acceptable to L{tcp.Server}... """ _realPortNumber = 'inherited' class ReportingHTTPService(Service, object): """ Service which starts up an HTTP server that can report back to its parent process via L{InheritedPort}. This is instantiated in the I{worker process}. @ivar site: a twext.web2 'site' object, i.e. a request factory @ivar fd: the file descriptor of a UNIX socket being used to receive connections from a master process calling accept() @type fd: C{int} @ivar contextFactory: A context factory for building SSL/TLS connections for inbound connections tagged with the string 'SSL' as their descriptive data, or None if SSL is not enabled for this server. @type contextFactory: L{twisted.internet.ssl.ContextFactory} or C{NoneType} """ _connectionCount = 0 def __init__(self, site, fd, contextFactory): self.contextFactory = contextFactory # Unlike other 'factory' constructions, config.MaxRequests and # config.MaxAccepts are dealt with in the master process, so we don't # need to propagate them here. self.site = site self.fd = fd def startService(self): """ Start reading on the inherited port. 
""" Service.startService(self) self.reportingFactory = ReportingHTTPFactory(self.site, vary=True) inheritedPort = self.reportingFactory.inheritedPort = InheritedPort( self.fd, self.createTransport, self.reportingFactory ) inheritedPort.startReading() inheritedPort.reportStatus("0") def stopService(self): """ Stop reading on the inherited port. @return: a Deferred which fires after the last outstanding request is complete. """ Service.stopService(self) # XXX stopping should really be destructive, because otherwise we will # always leak a file descriptor; i.e. this shouldn't be restartable. self.reportingFactory.inheritedPort.stopReading() # Let any outstanding requests finish return self.reportingFactory.allConnectionsClosed() def createTransport(self, skt, peer, data, protocol): """ Create a TCP transport, from a socket object passed by the parent. """ self._connectionCount += 1 transport = Server(skt, protocol, peer, JustEnoughLikeAPort, self._connectionCount, reactor) if data == 'SSL': transport.startTLS(self.contextFactory) transport.startReading() return transport class ReportingHTTPFactory(HTTPFactory): """ An L{HTTPFactory} which reports its status to a L{InheritedPort}. Since this is processing application-level bytes, it is of course instantiated in the I{worker process}, as is L{InheritedPort}. @ivar inheritedPort: an L{InheritedPort} to report status (the current number of outstanding connections) to. Since this - the L{ReportingHTTPFactory} - needs to be instantiated to be passed to L{InheritedPort}'s constructor, this attribute must be set afterwards but before any connections have occurred. """ def _report(self, message): """ Report a status message to the parent. """ self.inheritedPort.reportStatus(message) def addConnectedChannel(self, channel): """ Add the connected channel, and report the current number of open channels to the listening socket in the parent process. 
""" HTTPFactory.addConnectedChannel(self, channel) self._report("+") def removeConnectedChannel(self, channel): """ Remove the connected channel, and report the current number of open channels to the listening socket in the parent process. """ HTTPFactory.removeConnectedChannel(self, channel) self._report("-") @implementer(IStatus) class WorkerStatus(FancyStrMixin, object): """ The status of a worker process. """ showAttributes = ("acknowledged unacknowledged total started abandoned unclosed starting stopped" .split()) def __init__( self, acknowledged=0, unacknowledged=0, total=0, started=0, abandoned=0, unclosed=0, starting=1, stopped=0 ): """ Create a L{ConnectionStatus} with a number of sent connections and a number of un-acknowledged connections. @param acknowledged: the number of connections which we know the subprocess to be presently processing; i.e. those which have been transmitted to the subprocess. @param unacknowledged: The number of connections which we have sent to the subprocess which have never received a status response (a "C{+}" status message). @param total: The total number of acknowledged connections over the lifetime of this socket. @param started: The number of times this worker has been started. @param abandoned: The number of connections which have been sent to this worker, but were not acknowledged at the moment that the worker was stopped. @param unclosed: The number of sockets which have been sent to the subprocess but not yet closed. @param starting: The process that owns this socket is starting. Do not dispatch to it until we receive the started message. @param stopped: The process that owns this socket has stopped. Do not dispatch to it. """ self.acknowledged = acknowledged self.unacknowledged = unacknowledged self.total = total self.started = started self.abandoned = abandoned self.unclosed = unclosed self.starting = starting self.stopped = stopped def effective(self): """ The current effective load. 
""" return self.acknowledged + self.unacknowledged def active(self): """ Is the subprocess associated with this socket available to dispatch to. i.e, this socket is neither stopped nor starting """ return self.starting == 0 and self.stopped == 0 def start(self): """ The child process for this L{WorkerStatus} is about to (re)start. Reset the status to indicate it is starting - that should prevent any new connections being dispatched. """ return self.reset( starting=1, stopped=0, ) def restarted(self): """ The child process for this L{WorkerStatus} has indicated it is now available to accept connections, so reset the starting status so this socket will be available for dispatch. """ return self.reset( started=self.started + 1, starting=0, ) def stop(self): """ The child process for this L{WorkerStatus} has stopped. Stop the socket and clear out existing counters, but track abandoned connections. """ return self.reset( acknowledged=0, unacknowledged=0, abandoned=self.abandoned + self.unacknowledged, starting=0, stopped=1, ) def adjust(self, **kwargs): """ Update the L{WorkerStatus} by adding the supplied values to the specified attributes. """ for k, v in kwargs.items(): newval = getattr(self, k) + v setattr(self, k, max(newval, 0)) return self def reset(self, **kwargs): """ Reset the L{WorkerStatus} by setting the supplied values in the specified attributes. """ for k, v in kwargs.items(): setattr(self, k, v) return self @implementer(IStatusWatcher) class ConnectionLimiter(MultiService, object): """ Connection limiter for use with L{InheritedSocketDispatcher}. This depends on statuses being reported by L{ReportingHTTPFactory} """ _outstandingRequests = 0 def __init__(self, maxAccepts, maxRequests): """ Create a L{ConnectionLimiter} with an associated dispatcher and list of factories. """ MultiService.__init__(self) self.factories = [] # XXX dispatcher needs to be a service, so that it can shut down its # sub-sockets. 
self.dispatcher = InheritedSocketDispatcher(self) self.maxAccepts = maxAccepts self.maxRequests = maxRequests self.overloaded = False def startService(self): """ Start up multiservice, then start up the dispatcher. """ super(ConnectionLimiter, self).startService() self.dispatcher.startDispatching() def addPortService(self, description, port, interface, backlog, serverServiceMaker=MaxAcceptTCPServer): """ Add a L{MaxAcceptTCPServer} to bind a TCP port to a socket description. """ lipf = LimitingInheritingProtocolFactory(self, description) self.factories.append(lipf) serverServiceMaker( port, lipf, interface=interface, backlog=backlog ).setServiceParent(self) # IStatusWatcher def initialStatus(self): """ The status of a new worker added to the pool. """ return WorkerStatus() def statusFromMessage(self, previousStatus, message): """ Determine a subprocess socket's status from its previous status and a status message. """ if message == '-': # A connection has gone away in a subprocess; we should start # accepting connections again if we paused (see # newConnectionStatus) return previousStatus.adjust(acknowledged=-1) elif message == '0': # A new process just started accepting new connections. return previousStatus.restarted() else: # '+' acknowledges that the subprocess has taken on the work. return previousStatus.adjust( acknowledged=1, unacknowledged=-1, total=1, unclosed=1, ) def closeCountFromStatus(self, status): """ Determine the number of sockets to close from the current status. """ toClose = status.unclosed return (toClose, status.adjust(unclosed=-toClose)) def newConnectionStatus(self, previousStatus): """ A connection was just sent to the process, but not yet acknowledged. """ return previousStatus.adjust(unacknowledged=1) def statusesChanged(self, statuses): """ The L{InheritedSocketDispatcher} is reporting that the list of connection-statuses have changed. Check to see if we are overloaded or if there are no active processes left. 
If so, stop the protocol factory from processing more requests until capacity is back. (The argument to this function is currently duplicated by the C{self.dispatcher.statuses} attribute, which is what C{self.outstandingRequests} uses to compute it.) """ current = sum(status.effective() for status in self.dispatcher.statuses) self._outstandingRequests = current # preserve for or= field in log maximum = self.maxRequests overloaded = (current >= maximum) available = len(filter(lambda x: x.active(), self.dispatcher.statuses)) self.overloaded = (overloaded or available == 0) for f in self.factories: if self.overloaded: f.loadAboveMaximum() else: f.loadNominal() @property # make read-only def outstandingRequests(self): return self._outstandingRequests class LimitingInheritingProtocolFactory(InheritingProtocolFactory): """ An L{InheritingProtocolFactory} that supports the implicit factory contract required by L{MaxAcceptTCPServer}/L{MaxAcceptTCPPort}. Since L{InheritingProtocolFactory} is instantiated in the I{master process}, so is L{LimitingInheritingProtocolFactory}. @ivar outstandingRequests: a read-only property for the number of currently active connections. @ivar maxAccepts: The maximum number of times to call 'accept()' in a single reactor loop iteration. @ivar maxRequests: The maximum number of concurrent connections to accept at once - note that this is for the I{entire server}, whereas the value in the configuration file is for only a single process. """ def __init__(self, limiter, description): super(LimitingInheritingProtocolFactory, self).__init__( limiter.dispatcher, description) self.limiter = limiter self.maxAccepts = limiter.maxAccepts self.maxRequests = limiter.maxRequests self.stopping = False def stopFactory(self): """ Mark this factory as being stopped to prevent attempts to start reading on its port again when the limiter statuses change during shutdown. 
""" super(LimitingInheritingProtocolFactory, self).stopFactory() self.stopping = True def loadAboveMaximum(self): """ The current server load has exceeded the maximum allowable. """ self.myServer.myPort.stopReading() def loadNominal(self): """ The current server load is nominal; proceed with reading requests (but only if the server itself is still running). """ if not self.stopping: self.myServer.myPort.startReading() @property def outstandingRequests(self): return self.limiter.outstandingRequests calendarserver-5.2+dfsg/twext/web2/_version.py0000644000175000017500000000021011337102650020506 0ustar rahulrahul# This is an auto-generated file. Do not edit it. from twisted.python import versions version = versions.Version('twext.web2', 9, 0, 0) calendarserver-5.2+dfsg/twext/web2/__init__.py0000644000175000017500000000332412263343324020436 0ustar rahulrahul# -*- test-case-name: twext.web2.test -*- ## # Copyright (c) 2009 Twisted Matrix Laboratories. # Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##

"""
twext.web2: a transitional package for Calendar Server to move from a
dependence on twisted.web2 to twisted.web.

This is a copy of (most of) twisted.web2, but the intention is for this
package to disappear and gradually get replaced with twisted.web.

Features from this package are being merged into L{twisted.web}.  Once that
is complete, this package will be removed.

See U{http://twistedmatrix.com/trac/wiki/WebDevelopmentWithTwisted}.
"""

# Expose the package version: the Version object defined in the generated
# _version module, plus its short string form under the conventional
# __version__ name.
from twext.web2._version import version
__version__ = version.short()
calendarserver-5.2+dfsg/twext/web2/static.py0000644000175000017500000005000412263343324020163 0ustar rahulrahul# -*- test-case-name: twext.web2.test.test_static -*-
##
# Copyright (c) 2001-2008 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ## """ I deal with static resources. """ # System Imports import os, time import tempfile # Sibling Imports from twext.web2 import http_headers, resource from twext.web2 import http, iweb, stream, responsecode, server from twext.web2.http import HTTPError # Twisted Imports from twext.python.filepath import CachingFilePath as FilePath from twisted.internet.defer import inlineCallbacks, returnValue, succeed from zope.interface import implements class MetaDataMixin(object): """ Mix-in class for L{iweb.IResource} which provides methods for accessing resource metadata specified by HTTP. """ def etag(self): """ @return: The current etag for the resource if available, None otherwise. """ return succeed(None) def lastModified(self): """ @return: The last modified time of the resource if available, None otherwise. """ return None def creationDate(self): """ @return: The creation date of the resource if available, None otherwise. """ return None def contentLength(self): """ @return: The size in bytes of the resource if available, None otherwise. """ return None def contentType(self): """ @return: The MIME type of the resource if available, None otherwise. """ return None def contentEncoding(self): """ @return: The encoding of the resource if available, None otherwise. """ return None def displayName(self): """ @return: The display name of the resource if available, None otherwise. """ return None def exists(self): """ @return: True if the resource exists on the server, False otherwise. 
""" return True class StaticRenderMixin(resource.RenderMixin, MetaDataMixin): @inlineCallbacks def checkPreconditions(self, request): # This code replaces the code in resource.RenderMixin if request.method not in ("GET", "HEAD"): etag = (yield self.etag()) http.checkPreconditions( request, entityExists = self.exists(), etag = etag, lastModified = self.lastModified(), ) # Check per-method preconditions method = getattr(self, "preconditions_" + request.method, None) if method: returnValue((yield method(request))) @inlineCallbacks def renderHTTP(self, request): """ See L{resource.RenderMixIn.renderHTTP}. This implementation automatically sets some headers on the response based on data available from L{MetaDataMixin} methods. """ try: response = yield super(StaticRenderMixin, self).renderHTTP(request) except HTTPError, he: response = he.response response = iweb.IResponse(response) # Don't provide additional resource information to error responses if response.code < 400: # Content-* headers refer to the response content, not # (necessarily) to the resource content, so they depend on the # request method, and therefore can't be set here. etag = (yield self.etag()) for (header, value) in ( ("etag", etag), ("last-modified", self.lastModified()), ): if value is not None: response.headers.setHeader(header, value) returnValue(response) class Data(resource.Resource): """ This is a static, in-memory resource. 
""" def __init__(self, data, type): self.data = data self.type = http_headers.MimeType.fromString(type) self.created_time = time.time() def etag(self): lastModified = self.lastModified() return succeed(http_headers.ETag("%X-%X" % (lastModified, hash(self.data)), weak=(time.time() - lastModified <= 1))) def lastModified(self): return self.creationDate() def creationDate(self): return self.created_time def contentLength(self): return len(self.data) def contentType(self): return self.type def render(self, req): return http.Response( responsecode.OK, http_headers.Headers({'content-type': self.contentType()}), stream=self.data) class File(StaticRenderMixin): """ File is a resource that represents a plain non-interpreted file (although it can look for an extension like .rpy or .cgi and hand the file to a processor for interpretation if you wish). Its constructor takes a file path. Alternatively, you can give a directory path to the constructor. In this case the resource will represent that directory, and its children will be files underneath that directory. This provides access to an entire filesystem tree with a single Resource. If you map the URL C{http://server/FILE} to a resource created as File('/tmp'), C{http://server/FILE/foo/bar.html} will return the contents of C{/tmp/foo/bar.html} . """ implements(iweb.IResource) def _getContentTypes(self): if not hasattr(File, "_sharedContentTypes"): File._sharedContentTypes = loadMimeTypes() return File._sharedContentTypes contentTypes = property(_getContentTypes) contentEncodings = { ".gz" : "gzip", ".bz2": "bzip2" } processors = {} indexNames = ["index", "index.html", "index.htm", "index.trp", "index.rpy"] type = None def __init__(self, path, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None): """Create a file with the given path. """ super(File, self).__init__() self.putChildren = {} if isinstance(path, FilePath): self.fp = path else: assert isinstance(path, str), "This should be a string." 
self.fp = FilePath(path) # Remove the dots from the path to split self.defaultType = defaultType self.ignoredExts = list(ignoredExts) if processors is not None: self.processors = dict([ (key.lower(), value) for key, value in processors.items() ]) if indexNames is not None: self.indexNames = indexNames def comparePath(self, path): if isinstance(path, FilePath): return path.path == self.fp.path else: return path == self.fp.path def exists(self): return self.fp.exists() def etag(self): if not self.fp.exists(): return succeed(None) st = self.fp.statinfo # # Mark ETag as weak if it was modified more recently than we can # measure and report, as it could be modified again in that span # and we then wouldn't know to provide a new ETag. # weak = (time.time() - st.st_mtime <= 1) return succeed(http_headers.ETag( "%X-%X-%X" % (st.st_ino, st.st_size, st.st_mtime), weak=weak )) def lastModified(self): if self.fp.exists(): return self.fp.getmtime() else: return None def creationDate(self): if self.fp.exists(): return self.fp.getmtime() else: return None def contentLength(self): if self.fp.exists(): if self.fp.isfile(): return self.fp.getsize() else: # Computing this would require rendering the resource; let's # punt instead. return None else: return None def _initTypeAndEncoding(self): self._type, self._encoding = getTypeAndEncoding( self.fp.basename(), self.contentTypes, self.contentEncodings, self.defaultType ) # Handle cases not covered by getTypeAndEncoding() if self.fp.isdir(): self._type = "httpd/unix-directory" def contentType(self): if not hasattr(self, "_type"): self._initTypeAndEncoding() return http_headers.MimeType.fromString(self._type) def contentEncoding(self): if not hasattr(self, "_encoding"): self._initTypeAndEncoding() return self._encoding def displayName(self): if self.fp.exists(): return self.fp.basename() else: return None def ignoreExt(self, ext): """Ignore the given extension. 
Serve file.ext if file is requested """ self.ignoredExts.append(ext) def putChild(self, name, child): """ Register a child with the given name with this resource. @param name: the name of the child (a URI path segment) @param child: the child to register """ self.putChildren[name] = child def getChild(self, name): """ Look up a child resource. @return: the child of this resource with the given name. """ if name == "": return self child = self.putChildren.get(name, None) if child: return child child_fp = self.fp.child(name) if hasattr(self, "knownChildren"): if name in self.knownChildren: child_fp.existsCached = True if child_fp.exists(): return self.createSimilarFile(child_fp) else: return None def listChildren(self): """ @return: a sequence of the names of all known children of this resource. """ children = self.putChildren.keys() if self.fp.isdir(): children += [c for c in self.fp.listdir() if c not in children] self.knownChildren = set(children) return children def locateChild(self, req, segments): """ See L{IResource}C{.locateChild}. """ # If getChild() finds a child resource, return it child = self.getChild(segments[0]) if child is not None: return (child, segments[1:]) # If we're not backed by a directory, we have no children. # But check for existance first; we might be a collection resource # that the request wants created. self.fp.restat(False) if self.fp.exists() and not self.fp.isdir(): return (None, ()) # OK, we need to return a child corresponding to the first segment path = segments[0] if path: fpath = self.fp.child(path) else: # Request is for a directory (collection) resource return (self, server.StopTraversal) # Don't run processors on directories - if someone wants their own # customized directory rendering, subclass File instead. 
if fpath.isfile(): processor = self.processors.get(fpath.splitext()[1].lower()) if processor: return ( processor(fpath.path), segments[1:]) elif not fpath.exists(): sibling_fpath = fpath.siblingExtensionSearch(*self.ignoredExts) if sibling_fpath is not None: fpath = sibling_fpath return self.createSimilarFile(fpath.path), segments[1:] def renderHTTP(self, req): self.fp.changed() return super(File, self).renderHTTP(req) def render(self, req): """You know what you doing.""" if not self.fp.exists(): return responsecode.NOT_FOUND if self.fp.isdir(): if req.path[-1] != "/": # Redirect to include trailing '/' in URI return http.RedirectResponse(req.unparseURL(path=req.path+'/')) else: ifp = self.fp.childSearchPreauth(*self.indexNames) if ifp: # Render from the index file standin = self.createSimilarFile(ifp.path) else: # Directory listing is in twistedcaldav.extensions standin = Data( "\n".join(["Directory: " + str(req.path), "---"] + [x.basename() + ("/" if x.isdir() else "") for x in self.fp.children()]), "text/plain") return standin.render(req) try: f = self.fp.open() except IOError, e: import errno if e[0] == errno.EACCES: return responsecode.FORBIDDEN elif e[0] == errno.ENOENT: return responsecode.NOT_FOUND else: raise response = http.Response() response.stream = stream.FileStream(f, 0, self.fp.getsize()) for (header, value) in ( ("content-type", self.contentType()), ("content-encoding", self.contentEncoding()), ): if value is not None: response.headers.setHeader(header, value) return response def createSimilarFile(self, path): return self.__class__(path, self.defaultType, self.ignoredExts, self.processors, self.indexNames[:]) class FileSaver(resource.PostableResource): allowedTypes = (http_headers.MimeType('text', 'plain'), http_headers.MimeType('text', 'html'), http_headers.MimeType('text', 'css')) def __init__(self, destination, expectedFields=[], allowedTypes=None, maxBytes=1000000, permissions=0644): self.destination = destination self.allowedTypes = 
allowedTypes or self.allowedTypes self.maxBytes = maxBytes self.expectedFields = expectedFields self.permissions = permissions def makeUniqueName(self, filename): """Called when a unique filename is needed. filename is the name of the file as given by the client. Returns the fully qualified path of the file to create. The file must not yet exist. """ return tempfile.mktemp(suffix=os.path.splitext(filename)[1], dir=self.destination) def isSafeToWrite(self, filename, mimetype, filestream): """Returns True if it's "safe" to write this file, otherwise it raises an exception. """ if filestream.length > self.maxBytes: raise IOError("%s: File exceeds maximum length (%d > %d)" % (filename, filestream.length, self.maxBytes)) if mimetype not in self.allowedTypes: raise IOError("%s: File type not allowed %s" % (filename, mimetype)) return True def writeFile(self, filename, mimetype, fileobject): """Does the I/O dirty work after it calls isSafeToWrite to make sure it's safe to write this file. """ filestream = stream.FileStream(fileobject) if self.isSafeToWrite(filename, mimetype, filestream): outname = self.makeUniqueName(filename) flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0) fileobject = os.fdopen(os.open(outname, flags, self.permissions), 'wb', 0) stream.readIntoFile(filestream, fileobject) return outname def render(self, req): content = [""] if req.files: for fieldName in req.files: if fieldName in self.expectedFields: for finfo in req.files[fieldName]: try: outname = self.writeFile(*finfo) content.append("Saved file %s
" % outname) except IOError, err: content.append(str(err) + "
") else: content.append("%s is not a valid field" % fieldName) else: content.append("No files given") content.append("") return http.Response(responsecode.OK, {}, stream='\n'.join(content)) # FIXME: hi there I am a broken class # """I contain AsIsProcessor, which serves files 'As Is' # Inspired by Apache's mod_asis # """ # # class ASISProcessor: # implements(iweb.IResource) # # def __init__(self, path): # self.path = path # # def renderHTTP(self, request): # request.startedWriting = 1 # return File(self.path) # # def locateChild(self, request): # return None, () ## # Utilities ## dangerousPathError = http.HTTPError(responsecode.NOT_FOUND) #"Invalid request URL." def isDangerous(path): return path == '..' or '/' in path or os.sep in path def addSlash(request): return "http%s://%s%s/" % ( request.isSecure() and 's' or '', request.getHeader("host"), (request.uri.split('?')[0])) def loadMimeTypes(mimetype_locations=['/etc/mime.types']): """ Multiple file locations containing mime-types can be passed as a list. The files will be sourced in that order, overriding mime-types from the files sourced beforehand, but only if a new entry explicitly overrides the current entry. """ import mimetypes # Grab Python's built-in mimetypes dictionary. contentTypes = mimetypes.types_map #@UndefinedVariable # Update Python's semi-erroneous dictionary with a few of the # usual suspects. contentTypes.update( { '.conf': 'text/plain', '.diff': 'text/plain', '.exe': 'application/x-executable', '.flac': 'audio/x-flac', '.java': 'text/plain', '.ogg': 'application/ogg', '.oz': 'text/x-oz', '.swf': 'application/x-shockwave-flash', '.tgz': 'application/x-gtar', '.wml': 'text/vnd.wap.wml', '.xul': 'application/vnd.mozilla.xul+xml', '.py': 'text/plain', '.patch': 'text/plain', } ) # Users can override these mime-types by loading them out configuration # files (this defaults to ['/etc/mime.types']). 
for location in mimetype_locations: if os.path.exists(location): contentTypes.update(mimetypes.read_mime_types(location)) return contentTypes def getTypeAndEncoding(filename, types, encodings, defaultType): p, ext = os.path.splitext(filename) ext = ext.lower() if encodings.has_key(ext): enc = encodings[ext] ext = os.path.splitext(p)[1].lower() else: enc = None type = types.get(ext, defaultType) return type, enc ## # Test code ## if __name__ == '__builtin__': # Running from twistd -y from twisted.application import service, strports res = File('/') application = service.Application("demo") s = strports.service('8080', server.Site(res)) s.setServiceParent(application) calendarserver-5.2+dfsg/twext/__init__.py0000644000175000017500000000127412263343324017601 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extensions to the Twisted Framework. """ from twext import patches patches del(patches) calendarserver-5.2+dfsg/twext/internet/0000755000175000017500000000000012322625326017315 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/internet/decorate.py0000644000175000017500000001065612263343324021464 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Decorators. """ __all__ = [ "memoizedKey", ] from inspect import getargspec from twisted.internet.defer import Deferred, succeed class Memoizable(object): """ A class that stores itself in the memo dictionary. """ def memoMe(self, key, memo): """ Add this object to the memo dictionary in whatever fashion is appropriate. @param key: key used for lookup @type key: C{object} (typically C{str} or C{int}) @param memo: the dict to store to @type memo: C{dict} """ raise NotImplementedError def memoizedKey(keyArgument, memoAttribute, deferredResult=True): """ Decorator which memoizes the result of a method on that method's instance. If the instance is derived from class Memoizable, then the memoMe method is used to store the result, otherwise it is stored directly in the dict. @param keyArgument: The name of the "key" argument. @type keyArgument: C{str} @param memoAttribute: The name of the attribute on the instance which should be used for memoizing the result of this method; the attribute itself must be a dictionary. Alternately, if the specified argument is callable, it is a callable that takes the arguments passed to the decorated method and returns the memo dictionaries. @type memoAttribute: C{str} or C{callable} @param deferredResult: Whether the result must be a deferred. """ def getarg(argname, argspec, args, kw): """ Get an argument from some arguments. @param argname: The name of the argument to retrieve. @param argspec: The result of L{inspect.getargspec}. @param args: positional arguments passed to the function specified by argspec. 
@param kw: keyword arguments passed to the function specified by argspec. @return: The value of the argument named by 'argname'. """ argnames = argspec[0] try: argpos = argnames.index(argname) except ValueError: argpos = None if argpos is not None: if len(args) > argpos: return args[argpos] if argname in kw: return kw[argname] else: raise TypeError("could not find key argument %r in %r/%r (%r)" % (argname, args, kw, argpos) ) def decorate(thunk): # cheater move to try to get the right argspec from inlineCallbacks. # This could probably be more robust, but the 'cell_contents' thing # probably can't (that's the only real reference to the underlying # function). if thunk.func_code.co_name == "unwindGenerator": specTarget = thunk.func_closure[0].cell_contents else: specTarget = thunk spec = getargspec(specTarget) def outer(*a, **kw): self = a[0] if callable(memoAttribute): memo = memoAttribute(*a, **kw) else: memo = getattr(self, memoAttribute) key = getarg(keyArgument, spec, a, kw) if key in memo: memoed = memo[key] if deferredResult: return succeed(memoed) else: return memoed result = thunk(*a, **kw) if isinstance(result, Deferred): def memoResult(finalResult): if isinstance(finalResult, Memoizable): finalResult.memoMe(key, memo) elif finalResult is not None: memo[key] = finalResult return finalResult result.addCallback(memoResult) elif result is not None: memo[key] = result return result return outer return decorate calendarserver-5.2+dfsg/twext/internet/threadutils.py0000644000175000017500000000631712263343324022225 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## import sys from Queue import Queue from twisted.python.failure import Failure from twisted.internet.defer import Deferred _DONE = object() _STATE_STOPPED = 'STOPPED' _STATE_RUNNING = 'RUNNING' _STATE_STOPPING = 'STOPPING' class ThreadHolder(object): """ A queue which will hold a reactor threadpool thread open until all of the work in that queue is done. """ def __init__(self, reactor): self._reactor = reactor self._state = _STATE_STOPPED self._stopper = None self._q = None def _run(self): """ Worker function which runs in a non-reactor thread. """ while self._qpull(): pass def _qpull(self): """ Pull one item off the queue and react appropriately. Return whether or not to keep going. """ work = self._q.get() if work is _DONE: def finishStopping(): self._state = _STATE_STOPPED self._q = None s = self._stopper self._stopper = None s.callback(None) self._reactor.callFromThread(finishStopping) return False self._oneWorkUnit(*work) return True def _oneWorkUnit(self, deferred, instruction): try: result = instruction() except: etype, evalue, etb = sys.exc_info() def relayFailure(): f = Failure(evalue, etype, etb) deferred.errback(f) self._reactor.callFromThread(relayFailure) else: self._reactor.callFromThread(deferred.callback, result) def submit(self, work): """ Submit some work to be run. @param work: a 0-argument callable, which will be run in a thread. 
@return: L{Deferred} that fires with the result of L{work} """ if self._state != _STATE_RUNNING: raise RuntimeError("not running") d = Deferred() self._q.put((d, work)) return d def start(self): """ Start this thing, if it's stopped. """ if self._state != _STATE_STOPPED: raise RuntimeError("Not stopped.") self._state = _STATE_RUNNING self._q = Queue(0) self._reactor.callInThread(self._run) def stop(self): """ Stop this thing and release its thread, if it's running. """ if self._state != _STATE_RUNNING: raise RuntimeError("Not running.") s = self._stopper = Deferred() self._state = _STATE_STOPPING self._q.put(_DONE) return s calendarserver-5.2+dfsg/twext/internet/fswatch.py0000644000175000017500000001126212263343324021327 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## """ Watch the availablity of a file system directory """ import os from zope.interface import Interface from twisted.internet import reactor from twisted.python.log import Logger try: from select import (kevent, KQ_FILTER_VNODE, KQ_EV_ADD, KQ_EV_ENABLE, KQ_EV_CLEAR, KQ_NOTE_DELETE, KQ_NOTE_RENAME, KQ_EV_EOF) kqueueSupported = True except ImportError: # kqueue not supported on this platform kqueueSupported = False class IDirectoryChangeListenee(Interface): """ A delegate of DirectoryChangeListener """ def disconnected(): #@NoSelf """ The directory has been unmounted """ def deleted(): #@NoSelf """ The directory has been deleted """ def renamed(): #@NoSelf """ The directory has been renamed """ def connectionLost(reason): #@NoSelf """ The file descriptor has been closed """ #TODO: better way to tell if reactor is kqueue or not if kqueueSupported and hasattr(reactor, "_doWriteOrRead"): def patchReactor(reactor): # Wrap _doWriteOrRead to support KQ_FILTER_VNODE origDoWriteOrRead = reactor._doWriteOrRead def _doWriteOrReadOrVNodeEvent(selectable, fd, event): origDoWriteOrRead(selectable, fd, event) if event.filter == KQ_FILTER_VNODE: selectable.vnodeEventHappened(event) reactor._doWriteOrRead = _doWriteOrReadOrVNodeEvent patchReactor(reactor) class DirectoryChangeListener(Logger, object): """ Listens for the removal, renaming, or general unavailability of a given directory, and lets a delegate listenee know about them. 
""" def __init__(self, reactor, dirname, listenee): """ @param reactor: the reactor @param dirname: the full path to the directory to watch; it must already exist @param listenee: the delegate to call @type listenee: IDirectoryChangeListenee """ self._reactor = reactor self._fd = os.open(dirname, os.O_RDONLY) self._dirname = dirname self._listenee = listenee def logPrefix(self): return repr(self._dirname) def fileno(self): return self._fd def vnodeEventHappened(self, evt): if evt.flags & KQ_EV_EOF: self._listenee.disconnected() if evt.fflags & KQ_NOTE_DELETE: self._listenee.deleted() if evt.fflags & KQ_NOTE_RENAME: self._listenee.renamed() def startListening(self): ke = kevent(self._fd, filter=KQ_FILTER_VNODE, flags=(KQ_EV_ADD | KQ_EV_ENABLE | KQ_EV_CLEAR), fflags=KQ_NOTE_DELETE | KQ_NOTE_RENAME) self._reactor._kq.control([ke], 0, None) self._reactor._selectables[self._fd] = self def connectionLost(self, reason): os.close(self._fd) self._listenee.connectionLost(reason) else: # TODO: implement this for systems without kqueue support: class DirectoryChangeListener(Logger, object): """ Listens for the removal, renaming, or general unavailability of a given directory, and lets a delegate listenee know about them. """ def __init__(self, reactor, dirname, listenee): """ @param reactor: the reactor @param dirname: the full path to the directory to watch @param listenee: """ self._reactor = reactor self._fd = os.open(dirname, os.O_RDONLY) self._dirname = dirname self._listenee = listenee def logPrefix(self): return repr(self._dirname) def fileno(self): return self._fd def vnodeEventHappened(self, evt): pass def startListening(self): pass def connectionLost(self, reason): os.close(self._fd) self._listenee.connectionLost(reason) calendarserver-5.2+dfsg/twext/internet/ssl.py0000644000175000017500000000361712263343324020476 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extentions to twisted.internet.ssl. """ __all__ = [ "ChainingOpenSSLContextFactory", ] from OpenSSL.SSL import Context as SSLContext, SSLv3_METHOD from twisted.internet.ssl import DefaultOpenSSLContextFactory class ChainingOpenSSLContextFactory (DefaultOpenSSLContextFactory): def __init__( self, privateKeyFileName, certificateFileName, sslmethod=SSLv3_METHOD, certificateChainFile=None, passwdCallback=None, ciphers=None ): self.certificateChainFile = certificateChainFile self.passwdCallback = passwdCallback self.ciphers = ciphers DefaultOpenSSLContextFactory.__init__( self, privateKeyFileName, certificateFileName, sslmethod=sslmethod ) def cacheContext(self): # Unfortunate code duplication. ctx = SSLContext(self.sslmethod) if self.ciphers is not None: ctx.set_cipher_list(self.ciphers) if self.passwdCallback is not None: ctx.set_passwd_cb(self.passwdCallback) ctx.use_certificate_file(self.certificateFileName) ctx.use_privatekey_file(self.privateKeyFileName) if self.certificateChainFile != "": ctx.use_certificate_chain_file(self.certificateChainFile) self._context = ctx calendarserver-5.2+dfsg/twext/internet/spawnsvc.py0000644000175000017500000001644512263343324021544 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Utility service that can spawn subprocesses. """ import os import sys from twisted.python import log from twisted.python.reflect import namedAny from twisted.internet.stdio import StandardIO from twisted.internet.error import ReactorNotRunning if __name__ == '__main__': sys.stdout = sys.stderr there = sys.argv[1] protocolClass = namedAny(there) proto = protocolClass() origLost = proto.connectionLost def goodbye(reason): """ Stop the process if stdin is closed. """ try: reactor.stop() except ReactorNotRunning: pass return origLost(reason) proto.connectionLost = goodbye StandardIO(proto) from twisted.internet import reactor reactor.run() os._exit(0) import sys from zope.interface import implements from twisted.internet.interfaces import ITransport, IPushProducer, IConsumer from twisted.application.service import Service from twisted.python.reflect import qual from twisted.internet.protocol import ProcessProtocol from twisted.internet.defer import Deferred, succeed class BridgeTransport(object): """ ITransport implementation for the protocol in the parent process running a L{SpawnerService}. """ implements(ITransport, IPushProducer, IConsumer) def __init__(self, processTransport): """ Create this bridge transport connected to an L{IProcessTransport}. """ self.transport = processTransport def __getattr__(self, name): """ Delegate all attribute accesses to the process traansport. """ return getattr(self.transport, name) def getPeer(self): """ Get a fake peer address indicating the subprocess's pid. 
""" return "Peer:PID:" + str(self.transport.pid) def getHost(self): """ Get a fake host address indicating the subprocess's pid. """ return "Host:PID:" + str(self.transport.pid) class BridgeProtocol(ProcessProtocol, object): """ Process protocol implementation that delivers data to the C{hereProto} associated with an invocation of L{SpawnerService.spawn}. @ivar service: a L{SpawnerService} that created this L{BridgeProtocol} @ivar protocol: a reference to the L{IProtocol}. @ivar killTimeout: number of seconds after sending SIGINT that this process will send SIGKILL. """ def __init__(self, service, protocol, killTimeout=15.0): self.service = service self.protocol = protocol self.killTimeout = killTimeout self.service.addBridge(self) def connectionMade(self): """ The subprocess was started. """ self.protocol.makeConnection(BridgeTransport(self.transport)) def outReceived(self, data): """ Some data was received to standard output; relay it to the protocol. """ self.protocol.dataReceived(data) def errReceived(self, data): """ Some standard error was received from the subprocess. """ log.msg("Error output from process: " + data, isError=True) _killTimeout = None def eventuallyStop(self): """ Eventually stop this subprocess. Send it a SIGTERM, and if it hasn't stopped by C{self.killTimeout} seconds, send it a SIGKILL. """ self.transport.signalProcess('TERM') def reallyStop(): self.transport.signalProcess("KILL") self._killTimeout = None self._killTimeout = ( self.service.reactor.callLater(self.killTimeout, reallyStop) ) def processEnded(self, reason): """ The process has ended; notify the L{SpawnerService} that this bridge has stopped. """ if self._killTimeout is not None: self._killTimeout.cancel() self.protocol.connectionLost(reason) self.service.removeBridge(self) class SpawnerService(Service, object): """ Process to spawn services and then shut them down. 
@ivar reactor: an L{IReactorProcess}/L{IReactorTime} @ivar pendingSpawns: a C{list} of 2-C{tuple}s of hereProto, thereProto. @ivar bridges: a C{list} of L{BridgeProtocol} instances. """ def __init__(self, reactor=None): if reactor is None: from twisted.internet import reactor self.reactor = reactor self.pendingSpawns = [] self.bridges = [] self._stopAllDeferred = None def spawn(self, hereProto, thereProto, childFDs=None): """ Spawn a subprocess with a connected pair of protocol objects, one in the current process, one in the subprocess. @param hereProto: a L{Protocol} instance to listen in this process. @param thereProto: a top-level class or function that will be imported and called in the spawned subprocess. @param childFDs: File descriptors to share with the subprocess; same format as L{IReactorProcess.spawnProcess}. @return: a L{Deferred} that fires when C{hereProto} is ready. """ if not self.running: self.pendingSpawns.append((hereProto, thereProto)) return name = qual(thereProto) argv = [sys.executable, '-u', '-m', __name__, name] self.reactor.spawnProcess( BridgeProtocol(self, hereProto), sys.executable, argv, os.environ, childFDs=childFDs ) return succeed(hereProto) def startService(self): """ Start the service; spawn any processes previously started with spawn(). """ super(SpawnerService, self).startService() for spawn in self.pendingSpawns: self.spawn(*spawn) self.pendingSpawns = [] def addBridge(self, bridge): """ Add a L{BridgeProtocol} to the list to be tracked. """ self.bridges.append(bridge) def removeBridge(self, bridge): """ The process controlled by a L{BridgeProtocol} has terminated; remove it from the active list, and fire any outstanding Deferred. @param bridge: the protocol which has ended. """ self.bridges.remove(bridge) if self._stopAllDeferred is not None: if len(self.bridges) == 0: self._stopAllDeferred.callback(None) self._stopAllDeferred = None def stopService(self): """ Stop the service. 
""" super(SpawnerService, self).stopService() if self.bridges: self._stopAllDeferred = Deferred() for bridge in self.bridges: bridge.eventuallyStop() return self._stopAllDeferred return succeed(None) calendarserver-5.2+dfsg/twext/internet/test/0000755000175000017500000000000012322625326020274 5ustar rahulrahulcalendarserver-5.2+dfsg/twext/internet/test/test_gaiendpoint.py0000644000175000017500000000643512263343324024215 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Test cases for L{twext.internet.gaiendpoint} """ from socket import getaddrinfo, AF_INET, SOCK_STREAM from twext.internet.gaiendpoint import GAIEndpoint from twisted.trial.unittest import TestCase from twisted.internet.defer import Deferred from twisted.internet.protocol import Factory, Protocol from twisted.internet.task import Clock class FakeTCPEndpoint(object): def __init__(self, reactor, host, port, contextFactory): self._reactor = reactor self._host = host self._port = port self._attempt = None self._contextFactory = contextFactory def connect(self, factory): self._attempt = Deferred() self._factory = factory return self._attempt class GAIEndpointTestCase(TestCase): """ Test cases for L{GAIEndpoint}. 
""" def makeEndpoint(self, host="abcd.example.com", port=4321): gaie = GAIEndpoint(self.clock, host, port) gaie.subEndpoint = self.subEndpoint gaie.deferToThread = self.deferToSomething return gaie def subEndpoint(self, reactor, host, port, contextFactory): ftcpe = FakeTCPEndpoint(reactor, host, port, contextFactory) self.fakeRealEndpoints.append(ftcpe) return ftcpe def deferToSomething(self, func, *a, **k): """ Test replacement for L{deferToThread}, which can only call L{getaddrinfo}. """ d = Deferred() if func is not getaddrinfo: self.fail("Only getaddrinfo should be invoked in a thread.") self.inThreads.append((d, func, a, k)) return d def gaiResult(self, family, socktype, proto, canonname, sockaddr): """ A call to L{getaddrinfo} has succeeded; invoke the L{Deferred} waiting on it. """ d, f, a, k = self.inThreads.pop(0) d.callback([(family, socktype, proto, canonname, sockaddr)]) def setUp(self): """ Set up! """ self.inThreads = [] self.clock = Clock() self.fakeRealEndpoints = [] self.makeEndpoint() def test_simpleSuccess(self): """ If C{getaddrinfo} gives one L{GAIEndpoint.connect}. """ gaiendpoint = self.makeEndpoint() protos = [] f = Factory() f.protocol = Protocol gaiendpoint.connect(f).addCallback(protos.append) WHO_CARES = 0 WHAT_EVER = "" self.gaiResult(AF_INET, SOCK_STREAM, WHO_CARES, WHAT_EVER, ("1.2.3.4", 4321)) self.clock.advance(1.0) attempt = self.fakeRealEndpoints[0]._attempt attempt.callback(self.fakeRealEndpoints[0]._factory.buildProtocol(None)) self.assertEqual(len(protos), 1) calendarserver-5.2+dfsg/twext/internet/test/test_fswatch.py0000644000175000017500000001166212263343324023351 0ustar rahulrahul## # Copyright (c) 2013-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.internet.fswatch}. """ from twext.internet.fswatch import DirectoryChangeListener, patchReactor, \ IDirectoryChangeListenee from twisted.internet.kqreactor import KQueueReactor from twisted.python.filepath import FilePath from twisted.trial.unittest import TestCase from zope.interface import implements class KQueueReactorTestFixture(object): def __init__(self, testCase, action=None, timeout=10): """ Creates a kqueue reactor for use in unit tests. The reactor is patched with the vnode event handler. Once the reactor is running, it will call a supplied method. It's expected that the method will ultimately trigger the stop() of the reactor. The reactor will time out after 10 seconds. 
@param testCase: a test method which is needed for adding cleanup to @param action: a method which will get called after the reactor is running @param timeout: how many seconds to keep the reactor running before giving up and stopping it """ self.testCase = testCase self.reactor = KQueueReactor() patchReactor(self.reactor) self.action = action self.timeout = timeout def maybeStop(): if self.reactor.running: return self.reactor.stop() self.testCase.addCleanup(maybeStop) def runReactor(self): """ Run the test reactor, adding cleanup code to stop if after a timeout, and calling the action method """ def getReadyToStop(): self.reactor.callLater(self.timeout, self.reactor.stop) self.reactor.callWhenRunning(getReadyToStop) if self.action is not None: self.reactor.callWhenRunning(self.action) self.reactor.run(installSignalHandlers=False) class DataStoreMonitor(object): """ Stub IDirectoryChangeListenee """ implements(IDirectoryChangeListenee) def __init__(self, reactor, storageService): """ @param storageService: the service making use of the DataStore directory; we send it a hardStop() to shut it down """ self._reactor = reactor self._storageService = storageService self.methodCalled = "" def disconnected(self): self.methodCalled = "disconnected" self._storageService.hardStop() self._reactor.stop() def deleted(self): self.methodCalled = "deleted" self._storageService.hardStop() self._reactor.stop() def renamed(self): self.methodCalled = "renamed" self._storageService.hardStop() self._reactor.stop() def connectionLost(self, reason): pass class StubStorageService(object): """ Implements hardStop for testing """ def __init__(self, ignored): self.stopCalled = False def hardStop(self): self.stopCalled = True class DirectoryChangeListenerTestCase(TestCase): def test_delete(self): """ Verify directory deletions can be monitored """ self.tmpdir = FilePath(self.mktemp()) self.tmpdir.makedirs() def deleteAction(): self.tmpdir.remove() resource = KQueueReactorTestFixture(self, 
deleteAction) storageService = StubStorageService(resource.reactor) delegate = DataStoreMonitor(resource.reactor, storageService) dcl = DirectoryChangeListener(resource.reactor, self.tmpdir.path, delegate) dcl.startListening() resource.runReactor() self.assertTrue(storageService.stopCalled) self.assertEquals(delegate.methodCalled, "deleted") def test_rename(self): """ Verify directory renames can be monitored """ self.tmpdir = FilePath(self.mktemp()) self.tmpdir.makedirs() def renameAction(): self.tmpdir.moveTo(FilePath(self.mktemp())) resource = KQueueReactorTestFixture(self, renameAction) storageService = StubStorageService(resource.reactor) delegate = DataStoreMonitor(resource.reactor, storageService) dcl = DirectoryChangeListener(resource.reactor, self.tmpdir.path, delegate) dcl.startListening() resource.runReactor() self.assertTrue(storageService.stopCalled) self.assertEquals(delegate.methodCalled, "renamed") calendarserver-5.2+dfsg/twext/internet/test/test_adaptendpoint.py0000644000175000017500000001747012263343324024547 0ustar rahulrahul## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.internet.adaptendpoint}. 
""" from zope.interface.verify import verifyObject from twext.internet.adaptendpoint import connect from twisted.internet.defer import Deferred, CancelledError from twisted.python.failure import Failure from twisted.internet.protocol import ClientFactory, Protocol from twisted.internet.interfaces import IConnector from twisted.trial.unittest import TestCase class names(object): def __init__(self, **kw): self.__dict__.update(kw) class RecordingProtocol(Protocol, object): def __init__(self): super(RecordingProtocol, self).__init__() self.made = [] self.data = [] self.lost = [] def connectionMade(self): self.made.append(self.transport) def dataReceived(self, data): self.data.append(data) def connectionLost(self, why): self.lost.append(why) class RecordingClientFactory(ClientFactory): """ L{ClientFactory} subclass that records the things that happen to it. """ def __init__(self): """ Create some records of things that are about to happen. """ self.starts = [] self.built = [] self.fails = [] self.lost = [] def startedConnecting(self, ctr): self.starts.append(ctr) def clientConnectionFailed(self, ctr, reason): self.fails.append(names(connector=ctr, reason=reason)) def clientConnectionLost(self, ctr, reason): self.lost.append(names(connector=ctr, reason=reason)) def buildProtocol(self, addr): b = RecordingProtocol() self.built.append(names(protocol=b, addr=addr)) return b class RecordingEndpoint(object): def __init__(self): self.attempts = [] def connect(self, factory): d = Deferred() self.attempts.append(names(deferred=d, factory=factory)) return d class RecordingTransport(object): def __init__(self): self.lose = [] def loseConnection(self): self.lose.append(self) class AdaptEndpointTests(TestCase): """ Tests for L{connect} and the objects that it coordinates. 
""" def setUp(self): self.factory = RecordingClientFactory() self.endpoint = RecordingEndpoint() self.connector = connect(self.endpoint, self.factory) def connectionSucceeds(self, addr=object()): """ The most recent connection attempt succeeds, returning the L{ITransport} provider produced by its success. """ transport = RecordingTransport() attempt = self.endpoint.attempts[-1] proto = attempt.factory.buildProtocol(addr) proto.makeConnection(transport) transport.protocol = proto attempt.deferred.callback(proto) return transport def connectionFails(self, reason): """ The most recent in-progress connection fails. """ self.endpoint.attempts[-1].deferred.errback(reason) def test_connectStartsConnection(self): """ When used with a successful endpoint, L{connect} will simulate all aspects of the connection process; C{buildProtocol}, C{connectionMade}, C{dataReceived}. """ self.assertIdentical(self.connector.getDestination(), self.endpoint) verifyObject(IConnector, self.connector) self.assertEqual(self.factory.starts, [self.connector]) self.assertEqual(len(self.endpoint.attempts), 1) self.assertEqual(len(self.factory.built), 0) transport = self.connectionSucceeds() self.assertEqual(len(self.factory.built), 1) made = transport.protocol.made self.assertEqual(len(made), 1) self.assertIdentical(made[0], transport) def test_connectionLost(self): """ When the connection is lost, both the protocol and the factory will be notified via C{connectionLost} and C{clientConnectionLost}. """ why = Failure(RuntimeError()) proto = self.connectionSucceeds().protocol proto.connectionLost(why) self.assertEquals(len(self.factory.built), 1) self.assertEquals(self.factory.built[0].protocol.lost, [why]) self.assertEquals(len(self.factory.lost), 1) self.assertIdentical(self.factory.lost[0].reason, why) def test_connectionFailed(self): """ When the L{Deferred} from the endpoint fails, the L{ClientFactory} gets notified via C{clientConnectionFailed}. 
""" why = Failure(RuntimeError()) self.connectionFails(why) self.assertEquals(len(self.factory.fails), 1) self.assertIdentical(self.factory.fails[0].reason, why) def test_disconnectWhileConnecting(self): """ When the L{IConnector} is told to C{disconnect} before an in-progress L{Deferred} from C{connect} has fired, it will cancel that L{Deferred}. """ self.connector.disconnect() self.assertEqual(len(self.factory.fails), 1) self.assertTrue(self.factory.fails[0].reason.check(CancelledError)) def test_disconnectWhileConnected(self): """ When the L{IConnector} is told to C{disconnect} while an existing connection is established, that connection will be dropped via C{loseConnection}. """ transport = self.connectionSucceeds() self.factory.starts[0].disconnect() self.assertEqual(transport.lose, [transport]) def test_connectAfterFailure(self): """ When the L{IConnector} is told to C{connect} after a connection attempt has failed, a new connection attempt is started. """ why = Failure(ZeroDivisionError()) self.connectionFails(why) self.connector.connect() self.assertEqual(len(self.factory.starts), 2) self.assertEqual(len(self.endpoint.attempts), 2) self.connectionSucceeds() def test_reConnectTooSoon(self): """ When the L{IConnector} is told to C{connect} while another attempt is still in flight, it synchronously raises L{RuntimeError}. """ self.assertRaises(RuntimeError, self.connector.connect) self.assertEqual(len(self.factory.starts), 1) self.assertEqual(len(self.endpoint.attempts), 1) def test_stopConnectingWhileConnecting(self): """ When the L{IConnector} is told to C{stopConnecting} while another attempt is still in flight, it cancels that connection. """ self.connector.stopConnecting() self.assertEqual(len(self.factory.fails), 1) self.assertTrue(self.factory.fails[0].reason.check(CancelledError)) def test_stopConnectingWhileConnected(self): """ When the L{IConnector} is told to C{stopConnecting} while already connected, it raises a L{RuntimeError}. 
""" self.connectionSucceeds() self.assertRaises(RuntimeError, self.connector.stopConnecting) def test_stopConnectingWhileNotConnected(self): """ When the L{IConnector} is told to C{stopConnecting} while it is not connected or connecting, it raises L{RuntimeError}. """ self.connectionFails(Failure(ZeroDivisionError())) self.assertRaises(RuntimeError, self.connector.stopConnecting) calendarserver-5.2+dfsg/twext/internet/test/__init__.py0000644000175000017500000000113612263343324022405 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## calendarserver-5.2+dfsg/twext/internet/test/test_sendfdport.py0000644000175000017500000002037012306427141024054 0ustar rahulrahul# -*- test-case-name: twext.internet.test.test_sendfdport -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for L{twext.internet.sendfdport}. 
""" import os import fcntl from zope.interface.verify import verifyClass from zope.interface import implementer from twext.internet.sendfdport import InheritedSocketDispatcher from twext.internet.sendfdport import IStatusWatcher, IStatus from twext.web2.metafd import ConnectionLimiter from twisted.internet.interfaces import IReactorFDSet from twisted.trial.unittest import TestCase def verifiedImplementer(interface): def _(cls): result = implementer(interface)(cls) verifyClass(interface, result) return result return _ @verifiedImplementer(IReactorFDSet) class ReaderAdder(object): def __init__(self): self.readers = [] self.writers = [] def addReader(self, reader): self.readers.append(reader) def getReaders(self): return self.readers[:] def addWriter(self, writer): self.writers.append(writer) def removeAll(self): self.__init__() def getWriters(self): return self.writers[:] def removeReader(self, reader): self.readers.remove(reader) def removeWriter(self, writer): self.writers.remove(writer) def isNonBlocking(skt): """ Determine if the given socket is blocking or not. @param skt: a socket. @type skt: L{socket.socket} @return: L{True} if the socket is non-blocking, L{False} if the socket is blocking. 
@rtype: L{bool} """ return bool(fcntl.fcntl(skt.fileno(), fcntl.F_GETFL) & os.O_NONBLOCK) @verifiedImplementer(IStatus) class Status(object): def __init__(self): self.count = 0 self.available = False def effective(self): return self.count def active(self): return self.available def start(self): self.available = False return self def restarted(self): self.available = True return self def stop(self): self.count = 0 self.available = False return self @verifiedImplementer(IStatusWatcher) class Watcher(object): def __init__(self, q): self.q = q self._closeCounter = 1 def newConnectionStatus(self, previous): previous.count += 1 return previous def statusFromMessage(self, previous, message): previous.count -= 1 return previous def statusesChanged(self, statuses): self.q.append([(status.count, status.available) for status in statuses]) def initialStatus(self): return Status() def closeCountFromStatus(self, status): result = (self._closeCounter, status) self._closeCounter += 1 return result class InheritedSocketDispatcherTests(TestCase): """ Inherited socket dispatcher tests. """ def setUp(self): self.dispatcher = InheritedSocketDispatcher(ConnectionLimiter(2, 20)) self.dispatcher.reactor = ReaderAdder() def test_closeSomeSockets(self): """ L{InheritedSocketDispatcher} determines how many sockets to close from L{IStatusWatcher.closeCountFromStatus}. 
""" self.dispatcher.statusWatcher = Watcher([]) class SocketForClosing(object): blocking = True closed = False def setblocking(self, b): self.blocking = b def fileno(self): return object() def close(self): self.closed = True one = SocketForClosing() two = SocketForClosing() three = SocketForClosing() skt = self.dispatcher.addSocket( lambda: (SocketForClosing(), SocketForClosing()) ) skt.restarted() self.dispatcher.sendFileDescriptor(one, "one") self.dispatcher.sendFileDescriptor(two, "two") self.dispatcher.sendFileDescriptor(three, "three") def sendfd(unixSocket, tcpSocket, description): pass # Put something into the socket-close queue. self.dispatcher._subprocessSockets[0].doWrite(sendfd) # Nothing closed yet. self.assertEquals(one.closed, False) self.assertEquals(two.closed, False) self.assertEquals(three.closed, False) def recvmsg(fileno): return 'data', 0, 0 self.dispatcher._subprocessSockets[0].doRead(recvmsg) # One socket closed. self.assertEquals(one.closed, True) self.assertEquals(two.closed, False) self.assertEquals(three.closed, False) def test_nonBlocking(self): """ Creating a L{_SubprocessSocket} via L{InheritedSocketDispatcher.addSocket} results in a non-blocking L{socket.socket} object being assigned to its C{skt} attribute, as well as a non-blocking L{socket.socket} object being returned. """ dispatcher = self.dispatcher dispatcher.startDispatching() inputSocket = dispatcher.addSocket() outputSocket = self.dispatcher.reactor.readers[-1] self.assertTrue(isNonBlocking(inputSocket), "Input is blocking.") self.assertTrue(isNonBlocking(outputSocket), "Output is blocking.") def test_addAfterStart(self): """ Adding a socket to an L{InheritedSocketDispatcher} after it has already been started results in it immediately starting reading. 
""" dispatcher = self.dispatcher dispatcher.startDispatching() dispatcher.addSocket() self.assertEquals(dispatcher.reactor.getReaders(), dispatcher._subprocessSockets) def test_statusesChangedOnNewConnection(self): """ L{InheritedSocketDispatcher.sendFileDescriptor} will update its C{statusWatcher} via C{statusesChanged}. """ q = [] dispatcher = self.dispatcher dispatcher.statusWatcher = Watcher(q) description = "whatever" # Need to have a socket that will accept the descriptors. skt = dispatcher.addSocket() skt.restarted() dispatcher.sendFileDescriptor(object(), description) dispatcher.sendFileDescriptor(object(), description) self.assertEquals(q, [[(0, True)], [(1, True)], [(2, True)]]) def test_statusesChangedOnStatusMessage(self): """ L{InheritedSocketDispatcher.sendFileDescriptor} will update its C{statusWatcher} will update its C{statusWatcher} via C{statusesChanged}. """ q = [] dispatcher = self.dispatcher dispatcher.statusWatcher = Watcher(q) message = "whatever" # Need to have a socket that will accept the descriptors. dispatcher.addSocket() subskt = dispatcher._subprocessSockets[0] dispatcher.statusMessage(subskt, message) dispatcher.statusMessage(subskt, message) self.assertEquals(q, [[(-1, False)], [(-2, False)]]) def test_statusesChangedOnStartRestartStop(self): """ L{_SubprocessSocket} will update its C{status} when state change. """ q = [] dispatcher = self.dispatcher dispatcher.statusWatcher = Watcher(q) message = "whatever" # Need to have a socket that will accept the descriptors. subskt = dispatcher.addSocket() subskt.start() subskt.restarted() dispatcher.sendFileDescriptor(subskt, message) subskt.stop() subskt.start() subskt.restarted() self.assertEquals( q, [ [(0, False)], [(0, True)], [(1, True)], [(0, False)], [(0, False)], [(0, True)], ] ) calendarserver-5.2+dfsg/twext/internet/tcp.py0000644000175000017500000001401512263343324020455 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extentions to twisted.internet.tcp. """ __all__ = [ "MaxAcceptTCPServer", "MaxAcceptSSLServer", ] import socket from OpenSSL import SSL from twisted.application import internet from twisted.internet import tcp, ssl from twisted.internet.defer import succeed from twext.python.log import Logger log = Logger() class MaxAcceptPortMixin(object): """ Mixin for resetting maxAccepts. """ def doRead(self): self.numberAccepts = min( self.factory.maxRequests - self.factory.outstandingRequests, self.factory.maxAccepts ) tcp.Port.doRead(self) class MaxAcceptTCPPort(MaxAcceptPortMixin, tcp.Port): """ Use for non-inheriting tcp ports. """ class MaxAcceptSSLPort(MaxAcceptPortMixin, ssl.Port): """ Use for non-inheriting SSL ports. """ class InheritedTCPPort(MaxAcceptTCPPort): """ A tcp port which uses an inherited file descriptor. 
""" def __init__(self, fd, factory, reactor): tcp.Port.__init__(self, 0, factory, reactor=reactor) # MOR: careful because fromfd dup()'s the socket, so we need to # make sure we don't leak file descriptors self.socket = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) self._realPortNumber = self.port = self.socket.getsockname()[1] def createInternetSocket(self): return self.socket def startListening(self): log.info("%s starting on %s" % (self.factory.__class__, self._realPortNumber)) self.factory.doStart() self.connected = 1 self.fileno = self.socket.fileno self.numberAccepts = self.factory.maxRequests self.startReading() class InheritedSSLPort(InheritedTCPPort): """ An SSL port which uses an inherited file descriptor. """ _socketShutdownMethod = 'sock_shutdown' transport = ssl.Server def __init__(self, fd, factory, ctxFactory, reactor): InheritedTCPPort.__init__(self, fd, factory, reactor) self.ctxFactory = ctxFactory self.socket = SSL.Connection(self.ctxFactory.getContext(), self.socket) def _preMakeConnection(self, transport): transport._startTLS() return tcp.Port._preMakeConnection(self, transport) def _allConnectionsClosed(protocolFactory): """ Check to see if protocolFactory implements allConnectionsClosed( ) and if so, call it. Otherwise, return immediately. This allows graceful shutdown by waiting for all requests to be completed. @param protocolFactory: (usually) an HTTPFactory implementing allConnectionsClosed which returns a Deferred which fires when all connections are closed. @return: A Deferred firing None when all connections are closed, or immediately if the given factory does not track its connections (e.g. 
InheritingProtocolFactory) """ if hasattr(protocolFactory, "allConnectionsClosed"): return protocolFactory.allConnectionsClosed() return succeed(None) class MaxAcceptTCPServer(internet.TCPServer): """ TCP server which will uses MaxAcceptTCPPorts (and optionally, inherited ports) @ivar myPort: When running, this is set to the L{IListeningPort} being managed by this service. """ def __init__(self, *args, **kwargs): internet.TCPServer.__init__(self, *args, **kwargs) self.protocolFactory = self.args[1] self.protocolFactory.myServer = self self.inherit = self.kwargs.get("inherit", False) self.backlog = self.kwargs.get("backlog", None) self.interface = self.kwargs.get("interface", None) def _getPort(self): from twisted.internet import reactor if self.inherit: port = InheritedTCPPort(self.args[0], self.args[1], reactor) else: port = MaxAcceptTCPPort(self.args[0], self.args[1], self.backlog, self.interface, reactor) port.startListening() self.myPort = port return port def stopService(self): """ Wait for outstanding requests to finish @return: a Deferred which fires when all outstanding requests are complete """ internet.TCPServer.stopService(self) return _allConnectionsClosed(self.protocolFactory) class MaxAcceptSSLServer(internet.SSLServer): """ SSL server which will uses MaxAcceptSSLPorts (and optionally, inherited ports) """ def __init__(self, *args, **kwargs): internet.SSLServer.__init__(self, *args, **kwargs) self.protocolFactory = self.args[1] self.protocolFactory.myServer = self self.inherit = self.kwargs.get("inherit", False) self.backlog = self.kwargs.get("backlog", None) self.interface = self.kwargs.get("interface", None) def _getPort(self): from twisted.internet import reactor if self.inherit: port = InheritedSSLPort(self.args[0], self.args[1], self.args[2], reactor) else: port = MaxAcceptSSLPort(self.args[0], self.args[1], self.args[2], self.backlog, self.interface, self.reactor) port.startListening() self.myPort = port return port def stopService(self): """ 
Wait for outstanding requests to finish @return: a Deferred which fires when all outstanding requests are complete """ internet.SSLServer.stopService(self) # TODO: check for an ICompletionWaiter interface return _allConnectionsClosed(self.protocolFactory) calendarserver-5.2+dfsg/twext/internet/adaptendpoint.py0000644000175000017500000001205412263343324022522 0ustar rahulrahul# -*- test-case-name: twext.internet.test.test_adaptendpoint -*- ## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Adapter for old-style connectTCP/connectSSL code to use endpoints and be happy; specifically, to receive the additional duplicate notifications that it wants to receive, L{clientConnectionLost} and L{clientConnectionFailed} on the factory. """ from zope.interface import implements from twisted.internet.interfaces import IConnector from twisted.internet.protocol import Factory from twisted.python import log __all__ = [ "connect" ] class _WrappedProtocol(object): """ A protocol providing a thin wrapper that relays the connectionLost notification. """ def __init__(self, wrapped, wrapper): """ @param wrapped: the wrapped L{IProtocol} provider, to which all methods will be relayed. @param wrapper: The L{LegacyClientFactoryWrapper} that holds the relevant L{ClientFactory}. """ self._wrapped = wrapped self._wrapper = wrapper def __getattr__(self, attr): """ Relay all undefined methods to the wrapped protocol. 
""" return getattr(self._wrapped, attr) def connectionLost(self, reason): """ When the connection is lost, return the connection. """ try: self._wrapped.connectionLost(reason) except: log.err() self._wrapper.legacyFactory.clientConnectionLost(self._wrapper, reason) class LegacyClientFactoryWrapper(Factory): implements(IConnector) def __init__(self, legacyFactory, endpoint): self.currentlyConnecting = False self.legacyFactory = legacyFactory self.endpoint = endpoint self._connectedProtocol = None self._outstandingAttempt = None def getDestination(self): """ Implement L{IConnector.getDestination}. @return: the endpoint being connected to as the destination. """ return self.endpoint def buildProtocol(self, addr): """ Implement L{Factory.buildProtocol} to return a wrapper protocol that will capture C{connectionLost} notifications. @return: a L{Protocol}. """ return _WrappedProtocol(self.legacyFactory.buildProtocol(addr), self) def connect(self): """ Implement L{IConnector.connect} to connect the endpoint. """ if self._outstandingAttempt is not None: raise RuntimeError("connection already in progress") self.legacyFactory.startedConnecting(self) d = self._outstandingAttempt = self.endpoint.connect(self) @d.addBoth def attemptDone(result): self._outstandingAttempt = None return result def rememberProto(proto): self._connectedProtocol = proto return proto def callClientConnectionFailed(reason): self.legacyFactory.clientConnectionFailed(self, reason) d.addCallbacks(rememberProto, callClientConnectionFailed) def disconnect(self): """ Implement L{IConnector.disconnect}. """ if self._connectedProtocol is not None: self._connectedProtocol.transport.loseConnection() elif self._outstandingAttempt is not None: self._outstandingAttempt.cancel() def stopConnecting(self): """ Implement L{IConnector.stopConnecting}. 
""" if self._outstandingAttempt is None: raise RuntimeError("no connection attempt in progress") self.disconnect() def connect(endpoint, clientFactory): """ Connect a L{twisted.internet.protocol.ClientFactory} to a remote host using the given L{twisted.internet.interfaces.IStreamClientEndpoint}. This relays C{clientConnectionFailed} and C{clientConnectionLost} notifications as legacy code using the L{ClientFactory} interface, such as, L{ReconnectingClientFactory} would expect. @param endpoint: The endpoint to connect to. @type endpoint: L{twisted.internet.interfaces.IStreamClientEndpoint} @param clientFactory: The client factory doing the connecting. @type clientFactory: L{twisted.internet.protocol.ClientFactory} @return: A connector object representing the connection attempt just initiated. @rtype: L{IConnector} """ wrap = LegacyClientFactoryWrapper(clientFactory, endpoint) wrap.noisy = clientFactory.noisy # relay the noisy attribute to the wrapper wrap.connect() return wrap calendarserver-5.2+dfsg/twext/internet/kqreactor.py0000644000175000017500000001620311347513520021662 0ustar rahulrahul# Copyright (c) 2001-2008 Twisted Matrix Laboratories. # See LICENSE for details. """ A kqueue()/kevent() based implementation of the Twisted main loop. 
To install the event loop (and you should do this before any connections, listeners or connectors are added):: | from twisted.internet import kqreactor | kqreactor.install() Maintainer: U{Itamar Shtull-Trauring} """ import errno, sys try: from select import KQ_FILTER_READ, KQ_FILTER_WRITE, KQ_EV_DELETE, KQ_EV_ADD from select import kqueue, kevent, KQ_EV_ENABLE, KQ_EV_DISABLE, KQ_EV_EOF except ImportError: from select26 import KQ_FILTER_READ, KQ_FILTER_WRITE, KQ_EV_DELETE, KQ_EV_ADD from select26 import kqueue, kevent, KQ_EV_ENABLE, KQ_EV_DISABLE, KQ_EV_EOF from zope.interface import implements from twisted.python import log from twisted.internet import main, posixbase from twisted.internet.interfaces import IReactorFDSet class KQueueReactor(posixbase.PosixReactorBase): """ A reactor that uses kqueue(2)/kevent(2). @ivar _kq: A L{kqueue} which will be used to check for I/O readiness. @ivar _selectables: A dictionary mapping integer file descriptors to instances of L{FileDescriptor} which have been registered with the reactor. All L{FileDescriptors} which are currently receiving read or write readiness notifications will be present as values in this dictionary. @ivar _reads: A set storing integer file descriptors. These values will be registered with C{_kq} for read readiness notifications which will be dispatched to the corresponding L{FileDescriptor} instances in C{_selectables}. @ivar _writes: A set storing integer file descriptors. These values will be registered with C{_kq} for write readiness notifications which will be dispatched to the corresponding L{FileDescriptor} instances in C{_selectables}. """ implements(IReactorFDSet) def __init__(self): """ Initialize kqueue object, file descriptor tracking sets, and the base class. 
""" self._kq = kqueue() self._reads = set() self._writes = set() self._selectables = {} posixbase.PosixReactorBase.__init__(self) def _updateRegistration(self, fd, filter, flags): ev = kevent(fd, filter, flags) self._kq.control([ev], 0, 0) def addReader(self, reader): """ Add a FileDescriptor for notification of data available to read. """ fd = reader.fileno() if fd not in self._reads: if fd not in self._selectables: self._updateRegistration(fd, KQ_FILTER_READ, KQ_EV_ADD|KQ_EV_ENABLE) self._updateRegistration(fd, KQ_FILTER_WRITE, KQ_EV_ADD|KQ_EV_DISABLE) self._selectables[fd] = reader else: self._updateRegistration(fd, KQ_FILTER_READ, KQ_EV_ENABLE) self._reads.add(fd) def addWriter(self, writer): """ Add a FileDescriptor for notification of data available to write. """ fd = writer.fileno() if fd not in self._writes: if fd not in self._selectables: self._updateRegistration(fd, KQ_FILTER_WRITE, KQ_EV_ADD|KQ_EV_ENABLE) self._updateRegistration(fd, KQ_FILTER_READ, KQ_EV_ADD|KQ_EV_DISABLE) self._selectables[fd] = writer else: self._updateRegistration(fd, KQ_FILTER_WRITE, KQ_EV_ENABLE) self._writes.add(fd) def removeReader(self, reader): """ Remove a Selectable for notification of data available to read. """ fd = reader.fileno() if fd == -1: for fd, fdes in self._selectables.iteritems(): if reader is fdes: break else: return if fd in self._reads: self._reads.discard(fd) if fd not in self._writes: del self._selectables[fd] self._updateRegistration(fd, KQ_FILTER_READ, KQ_EV_DISABLE) def removeWriter(self, writer): """ Remove a Selectable for notification of data available to write. """ fd = writer.fileno() if fd == -1: for fd, fdes in self._selectables.iteritems(): if writer is fdes: break else: return if fd in self._writes: self._writes.discard(fd) if fd not in self._reads: del self._selectables[fd] self._updateRegistration(fd, KQ_FILTER_WRITE, KQ_EV_DISABLE) def removeAll(self): """ Remove all selectables, and return a list of them. 
""" if self.waker is not None: self.removeReader(self.waker) result = self._selectables.values() for fd in self._reads: self._updateRegistration(fd, KQ_FILTER_READ, KQ_EV_DELETE) for fd in self._writes: self._updateRegistration(fd, KQ_FILTER_WRITE, KQ_EV_DELETE) self._reads.clear() self._writes.clear() self._selectables.clear() if self.waker is not None: self.addReader(self.waker) return result def getReaders(self): return [self._selectables[fd] for fd in self._reads] def getWriters(self): return [self._selectables[fd] for fd in self._writes] def doKEvent(self, timeout): """ Poll the kqueue for new events. """ if timeout is None: timeout = 1 try: l = self._kq.control([], len(self._selectables), timeout) except OSError, e: if e[0] == errno.EINTR: return else: raise _drdw = self._doWriteOrRead for event in l: fd = event.ident try: selectable = self._selectables[fd] except KeyError: # Handles the infrequent case where one selectable's # handler disconnects another. continue log.callWithLogger(selectable, _drdw, selectable, fd, event) def _doWriteOrRead(self, selectable, fd, event): why = None inRead = False filter, flags, data, fflags = event.filter, event.flags, event.data, event.fflags if flags & KQ_EV_EOF and data and fflags: why = main.CONNECTION_LOST else: try: if filter == KQ_FILTER_READ: inRead = True why = selectable.doRead() if filter == KQ_FILTER_WRITE: inRead = False why = selectable.doWrite() if not selectable.fileno() == fd: inRead = False why = main.CONNECTION_LOST except: log.err() why = sys.exc_info()[1] if why: self._disconnectSelectable(selectable, why, inRead) doIteration = doKEvent def install(): k = KQueueReactor() main.installReactor(k) __all__ = ["KQueueReactor", "install"] calendarserver-5.2+dfsg/twext/internet/gaiendpoint.py0000644000175000017500000001451312263343324022173 0ustar rahulrahul# -*- test-case-name: twext.internet.test.test_gaiendpoint -*- ## # Copyright (c) 2012-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from __future__ import print_function """ L{getaddrinfo}()-based endpoint """ from socket import getaddrinfo, AF_UNSPEC, AF_INET, AF_INET6, SOCK_STREAM from twisted.internet.endpoints import TCP4ClientEndpoint, SSL4ClientEndpoint from twisted.internet.defer import Deferred from twisted.internet.threads import deferToThread from twisted.internet.task import LoopingCall class MultiFailure(Exception): def __init__(self, failures): super(MultiFailure, self).__init__("Failure with multiple causes.") self.failures = failures class GAIEndpoint(object): """ Client endpoint that will call L{getaddrinfo} in a thread and then attempt to connect to each endpoint (almost) in parallel. @ivar reactor: The reactor to attempt the connection with. @type reactor: provider of L{IReactorTCP} and L{IReactorTime} @ivar host: The host to resolve. @type host: L{str} @ivar port: The port number to resolve. @type port: L{int} @ivar deferToThread: A function like L{deferToThread}, used to invoke getaddrinfo. (Replaceable mainly for testing purposes.) """ deferToThread = staticmethod(deferToThread) def subEndpoint(self, reactor, host, port, contextFactory): """ Create an endpoint to connect to based on a single address result from L{getaddrinfo}. @param reactor: the reactor to connect to @type reactor: L{IReactorTCP} @param host: The IP address of the host to connect to, in presentation format. 
@type host: L{str} @param port: The numeric port number to connect to. @type port: L{int} @param contextFactory: If not L{None}, the OpenSSL context factory to use to produce client connections. @return: a stream client endpoint that will connect to the given host and port via the given reactor. @rtype: L{IStreamClientEndpoint} """ if contextFactory is None: return TCP4ClientEndpoint(reactor, host, port) else: return SSL4ClientEndpoint(reactor, host, port, contextFactory) def __init__(self, reactor, host, port, contextFactory=None): self.reactor = reactor self.host = host self.port = port self.contextFactory = contextFactory def connect(self, factory): dgai = self.deferToThread(getaddrinfo, self.host, self.port, AF_UNSPEC, SOCK_STREAM) @dgai.addCallback def gaiToEndpoints(gairesult): for family, socktype, proto, canonname, sockaddr in gairesult: if family in [AF_INET6, AF_INET]: yield self.subEndpoint(self.reactor, sockaddr[0], sockaddr[1], self.contextFactory) @gaiToEndpoints.addCallback def connectTheEndpoints(endpoints): doneTrying = [] outstanding = [] errors = [] succeeded = [] actuallyDidIt = Deferred() def removeMe(result, attempt): outstanding.remove(attempt) return result def connectingDone(result): if lc.running: lc.stop() succeeded.append(True) for o in outstanding[::]: o.cancel() actuallyDidIt.callback(result) return None def lastChance(): if doneTrying and not outstanding and not succeeded: # We've issued our last attempts. There are no remaining # outstanding attempts; they've all failed. We haven't # succeeded. Time... to die. actuallyDidIt.errback(MultiFailure(errors)) def connectingFailed(why): errors.append(why) lastChance() return None def nextOne(): try: endpoint = endpoints.next() except StopIteration: # Out of endpoints to try! Now it's time to wait for all of # the outstanding attempts to complete, and, if none of them # have been successful, then to give up with a relevant # error. 
They'll all be dealt with by connectingDone or # connectingFailed. doneTrying.append(True) lc.stop() lastChance() else: attempt = endpoint.connect(factory) attempt.addBoth(removeMe, attempt) attempt.addCallbacks(connectingDone, connectingFailed) outstanding.append(attempt) lc = LoopingCall(nextOne) lc.clock = self.reactor lc.start(0.0) return actuallyDidIt return dgai if __name__ == '__main__': from twisted.internet import reactor import sys if sys.argv[1:]: host = sys.argv[1] port = int(sys.argv[2]) else: host = "localhost" port = 22 gaie = GAIEndpoint(reactor, host, port) from twisted.internet.protocol import Factory, Protocol class HelloGoobye(Protocol, object): def connectionMade(self): print('Hello!') self.transport.loseConnection() def connectionLost(self, reason): print('Goodbye') class MyFactory(Factory, object): def buildProtocol(self, addr): print('Building protocol for:', addr) return HelloGoobye() def bye(what): print('bye', what) reactor.stop() gaie.connect(MyFactory()).addBoth(bye) reactor.run() calendarserver-5.2+dfsg/twext/internet/__init__.py0000644000175000017500000000120712263343324021425 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Extensions to twisted.internet. 
""" calendarserver-5.2+dfsg/twext/internet/sendfdport.py0000644000175000017500000004361712306427141022047 0ustar rahulrahul# -*- test-case-name: twext.internet.test.test_sendfdport -*- ## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Implementation of a TCP/SSL port that uses sendmsg/recvmsg as implemented by L{twext.python.sendfd}. """ from os import close from errno import EAGAIN, ENOBUFS from socket import (socketpair, fromfd, error as SocketError, AF_UNIX, SOCK_STREAM, SOCK_DGRAM) from zope.interface import Interface from twisted.internet.abstract import FileDescriptor from twisted.internet.protocol import Protocol, Factory from twext.python.log import Logger from twext.python.sendmsg import sendmsg, recvmsg from twext.python.sendfd import sendfd, recvfd from twext.python.sendmsg import getsockfam log = Logger() class InheritingProtocol(Protocol, object): """ When a connection comes in on this protocol, stop reading and writing, and dispatch the socket to another process via its factory. """ def connectionMade(self): """ A connection was received; transmit the file descriptor to another process via L{InheritingProtocolFactory} and remove my transport from the reactor. 
""" self.transport.stopReading() self.transport.stopWriting() skt = self.transport.getHandle() self.factory.sendSocket(skt) class InheritingProtocolFactory(Factory, object): """ An L{InheritingProtocolFactory} is a protocol factory which listens for incoming connections in a I{master process}, then sends those connections off to be inherited by a I{worker process} via an L{InheritedSocketDispatcher}. L{InheritingProtocolFactory} is instantiated in the master process. @ivar dispatcher: an L{InheritedSocketDispatcher} to use to dispatch incoming connections to an appropriate subprocess. @ivar description: the string to send along with connections received on this factory. """ protocol = InheritingProtocol def __init__(self, dispatcher, description): self.dispatcher = dispatcher self.description = description def sendSocket(self, socketObject): """ Send the given socket object on to my dispatcher. """ self.dispatcher.sendFileDescriptor(socketObject, self.description) class _SubprocessSocket(FileDescriptor, object): """ A socket in the master process pointing at a file descriptor that can be used to transmit sockets to a subprocess. @ivar outSocket: the UNIX socket used as the sendmsg() transport. @type outSocket: L{socket.socket} @ivar outgoingSocketQueue: an outgoing queue of sockets to send to the subprocess, along with their descriptions (strings describing their protocol so that the subprocess knows how to handle them; as of this writing, either C{"TCP"} or C{"SSL"}) @ivar outgoingSocketQueue: a C{list} of 2-tuples of C{(socket-object, bytes)} @ivar status: a record of the last status message received (via recvmsg) from the subprocess: this is an application-specific indication of how ready this subprocess is to receive more connections. A typical usage would be to count the open connections: this is what is passed to @type status: See L{IStatusWatcher} for an explanation of which methods determine this type. 
@ivar dispatcher: The socket dispatcher that owns this L{_SubprocessSocket} @type dispatcher: L{InheritedSocketDispatcher} """ def __init__(self, dispatcher, inSocket, outSocket, status, slavenum): FileDescriptor.__init__(self, dispatcher.reactor) self.status = status self.slavenum = slavenum self.dispatcher = dispatcher self.inSocket = inSocket self.outSocket = outSocket # XXX needs to be set non-blocking by somebody self.fileno = outSocket.fileno self.outgoingSocketQueue = [] self.pendingCloseSocketQueue = [] def childSocket(self): """ Return the socket that the child process will use to communicate with the master. """ return self.inSocket def start(self): """ The master process monitor is about to start the child process associated with this socket. Update status to ensure dispatcher know what is going on. """ self.status.start() self.dispatcher.statusChanged() def restarted(self): """ The child process associated with this socket has signaled it is ready. Update status to ensure dispatcher know what is going on. """ self.status.restarted() self.dispatcher.statusChanged() def stop(self): """ The master process monitor has determined the child process associated with this socket has died. Update status to ensure dispatcher know what is going on. """ self.status.stop() self.dispatcher.statusChanged() def remove(self): """ Remove this socket. """ self.status.stop() self.dispatcher.statusChanged() self.dispatcher.removeSocket() def sendSocketToPeer(self, skt, description): """ Enqueue a socket to send to the subprocess. """ self.outgoingSocketQueue.append((skt, description)) self.startWriting() def doRead(self, recvmsg=recvmsg): """ Receive a status / health message and record it. 
""" try: data, _ignore_flags, _ignore_ancillary = recvmsg(self.outSocket.fileno()) except SocketError, se: if se.errno not in (EAGAIN, ENOBUFS): raise else: closeCount = self.dispatcher.statusMessage(self, data) for ignored in xrange(closeCount): self.pendingCloseSocketQueue.pop(0).close() def doWrite(self, sendfd=sendfd): """ Transmit as many queued pending file descriptors as we can. """ while self.outgoingSocketQueue: skt, desc = self.outgoingSocketQueue.pop(0) try: sendfd(self.outSocket.fileno(), skt.fileno(), desc) except SocketError, se: if se.errno in (EAGAIN, ENOBUFS): self.outgoingSocketQueue.insert(0, (skt, desc)) return raise # Ready to close this socket; wait until it is acknowledged. self.pendingCloseSocketQueue.append(skt) if not self.outgoingSocketQueue: self.stopWriting() class IStatus(Interface): """ Defines the status of a socket. This keeps track of active connections etc. """ def effective(): """ The current effective load. @return: The current effective load. @rtype: L{int} """ def active(): """ Whether the socket should be active (able to be dispatched to). @return: Active state. @rtype: L{bool} """ def start(): """ Worker process is starting. Mark status accordingly but do not make it active. @return: C{self} """ def restarted(): """ Worker process has signaled it is ready so make this active. @return: C{self} """ def stop(): """ Worker process has stopped so make this inactive. @return: C{self} """ class IStatusWatcher(Interface): """ A provider of L{IStatusWatcher} tracks the I{status messages} reported by the worker processes over their control sockets, and computes internal I{status values} for those messages. The I{messages} are individual octets, representing one of three operations. 
C{0} meaning "a new worker process has started, with zero connections being processed", C{+} meaning "I have received and am processing your request; I am confirming that my requests-being-processed count has gone up by one", and C{-} meaning "I have completed processing a request, my requests-being-processed count has gone down by one". The I{status value} tracked by L{_SubprocessSocket.status} is an integer, indicating the current requests-being-processed value. (FIXME: the intended design here is actually just that all I{this} object knows about is that L{_SubprocessSocket.status} is an orderable value, and that this C{statusWatcher} will compute appropriate values so the status that I{sorts the least} is the socket to which new connections should be directed; also, the format of the status messages is only known / understood by the C{statusWatcher}, not the L{InheritedSocketDispatcher}. It's hard to explain it in that manner though.) @note: the intention of this interface is to eventually provide a broader notion of what might constitute 'status', so the above explanation just explains the current implementation, in for expediency's sake, rather than the somewhat more abstract language that would be accurate. """ def initialStatus(): """ A new socket was created and added to the dispatcher. Compute an initial value for its status. @return: the new status. """ def newConnectionStatus(previousStatus): """ A new connection was sent to a given socket. Compute its status based on the previous status of that socket. @param previousStatus: A status value for the socket being sent work, previously returned by one of the methods on this interface. @return: the socket's status after incrementing its outstanding work. """ def statusFromMessage(previousStatus, message): """ A status message was received by a worker. Convert the previous status value (returned from L{newConnectionStatus}, L{initialStatus}, or L{statusFromMessage}). 
@param previousStatus: A status value for the socket being sent work, previously returned by one of the methods on this interface. @return: the socket's status after taking the reported message into account. """ def closeCountFromStatus(previousStatus): """ Based on a status previously returned from a method on this L{IStatusWatcher}, determine how many sockets may be closed. @return: a 2-tuple of C{number of sockets that may safely be closed}, C{new status}. @rtype: 2-tuple of (C{int}, C{}) """ class InheritedSocketDispatcher(object): """ Used by one or more L{InheritingProtocolFactory}s, this keeps track of a list of available sockets that connect to I{worker process}es and sends inbound connections to be inherited over those sockets, by those processes. L{InheritedSocketDispatcher} is therefore instantiated in the I{master process}. @ivar statusWatcher: The object which will handle status messages and convert them into current statuses, as well as . @type statusWatcher: L{IStatusWatcher} """ def __init__(self, statusWatcher): """ Create a socket dispatcher. """ self._subprocessSockets = [] self.statusWatcher = statusWatcher from twisted.internet import reactor self.reactor = reactor self._isDispatching = False @property def statuses(self): """ Yield the current status of all subprocess sockets in the current priority order. """ for subsocket in self._subprocessSockets: yield subsocket.status @property def slavestates(self): """ Yield the current status of all subprocess sockets, ordered by slave number. """ for subsocket in sorted(self._subprocessSockets, key=lambda x: x.slavenum): yield (subsocket.slavenum, subsocket.status,) def statusChanged(self): """ Someone is telling us a child socket status changed. """ self.statusWatcher.statusesChanged(self.statuses) def statusMessage(self, subsocket, message): """ The status of a connection has changed; update all registered status change listeners. 
""" status = self.statusWatcher.statusFromMessage(subsocket.status, message) closeCount, subsocket.status = self.statusWatcher.closeCountFromStatus(status) self.statusChanged() return closeCount def sendFileDescriptor(self, skt, description): """ A connection has been received. Dispatch it to active sockets, sorted by how much work they have. @param skt: the I{connection socket} (i.e.: not the listening socket) @type skt: L{socket.socket} @param description: some text to identify to the subprocess's L{InheritedPort} what type of transport to create for this socket. @type description: C{bytes} """ self._subprocessSockets.sort(key=lambda x: x.status.effective()) selectedSocket = filter(lambda x: x.status.active(), self._subprocessSockets)[0] selectedSocket.sendSocketToPeer(skt, description) # XXX Maybe want to send along 'description' or 'skt' or some # properties thereof? -glyph selectedSocket.status = self.statusWatcher.newConnectionStatus( selectedSocket.status ) self.statusChanged() def startDispatching(self): """ Start listening on all subprocess sockets. """ self._isDispatching = True for subSocket in self._subprocessSockets: subSocket.startReading() def addSocket(self, slavenum=0, socketpair=lambda: socketpair(AF_UNIX, SOCK_DGRAM)): """ Add a C{sendmsg()}-oriented AF_UNIX socket to the pool of sockets being used for transmitting file descriptors to child processes. @return: a socket object for the receiving side; pass this object's C{fileno()} as part of the C{childFDs} argument to C{spawnProcess()}, then close it. """ i, o = socketpair() i.setblocking(False) o.setblocking(False) a = _SubprocessSocket(self, i, o, self.statusWatcher.initialStatus(), slavenum) self._subprocessSockets.append(a) if self._isDispatching: a.startReading() return a def removeSocket(self, skt): """ Removes a previously added socket from the pool of sockets being used for transmitting file descriptors to child processes. 
""" self._subprocessSockets.remove(skt) class InheritedPort(FileDescriptor, object): """ An L{InheritedPort} is an L{IReadDescriptor}/L{IWriteDescriptor} created in the I{worker process} to handle incoming connections dispatched via C{sendmsg}. """ def __init__(self, fd, transportFactory, protocolFactory): """ @param fd: the file descriptor representing a UNIX socket connected to a I{master process}. We will call C{recvmsg} on this socket to receive file descriptors. @type fd: C{int} @param transportFactory: a 4-argument function that takes the socket object produced from the file descriptor, the peer address of that socket, the (non-ancillary) data sent along with the incoming file descriptor, and the protocol built along with it, and returns an L{ITransport} provider. Note that this should NOT call C{makeConnection} on the protocol that it produces, as this class will do that. @param protocolFactory: an L{IProtocolFactory} """ FileDescriptor.__init__(self) self.fd = fd self.transportFactory = transportFactory self.protocolFactory = protocolFactory self.statusQueue = [] def fileno(self): """ Get the FD number for this socket. """ return self.fd def doRead(self): """ A message is ready to read. Receive a file descriptor from our parent process. """ try: fd, description = recvfd(self.fd) except SocketError, se: if se.errno != EAGAIN: raise else: try: skt = fromfd(fd, getsockfam(fd), SOCK_STREAM) close(fd) # fromfd() calls dup() try: peeraddr = skt.getpeername() except SocketError: peeraddr = ('0.0.0.0', 0) protocol = self.protocolFactory.buildProtocol(peeraddr) transport = self.transportFactory(skt, peeraddr, description, protocol) protocol.makeConnection(transport) except: log.failure("doRead()") def doWrite(self): """ Write some data. 
""" while self.statusQueue: msg = self.statusQueue.pop(0) try: sendmsg(self.fd, msg, 0) except SocketError, se: if se.errno in (EAGAIN, ENOBUFS): self.statusQueue.insert(0, msg) return raise self.stopWriting() def reportStatus(self, statusMessage): """ Report a status message to the L{_SubprocessSocket} monitoring this L{InheritedPort}'s health in the master process. """ self.statusQueue.append(statusMessage) self.startWriting() calendarserver-5.2+dfsg/LICENSE0000644000175000017500000002613610464710524015327 0ustar rahulrahul Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. calendarserver-5.2+dfsg/lib-patches/0000755000175000017500000000000012322625311016477 5ustar rahulrahulcalendarserver-5.2+dfsg/lib-patches/setproctitle/0000755000175000017500000000000012322625311021220 5ustar rahulrahulcalendarserver-5.2+dfsg/lib-patches/setproctitle/HAVE_DECL_STRLCPY.patch0000644000175000017500000000077012107015563025002 0ustar rahulrahuldiff -ur ./src/spt_config.h ../setproctitle-1.1.6.patched/src/spt_config.h --- ./src/spt_config.h 2012-04-27 03:00:09.000000000 -0700 +++ ../setproctitle-1.1.6.patched/src/spt_config.h 2013-02-13 14:53:20.000000000 -0800 @@ -15,7 +15,7 @@ /* Define to 1 if you have the declaration of `strlcpy', and to 0 if you don't. 
*/ -#define HAVE_DECL_STRLCPY 0 +#define HAVE_DECL_STRLCPY 1 /* GCC 4.0 and later have support for specifying symbol visibility */ #if __GNUC__ >= 4 && !defined(__MINGW32__) calendarserver-5.2+dfsg/lib-patches/cx_Oracle/0000755000175000017500000000000012322625311020376 5ustar rahulrahulcalendarserver-5.2+dfsg/lib-patches/cx_Oracle/nclob-fixes-and-prefetch.patch0000644000175000017500000000756412030675176026214 0ustar rahulrahulIndex: Connection.c =================================================================== --- Connection.c 2011-03-19 16:05:30.000000000 -0700 +++ Connection.c 2012-08-01 09:22:17.000000000 -0700 @@ -713,6 +713,21 @@ if (newPasswordObj) return Connection_ChangePassword(self, self->password, newPasswordObj); +#ifdef OCI_ATTR_DEFAULT_LOBPREFETCH_SIZE + /* set lob prefetch attribute to session */ + ub4 default_lobprefetch_size = 4096; /* Set default size to 4K */ + status = OCIAttrSet (self->sessionHandle, (ub4) OCI_HTYPE_SESSION, + (void *)&default_lobprefetch_size, /* attribute value */ + 0, /* attribute size; not required to specify; */ + (ub4) OCI_ATTR_DEFAULT_LOBPREFETCH_SIZE, + self->environment->errorHandle); + if (Environment_CheckForError(self->environment, status, + "Connection_Connect(): OCI_ATTR_DEFAULT_LOBPREFETCH_SIZE") < 0) { + self->sessionHandle = NULL; + return -1; + } + +#endif // begin the session Py_BEGIN_ALLOW_THREADS status = OCISessionBegin(self->handle, self->environment->errorHandle, Index: Cursor.c =================================================================== --- Cursor.c 2011-03-19 16:05:30.000000000 -0700 +++ Cursor.c 2012-08-01 09:15:53.000000000 -0700 @@ -1813,8 +1813,8 @@ } } Py_BEGIN_ALLOW_THREADS - status = OCIStmtFetch(self->handle, self->environment->errorHandle, - numRows, OCI_FETCH_NEXT, OCI_DEFAULT); + status = OCIStmtFetch2(self->handle, self->environment->errorHandle, + numRows, OCI_FETCH_NEXT, 0, OCI_DEFAULT); Py_END_ALLOW_THREADS if (status != OCI_NO_DATA) { if 
(Environment_CheckForError(self->environment, status, Index: ExternalLobVar.c =================================================================== --- ExternalLobVar.c 2011-03-19 16:05:30.000000000 -0700 +++ ExternalLobVar.c 2012-07-31 14:26:16.000000000 -0700 @@ -170,6 +170,8 @@ int offset) // offset { sword status; + oraub8 blength = 0; + oraub8 clength = *length; if (var->lobVar->isFile) { Py_BEGIN_ALLOW_THREADS @@ -183,11 +185,13 @@ } Py_BEGIN_ALLOW_THREADS - status = OCILobRead(var->lobVar->connection->handle, + status = OCILobRead2(var->lobVar->connection->handle, var->lobVar->environment->errorHandle, - var->lobVar->data[var->pos], length, offset, buffer, - bufferSize, NULL, NULL, 0, var->lobVar->type->charsetForm); + var->lobVar->data[var->pos], &blength, &clength, offset, buffer, + bufferSize, OCI_ONE_PIECE, NULL, NULL, 0, var->lobVar->type->charsetForm); Py_END_ALLOW_THREADS + *length = blength; + if (Environment_CheckForError(var->lobVar->environment, status, "ExternalLobVar_LobRead()") < 0) { OCILobFileClose(var->lobVar->connection->handle, @@ -219,10 +223,10 @@ udt_ExternalLobVar *var) // variable to return the size of { sword status; - ub4 length; + oraub8 length; Py_BEGIN_ALLOW_THREADS - status = OCILobGetLength(var->lobVar->connection->handle, + status = OCILobGetLength2(var->lobVar->connection->handle, var->lobVar->environment->errorHandle, var->lobVar->data[var->pos], &length); Py_END_ALLOW_THREADS @@ -259,10 +263,9 @@ amount = 1; } length = amount; - if (var->lobVar->type == &vt_CLOB) + if ((var->lobVar->type == &vt_CLOB) || (var->lobVar->type == &vt_NCLOB)) + // Always use environment setting for character LOBs bufferSize = amount * var->lobVar->environment->maxBytesPerCharacter; - else if (var->lobVar->type == &vt_NCLOB) - bufferSize = amount * 2; else bufferSize = amount; // create a string for retrieving the value calendarserver-5.2+dfsg/lib-patches/pycrypto/0000755000175000017500000000000012322625311020370 5ustar 
rahulrahulcalendarserver-5.2+dfsg/lib-patches/pycrypto/__init__.py.patch0000644000175000017500000000032612046304325023602 0ustar rahulrahulIndex: lib/Crypto/Random/Fortuna/__init__.py =================================================================== --- lib/Crypto/Random/Fortuna/__init__.py +++ lib/Crypto/Random/Fortuna/__init__.py @@ -0,0 +1 @@ +# calendarserver-5.2+dfsg/HACKING0000644000175000017500000003400312235042574015302 0ustar rahulrahulDeveloper's Guide to Hacking the Calendar Server ================================================ If you are interested in contributing to the Calendar and Contacts Server project, please read this document. Participating in the Community ============================== Although the Calendar and Contacts Server is sponsored and hosted by Apple Inc. (http://www.apple.com/), it's a true open-source project under an Apache license. Contributions from other developers are welcome, and, as with all open development projects, may lead to "commit access" and a voice in the future of the project. The community exists mainly through mailing lists and a Subversion repository. To participate, go to: http://trac.calendarserver.org/projects/calendarserver/wiki/MailLists and join the appropriate mailing lists. We also use IRC, as described here: http://trac.calendarserver.org/projects/calendarserver/wiki/IRC There are many ways to join the project. One may write code, test the software and file bugs, write documentation, etc. The bug tracking database is here: http://trac.calendarserver.org/projects/calendarserver/report To help manage the issues database, read over the issue summaries, looking and testing for issues that are either invalid, or are duplicates of other issues. Both kinds are very common, the first because bugs often get unknowingly fixed as side effects of other changes in the code, and the second because people sometimes file an issue without noticing that it has already been reported. 
If you are not sure about an issue, post a question to calendarserver-dev@lists.macosforge.org. Before filing bugs, please take a moment to perform a quick search to see if someone else has already filed your bug. In that case, add a comment to the existing bug if appropriate and monitor it, rather than filing a duplicate. Obtaining the Code ================== The source code to the Calendar and Contacts Server is available via Subversion at this repository URL: http://svn.calendarserver.org/repository/calendarserver/CalendarServer/trunk/ You can also browse the repository directly using your web browser, or use WebDAV clients to browse the repository, such as Mac OS X's Finder (`Go -> Connect to Server`). A richer web interface which provides access to version history and logs is available via Trac here: http://trac.calendarserver.org/browser/ Most developers will want to use a full-featured Subversion client. More information about Subversion, including documentation and client download instructions, is available from the Subversion project: http://subversion.tigris.org/ Directory Layout ================ A rough guide to the source tree: * ``doc/`` - User and developer documentation, including relevant protocol specifications and extensions. * ``bin/`` - Executable programs. * ``conf/`` - Configuration files. * ``calendarserver/`` - Source code for the Calendar and Contacts Server * ``twistedcaldav/`` - Source code for the CalDAV library * ``twext/`` - Source code for extensions to Twisted * ``lib-patches/`` - Patch files which modify 3rd-party software required by the Calendar and Contacts Server. In an ideal world, this would be empty. * ``twisted/`` - Files required to set up the Calendar and Contacts Server as a Twisted service. Twisted (http://twistedmatrix.com/) is a networking framework upon which the Calendar and Contacts Server is built. * ``locales/`` - Localization files.
* ``contrib/`` - Extra stuff that works with the Calendar and Contacts Server, or that helps integrate with other software (including operating systems), but that the Calendar and Contacts Server does not depend on. * ``support/`` - Support files of possible use to developers. Coding Standards ================ The vast majority of the Calendar and Contacts Server is written in the Python programming language. When writing Python code for the Calendar and Contacts Server, please observe the following conventions. Please note that not all of our code currently follows these standards, but that does not mean that one shouldn't bother to follow them. On the contrary, code changes that do nothing but reformat code to comply with these standards are welcome, and code changes that do not conform to these standards are discouraged. **We require Python 2.6 or higher.** It is therefore OK to write code that does not work with Python versions older than 2.6. Read PEP-8: http://www.python.org/dev/peps/pep-0008/ For the most part, our code should follow PEP-8, with a few exceptions and a few additions. It is also useful to review the Twisted Coding Standard, from which we borrow some standards, though we don't strictly follow it: http://twistedmatrix.com/trac/browser/trunk/doc/development/policy/coding-standard.xhtml?format=raw Key items to follow, and specifics: * Indent level is 4 spaces. * Never indent code with tabs. Always use spaces. PEP-8 items we do not follow: * PEP-8 recommends using a backslash to break long lines up: :: if width == 0 and height == 0 and \ color == 'red' and emphasis == 'strong' or \ highlight > 100: raise ValueError("sorry, you lose") Don't do that, it's gross, and the indentation for the ``raise`` line gets confusing.
Use parentheses: :: if ( width == 0 and height == 0 and color == "red" and emphasis == "strong" or highlight > 100 ): raise ValueError("sorry, you lose") Just don't do it the way PEP-8 suggests: :: if width == 0 and height == 0 and (color == 'red' or emphasis is None): raise ValueError("I don't think so") Because that's just silly. Additions: * Close parentheses and brackets such as ``()``, ``[]`` and ``{}`` at the same indent level as the line in which you opened it: :: launchAtTarget( target="David", object=PaperWad( message="Yo!", crumpleFactor=0.7, ), speed=0.4, ) * Long lines are often due to long strings. Try to break strings up into multiple lines: :: processString( "This is a very long string with a lot of text. " "Fortunately, it is easy to break it up into parts " "like this." ) Similarly, callables that take many arguments can be broken up into multiple lines, as in the ``launchAtTarget()`` example above. * Breaking generator expressions and list comprehensions into multiple lines can improve readability. For example: :: myStuff = ( item.obtainUsefulValue() for item in someDataStore if item.owner() == me ) * Import symbols (especially class names) from modules instead of importing modules and referencing the symbol via the module unless it doesn't make sense to do so. For example: :: from subprocess import Popen process = Popen(...) Instead of: :: import subprocess process = subprocess.Popen(...) This makes code shorter and makes it easier to replace one implementation with another. * All files should have an ``__all__`` specification. Put them at the top of the file, before imports (PEP-8 puts them at the top, but after the imports), so you can see what the public symbols are for a file right at the top. * It is more important that symbol names are meaningful than it is that they be concise. ``x`` is rarely an appropriate name for a variable. Avoid contractions: ``transmogrifierStatus`` is more useful to the reader than ``trmgStat``. 
* A deferred that will be immediately returned may be called ``d``: :: d = doThisAndThat() d.addCallback(onResult) d.addErrback(onError) return d * Do not use ``deferredGenerator``. Use ``inlineCallbacks`` instead. * That said, avoid using ``inlineCallbacks`` when chaining deferreds is straightforward, as they are more expensive. Use ``inlineCallbacks`` when necessary for keeping code maintainable, such as when creating serialized deferreds in a for loop. * ``_`` may be used to denote unused callback arguments: :: def onCompletion(_): # Don't care about result of doThisAndThat() in here; # we only care that it has completed. doNextThing() d = doThisAndThat() d.addCallback(onCompletion) return d * Do not prefix symbols with ``_`` unless they might otherwise be exposed as a public symbol: a private method name should begin with ``_``, but a locally scoped variable should not, as there is no danger of it being exposed. Locally scoped variables are already private. * Per twisted convention, use camel-case (``fuzzyWidget``, ``doThisAndThat()``) for symbol names instead of using underscores (``fuzzy_widget``, ``do_this_and_that()``). Use of underscores is reserved for implied dispatching and the like (eg. ``http_FOO()``). See the Twisted Coding Standard for details. * Do not use ``%``-formatting: :: error = "Unexpected value: %s" % (value,) Use PEP-3101 formatting instead: :: error = "Unexpected value: {value}".format(value=value) * If you must use ``%``-formatting for some reason, always use a tuple as the format argument, even when only one value is being provided: :: error = "Unexpected value: %s" % (value,) Never use the non-tuple form: :: error = "Unexpected value: %s" % value Which is allowed in Python, but results in a programming error if ``type(value) is tuple and len(value) != 1``. 
* Don't use a trailing ``,`` at the end of a tuple if it's on one line: :: numbers = (1,2,3,) # No numbers = (1,2,3) # Yes The trailing comma is desirable on multiple lines, though, as that makes re-ordering items easy, and avoids a diff on the last line when adding another: :: strings = ( "This is a string.", "And so is this one.", "And here is yet another string.", ) * Docstrings are important. All public symbols (anything declared in ``__all__``) must have a correct docstring. The script ``docs/Developer/gendocs`` will generate the API documentation using ``pydoctor``. See the ``pydoctor`` documentation for details on the formatting: http://codespeak.net/~mwh/pydoctor/ Note: existing docstrings need a complete review. * Use PEP-257 as a guideline for docstrings. * Begin all multi-line docstrings with 3 double quotes and a newline: :: def doThisAndThat(...): """ Do this, and that. ... """ Best Practices ============== * If a callable is going to return a Deferred some of the time, it should return a deferred all of the time. Return ``succeed(value)`` instead of ``value`` if necessary. This avoids forcing the caller to check as to whether the value is a deferred or not (eg. by using ``maybeDeferred()``), which is both annoying to code and potentially expensive at runtime. * Be proactive about closing files and file-like objects. For a lot of Python software, letting Python close the stream for you works fine, but in a long-lived server that's processing many data streams at a time, it is important to close them as soon as possible. On some platforms (eg. Windows), deleting a file will fail if the file is still open. By leaving it up to Python to decide when to close a file, you may find yourself being unable to reliably delete it. The most reliable way to ensure that a stream is closed is to put the call to ``close()`` in a ``finally`` block: :: stream = file(somePath) try: ... do something with stream ... 
finally: stream.close() Testing ======= Be sure that all of the unit tests pass before you commit new code. Code that breaks unit tests may be reverted without further discussion; it is up to the committer to fix the problem and try again. Note that repeatedly committing code that breaks unit tests presents a possible time sink for other developers, and is not looked upon favorably. Unit tests can be run rather easily by executing the ``test`` script at the top of the Calendar and Contacts Server source tree. By default, it will run all of the Calendar and Contacts Server tests followed by all of the Twisted tests. You can run specific tests by specifying them as arguments like this: :: ./test twistedcaldav.static All non-trivial public callables must have unit tests. (Note we don't totally comply with this rule; that's a problem we'd like to fix.) All other callables should have unit tests. Unit tests are written using the ``twisted.trial`` framework. Test module names should start with ``test_``. Twisted has some tips on writing tests here: http://twistedmatrix.com/projects/core/documentation/howto/testing.html http://twistedmatrix.com/trac/browser/trunk/doc/development/policy/test-standard.xhtml?format=raw We also use CalDAVTester (which is a companion to the Calendar and Contacts Server in the same Mac OS Forge project), which performs more "black box"-type testing against the server to ensure compliance with the CalDAV protocol. That requires running the server with a test configuration and then running CalDAVTester against it. Information about CalDAVTester is available here: http://trac.calendarserver.org/projects/calendarserver/wiki/CalDAVTester Commit Policy ============= We follow a commit-then-review policy for relatively "safe" changes to the code. If you have a rather straightforward change or are working on new functionality that does not affect existing functionality, you can commit that code without review at your discretion.
Developers are encouraged to monitor the commit notifications that are sent via email after each commit and review/critique/comment on modifications as appropriate. Any changes that impact existing functionality should be reviewed by another developer before being committed. Large changes should be made on a branch and merged after review. This policy relies on the discretion of committers. calendarserver-5.2+dfsg/twistedcaldav/0000755000175000017500000000000012322625316017147 5ustar rahulrahulcalendarserver-5.2+dfsg/twistedcaldav/mkcolxml.py0000644000175000017500000000301412263343324021345 0ustar rahulrahul
##
# Copyright (c) 2005-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##

"""
Extended MKCOL Support.

This module provides XML utilities for use with Extended MKCOL.

This API is considered private to static.py and is therefore subject to
change.

See RFC 5689.
"""

from txdav.xml import element as davxml
from txdav.xml.element import registerElement

##
# Extended MKCOL objects
##

# Compliance-class token(s) for Extended MKCOL (RFC 5689). Presumably
# consumed by callers that build the DAV: response header -- NOTE(review):
# confirm against static.py, which is not visible here.
mkcol_compliance = (
    "extended-mkcol",
)



@registerElement
class MakeCollection (davxml.WebDAVElement):
    """
    Top-level element for request body in MKCOL.
    (Extended-MKCOL, RFC 5689 section 5.1)

    Registered with the WebDAV element registry via L{registerElement}.
    """
    name = "mkcol"

    # Child rules map a child key to a (min, max) occurrence pair; a max of
    # None presumably means "unbounded" -- TODO confirm in WebDAVElement.
    # At most one DAV:set child is permitted.
    allowed_children = { (davxml.dav_namespace, "set"): (0, 1) }
    child_types = { "WebDAVUnknownElement": (0, None) }



@registerElement
class MakeCollectionResponse (davxml.WebDAVElement):
    """
    Top-level element for response body in MKCOL.
    (Extended-MKCOL, RFC 5689 section 5.2)

    Registered with the WebDAV element registry via L{registerElement}.
    """
    name = "mkcol-response"

    # Any WebDAV element may appear as a child, any number of times.
    allowed_children = { davxml.WebDAVElement: (0, None) }
calendarserver-5.2+dfsg/twistedcaldav/directory-listing.html0000644000175000017500000000441511633725601023516 0ustar rahulrahul Collection listing for <t:slot name="name" />

Collection Listing

NameSizeLast Modified MIME Type

Properties

NameValue
calendarserver-5.2+dfsg/twistedcaldav/instance.py0000644000175000017500000005235612263343324021340 0ustar rahulrahul## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ iCalendar Recurrence Expansion Utilities """ from twistedcaldav.config import config from twistedcaldav.dateops import normalizeForIndex, differenceDateTime from pycalendar.datetime import PyCalendarDateTime from pycalendar.duration import PyCalendarDuration from pycalendar.period import PyCalendarPeriod from pycalendar.timezone import PyCalendarTimezone class TooManyInstancesError(Exception): def __init__(self): Exception.__init__(self) self.max_allowed = config.MaxAllowedInstances def __str__(self): return "Too many recurrence instances." 
def __repr__(self): return "<%s max:%s>" % (self.__class__.__name__, self.max_allowed) class InvalidOverriddenInstanceError(Exception): def __init__(self, rid): Exception.__init__(self) self.rid = rid def __str__(self): return "Invalid overridden instance :%s" % (self.rid,) def __repr__(self): return "<%s invalid:%s>" % (self.__class__.__name__, self.rid) class Instance(object): def __init__(self, component, start=None, end=None, rid=None, overridden=False, future=False): self.component = component self.start = component.getStartDateUTC() if start is None else start self.end = component.getEndDateUTC() if end is None else end self.rid = self.start if rid is None else rid self.overridden = overridden self.future = future def getAlarmTriggers(self): """ Get the set of alarm triggers for this instance. @return: a set containing the UTC datetime's of each trigger in each alarm """ triggers = set() for alarm in [x for x in self.component.subcomponents() if x.name() == "VALARM"]: (trigger, related, repeat, duration) = alarm.getTriggerDetails() # Handle relative vs absolute triggers if isinstance(trigger, PyCalendarDateTime): # Absolute trigger start = trigger else: # Relative trigger start = (self.start if related else self.end) + trigger triggers.add(start) # Handle repeats if repeat > 0: tstart = start.duplicate() for _ignore in xrange(1, repeat + 1): tstart += duration triggers.add(tstart) return triggers def isMasterInstance(self): return not self.overridden and self.start == self.component.getStartDateUTC() class InstanceList(object): def __init__(self, ignoreInvalidInstances=False, normalizeFunction=normalizeForIndex): self.instances = {} self.limit = None self.lowerLimit = None self.ignoreInvalidInstances = ignoreInvalidInstances self.normalizeFunction = normalizeFunction self.adjustedLowerLimit = None self.adjustedUpperLimit = None def __iter__(self): # Return keys in sorted order via iterator for i in sorted(self.instances.keys()): yield i def __getitem__(self, 
key): return self.instances[key] def expandTimeRanges(self, componentSet, limit, lowerLimit=None): """ Expand the set of recurrence instances up to the specified date limit. What we do is first expand the master instance into the set of generate instances. Then we merge the overridden instances, taking into account THISANDFUTURE and THISANDPRIOR. @param componentSet: the set of components that are to make up the recurrence set. These MUST all be components with the same UID and type, forming a proper recurring set. @param limit: L{PyCalendarDateTime} value representing the end of the expansion. """ # Look at each component type master = None overrides = [] for component in componentSet: if component.name() == "VEVENT": if component.hasProperty("RECURRENCE-ID"): overrides.append(component) else: self._addMasterEventComponent(component, lowerLimit, limit) master = component elif component.name() == "VTODO": if component.hasProperty("RECURRENCE-ID"): overrides.append(component) else: self._addMasterToDoComponent(component, lowerLimit, limit) master = component elif component.name() == "VJOURNAL": #TODO: VJOURNAL raise NotImplementedError("VJOURNAL recurrence expansion not supported yet") elif component.name() == "VFREEBUSY": self._addFreeBusyComponent(component, lowerLimit, limit) elif component.name() == "VAVAILABILITY": self._addAvailabilityComponent(component, lowerLimit, limit) elif component.name() == "AVAILABLE": if component.hasProperty("RECURRENCE-ID"): overrides.append(component) else: # AVAILABLE components are just like VEVENT components self._addMasterEventComponent(component, lowerLimit, limit) master = component for component in overrides: if component.name() == "VEVENT": self._addOverrideEventComponent(component, lowerLimit, limit, master) elif component.name() == "VTODO": self._addOverrideToDoComponent(component, lowerLimit, limit, master) elif component.name() == "VJOURNAL": #TODO: VJOURNAL raise NotImplementedError("VJOURNAL recurrence expansion not 
supported yet") elif component.name() == "AVAILABLE": # AVAILABLE components are just like VEVENT components self._addOverrideEventComponent(component, lowerLimit, limit, master) def addInstance(self, instance): """ Add the supplied instance to the map. @param instance: the instance to add """ self.instances[str(instance.rid)] = instance # Check for too many instances if config.MaxAllowedInstances and len(self.instances) > config.MaxAllowedInstances: raise TooManyInstancesError() def _setupLimits(self, dt, lowerLimit, upperLimit): """ Change the limits to account for testing against DATE only values. The lower limit is simply truncated to its date value. The upper limit is truncated to one day past the date value. """ if self.adjustedUpperLimit is None: if dt.isDateOnly(): if lowerLimit: self.adjustedLowerLimit = lowerLimit.duplicate() self.adjustedLowerLimit.setDateOnly(True) self.adjustedUpperLimit = upperLimit.duplicate() self.adjustedUpperLimit.setDateOnly(True) self.adjustedUpperLimit.offsetDay(1) else: self.adjustedLowerLimit = lowerLimit self.adjustedUpperLimit = upperLimit return (self.adjustedLowerLimit, self.adjustedUpperLimit,) def _getMasterEventDetails(self, component): """ Logic here comes from RFC4791 Section 9.9 """ start = component.getStartDateUTC() if start is None: return None rulestart = component.propertyValue("DTSTART") end = component.getEndDateUTC() duration = None if end is None: if not start.isDateOnly(): # Timed event with zero duration duration = PyCalendarDuration(days=0) else: # All day event default duration is one day duration = PyCalendarDuration(days=1) end = start + duration else: duration = differenceDateTime(start, end) return (rulestart, start, end, duration,) def _addMasterEventComponent(self, component, lowerLimit, upperLimit): """ Add the specified master VEVENT Component to the instance list, expanding it within the supplied time range. 
@param component: the Component to expand @param limit: the end L{PyCalendarDateTime} for expansion """ details = self._getMasterEventDetails(component) if details is None: return rulestart, start, end, duration = details lowerLimit, upperLimit = self._setupLimits(start, lowerLimit, upperLimit) self._addMasterComponent(component, lowerLimit, upperLimit, rulestart, start, end, duration) def _addOverrideEventComponent(self, component, lowerLimit, upperLimit, master): """ Add the specified overridden VEVENT Component to the instance list, replacing the one generated by the master component. @param component: the overridden Component. @param master: the master component which has already been expanded, or C{None}. """ #TODO: This does not take into account THISANDPRIOR - only THISANDFUTURE details = self._getMasterEventDetails(component) if details is None: return _ignore_rulestart, start, end, _ignore_duration = details lowerLimit, upperLimit = self._setupLimits(start, lowerLimit, upperLimit) self._addOverrideComponent(component, lowerLimit, upperLimit, start, end, master) def _getMasterToDoDetails(self, component): """ Logic here comes from RFC4791 Section 9.9 """ dtstart = component.getStartDateUTC() dtend = component.getEndDateUTC() dtdue = component.getDueDateUTC() # DTSTART and DURATION or DUE case if dtstart is not None: rulestart = component.propertyValue("DTSTART") start = dtstart if dtend is not None: end = dtend elif dtdue is not None: end = dtdue else: end = dtstart # DUE case elif dtdue is not None: rulestart = component.propertyValue("DUE") start = end = dtdue # Fall back to COMPLETED or CREATED - cannot be recurring else: rulestart = None from twistedcaldav.ical import maxDateTime, minDateTime dtcreated = component.getCreatedDateUTC() dtcompleted = component.getCompletedDateUTC() if dtcompleted: end = dtcompleted start = dtcreated if dtcreated else end elif dtcreated: start = dtcreated end = maxDateTime else: start = minDateTime end = maxDateTime 
duration = differenceDateTime(start, end) return (rulestart, start, end, duration,) def _addMasterToDoComponent(self, component, lowerLimit, upperLimit): """ Add the specified master VTODO Component to the instance list, expanding it within the supplied time range. @param component: the Component to expand @param limit: the end L{PyCalendarDateTime} for expansion """ details = self._getMasterToDoDetails(component) if details is None: return rulestart, start, end, duration = details lowerLimit, upperLimit = self._setupLimits(start, lowerLimit, upperLimit) self._addMasterComponent(component, lowerLimit, upperLimit, rulestart, start, end, duration) def _addOverrideToDoComponent(self, component, lowerLimit, upperLimit, master): """ Add the specified overridden VTODO Component to the instance list, replacing the one generated by the master component. @param component: the overridden Component. @param master: the master component which has already been expanded, or C{None}. """ #TODO: This does not take into account THISANDPRIOR - only THISANDFUTURE details = self._getMasterToDoDetails(component) if details is None: return _ignore_rulestart, start, end, _ignore_duration = details lowerLimit, upperLimit = self._setupLimits(start, lowerLimit, upperLimit) self._addOverrideComponent(component, lowerLimit, upperLimit, start, end, master) def _addMasterComponent(self, component, lowerLimit, upperlimit, rulestart, start, end, duration): rrules = component.getRecurrenceSet() if rrules is not None and rulestart is not None: # Do recurrence set expansion expanded = [] # Begin expansion far in the past because there may be RDATEs earlier # than the master DTSTART, and if we exclude those, the associated # overridden instances will cause an InvalidOverriddenInstance. 
limited = rrules.expand(rulestart, PyCalendarPeriod(PyCalendarDateTime(1900, 1, 1), upperlimit), expanded) for startDate in expanded: startDate = self.normalizeFunction(startDate) endDate = startDate + duration if lowerLimit is None or endDate >= lowerLimit: self.addInstance(Instance(component, startDate, endDate)) else: self.lowerLimit = lowerLimit if limited: self.limit = upperlimit else: # Always add main instance if included in range. if start < upperlimit: if lowerLimit is None or end >= lowerLimit: start = self.normalizeFunction(start) end = self.normalizeFunction(end) self.addInstance(Instance(component, start, end)) else: self.lowerLimit = lowerLimit else: self.limit = upperlimit self.master_cancelled = component.propertyValue("STATUS") == "CANCELLED" def _addOverrideComponent(self, component, lowerLimit, upperlimit, start, end, master): # Get the recurrence override info rid = component.getRecurrenceIDUTC() range = component.getRange() # Now add this instance, effectively overriding the one with the matching R-ID start = self.normalizeFunction(start) end = self.normalizeFunction(end) rid = self.normalizeFunction(rid) # Make sure start is within the limit if start > upperlimit and rid > upperlimit: return if lowerLimit is not None and end < lowerLimit and rid < lowerLimit: return # Make sure override RECURRENCE-ID is a valid instance of the master cancelled = component.propertyValue("STATUS") == "CANCELLED" if master is not None: if str(rid) not in self.instances and rid < upperlimit and (lowerLimit is None or rid >= lowerLimit): if self.master_cancelled or cancelled: # Ignore invalid overrides when either the master or override is cancelled pass elif self.ignoreInvalidInstances: return elif component.name() == "VEVENT": # Try to fix the R-ID in the case where the hour/minute/second components are all zero original_rid = component.propertyValue("RECURRENCE-ID").duplicate() if not original_rid.isDateOnly() and original_rid.mHours == 0 and 
original_rid.mMinutes == 0 and original_rid.mSeconds == 0: master_start = master.propertyValue("DTSTART") original_rid.setHHMMSS(master_start.mHours, master_start.mMinutes, master_start.mSeconds) rid = original_rid.duplicateAsUTC() rid = self.normalizeFunction(rid) if str(rid) not in self.instances: raise InvalidOverriddenInstanceError(str(rid)) else: component.getProperty("RECURRENCE-ID").setValue(original_rid) else: raise InvalidOverriddenInstanceError(str(rid)) else: raise InvalidOverriddenInstanceError(str(rid)) self.addInstance(Instance(component, start, end, rid, True, range)) # Handle THISANDFUTURE if present if range: # Iterate over all the instances after this one, replacing those # with a version based on this override component # We need to account for a time shift in the overridden component by # applying that shift to the future instances as well timeShift = (start != rid) if timeShift: offsetTime = start - rid newDuration = end - start # First get sorted instance keys greater than the current components R-ID for key in sorted(x for x in self.instances.keys() if x > str(rid)): oldinstance = self.instances[key] # Do not override instance that is already overridden if oldinstance.overridden: continue # Determine the start/end of the new instance originalStart = oldinstance.rid start = oldinstance.start end = oldinstance.end if timeShift: start += offsetTime end = start + newDuration # Now replacing existing entry with the new one self.addInstance(Instance(component, start, end, originalStart, False, False)) def _addFreeBusyComponent(self, component, lowerLimit, upperLimit): """ Add the specified master VFREEBUSY Component to the instance list, expanding it within the supplied time range. 
@param component: the Component to expand @param limit: the end L{PyCalendarDateTime} for expansion """ start = component.getStartDateUTC() end = component.getEndDateUTC() if end is None and start is not None: raise ValueError("VFREEBUSY component must have both DTSTART and DTEND: %r" % (component,)) if start: lowerLimit, upperLimit = self._setupLimits(start, lowerLimit, upperLimit) # If the free busy is beyond the end of the range we want, ignore it if start is not None and start >= upperLimit: return # If the free busy is before the start of the range we want, ignore it if lowerLimit is not None and end is not None and end < lowerLimit: return # Now look at each FREEBUSY property for fb in component.properties("FREEBUSY"): # Look at each period in the property assert isinstance(fb.value(), list), "FREEBUSY property does not contain a list of values: %r" % (fb,) for period in fb.value(): # Ignore if period starts after limit period = period.getValue() if period.getStart() >= upperLimit: continue start = self.normalizeFunction(period.getStart()) end = self.normalizeFunction(period.getEnd()) self.addInstance(Instance(component, start, end)) def _addAvailabilityComponent(self, component, lowerLimit, upperLimit): """ Add the specified master VAVAILABILITY Component to the instance list, expanding it within the supplied time range. VAVAILABILITY components are not recurring, they have an optional DTSTART and DTEND/DURATION defining a single time-range which may be bounded depending on the presence of the properties. If unbounded at one or both ends, we will set the time to 1/1/1900 in the past and 1/1/3000 in the future. 
@param component: the Component to expand @param limit: the end L{PyCalendarDateTime} for expansion """ start = component.getStartDateUTC() if start: lowerLimit, upperLimit = self._setupLimits(start, lowerLimit, upperLimit) if start is not None and start >= upperLimit: # If the availability is beyond the end of the range we want, ignore it return if start is None: start = PyCalendarDateTime(1900, 1, 1, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)) start = self.normalizeFunction(start) end = component.getEndDateUTC() if lowerLimit is not None and end is not None and end < lowerLimit: # If the availability is before the start of the range we want, ignore it return if end is None: end = PyCalendarDateTime(2100, 1, 1, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)) end = self.normalizeFunction(end) self.addInstance(Instance(component, start, end)) calendarserver-5.2+dfsg/twistedcaldav/xmlutil.py0000644000175000017500000000624112263343324021222 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from __future__ import with_statement import cStringIO as StringIO import xml.etree.ElementTree as XML try: from xml.etree.ElementTree import ParseError as XMLParseError except ImportError: from xml.parsers.expat import ExpatError as XMLParseError # Utilities for working with ElementTree def readXMLString(xmldata, expectedRootTag=None): io = StringIO.StringIO(xmldata) return readXML(io, expectedRootTag) def readXML(xmlfile, expectedRootTag=None): """ Read in XML data from a file and parse into ElementTree. Optionally verify the root node is what we expect. @param xmlfile: file to read from @type xmlfile: C{File} @param expectedRootTag: root tag (qname) to test or C{None} @type expectedRootTag: C{str} @return: C{tuple} of C{ElementTree}, C{Element} """ # Read in XML try: etree = XML.ElementTree(file=xmlfile) except XMLParseError, e: raise ValueError("Unable to parse file '%s' because: %s" % (xmlfile, e,)) if expectedRootTag: root = etree.getroot() if root.tag != expectedRootTag: raise ValueError("Ignoring file '%s' because it is not a %s file" % (xmlfile, expectedRootTag,)) return etree, etree.getroot() def elementToXML(element): return XML.tostring(element, "utf-8") def writeXML(xmlfile, root): data = """ """ % (root.tag, root.tag) INDENT = 2 # Generate indentation def _indentNode(node, level=0): if node.text is not None and node.text.strip(): return elif len(node): indent = "\n" + " " * (level + 1) * INDENT node.text = indent for child in node: child.tail = indent _indentNode(child, level + 1) if len(node): node[-1].tail = "\n" + " " * level * INDENT _indentNode(root, 0) data += XML.tostring(root) + "\n" with open(xmlfile, "w") as f: f.write(data) def newElementTreeWithRoot(roottag): root = createElement(roottag) etree = XML.ElementTree(root) return etree, root def createElement(tag, text=None, **attrs): child = XML.Element(tag, attrs) child.text = text return child def addSubElement(parent, tag, text=None): child = XML.SubElement(parent, tag) child.text = text 
return child def changeSubElementText(parent, tag, text): child = parent.find(tag) if child is not None: child.text = text else: addSubElement(parent, tag, text) calendarserver-5.2+dfsg/twistedcaldav/client/0000755000175000017500000000000012322625316020425 5ustar rahulrahulcalendarserver-5.2+dfsg/twistedcaldav/client/pool.py0000644000175000017500000003351612263343324021760 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## __all__ = [ "installPools", "installPool", "getHTTPClientPool", ] import OpenSSL import urlparse from twext.python.log import Logger from twext.internet.gaiendpoint import GAIEndpoint from twext.internet.adaptendpoint import connect from twext.internet.ssl import ChainingOpenSSLContextFactory from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.internet.error import ConnectionLost, ConnectionDone, ConnectError from twisted.internet.protocol import ClientFactory from twext.web2 import responsecode from twext.web2.client.http import HTTPClientProtocol from twext.web2.http import StatusResponse, HTTPError from twext.web2.dav.util import allDataFromStream from twext.web2.stream import MemoryStream class PooledHTTPClientFactory(ClientFactory): """ A client factory for HTTPClient that notifies a pool of it's state. It the connection fails in the middle of a request it will retry the request. 
@ivar protocol: The current instance of our protocol that we pass to our connectionPool. @ivar connectionPool: A managing connection pool that we notify of events. """ log = Logger() protocol = HTTPClientProtocol connectionPool = None def __init__(self, reactor): self.reactor = reactor self.instance = None self.onConnect = Deferred() self.afterConnect = Deferred() def clientConnectionLost(self, connector, reason): """ Notify the connectionPool that we've lost our connection. """ if hasattr(self, "afterConnect"): self.reactor.callLater(0, self.afterConnect.errback, reason) del self.afterConnect if self.connectionPool.shutdown_requested: # The reactor is stopping; don't reconnect return def clientConnectionFailed(self, connector, reason): """ Notify the connectionPool that we're unable to connect """ if hasattr(self, "onConnect"): self.reactor.callLater(0, self.onConnect.errback, reason) del self.onConnect elif hasattr(self, "afterConnect"): self.reactor.callLater(0, self.afterConnect.errback, reason) del self.afterConnect def buildProtocol(self, addr): self.instance = self.protocol() self.reactor.callLater(0, self.onConnect.callback, self.instance) del self.onConnect return self.instance class HTTPClientPool(object): """ A connection pool for HTTPClientProtocol instances. @ivar clientFactory: The L{ClientFactory} implementation that will be used for each protocol. @ivar _maxClients: A C{int} indicating the maximum number of clients. @ivar _endpoint: An L{IStreamClientEndpoint} provider indicating the server to connect to. @ivar _reactor: The L{IReactorTCP} provider used to initiate new connections. @ivar _busyClients: A C{set} that contains all currently busy clients. @ivar _freeClients: A C{set} that contains all currently free clients. @ivar _pendingConnects: A C{int} indicating how many connections are in progress. 
""" log = Logger() clientFactory = PooledHTTPClientFactory maxRetries = 2 def __init__(self, name, scheme, endpoint, secureEndpoint, maxClients=5, reactor=None): """ @param endpoint: An L{IStreamClientEndpoint} indicating the server to connect to. @param maxClients: A C{int} indicating the maximum number of clients. @param reactor: An L{IReactorTCP} provider used to initiate new connections. """ self._name = name self._scheme = scheme self._endpoint = endpoint self._secureEndpoint = secureEndpoint self._maxClients = maxClients if reactor is None: from twisted.internet import reactor self._reactor = reactor self.shutdown_deferred = None self.shutdown_requested = False reactor.addSystemEventTrigger('before', 'shutdown', self._shutdownCallback) self._busyClients = set([]) self._freeClients = set([]) self._pendingConnects = 0 self._pendingRequests = [] def _isIdle(self): return ( len(self._busyClients) == 0 and len(self._pendingRequests) == 0 and self._pendingConnects == 0 ) def _shutdownCallback(self): self.shutdown_requested = True if self._isIdle(): return None self.shutdown_deferred = Deferred() return self.shutdown_deferred def _newClientConnection(self): """ Create a new client connection. @return: A L{Deferred} that fires with the L{IProtocol} instance. 
""" self._pendingConnects += 1 self.log.debug("Initating new client connection to: %r" % ( self._endpoint,)) self._logClientStats() factory = self.clientFactory(self._reactor) factory.connectionPool = self if self._scheme == "https": connect(self._secureEndpoint, factory) elif self._scheme == "http": connect(self._endpoint, factory) else: raise ValueError("URL scheme for client pool not supported") def _doneOK(client): self._pendingConnects -= 1 def _goneClientAfterError(f, client): f.trap(ConnectionLost, ConnectionDone, ConnectError) self.clientGone(client) d2 = factory.afterConnect d2.addErrback(_goneClientAfterError, client) return client def _doneError(result): self._pendingConnects -= 1 return result d = factory.onConnect d.addCallbacks(_doneOK, _doneError) return d def _performRequestOnClient(self, client, request, *args, **kwargs): """ Perform the given request on the given client. @param client: A L{PooledMemCacheProtocol} that will be used to perform the given request. @param command: A C{str} representing an attribute of L{MemCacheProtocol}. @parma args: Any positional arguments that should be passed to C{command}. @param kwargs: Any keyword arguments that should be passed to C{command}. @return: A L{Deferred} that fires with the result of the given command. """ def _freeClientAfterRequest(result): self.clientFree(client) return result def _goneClientAfterError(result): self.clientGone(client) return result self.clientBusy(client) d = client.submitRequest(request, closeAfter=True) d.addCallbacks(_freeClientAfterRequest, _goneClientAfterError) return d @inlineCallbacks def submitRequest(self, request, *args, **kwargs): """ Select an available client and perform the given request on it. @param command: A C{str} representing an attribute of L{MemCacheProtocol}. @parma args: Any positional arguments that should be passed to C{command}. @param kwargs: Any keyword arguments that should be passed to C{command}. 
@return: A L{Deferred} that fires with the result of the given command. """ # Since we may need to replay the request we have to read the request.stream # into memory and reset request.stream to use a MemoryStream each time we repeat # the request data = (yield allDataFromStream(request.stream)) # Try this maxRetries times for ctr in xrange(self.maxRetries + 1): try: request.stream = MemoryStream(data if data is not None else "") request.stream.doStartReading = None response = (yield self._submitRequest(request, args, kwargs)) except (ConnectionLost, ConnectionDone, ConnectError), e: self.log.error("HTTP pooled client connection error (attempt: %d) - retrying: %s" % (ctr + 1, e,)) continue # TODO: find the proper cause of these assertions and fix except (AssertionError,), e: self.log.error("HTTP pooled client connection assertion error (attempt: %d) - retrying: %s" % (ctr + 1, e,)) continue else: returnValue(response) else: self.log.error("HTTP pooled client connection error - exhausted retry attempts.") raise HTTPError(StatusResponse(responsecode.BAD_GATEWAY, "Could not connect to HTTP pooled client host.")) def _submitRequest(self, request, *args, **kwargs): """ Select an available client and perform the given request on it. @param command: A C{str} representing an attribute of L{MemCacheProtocol}. @parma args: Any positional arguments that should be passed to C{command}. @param kwargs: Any keyword arguments that should be passed to C{command}. @return: A L{Deferred} that fires with the result of the given command. 
""" if len(self._freeClients) > 0: d = self._performRequestOnClient(self._freeClients.pop(), request, *args, **kwargs) elif len(self._busyClients) + self._pendingConnects >= self._maxClients: d = Deferred() self._pendingRequests.append((d, request, args, kwargs)) self.log.debug("Request queued: %s, %r, %r" % (request, args, kwargs)) self._logClientStats() else: d = self._newClientConnection() d.addCallback(self._performRequestOnClient, request, *args, **kwargs) return d def _logClientStats(self): self.log.debug("Clients #free: %d, #busy: %d, " "#pending: %d, #queued: %d" % ( len(self._freeClients), len(self._busyClients), self._pendingConnects, len(self._pendingRequests))) def clientGone(self, client): """ Notify that the given client is to be removed from the pool completely. @param client: An instance of L{PooledMemCacheProtocol}. """ if client in self._busyClients: self._busyClients.remove(client) elif client in self._freeClients: self._freeClients.remove(client) self.log.debug("Removed client: %r" % (client,)) self._logClientStats() self._processPending() def clientBusy(self, client): """ Notify that the given client is being used to complete a request. @param client: An instance of C{self.clientFactory} """ if client in self._freeClients: self._freeClients.remove(client) self._busyClients.add(client) self.log.debug("Busied client: %r" % (client,)) self._logClientStats() def clientFree(self, client): """ Notify that the given client is free to handle more requests. 
@param client: An instance of C{self.clientFactory} """ if client in self._busyClients: self._busyClients.remove(client) self._freeClients.add(client) if self.shutdown_deferred and self._isIdle(): self.shutdown_deferred.callback(None) self.log.debug("Freed client: %r" % (client,)) self._logClientStats() self._processPending() def _processPending(self): if len(self._pendingRequests) > 0: d, request, args, kwargs = self._pendingRequests.pop(0) self.log.debug("Performing Queued Request: %s, %r, %r" % ( request, args, kwargs)) self._logClientStats() _ign_d = self._submitRequest(request, *args, **kwargs) _ign_d.addCallbacks(d.callback, d.errback) def suggestMaxClients(self, maxClients): """ Suggest the maximum number of concurrently connected clients. @param maxClients: A C{int} indicating how many client connections we should keep open. """ self._maxClients = maxClients _clientPools = {} # Maps a host:port to a pool object def installPools(hosts, maxClients=5, reactor=None): for name, url in hosts: installPool( name, url, maxClients, reactor, ) def _configuredClientContextFactory(): """ Get a client context factory from the configuration. 
""" from twistedcaldav.config import config return ChainingOpenSSLContextFactory( config.SSLPrivateKey, config.SSLCertificate, certificateChainFile=config.SSLAuthorityChain, sslmethod=getattr(OpenSSL.SSL, config.SSLMethod) ) def installPool(name, url, maxClients=5, reactor=None): if reactor is None: from twisted.internet import reactor parsedURL = urlparse.urlparse(url) ctxf = _configuredClientContextFactory() pool = HTTPClientPool( name, parsedURL.scheme, GAIEndpoint(reactor, parsedURL.hostname, parsedURL.port), GAIEndpoint(reactor, parsedURL.hostname, parsedURL.port, ctxf), maxClients, reactor, ) _clientPools[name] = pool def getHTTPClientPool(name): return _clientPools[name] calendarserver-5.2+dfsg/twistedcaldav/client/test/0000755000175000017500000000000012322625316021404 5ustar rahulrahulcalendarserver-5.2+dfsg/twistedcaldav/client/test/test_reverseproxy.py0000644000175000017500000000572012263343324025576 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from twext.web2.client.http import ClientRequest from twext.web2.http import HTTPError from twext.web2.test.test_server import SimpleRequest from twistedcaldav.client.pool import _clientPools from twistedcaldav.client.reverseproxy import ReverseProxyResource from twistedcaldav.config import config import twistedcaldav.test.util class ReverseProxyNoLoop (twistedcaldav.test.util.TestCase): """ Prevent loops in reverse proxy """ def setUp(self): class DummyPool(object): def submitRequest(self, request): return request _clientPools["pool"] = DummyPool() super(ReverseProxyNoLoop, self).setUp() def test_No_Header(self): proxy = ReverseProxyResource("pool") request = SimpleRequest(proxy, "GET", "/") self.assertIsInstance(proxy.renderHTTP(request), ClientRequest) def test_Header_Other_Server(self): proxy = ReverseProxyResource("pool") request = SimpleRequest(proxy, "GET", "/") request.headers.addRawHeader("x-forwarded-server", "foobar.example.com") self.assertIsInstance(proxy.renderHTTP(request), ClientRequest) def test_Header_Other_Servers(self): proxy = ReverseProxyResource("pool") request = SimpleRequest(proxy, "GET", "/") request.headers.setHeader("x-forwarded-server", ("foobar.example.com", "bar.example.com",)) self.assertIsInstance(proxy.renderHTTP(request), ClientRequest) def test_Header_Our_Server(self): proxy = ReverseProxyResource("pool") request = SimpleRequest(proxy, "GET", "/") request.headers.addRawHeader("x-forwarded-server", config.ServerHostName) self.assertRaises(HTTPError, proxy.renderHTTP, request) def test_Header_Our_Server_Moxied(self): proxy = ReverseProxyResource("pool") request = SimpleRequest(proxy, "GET", "/") request.headers.setHeader("x-forwarded-server", ("foobar.example.com", "bar.example.com", config.ServerHostName,)) self.assertRaises(HTTPError, proxy.renderHTTP, request) def test_Header_Our_Server_Allowed(self): proxy = ReverseProxyResource("pool") proxy.allowMultiHop = True request = SimpleRequest(proxy, "GET", "/") 
request.headers.addRawHeader("x-forwarded-server", config.ServerHostName) self.assertIsInstance(proxy.renderHTTP(request), ClientRequest) calendarserver-5.2+dfsg/twistedcaldav/client/test/__init__.py0000644000175000017500000000122212263343324023512 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for the twistedcaldav.client module. """ calendarserver-5.2+dfsg/twistedcaldav/client/geturl.py0000644000175000017500000000745512263343324022314 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from twext.python.log import Logger from twisted.internet import reactor, protocol from twisted.internet.defer import inlineCallbacks, Deferred, returnValue from twisted.web import http_headers from twisted.web.client import Agent from twisted.web.http import MOVED_PERMANENTLY, TEMPORARY_REDIRECT, FOUND from urlparse import urlparse from urlparse import urlunparse __all__ = [ "getURL", ] log = Logger() class AccumulatingProtocol(protocol.Protocol): """ L{AccumulatingProtocol} is an L{IProtocol} implementation which collects the data delivered to it and can fire a Deferred when it is connected or disconnected. @ivar made: A flag indicating whether C{connectionMade} has been called. @ivar data: A string giving all the data passed to C{dataReceived}. @ivar closed: A flag indicated whether C{connectionLost} has been called. @ivar closedReason: The value of the I{reason} parameter passed to C{connectionLost}. @ivar closedDeferred: If set to a L{Deferred}, this will be fired when C{connectionLost} is called. 
""" made = closed = 0 closedReason = None closedDeferred = None data = "" factory = None def connectionMade(self): self.made = 1 if (self.factory is not None and self.factory.protocolConnectionMade is not None): d = self.factory.protocolConnectionMade self.factory.protocolConnectionMade = None d.callback(self) def dataReceived(self, data): self.data += data def connectionLost(self, reason): self.closed = 1 self.closedReason = reason if self.closedDeferred is not None: d, self.closedDeferred = self.closedDeferred, None d.callback(None) @inlineCallbacks def getURL(url, method="GET", redirect=0): if isinstance(url, unicode): url = url.encode("utf-8") agent = Agent(reactor) headers = http_headers.Headers({}) try: response = (yield agent.request(method, url, headers, None)) except Exception, e: log.error(str(e)) response = None else: if response.code in (MOVED_PERMANENTLY, FOUND, TEMPORARY_REDIRECT,): if redirect > 3: log.error("Too many redirects") else: location = response.headers.getRawHeaders("location") if location: newresponse = (yield getURL(location[0], method=method, redirect=redirect + 1)) if response.code == MOVED_PERMANENTLY: scheme, netloc, url, _ignore_params, _ignore_query, _ignore_fragment = urlparse(location[0]) newresponse.location = urlunparse((scheme, netloc, url, None, None, None,)) returnValue(newresponse) else: log.error("Redirect without a Location header") if response is not None and response.code / 100 == 2: protocol = AccumulatingProtocol() response.deliverBody(protocol) whenFinished = protocol.closedDeferred = Deferred() yield whenFinished response.data = protocol.data else: log.error("Failed getURL: %s" % (url,)) returnValue(response) calendarserver-5.2+dfsg/twistedcaldav/client/reverseproxy.py0000644000175000017500000000601712263343324023560 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## __all__ = [ "ReverseProxyResource", ] from zope.interface.declarations import implements from twext.web2 import iweb, responsecode from twext.web2.client.http import ClientRequest from twext.web2.http import StatusResponse, HTTPError from twext.web2.resource import LeafResource from twext.python.log import Logger from twistedcaldav.client.pool import getHTTPClientPool from twistedcaldav.config import config class ReverseProxyResource(LeafResource): """ A L{LeafResource} which always performs a reverse proxy operation. """ log = Logger() implements(iweb.IResource) def __init__(self, poolID, *args, **kwargs): """ @param poolID: idenitifier of the pool to use @type poolID: C{str} """ self.poolID = poolID self._args = args self._kwargs = kwargs self.allowMultiHop = False def isCollection(self): return True def exists(self): return False def renderHTTP(self, request): """ Do the reverse proxy request and return the response. @param request: the incoming request that needs to be proxied. 
@type request: L{Request} @return: Deferred L{Response} """ self.log.info("%s %s %s" % (request.method, request.uri, "HTTP/%s.%s" % request.clientproto)) # Check for multi-hop if not self.allowMultiHop: x_server = request.headers.getHeader("x-forwarded-server") if x_server: for item in x_server: if item.lower() == config.ServerHostName.lower(): raise HTTPError(StatusResponse(responsecode.BAD_GATEWAY, "Too many x-forwarded-server hops")) clientPool = getHTTPClientPool(self.poolID) proxyRequest = ClientRequest(request.method, request.uri, request.headers, request.stream) # Need x-forwarded-(for|host|server) headers. First strip any existing ones out, then add ours proxyRequest.headers.removeHeader("x-forwarded-host") proxyRequest.headers.removeHeader("x-forwarded-for") proxyRequest.headers.removeHeader("x-forwarded-server") proxyRequest.headers.addRawHeader("x-forwarded-host", request.host) proxyRequest.headers.addRawHeader("x-forwarded-for", request.remoteAddr.host) proxyRequest.headers.addRawHeader("x-forwarded-server", config.ServerHostName) return clientPool.submitRequest(proxyRequest) calendarserver-5.2+dfsg/twistedcaldav/client/__init__.py0000644000175000017500000000205612263343324022541 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from twext.web2.http_headers import DefaultHTTPHandler, tokenize, generateList, singleHeader DefaultHTTPHandler.updateParsers({ "x-forwarded-for": (tokenize, list), "x-forwarded-host": (tokenize, list), "x-forwarded-server": (tokenize, list), }) DefaultHTTPHandler.updateGenerators({ "x-forwarded-for": (generateList, singleHeader), "x-forwarded-host": (generateList, singleHeader), "x-forwarded-server": (generateList, singleHeader), }) calendarserver-5.2+dfsg/twistedcaldav/extensions.py0000644000175000017500000010443612263343324021730 0ustar rahulrahul# -*- test-case-name: twistedcaldav.test.test_extensions -*- ## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from __future__ import print_function """ Extensions to web2.dav """ __all__ = [ "DAVResource", "DAVResourceWithChildrenMixin", "DAVPrincipalResource", "DAVFile", "ReadOnlyWritePropertiesResourceMixIn", "ReadOnlyResourceMixIn", "CachingPropertyStore", ] import urllib import time from itertools import cycle from twisted.internet.defer import succeed, maybeDeferred from twisted.internet.defer import inlineCallbacks, returnValue from twisted.web.template import Element, XMLFile, renderer, tags, flattenString from twisted.python.modules import getModule from twext.web2 import responsecode, server from twext.web2.http import HTTPError, Response, RedirectResponse from twext.web2.http import StatusResponse from twext.web2.http_headers import MimeType from twext.web2.stream import FileStream from twext.web2.static import MetaDataMixin, StaticRenderMixin from txdav.xml import element from txdav.xml.base import encodeXMLName from txdav.xml.element import dav_namespace from twext.web2.dav.http import MultiStatusResponse from twext.web2.dav.static import DAVFile as SuperDAVFile from twext.web2.dav.resource import DAVResource as SuperDAVResource from twext.web2.dav.resource import ( DAVPrincipalResource as SuperDAVPrincipalResource ) from twisted.internet.defer import gatherResults from twext.web2.dav.method import prop_common from twext.python.log import Logger from twistedcaldav import customxml from twistedcaldav.customxml import calendarserver_namespace from twistedcaldav.method.report import http_REPORT from twistedcaldav.config import config thisModule = getModule(__name__) log = Logger() class DirectoryPrincipalPropertySearchMixIn(object): @inlineCallbacks def report_DAV__principal_property_search(self, request, principal_property_search): """ Generate a principal-property-search REPORT. (RFC 3744, section 9.4) Overrides twisted implementation, targeting only directory-enabled searching. 
""" # Verify root element if not isinstance(principal_property_search, element.PrincipalPropertySearch): msg = "%s expected as root element, not %s." % (element.PrincipalPropertySearch.sname(), principal_property_search.sname()) log.warn(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) # Should we AND (the default) or OR (if test="anyof")? testMode = principal_property_search.attributes.get("test", "allof") if testMode not in ("allof", "anyof"): msg = "Bad XML: unknown value for test attribute: %s" % (testMode,) log.warn(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) operand = "and" if testMode == "allof" else "or" # Are we narrowing results down to a single CUTYPE? cuType = principal_property_search.attributes.get("type", None) if cuType not in ("INDIVIDUAL", "GROUP", "RESOURCE", "ROOM", None): msg = "Bad XML: unknown value for type attribute: %s" % (cuType,) log.warn(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) # Only handle Depth: 0 depth = request.headers.getHeader("depth", "0") if depth != "0": log.error("Error in principal-property-search REPORT, Depth set to %s" % (depth,)) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,))) # Get any limit value from xml clientLimit = None # Get a single DAV:prop element from the REPORT request body propertiesForResource = None propElement = None propertySearches = [] applyTo = False for child in principal_property_search.children: if child.qname() == (dav_namespace, "prop"): propertiesForResource = prop_common.propertyListForResource propElement = child elif child.qname() == (dav_namespace, "apply-to-principal-collection-set"): applyTo = True elif child.qname() == (dav_namespace, "property-search"): props = child.childOfType(element.PropertyContainer) props.removeWhitespaceNodes() match = child.childOfType(element.Match) caseless = match.attributes.get("caseless", "yes") if caseless not in ("yes", "no"): msg = "Bad XML: 
unknown value for caseless attribute: %s" % (caseless,) log.warn(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) caseless = (caseless == "yes") matchType = match.attributes.get("match-type", u"contains").encode("utf-8") if matchType not in ("starts-with", "contains", "equals"): msg = "Bad XML: unknown value for match-type attribute: %s" % (matchType,) log.warn(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) # Ignore any query strings under three letters matchText = str(match) if len(matchText) >= 3: propertySearches.append((props.children, matchText, caseless, matchType)) elif child.qname() == (calendarserver_namespace, "limit"): try: nresults = child.childOfType(customxml.NResults) clientLimit = int(str(nresults)) except (TypeError, ValueError,): msg = "Bad XML: unknown value for element" log.warn(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) # Run report resultsWereLimited = None resources = [] if applyTo or not hasattr(self, "directory"): for principalCollection in self.principalCollections(): uri = principalCollection.principalCollectionURL() resource = (yield request.locateResource(uri)) if resource: resources.append((resource, uri)) else: resources.append((self, request.uri)) # We need to access a directory service principalCollection = resources[0][0] if not hasattr(principalCollection, "directory"): # Use Twisted's implementation instead in this case result = (yield super(DirectoryPrincipalPropertySearchMixIn, self).report_DAV__principal_property_search(request, principal_property_search)) returnValue(result) dir = principalCollection.directory # See if we can take advantage of the directory fields = [] nonDirectorySearches = [] for props, match, caseless, matchType in propertySearches: nonDirectoryProps = [] for prop in props: try: fieldName, match = principalCollection.propertyToField( prop, match) except ValueError, e: raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e))) if 
fieldName: fields.append((fieldName, match, caseless, matchType)) else: nonDirectoryProps.append(prop) if nonDirectoryProps: nonDirectorySearches.append((nonDirectoryProps, match, caseless, matchType)) matchingResources = [] matchcount = 0 # nonDirectorySearches are ignored if fields: records = (yield dir.recordsMatchingFieldsWithCUType(fields, operand=operand, cuType=cuType)) for record in records: resource = principalCollection.principalForRecord(record) if resource: matchingResources.append(resource) # We've determined this is a matching resource matchcount += 1 if clientLimit is not None and matchcount >= clientLimit: resultsWereLimited = ("client", matchcount) break if matchcount >= config.MaxPrincipalSearchReportResults: resultsWereLimited = ("server", matchcount) break # Generate the response responses = [] for resource in matchingResources: url = resource.url() yield prop_common.responseForHref( request, responses, element.HRef.fromString(url), resource, propertiesForResource, propElement ) if resultsWereLimited is not None: if resultsWereLimited[0] == "server": log.error("Too many matching resources in principal-property-search report") responses.append(element.StatusResponse( element.HRef.fromString(request.uri), element.Status.fromResponseCode( responsecode.INSUFFICIENT_STORAGE_SPACE ), element.Error(element.NumberOfMatchesWithinLimits()), element.ResponseDescription("Results limited by %s at %d" % resultsWereLimited), )) returnValue(MultiStatusResponse(responses)) @inlineCallbacks def report_http___calendarserver_org_ns__calendarserver_principal_search(self, request, calendarserver_principal_search): """ Generate a calendarserver-principal-search REPORT. @param request: Request object @param calendarserver_principal_search: CalendarServerPrincipalSearch object """ # Verify root element if not isinstance(calendarserver_principal_search, customxml.CalendarServerPrincipalSearch): msg = "%s expected as root element, not %s." 
% (customxml.CalendarServerPrincipalSearch.sname(), calendarserver_principal_search.sname()) log.warn(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) # Only handle Depth: 0 depth = request.headers.getHeader("depth", "0") if depth != "0": log.error("Error in calendarserver-principal-search REPORT, Depth set to %s" % (depth,)) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,))) tokens, context, applyTo, clientLimit, propElement = extractCalendarServerPrincipalSearchData(calendarserver_principal_search) if not validateTokens(tokens): raise HTTPError(StatusResponse(responsecode.FORBIDDEN, "Insufficient search token length")) # Run report resultsWereLimited = None resources = [] if applyTo or not hasattr(self, "directory"): for principalCollection in self.principalCollections(): uri = principalCollection.principalCollectionURL() resource = (yield request.locateResource(uri)) if resource: resources.append((resource, uri)) else: resources.append((self, request.uri)) # We need to access a directory service principalCollection = resources[0][0] dir = principalCollection.directory matchingResources = [] matchcount = 0 records = (yield dir.recordsMatchingTokens(tokens, context=context)) for record in records: resource = principalCollection.principalForRecord(record) if resource: matchingResources.append(resource) # We've determined this is a matching resource matchcount += 1 if clientLimit is not None and matchcount >= clientLimit: resultsWereLimited = ("client", matchcount) break if matchcount >= config.MaxPrincipalSearchReportResults: resultsWereLimited = ("server", matchcount) break # Generate the response responses = [] for resource in matchingResources: url = resource.url() yield prop_common.responseForHref( request, responses, element.HRef.fromString(url), resource, prop_common.propertyListForResource, propElement ) if resultsWereLimited is not None: if resultsWereLimited[0] == "server": log.error("Too many 
matching resources in calendarserver-principal-search report") responses.append(element.StatusResponse( element.HRef.fromString(request.uri), element.Status.fromResponseCode( responsecode.INSUFFICIENT_STORAGE_SPACE ), element.Error(element.NumberOfMatchesWithinLimits()), element.ResponseDescription("Results limited by %s at %d" % resultsWereLimited), )) returnValue(MultiStatusResponse(responses)) class DirectoryElement(Element): """ A L{DirectoryElement} is an L{Element} for rendering the contents of a L{DirectoryRenderingMixIn} resource as HTML. """ loader = XMLFile( thisModule.filePath.sibling("directory-listing.html") ) def __init__(self, resource): """ @param resource: the L{DirectoryRenderingMixIn} resource being listed. """ super(DirectoryElement, self).__init__() self.resource = resource @renderer def resourceDetail(self, request, tag): """ Renderer which returns a distinct element for this resource's data. Subclasses should override. """ return '' @renderer def children(self, request, tag): """ Renderer which yields all child object tags as table rows. """ whenChildren = ( maybeDeferred(self.resource.listChildren) .addCallback(sorted) .addCallback( lambda names: gatherResults( [maybeDeferred(self.resource.getChild, x) for x in names] ) .addCallback(lambda children: zip(children, names)) ) ) @whenChildren.addCallback def gotChildren(children): for even, [child, name] in zip(cycle(["odd", "even"]), children): [url, name, size, lastModified, contentType] = map( str, self.resource.getChildDirectoryEntry( child, name, request) ) yield tag.clone().fillSlots( url=url, name=name, size=str(size), lastModified=lastModified, even=even, type=contentType, ) return whenChildren @renderer def main(self, request, tag): """ Main renderer; fills slots for title, etc. """ return tag.fillSlots(name=request.path) @renderer def properties(self, request, tag): """ Renderer which yields all properties as table row tags. 
""" whenPropertiesListed = self.resource.listProperties(request) @whenPropertiesListed.addCallback def gotProperties(qnames): accessDeniedValue = object() def gotError(f, name): f.trap(HTTPError) code = f.value.response.code if code == responsecode.NOT_FOUND: log.error("Property %s was returned by listProperties() " "but does not exist for resource %s." % (name, self.resource)) return (name, None) if code == responsecode.UNAUTHORIZED: return (name, accessDeniedValue) return f whenAllProperties = gatherResults([ maybeDeferred(self.resource.readProperty, qn, request) .addCallback(lambda p, iqn=qn: (p.sname(), p.toxml()) if p is not None else (encodeXMLName(*iqn), None)) .addErrback(gotError, encodeXMLName(*qn)) for qn in sorted(qnames) ]) @whenAllProperties.addCallback def gotValues(items): for even, [name, value] in zip(cycle(["odd", "even"]), items): if value is None: value = tags.i("(no value)") elif value is accessDeniedValue: value = tags.i("(access forbidden)") yield tag.clone().fillSlots( even=even, name=name, value=value, ) return whenAllProperties return whenPropertiesListed class DirectoryRenderingMixIn(object): def renderDirectory(self, request): """ Render a directory listing. """ def gotBody(output): mime_params = {"charset": "utf-8"} response = Response(200, {}, output) response.headers.setHeader( "content-type", MimeType("text", "html", mime_params) ) return response return flattenString(request, self.htmlElement()).addCallback(gotBody) def htmlElement(self): """ Create a L{DirectoryElement} or appropriate subclass for rendering this resource. 
""" return DirectoryElement(self) def getChildDirectoryEntry(self, child, name, request): def orNone(value, default="?", f=None): if value is None: return default elif f is not None: return f(value) else: return value url = urllib.quote(name, '/') if isinstance(child, DAVResource) and child.isCollection(): url += "/" name += "/" if isinstance(child, MetaDataMixin): size = child.contentLength() lastModified = child.lastModified() rtypes = [] fullrtype = child.resourceType() if hasattr(child, "resourceType") else None if fullrtype is not None: for rtype in fullrtype.children: rtypes.append(rtype.name) if rtypes: rtypes = "(%s)" % (", ".join(rtypes),) if child.isCollection() if hasattr(child, "isCollection") else False: contentType = rtypes else: mimeType = child.contentType() if mimeType is None: print('BAD contentType() IMPLEMENTATION', child) contentType = 'application/octet-stream' else: contentType = "%s/%s" % (mimeType.mediaType, mimeType.mediaSubtype) if rtypes: contentType += " %s" % (rtypes,) else: size = None lastModified = None contentType = None if hasattr(child, "resourceType"): rtypes = [] fullrtype = child.resourceType() for rtype in fullrtype.children: rtypes.append(rtype.name) if rtypes: contentType = "(%s)" % (", ".join(rtypes),) return ( url, name, orNone(size), orNone( lastModified, default="", f=lambda t: time.strftime("%Y-%b-%d %H:%M", time.localtime(t)) ), contentType, ) class DAVResource (DirectoryPrincipalPropertySearchMixIn, SuperDAVResource, DirectoryRenderingMixIn, StaticRenderMixin): """ Extended L{twext.web2.dav.resource.DAVResource} implementation. Note we add StaticRenderMixin as a base class because we need all the etag etc behavior that is currently in static.py but is actually applicable to any type of resource. 
""" log = Logger() http_REPORT = http_REPORT def davComplianceClasses(self): return ("1", "access-control") # Add "2" when we have locking def render(self, request): if not self.exists(): return responsecode.NOT_FOUND if self.isCollection(): return self.renderDirectory(request) return super(DAVResource, self).render(request) def resourceType(self): # Allow live property to be overridden by dead property if self.deadProperties().contains((dav_namespace, "resourcetype")): return self.deadProperties().get((dav_namespace, "resourcetype")) return element.ResourceType(element.Collection()) if self.isCollection() else element.ResourceType() def contentType(self): return MimeType("httpd", "unix-directory") if self.isCollection() else None class DAVResourceWithChildrenMixin (object): """ Bits needed from twext.web2.static """ def __init__(self, principalCollections=None): self.putChildren = {} super(DAVResourceWithChildrenMixin, self).__init__(principalCollections=principalCollections) def putChild(self, name, child): """ Register a child with the given name with this resource. @param name: the name of the child (a URI path segment) @param child: the child to register """ self.putChildren[name] = child def getChild(self, name): """ Look up a child resource. First check C{self.putChildren}, then call C{self.makeChild} if no pre-existing children were found. @return: the child of this resource with the given name. """ if name == "": return self result = self.putChildren.get(name, None) if not result: result = self.makeChild(name) return result def makeChild(self, name): """ Called by L{DAVResourceWithChildrenMixin.getChild} to dynamically create children that have not been pre-created with C{putChild}. """ return None def listChildren(self): """ @return: a sequence of the names of all known children of this resource. """ return self.putChildren.keys() def countChildren(self): """ @return: the number of all known children of this resource. 
""" return len(self.putChildren.keys()) def locateChild(self, req, segments): """ See L{IResource.locateChild}. """ thisSegment = segments[0] moreSegments = segments[1:] return maybeDeferred(self.getChild, thisSegment).addCallback( lambda it: (it, moreSegments) ) class DAVResourceWithoutChildrenMixin (object): """ Bits needed from twext.web2.static """ def __init__(self, principalCollections=None): self.putChildren = {} super(DAVResourceWithChildrenMixin, self).__init__(principalCollections=principalCollections) def findChildren( self, depth, request, callback, privileges=None, inherited_aces=None ): return succeed(None) def locateChild(self, request, segments): return self, server.StopTraversal class DAVPrincipalResource (DirectoryPrincipalPropertySearchMixIn, SuperDAVPrincipalResource, DirectoryRenderingMixIn): """ Extended L{twext.web2.dav.static.DAVFile} implementation. """ log = Logger() def liveProperties(self): return super(DAVPrincipalResource, self).liveProperties() + ( (calendarserver_namespace, "expanded-group-member-set"), (calendarserver_namespace, "expanded-group-membership"), (calendarserver_namespace, "record-type"), ) http_REPORT = http_REPORT def render(self, request): if not self.exists(): return responsecode.NOT_FOUND if self.isCollection(): return self.renderDirectory(request) return super(DAVResource, self).render(request) @inlineCallbacks def readProperty(self, property, request): if type(property) is tuple: qname = property else: qname = property.qname() namespace, name = qname if namespace == dav_namespace: if name == "resourcetype": rtype = self.resourceType() returnValue(rtype) elif namespace == calendarserver_namespace: if name == "expanded-group-member-set": principals = (yield self.expandedGroupMembers()) returnValue(customxml.ExpandedGroupMemberSet( *[element.HRef(p.principalURL()) for p in principals] )) elif name == "expanded-group-membership": principals = (yield self.expandedGroupMemberships()) 
returnValue(customxml.ExpandedGroupMembership( *[element.HRef(p.principalURL()) for p in principals] )) elif name == "record-type": if hasattr(self, "record"): returnValue(customxml.RecordType(self.record.recordType)) else: raise HTTPError(StatusResponse( responsecode.NOT_FOUND, "Property %s does not exist." % (qname,) )) result = (yield super(DAVPrincipalResource, self).readProperty(property, request)) returnValue(result) def groupMembers(self): return succeed(()) def expandedGroupMembers(self): return succeed(()) def groupMemberships(self): return succeed(()) def expandedGroupMemberships(self): return succeed(()) def resourceType(self): # Allow live property to be overridden by dead property if self.deadProperties().contains((dav_namespace, "resourcetype")): return self.deadProperties().get((dav_namespace, "resourcetype")) if self.isCollection(): return element.ResourceType(element.Principal(), element.Collection()) else: return element.ResourceType(element.Principal()) class DAVFile (SuperDAVFile, DirectoryRenderingMixIn): """ Extended L{twext.web2.dav.static.DAVFile} implementation. 
""" log = Logger() def resourceType(self): # Allow live property to be overridden by dead property if self.deadProperties().contains((dav_namespace, "resourcetype")): return self.deadProperties().get((dav_namespace, "resourcetype")) if self.isCollection(): return element.ResourceType.collection #@UndefinedVariable return element.ResourceType.empty #@UndefinedVariable def render(self, request): if not self.fp.exists(): return responsecode.NOT_FOUND if self.fp.isdir(): if request.path[-1] != "/": # Redirect to include trailing '/' in URI return RedirectResponse(request.unparseURL(path=urllib.quote(urllib.unquote(request.path), safe=':/') + '/')) else: ifp = self.fp.childSearchPreauth(*self.indexNames) if ifp: # Render from the index file return self.createSimilarFile(ifp.path).render(request) return self.renderDirectory(request) try: f = self.fp.open() except IOError, e: import errno if e[0] == errno.EACCES: return responsecode.FORBIDDEN elif e[0] == errno.ENOENT: return responsecode.NOT_FOUND else: raise response = Response() response.stream = FileStream(f, 0, self.fp.getsize()) for (header, value) in ( ("content-type", self.contentType()), ("content-encoding", self.contentEncoding()), ): if value is not None: response.headers.setHeader(header, value) return response class ReadOnlyWritePropertiesResourceMixIn (object): """ Read only that will allow writing of properties resource. """ readOnlyResponse = StatusResponse( responsecode.FORBIDDEN, "Resource is read only." ) def _forbidden(self, request): return self.readOnlyResponse http_DELETE = _forbidden http_MOVE = _forbidden http_PUT = _forbidden class ReadOnlyResourceMixIn (ReadOnlyWritePropertiesResourceMixIn): """ Read only resource. 
""" http_PROPPATCH = ReadOnlyWritePropertiesResourceMixIn._forbidden def writeProperty(self, property, request): raise HTTPError(self.readOnlyResponse) def accessControlList( self, request, inheritance=True, expanding=False, inherited_aces=None ): # Permissions here are fixed, and are not subject to # inheritance rules, etc. return succeed(self.defaultAccessControlList()) class PropertyNotFoundError (HTTPError): def __init__(self, qname): HTTPError.__init__(self, StatusResponse( responsecode.NOT_FOUND, "No such property: %s" % encodeXMLName(*qname) ) ) class CachingPropertyStore (object): """ DAV property store using a dict in memory on top of another property store implementation. """ log = Logger() def __init__(self, propertyStore): self.propertyStore = propertyStore self.resource = propertyStore.resource def get(self, qname, uid=None): #self.log.debug("Get: %r, %r" % (self.resource.fp.path, qname)) cache = self._cache() cachedQname = qname + (uid,) if cachedQname in cache: property = cache.get(cachedQname, None) if property is None: self.log.debug("Cache miss: %r, %r, %r" % (self, self.resource.fp.path, qname)) try: property = self.propertyStore.get(qname, uid) except HTTPError: del cache[cachedQname] raise PropertyNotFoundError(qname) cache[cachedQname] = property return property else: raise PropertyNotFoundError(qname) def set(self, property, uid=None): #self.log.debug("Set: %r, %r" % (self.resource.fp.path, property)) cache = self._cache() cachedQname = property.qname() + (uid,) cache[cachedQname] = None self.propertyStore.set(property, uid) cache[cachedQname] = property def contains(self, qname, uid=None): #self.log.debug("Contains: %r, %r" % (self.resource.fp.path, qname)) cachedQname = qname + (uid,) try: cache = self._cache() except HTTPError, e: if e.response.code == responsecode.NOT_FOUND: return False else: raise if cachedQname in cache: #self.log.debug("Contains cache hit: %r, %r, %r" % (self, self.resource.fp.path, qname)) return True else: return 
False def delete(self, qname, uid=None): #self.log.debug("Delete: %r, %r" % (self.resource.fp.path, qname)) cachedQname = qname + (uid,) if self._data is not None and cachedQname in self._data: del self._data[cachedQname] self.propertyStore.delete(qname, uid) def list(self, uid=None, filterByUID=True): #self.log.debug("List: %r" % (self.resource.fp.path,)) keys = self._cache().iterkeys() if filterByUID: return [ (namespace, name) for namespace, name, propuid in keys if propuid == uid ] else: return keys def _cache(self): if not hasattr(self, "_data"): #self.log.debug("Cache init: %r" % (self.resource.fp.path,)) self._data = dict( (name, None) for name in self.propertyStore.list(filterByUID=False) ) return self._data def extractCalendarServerPrincipalSearchData(doc): """ Extract relevant info from a CalendarServerPrincipalSearch document @param doc: CalendarServerPrincipalSearch object to extract info from @return: A tuple containing: the list of tokens the context string the applyTo boolean the clientLimit integer the propElement containing the properties to return """ context = doc.attributes.get("context", None) applyTo = False tokens = [] clientLimit = None for child in doc.children: if child.qname() == (dav_namespace, "prop"): propElement = child elif child.qname() == (dav_namespace, "apply-to-principal-collection-set"): applyTo = True elif child.qname() == (calendarserver_namespace, "search-token"): tokens.append(str(child)) elif child.qname() == (calendarserver_namespace, "limit"): try: nresults = child.childOfType(customxml.NResults) clientLimit = int(str(nresults)) except (TypeError, ValueError,): msg = "Bad XML: unknown value for element" log.warn(msg) raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg)) return tokens, context, applyTo, clientLimit, propElement def validateTokens(tokens): """ Make sure there is at least one token longer than one character @param tokens: the tokens to inspect @type tokens: iterable of utf-8 encoded strings 
@return: True if tokens are valid, False otherwise @rtype: boolean """ for token in tokens: if len(token) > 1: return True return False calendarserver-5.2+dfsg/twistedcaldav/memcacheprops.py0000644000175000017500000003331212263343324022351 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Computer, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. ## """ DAV Property store using memcache on top of another property store implementation. 
""" __all__ = ["MemcachePropertyCollection"] try: from hashlib import md5 except ImportError: from md5 import new as md5 from twext.python.log import Logger from twext.python.memcacheclient import ClientFactory from twext.python.memcacheclient import MemcacheError, TokenMismatchError from twext.python.filepath import CachingFilePath as FilePath from txdav.xml.base import encodeXMLName from twext.web2 import responsecode from twext.web2.http import HTTPError, StatusResponse from twistedcaldav.config import config NoValue = "" class MemcachePropertyCollection (object): """ Manages a single property store for all resources in a collection. """ log = Logger() def __init__(self, collection, cacheTimeout=0): self.collection = collection self.cacheTimeout = cacheTimeout @classmethod def memcacheClient(cls, refresh=False): if not hasattr(MemcachePropertyCollection, "_memcacheClient"): cls.log.info("Instantiating memcache connection for MemcachePropertyCollection") MemcachePropertyCollection._memcacheClient = ClientFactory.getClient([ "%s:%s" % (config.Memcached.Pools.Default.BindAddress, config.Memcached.Pools.Default.Port) ], debug=0, pickleProtocol=2, ) assert MemcachePropertyCollection._memcacheClient is not None return MemcachePropertyCollection._memcacheClient def propertyCache(self): # The property cache has this format: # { # "/path/to/resource/file": # ( # { # (namespace, name, uid): property, # ..., # }, # memcache_token, # ), # ..., # } if not hasattr(self, "_propertyCache"): self._propertyCache = self._loadCache() return self._propertyCache def childCache(self, child): path = child.fp.path key = self._keyForPath(path) propertyCache = self.propertyCache() try: childCache, token = propertyCache[key] except KeyError: self.log.debug("No child property cache for %s" % (child,)) childCache, token = ({}, None) #message = "No child property cache for %s" % (child,) #self.log.error(message) #raise AssertionError(message) return propertyCache, key, childCache, token def 
_keyForPath(self, path): key = "|".join(( self.__class__.__name__, path )) return md5(key).hexdigest() def _loadCache(self, childNames=None): if childNames is None: abortIfMissing = False childNames = self.collection.listChildren() else: if childNames: abortIfMissing = True else: return {} self.log.debug("Loading cache for %s" % (self.collection,)) client = self.memcacheClient() assert client is not None, "OMG no cache!" if client is None: return None keys = tuple(( (self._keyForPath(self.collection.fp.child(childName).path), childName) for childName in childNames )) result = self._split_gets_multi((key for key, _ignore_name in keys), client.gets_multi) if abortIfMissing: missing = "missing " else: missing = "" self.log.debug( "Loaded keys for {missing}children of {collection}: {children()}", missing=missing, collection=self.collection, children=lambda: [name for _ignore_key, name in keys], ) missing = tuple(( name for key, name in keys if key not in result )) if missing: if abortIfMissing: raise MemcacheError("Unable to fully load cache for %s" % (self.collection,)) loaded = self._buildCache(childNames=missing) loaded = self._loadCache(childNames=(FilePath(name).basename() for name in loaded.iterkeys())) result.update(loaded.iteritems()) return result def _split_gets_multi(self, keys, func, chunksize=250): """ Splits gets_multi into chunks to avoid a memcacheclient timeout due of a large number of keys. Consolidates and returns results. Takes a function parameter for easier unit testing. """ results = {} count = 0 subset = [] for key in keys: if count == 0: subset = [] subset.append(key) count += 1 if count == chunksize: results.update(func(subset)) count = 0 if count: results.update(func(subset)) return results def _split_set_multi(self, values, func, time=0, chunksize=250): """ Splits set_multi into chunks to avoid a memcacheclient timeout due of a large number of keys. Takes a function parameter for easier unit testing. 
""" count = 0 subset = {} for key, value in values.iteritems(): if count == 0: subset.clear() subset[key] = value count += 1 if count == chunksize: func(subset, time=time) count = 0 if count: func(subset, time=time) def _storeCache(self, cache): self.log.debug("Storing cache for %s" % (self.collection,)) values = dict(( (self._keyForPath(path), props) for path, props in cache.iteritems() )) client = self.memcacheClient() if client is not None: self._split_set_multi(values, client.set_multi, time=self.cacheTimeout) def _buildCache(self, childNames=None): if childNames is None: childNames = self.collection.listChildren() elif not childNames: return {} self.log.debug("Building cache for %s" % (self.collection,)) cache = {} for childName in childNames: child = self.collection.getChild(childName) if child is None: continue propertyStore = child.deadProperties() props = {} for pnamespace, pname, puid in propertyStore.list(filterByUID=False, cache=False): props[(pnamespace, pname, puid,)] = propertyStore.get((pnamespace, pname,), uid=puid, cache=False) cache[child.fp.path] = props self._storeCache(cache) return cache def setProperty(self, child, property, uid, delete=False): propertyCache, key, childCache, token = self.childCache(child) if delete: qname = property qnameuid = qname + (uid,) if qnameuid in childCache: del childCache[qnameuid] else: qname = property.qname() qnameuid = qname + (uid,) childCache[qnameuid] = property client = self.memcacheClient() if client is not None: retries = 10 while retries: try: if client.set(key, childCache, time=self.cacheTimeout, token=token): # Success break except TokenMismatchError: # The value in memcache has changed since we last # fetched it self.log.debug("memcacheprops setProperty TokenMismatchError; retrying...") finally: # Re-fetch the properties for this child loaded = self._loadCache(childNames=(child.fp.basename(),)) propertyCache.update(loaded.iteritems()) retries -= 1 propertyCache, key, childCache, token = 
self.childCache(child) if delete: if qnameuid in childCache: del childCache[qnameuid] else: childCache[qnameuid] = property else: self.log.error("memcacheprops setProperty had too many failures") delattr(self, "_propertyCache") raise MemcacheError("Unable to %s property %s%s on %s" % ( "delete" if delete else "set", uid if uid else "", encodeXMLName(*qname), child )) def deleteProperty(self, child, qname, uid): return self.setProperty(child, qname, uid, delete=True) def flushCache(self, child): path = child.fp.path key = self._keyForPath(path) propertyCache = self.propertyCache() if key in propertyCache: del propertyCache[key] client = self.memcacheClient() if client is not None: result = client.delete(key) if not result: raise MemcacheError("Unable to flush cache on %s" % (child,)) def propertyStoreForChild(self, child, childPropertyStore): return self.ChildPropertyStore(self, child, childPropertyStore) class ChildPropertyStore (object): log = Logger() def __init__(self, parentPropertyCollection, child, childPropertyStore): self.parentPropertyCollection = parentPropertyCollection self.child = child self.childPropertyStore = childPropertyStore def propertyCache(self): path = self.child.fp.path key = self.parentPropertyCollection._keyForPath(path) parentPropertyCache = self.parentPropertyCollection.propertyCache() return parentPropertyCache.get(key, ({}, None))[0] def flushCache(self): self.parentPropertyCollection.flushCache(self.child) def get(self, qname, uid=None, cache=True): if cache: propertyCache = self.propertyCache() qnameuid = qname + (uid,) if qnameuid in propertyCache: return propertyCache[qnameuid] else: raise HTTPError(StatusResponse( responsecode.NOT_FOUND, "No such property: %s%s" % (uid if uid else "", encodeXMLName(*qname)) )) self.log.debug("Read for %s%s on %s" % ( ("{%s}:" % (uid,)) if uid else "", qname, self.childPropertyStore.resource.fp.path )) return self.childPropertyStore.get(qname, uid=uid) def set(self, property, uid=None): 
self.log.debug("Write for %s%s on %s" % ( ("{%s}:" % (uid,)) if uid else "", property.qname(), self.childPropertyStore.resource.fp.path )) self.parentPropertyCollection.setProperty(self.child, property, uid) self.childPropertyStore.set(property, uid=uid) def delete(self, qname, uid=None): self.log.debug("Delete for %s%s on %s" % ( ("{%s}:" % (uid,)) if uid else "", qname, self.childPropertyStore.resource.fp.path )) self.parentPropertyCollection.deleteProperty(self.child, qname, uid) self.childPropertyStore.delete(qname, uid=uid) def contains(self, qname, uid=None, cache=True): if cache: propertyCache = self.propertyCache() qnameuid = qname + (uid,) return qnameuid in propertyCache self.log.debug("Contains for %s%s on %s" % ( ("{%s}:" % (uid,)) if uid else "", qname, self.childPropertyStore.resource.fp.path, )) return self.childPropertyStore.contains(qname, uid=uid) def list(self, uid=None, filterByUID=True, cache=True): if cache: propertyCache = self.propertyCache() results = propertyCache.keys() if filterByUID: return [ (namespace, name) for namespace, name, propuid in results if propuid == uid ] else: return results self.log.debug("List for %s" % (self.childPropertyStore.resource.fp.path,)) return self.childPropertyStore.list(uid=uid, filterByUID=filterByUID) calendarserver-5.2+dfsg/twistedcaldav/bind.py0000644000175000017500000000201712263343324020435 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## """ Bind methods. Have to have this in a separate module for now. """ from twext.web2.dav.util import bindMethods ## # Attach methods ## def doBind(): import twext.web2.dav.method from twext.web2.dav.resource import DAVResource bindMethods(twext.web2.dav.method, DAVResource) import twistedcaldav.method from twistedcaldav.resource import CalDAVResource bindMethods(twistedcaldav.method, CalDAVResource) calendarserver-5.2+dfsg/twistedcaldav/customxml.py0000644000175000017500000010316212264057476021572 0ustar rahulrahul## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Custom CalDAV XML Support. This module provides custom XML utilities for use with CalDAV. This API is considered private to static.py and is therefore subject to change. 
""" from txdav.xml.element import registerElement, dav_namespace from txdav.xml.element import twisted_dav_namespace, twisted_private_namespace from txdav.xml.element import WebDAVElement, PCDATAElement from txdav.xml.element import WebDAVEmptyElement, WebDAVTextElement from txdav.xml.element import PrincipalPropertySearch, Match from txdav.xml.element import ResourceType, Collection, Principal from twistedcaldav import caldavxml, carddavxml from twistedcaldav.caldavxml import caldav_namespace from twistedcaldav.ical import Component as iComponent from pycalendar.datetime import PyCalendarDateTime calendarserver_namespace = "http://calendarserver.org/ns/" calendarserver_proxy_compliance = ( "calendar-proxy", ) calendarserver_private_events_compliance = ( "calendarserver-private-events", ) calendarserver_private_comments_compliance = ( "calendarserver-private-comments", ) calendarserver_principal_property_search_compliance = ( "calendarserver-principal-property-search", ) calendarserver_principal_search_compliance = ( "calendarserver-principal-search", ) calendarserver_sharing_compliance = ( "calendarserver-sharing", ) # TODO: This is only needed whilst we do not support scheduling in shared calendars calendarserver_sharing_no_scheduling_compliance = ( "calendarserver-sharing-no-scheduling", ) calendarserver_partstat_changes_compliance = ( "calendarserver-partstat-changes", ) calendarserver_home_sync_compliance = ( "calendarserver-home-sync", ) calendarserver_recurrence_split = ( "calendarserver-recurrence-split", ) @registerElement class TwistedCalendarSupportedComponents (WebDAVTextElement): """ Contains the calendar supported components list. """ namespace = twisted_dav_namespace name = "calendar-supported-components" hidden = True def getValue(self): return str(self) @registerElement class TwistedCalendarAccessProperty (WebDAVTextElement): """ Contains the calendar access level (private events) for the resource. 
""" namespace = twisted_dav_namespace name = "calendar-access" hidden = True def getValue(self): return str(self) @registerElement class TwistedSchedulingObjectResource (WebDAVTextElement): """ Indicates that the resource is a scheduling object resource. """ namespace = twisted_private_namespace name = "scheduling-object-resource" hidden = True @registerElement class TwistedScheduleMatchETags(WebDAVElement): """ List of ETags that can be used for a "weak" If-Match comparison. """ namespace = twisted_private_namespace name = "scheduling-match-etags" hidden = True allowed_children = {(dav_namespace, "getetag"): (0, None)} @registerElement class TwistedCalendarHasPrivateCommentsProperty (WebDAVEmptyElement): """ Indicates that a calendar resource has private comments. NB This MUST be a private property as we don't want to expose the presence of private comments in private events. """ namespace = twisted_private_namespace name = "calendar-has-private-comments" hidden = True @registerElement class CalendarProxyRead (WebDAVEmptyElement): """ A read-only calendar user proxy principal resource. (Apple Extension to CalDAV) """ namespace = calendarserver_namespace name = "calendar-proxy-read" @registerElement class CalendarProxyWrite (WebDAVEmptyElement): """ A read-write calendar user proxy principal resource. (Apple Extension to CalDAV) """ namespace = calendarserver_namespace name = "calendar-proxy-write" @registerElement class CalendarProxyReadFor (WebDAVElement): """ List of principals granting read-only proxy status. (Apple Extension to CalDAV) """ namespace = calendarserver_namespace name = "calendar-proxy-read-for" hidden = True protected = True allowed_children = {(dav_namespace, "href"): (0, None)} @registerElement class CalendarProxyWriteFor (WebDAVElement): """ List of principals granting read-write proxy status. 
(Apple Extension to CalDAV) """ namespace = calendarserver_namespace name = "calendar-proxy-write-for" hidden = True protected = True allowed_children = {(dav_namespace, "href"): (0, None)} @registerElement class DropBoxHome (WebDAVEmptyElement): """ Denotes a drop box home collection (a collection that will contain drop boxes). (Apple Extension to CalDAV) """ namespace = calendarserver_namespace name = "dropbox-home" @registerElement class DropBox (WebDAVEmptyElement): """ Denotes a drop box collection. (Apple Extension to CalDAV) """ namespace = calendarserver_namespace name = "dropbox" @registerElement class DropBoxHomeURL (WebDAVElement): """ A principal property to indicate the location of the drop box home. (Apple Extension to CalDAV) """ namespace = calendarserver_namespace name = "dropbox-home-URL" hidden = True protected = True allowed_children = {(dav_namespace, "href"): (0, 1)} @registerElement class GETCTag (WebDAVTextElement): """ Contains the calendar collection entity tag. """ namespace = calendarserver_namespace name = "getctag" protected = True @registerElement class CalendarAvailability (WebDAVTextElement): """ Contains the calendar availability property. """ namespace = calendarserver_namespace name = "calendar-availability" hidden = True def calendar(self): """ Returns a calendar component derived from this element. """ return iComponent.fromString(str(self)) def valid(self): """ Determine whether the content of this element is a valid single VAVAILABILITY component, with zero or more VTIEMZONE components. @return: True if valid, False if not. 
""" try: calendar = self.calendar() if calendar is None: return False except ValueError: return False found = False for subcomponent in calendar.subcomponents(): if subcomponent.name() == "VAVAILABILITY": if found: return False else: found = True elif subcomponent.name() == "VTIMEZONE": continue else: return False return found @registerElement class MaxCollections (WebDAVTextElement): """ Maximum number of child collections in a home collection """ namespace = calendarserver_namespace name = "max-collections" hidden = True protected = True @registerElement class MaxResources (WebDAVTextElement): """ Maximum number of child resources in a collection """ namespace = calendarserver_namespace name = "max-resources" hidden = True protected = True @registerElement class Timezones (WebDAVEmptyElement): """ Denotes a timezone service resource. (Apple Extension to CalDAV) """ namespace = calendarserver_namespace name = "timezones" @registerElement class TZIDs (WebDAVElement): """ Wraps a list of timezone ids. """ namespace = calendarserver_namespace name = "tzids" allowed_children = {(calendarserver_namespace, "tzid"): (0, None)} @registerElement class TZID (WebDAVTextElement): """ A timezone id. """ namespace = calendarserver_namespace name = "tzid" @registerElement class TZData (WebDAVElement): """ Wraps a list of timezone observances. """ namespace = calendarserver_namespace name = "tzdata" allowed_children = {(calendarserver_namespace, "observance"): (0, None)} @registerElement class Observance (WebDAVElement): """ A timezone observance. """ namespace = calendarserver_namespace name = "observance" allowed_children = { (calendarserver_namespace, "onset") : (1, 1), (calendarserver_namespace, "utc-offset"): (1, 1), } @registerElement class Onset (WebDAVTextElement): """ The onset date-time for a DST transition. """ namespace = calendarserver_namespace name = "onset" @registerElement class UTCOffset (WebDAVTextElement): """ A UTC offset value for a timezone observance. 
""" namespace = calendarserver_namespace name = "utc-offset" @registerElement class PubSubPushTransportsProperty (WebDAVTextElement): """ A calendar property describing the available push notification transports available. """ namespace = calendarserver_namespace name = "push-transports" protected = True hidden = True allowed_children = { (calendarserver_namespace, "transport") : (0, 1), } @registerElement class PubSubTransportProperty (WebDAVTextElement): namespace = calendarserver_namespace name = "transport" protected = True hidden = True allowed_attributes = { "type" : True, } allowed_children = { (calendarserver_namespace, "subscription-url") : (1, 1), (calendarserver_namespace, "apsbundleid") : (1, 1), (calendarserver_namespace, "env") : (1, 1), } @registerElement class PubSubSubscriptionProperty (WebDAVTextElement): namespace = calendarserver_namespace name = "subscription-url" protected = True hidden = True allowed_children = {(dav_namespace, "href"): (0, 1)} @registerElement class PubSubAPSBundleIDProperty (WebDAVTextElement): namespace = calendarserver_namespace name = "apsbundleid" protected = True hidden = True @registerElement class PubSubAPSEnvironmentProperty (WebDAVTextElement): namespace = calendarserver_namespace name = "env" protected = True hidden = True @registerElement class PubSubAPSRefreshIntervalProperty (WebDAVTextElement): namespace = calendarserver_namespace name = "refresh-interval" protected = True hidden = True @registerElement class PubSubXMPPPushKeyProperty (WebDAVTextElement): namespace = calendarserver_namespace name = "pushkey" protected = True hidden = True PrincipalPropertySearch.allowed_children[(calendarserver_namespace, "limit")] = (0, 1) PrincipalPropertySearch.allowed_attributes["type"] = False Match.allowed_attributes = { "caseless": False, "match-type": False, } @registerElement class Limit (WebDAVElement): """ Client supplied limit for reports. 
""" namespace = calendarserver_namespace name = "limit" allowed_children = { (calendarserver_namespace, "nresults") : (1, 1), } @registerElement class NResults (WebDAVTextElement): """ Number of results limit. """ namespace = calendarserver_namespace name = "nresults" @registerElement class FirstNameProperty (WebDAVTextElement): """ A property representing first name of a principal """ namespace = calendarserver_namespace name = "first-name" protected = True hidden = True @registerElement class LastNameProperty (WebDAVTextElement): """ A property representing last name of a principal """ namespace = calendarserver_namespace name = "last-name" protected = True hidden = True @registerElement class EmailAddressProperty (WebDAVTextElement): """ A property representing email address of a principal """ namespace = calendarserver_namespace name = "email-address" protected = True hidden = True @registerElement class EmailAddressSet (WebDAVElement): """ The list of email addresses of a principal """ namespace = calendarserver_namespace name = "email-address-set" hidden = True allowed_children = {(calendarserver_namespace, "email-address"): (0, None)} @registerElement class ExpandedGroupMemberSet (WebDAVElement): """ The expanded list of members of a (group) principal """ namespace = calendarserver_namespace name = "expanded-group-member-set" protected = True hidden = True allowed_children = {(dav_namespace, "href"): (0, None)} @registerElement class ExpandedGroupMembership (WebDAVElement): """ The expanded list of groups a principal is a member of """ namespace = calendarserver_namespace name = "expanded-group-membership" protected = True hidden = True allowed_children = {(dav_namespace, "href"): (0, None)} @registerElement class IScheduleInbox (WebDAVEmptyElement): """ Denotes the resourcetype of a iSchedule Inbox. 
(CalDAV-s2s-xx, section x.x.x) """ namespace = calendarserver_namespace name = "ischedule-inbox" @registerElement class FreeBusyURL (WebDAVEmptyElement): """ Denotes the resourcetype of a free-busy URL resource. (CalDAV-s2s-xx, section x.x.x) """ namespace = calendarserver_namespace name = "free-busy-url" @registerElement class ScheduleChanges (WebDAVElement): """ Change indicator for a scheduling message. """ namespace = calendarserver_namespace name = "schedule-changes" protected = True hidden = True allowed_children = { (calendarserver_namespace, "dtstamp") : (0, 1), # Have to allow 0 as element is empty in PROPFIND requests (calendarserver_namespace, "action") : (0, 1), # Have to allow 0 as element is empty in PROPFIND requests } @registerElement class ScheduleDefaultTasksURL (WebDAVElement): """ A single href indicating which calendar is the default for VTODO scheduling. """ namespace = calendarserver_namespace name = "schedule-default-tasks-URL" allowed_children = {(dav_namespace, "href"): (0, 1)} @registerElement class DTStamp (WebDAVTextElement): """ A UTC timestamp in iCal format. """ namespace = calendarserver_namespace name = "dtstamp" def __init__(self, *children): super(DTStamp, self).__init__(children) self.children = (PCDATAElement(PyCalendarDateTime.getNowUTC().getText()),) @registerElement class Action (WebDAVElement): """ A UTC timestamp in iCal format. """ namespace = calendarserver_namespace name = "action" allowed_children = { (calendarserver_namespace, "create") : (0, 1), (calendarserver_namespace, "update") : (0, 1), (calendarserver_namespace, "cancel") : (0, 1), (calendarserver_namespace, "reply") : (0, 1), } @registerElement class Create (WebDAVEmptyElement): """ Event created. """ namespace = calendarserver_namespace name = "create" @registerElement class Update (WebDAVElement): """ Event updated. 
""" namespace = calendarserver_namespace name = "update" allowed_children = { (calendarserver_namespace, "recurrence") : (1, None), } @registerElement class Cancel (WebDAVElement): """ Event cancelled. """ namespace = calendarserver_namespace name = "cancel" allowed_children = { (calendarserver_namespace, "recurrence") : (0, 1), } @registerElement class Reply (WebDAVElement): """ Event replied to. """ namespace = calendarserver_namespace name = "reply" allowed_children = { (calendarserver_namespace, "attendee") : (1, 1), (calendarserver_namespace, "recurrence") : (1, None), } @registerElement class Recurrence (WebDAVElement): """ Changes to an event. """ namespace = calendarserver_namespace name = "recurrence" allowed_children = { (calendarserver_namespace, "master") : (0, 1), (calendarserver_namespace, "recurrenceid") : (0, None), (calendarserver_namespace, "changes") : (0, 1), } @registerElement class Master (WebDAVEmptyElement): """ Master instance changed. """ namespace = calendarserver_namespace name = "master" @registerElement class RecurrenceID (WebDAVTextElement): """ A recurrence instance changed. """ namespace = calendarserver_namespace name = "recurrenceid" @registerElement class Changes (WebDAVElement): """ Changes to an event. """ namespace = calendarserver_namespace name = "changes" allowed_children = { (calendarserver_namespace, "changed-property") : (0, None), } @registerElement class ChangedProperty (WebDAVElement): """ Changes to a property. """ namespace = calendarserver_namespace name = "changed-property" allowed_children = { (calendarserver_namespace, "changed-parameter") : (0, None), } allowed_attributes = { "name" : True, } @registerElement class ChangedParameter (WebDAVEmptyElement): """ Changes to a parameter. """ namespace = calendarserver_namespace name = "changed-parameter" allowed_attributes = { "name" : True, } @registerElement class Attendee (WebDAVTextElement): """ An attendee calendar user address. 
""" namespace = calendarserver_namespace name = "attendee" @registerElement class RecordType (WebDAVTextElement): """ Exposes the type of a record """ namespace = calendarserver_namespace name = "record-type" protected = True hidden = True @registerElement class AutoSchedule (WebDAVTextElement): """ Whether the principal automatically accepts invitations """ namespace = calendarserver_namespace name = "auto-schedule" @registerElement class AutoScheduleMode (WebDAVTextElement): """ The principal's auto-schedule mode """ namespace = calendarserver_namespace name = "auto-schedule-mode" ## # Sharing ## @registerElement class ReadAccess (WebDAVEmptyElement): """ Denotes read and update attendee partstat on a shared calendar. """ namespace = calendarserver_namespace name = "read" @registerElement class ReadWriteAccess (WebDAVEmptyElement): """ Denotes read and write access on a shared calendar. """ namespace = calendarserver_namespace name = "read-write" @registerElement class UID (WebDAVTextElement): namespace = calendarserver_namespace name = "uid" @registerElement class InReplyTo (WebDAVTextElement): namespace = calendarserver_namespace name = "in-reply-to" @registerElement class SharedOwner (WebDAVEmptyElement): """ Denotes a shared collection. """ namespace = calendarserver_namespace name = "shared-owner" @registerElement class Shared (WebDAVEmptyElement): """ Denotes a shared collection. """ namespace = calendarserver_namespace name = "shared" @registerElement class Subscribed (WebDAVEmptyElement): """ Denotes a subscribed calendar collection. """ namespace = calendarserver_namespace name = "subscribed" @registerElement class SharedURL (WebDAVTextElement): """ The source url for a shared calendar. """ namespace = calendarserver_namespace name = "shared-url" protected = True hidden = True @registerElement class SharedAs (WebDAVElement): """ The url for a shared calendar. 
""" namespace = calendarserver_namespace name = "shared-as" allowed_children = { (dav_namespace, "href") : (1, 1), } @registerElement class SharedAcceptEmailNotification (WebDAVTextElement): """ The accept email flag for a shared calendar. """ namespace = calendarserver_namespace name = "shared-accept-email-notification" @registerElement class Birthday (WebDAVEmptyElement): """ Denotes a birthday calendar collection. """ namespace = calendarserver_namespace name = "birthday" @registerElement class AllowedSharingModes (WebDAVElement): namespace = calendarserver_namespace name = "allowed-sharing-modes" protected = True hidden = True allowed_children = { (calendarserver_namespace, "can-be-shared") : (0, 1), (calendarserver_namespace, "can-be-published") : (0, 1), } @registerElement class CanBeShared (WebDAVEmptyElement): namespace = calendarserver_namespace name = "can-be-shared" @registerElement class CanBePublished (WebDAVEmptyElement): namespace = calendarserver_namespace name = "can-be-published" @registerElement class InviteShare (WebDAVElement): namespace = calendarserver_namespace name = "share" allowed_children = { (calendarserver_namespace, "set") : (0, None), (calendarserver_namespace, "remove") : (0, None), } @registerElement class InviteSet (WebDAVElement): namespace = calendarserver_namespace name = "set" allowed_children = { (dav_namespace, "href") : (1, 1), (calendarserver_namespace, "common-name") : (0, 1), (calendarserver_namespace, "summary") : (0, 1), (calendarserver_namespace, "read") : (0, 1), (calendarserver_namespace, "read-write") : (0, 1), (calendarserver_namespace, "read-write-schedule") : (0, 1), } @registerElement class InviteRemove (WebDAVElement): namespace = calendarserver_namespace name = "remove" allowed_children = { (dav_namespace, "href") : (1, 1), (calendarserver_namespace, "read") : (0, 1), (calendarserver_namespace, "read-write") : (0, 1), (calendarserver_namespace, "read-write-schedule") : (0, 1), } @registerElement class 
InviteUser (WebDAVElement): namespace = calendarserver_namespace name = "user" allowed_children = { (calendarserver_namespace, "uid") : (0, 1), (dav_namespace, "href") : (1, 1), (calendarserver_namespace, "common-name") : (0, 1), (calendarserver_namespace, "invite-noresponse") : (0, 1), (calendarserver_namespace, "invite-deleted") : (0, 1), (calendarserver_namespace, "invite-accepted") : (0, 1), (calendarserver_namespace, "invite-declined") : (0, 1), (calendarserver_namespace, "invite-invalid") : (0, 1), (calendarserver_namespace, "access") : (1, 1), (calendarserver_namespace, "summary") : (0, 1), } @registerElement class InviteAccess (WebDAVElement): namespace = calendarserver_namespace name = "access" allowed_children = { (calendarserver_namespace, "read") : (0, 1), (calendarserver_namespace, "read-write") : (0, 1), (calendarserver_namespace, "read-write-schedule") : (0, 1), } @registerElement class Invite (WebDAVElement): namespace = calendarserver_namespace name = "invite" allowed_children = { (calendarserver_namespace, "organizer") : (0, 1), (calendarserver_namespace, "user") : (0, None), } @registerElement class InviteSummary (WebDAVTextElement): namespace = calendarserver_namespace name = "summary" @registerElement class InviteStatusNoResponse (WebDAVEmptyElement): namespace = calendarserver_namespace name = "invite-noresponse" @registerElement class InviteStatusDeleted (WebDAVEmptyElement): namespace = calendarserver_namespace name = "invite-deleted" @registerElement class InviteStatusAccepted (WebDAVEmptyElement): namespace = calendarserver_namespace name = "invite-accepted" @registerElement class InviteStatusDeclined (WebDAVEmptyElement): namespace = calendarserver_namespace name = "invite-declined" @registerElement class InviteStatusInvalid (WebDAVEmptyElement): namespace = calendarserver_namespace name = "invite-invalid" @registerElement class HostURL (WebDAVElement): """ The source for a shared calendar """ namespace = calendarserver_namespace name = 
"hosturl" allowed_children = { (dav_namespace, "href") : (0, None) } @registerElement class Organizer (WebDAVElement): """ The organizer for a shared calendar """ namespace = calendarserver_namespace name = "organizer" allowed_children = { (dav_namespace, "href") : (0, None), (calendarserver_namespace, "common-name") : (0, 1) } @registerElement class CommonName (WebDAVTextElement): """ Common name for Sharer or Sharee """ namespace = calendarserver_namespace name = "common-name" @registerElement class InviteNotification (WebDAVElement): namespace = calendarserver_namespace name = "invite-notification" allowed_children = { (calendarserver_namespace, "uid") : (0, 1), (dav_namespace, "href") : (0, 1), (calendarserver_namespace, "invite-noresponse") : (0, 1), (calendarserver_namespace, "invite-deleted") : (0, 1), (calendarserver_namespace, "invite-accepted") : (0, 1), (calendarserver_namespace, "invite-declined") : (0, 1), (calendarserver_namespace, "access") : (0, 1), (calendarserver_namespace, "hosturl") : (0, 1), (calendarserver_namespace, "organizer") : (0, 1), (calendarserver_namespace, "summary") : (0, 1), (caldav_namespace, "supported-calendar-component-set") : (0, 1), } allowed_attributes = { "shared-type" : True, } @registerElement class InviteReply (WebDAVElement): namespace = calendarserver_namespace name = "invite-reply" allowed_children = { (dav_namespace, "href") : (0, 1), (calendarserver_namespace, "common-name") : (0, 1), (calendarserver_namespace, "first-name") : (0, 1), (calendarserver_namespace, "last-name") : (0, 1), (calendarserver_namespace, "invite-accepted") : (0, 1), (calendarserver_namespace, "invite-declined") : (0, 1), (calendarserver_namespace, "hosturl") : (0, 1), (calendarserver_namespace, "in-reply-to") : (0, 1), (calendarserver_namespace, "summary") : (0, 1), } @registerElement class ResourceUpdateNotification (WebDAVElement): namespace = calendarserver_namespace name = "resource-update-notification" allowed_children = { (dav_namespace, 
"href") : (0, 1), (calendarserver_namespace, "uid") : (0, 1), (calendarserver_namespace, "resource-added-notification") : (0, 1), (calendarserver_namespace, "resource-updated-notification") : (0, 1), (calendarserver_namespace, "resource-deleted-notification") : (0, 1), } @registerElement class ResourceUpdateAdded(WebDAVEmptyElement): namespace = calendarserver_namespace name = "resource-added-notification" @registerElement class ResourceUpdateUpdated(WebDAVEmptyElement): namespace = calendarserver_namespace name = "resource-updated-notification" @registerElement class ResourceUpdateDeleted(WebDAVEmptyElement): namespace = calendarserver_namespace name = "resource-deleted-notification" @registerElement class SharedCalendarUpdateNotification (WebDAVElement): namespace = calendarserver_namespace name = "shared-update-notification" allowed_children = { (calendarserver_namespace, "hosturl") : (0, 1), # The shared calendar url (dav_namespace, "href") : (0, 1), # Email userid that was invited (calendarserver_namespace, "invite-deleted") : (0, 1), # What the user did... (calendarserver_namespace, "invite-accepted") : (0, 1), (calendarserver_namespace, "invite-declined") : (0, 1), } ## # Notifications ## @registerElement class Notification (WebDAVElement): """ Denotes a notification collection, or a notification message. """ namespace = calendarserver_namespace name = "notification" allowed_children = { (calendarserver_namespace, "dtstamp") : (0, None), (calendarserver_namespace, "invite-notification") : (0, None), (calendarserver_namespace, "invite-reply") : (0, None), (calendarserver_namespace, "resource-update-notification") : (0, None), (calendarserver_namespace, "shared-update-notification") : (0, None), } @registerElement class NotificationURL (WebDAVElement): """ A principal property to indicate the notification collection for the principal. 
""" namespace = calendarserver_namespace name = "notification-URL" hidden = True protected = True allowed_children = { (dav_namespace, "href") : (0, 1) } @registerElement class NotificationType (WebDAVElement): """ A property to indicate what type of notification the resource represents. """ namespace = calendarserver_namespace name = "notificationtype" hidden = True protected = True allowed_children = { (calendarserver_namespace, "invite-notification") : (0, None), (calendarserver_namespace, "invite-reply") : (0, None), } @registerElement class Link (WebDAVEmptyElement): """ Denotes a linked resource. """ namespace = calendarserver_namespace name = "link" mm_namespace = "http://me.com/_namespace/" @registerElement class Multiput (WebDAVElement): namespace = mm_namespace name = "multiput" allowed_children = { (mm_namespace, "resource") : (1, None), } @registerElement class Resource (WebDAVElement): namespace = mm_namespace name = "resource" allowed_children = { (dav_namespace, "href") : (0, 1), (mm_namespace, "if-match") : (0, 1), (dav_namespace, "set") : (0, 1), (dav_namespace, "remove") : (0, 1), (mm_namespace, "delete") : (0, 1), } @registerElement class IfMatch (WebDAVElement): namespace = mm_namespace name = "if-match" allowed_children = { (dav_namespace, "getetag") : (1, 1), } @registerElement class Delete (WebDAVEmptyElement): namespace = mm_namespace name = "delete" @registerElement class BulkRequests (WebDAVElement): namespace = mm_namespace name = "bulk-requests" hidden = True protected = True allowed_children = { (mm_namespace, "simple") : (0, 1), (mm_namespace, "crud") : (0, 1), } @registerElement class Simple (WebDAVElement): namespace = mm_namespace name = "simple" hidden = True protected = True allowed_children = { (mm_namespace, "max-resources") : (1, 1), (mm_namespace, "max-bytes") : (1, 1), } @registerElement class CRUD (WebDAVElement): namespace = mm_namespace name = "crud" hidden = True protected = True allowed_children = { (mm_namespace, 
"max-resources") : (1, 1), (mm_namespace, "max-bytes") : (1, 1), } @registerElement class MaxBulkResources (WebDAVTextElement): namespace = mm_namespace name = "max-resources" @registerElement class MaxBulkBytes (WebDAVTextElement): namespace = mm_namespace name = "max-bytes" # # Client properties we might care about # @registerElement class CalendarColor(WebDAVTextElement): namespace = "http://apple.com/ns/ical/" name = "calendar-color" # # calendarserver-principal-search REPORT # @registerElement class CalendarServerPrincipalSearchToken (WebDAVTextElement): """ Contains a search token. """ namespace = calendarserver_namespace name = "search-token" @registerElement class CalendarServerPrincipalSearch (WebDAVElement): namespace = calendarserver_namespace name = "calendarserver-principal-search" allowed_children = { (calendarserver_namespace, "search-token"): (0, None), (calendarserver_namespace, "limit"): (0, 1), (dav_namespace, "prop"): (0, 1), (dav_namespace, "apply-to-principal-collection-set"): (0, 1), } allowed_attributes = {"context": False} ## # Extensions to ResourceType ## ResourceType.dropboxhome = ResourceType(Collection(), DropBoxHome()) ResourceType.dropbox = ResourceType(Collection(), DropBox()) ResourceType.calendarproxyread = ResourceType(Principal(), Collection(), CalendarProxyRead()) ResourceType.calendarproxywrite = ResourceType(Principal(), Collection(), CalendarProxyWrite()) ResourceType.timezones = ResourceType(Timezones()) ResourceType.ischeduleinbox = ResourceType(IScheduleInbox()) ResourceType.freebusyurl = ResourceType(FreeBusyURL()) ResourceType.notification = ResourceType(Collection(), Notification()) ResourceType.sharedownercalendar = ResourceType(Collection(), caldavxml.Calendar(), SharedOwner()) ResourceType.sharedcalendar = ResourceType(Collection(), caldavxml.Calendar(), Shared()) ResourceType.sharedowneraddressbook = ResourceType(Collection(), carddavxml.AddressBook(), SharedOwner()) ResourceType.sharedaddressbook = 
ResourceType(Collection(), carddavxml.AddressBook(), Shared()) ResourceType.sharedownergroup = ResourceType(SharedOwner()) ResourceType.sharedgroup = ResourceType(Shared()) ResourceType.link = ResourceType(Link()) calendarserver-5.2+dfsg/twistedcaldav/timezonestdservice.py0000644000175000017500000006766112263343324023467 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Timezone service resource and operations. This is based on http://tools.ietf.org/html/draft-douglass-timezone-service which is the CalConnect proposal for a standard timezone service. 
""" __all__ = [ "TimezoneStdServiceResource", ] from twext.python.log import Logger from twext.web2 import responsecode from twext.web2.dav.method.propfind import http_PROPFIND from twext.web2.dav.noneprops import NonePropertyStore from twext.web2.http import HTTPError, JSONResponse from twext.web2.http import Response from twext.web2.http_headers import MimeType from twext.web2.stream import MemoryStream from txdav.xml import element as davxml from twisted.internet.defer import succeed, inlineCallbacks, returnValue, \ DeferredList from twistedcaldav import xmlutil from twistedcaldav.client.geturl import getURL from twistedcaldav.config import config from twistedcaldav.extensions import DAVResource, \ DAVResourceWithoutChildrenMixin from twistedcaldav.ical import tzexpandlocal from twistedcaldav.resource import ReadOnlyNoCopyResourceMixIn from twistedcaldav.timezones import TimezoneException, TimezoneCache, readVTZ, \ addVTZ from twistedcaldav.xmlutil import addSubElement from pycalendar.calendar import PyCalendar from pycalendar.datetime import PyCalendarDateTime from pycalendar.exceptions import PyCalendarInvalidData import hashlib import itertools import json import os log = Logger() class TimezoneStdServiceResource (ReadOnlyNoCopyResourceMixIn, DAVResourceWithoutChildrenMixin, DAVResource): """ Timezone Service resource. Strictly speaking this is an HTTP-only resource no WebDAV support needed. Extends L{DAVResource} to provide timezone service functionality. """ def __init__(self, parent): """ @param parent: the parent resource of this one. 
""" assert parent is not None DAVResource.__init__(self, principalCollections=parent.principalCollections()) self.parent = parent self.expandcache = {} self.primary = True self.info_source = None if config.TimezoneService.Mode == "primary": log.info("Using primary Timezone Service") self._initPrimaryService() elif config.TimezoneService.Mode == "secondary": log.info("Using secondary Timezone Service") self._initSecondaryService() else: raise ValueError("Invalid TimezoneService mode: %s" % (config.TimezoneService.Mode,)) def _initPrimaryService(self): tzpath = TimezoneCache.getDBPath() xmlfile = os.path.join(tzpath, "timezones.xml") self.timezones = PrimaryTimezoneDatabase(tzpath, xmlfile) if not os.path.exists(xmlfile): self.timezones.createNewDatabase() else: self.timezones.readDatabase() self.info_source = TimezoneCache.version def _initSecondaryService(self): # Must have writeable paths tzpath = TimezoneCache.getDBPath() xmlfile = config.TimezoneService.XMLInfoPath if not xmlfile: xmlfile = os.path.join(tzpath, "timezones.xml") self.timezones = SecondaryTimezoneDatabase(tzpath, xmlfile, None) try: self.timezones.readDatabase() except: pass self.info_source = "Secondary" self.primary = False def onStartup(self): return self.timezones.onStartup() def deadProperties(self): if not hasattr(self, "_dead_properties"): self._dead_properties = NonePropertyStore(self) return self._dead_properties def etag(self): return succeed(None) def checkPreconditions(self, request): return None def checkPrivileges(self, request, privileges, recurse=False, principal=None, inherited_aces=None): return succeed(None) def defaultAccessControlList(self): return davxml.ACL( # DAV:Read for all principals (includes anonymous) davxml.ACE( davxml.Principal(davxml.All()), davxml.Grant( davxml.Privilege(davxml.Read()), ), davxml.Protected(), ), ) def contentType(self): return MimeType.fromString("text/html; charset=utf-8") def resourceType(self): return None def isCollection(self): return False 
def isCalendarCollection(self): return False def isPseudoCalendarCollection(self): return False def render(self, request): output = """ Timezone Standard Service Resource

Timezone Standard Service Resource.

""" response = Response(200, {}, output) response.headers.setHeader("content-type", MimeType("text", "html")) return response http_PROPFIND = http_PROPFIND def http_GET(self, request): """ The timezone service POST method. """ # GET and POST do the same thing return self.http_POST(request) def http_POST(self, request): """ The timezone service POST method. """ # Check authentication and access controls def _gotResult(_): if not request.args: # Do normal GET behavior return self.render(request) action = request.args.get("action", ("",)) if len(action) != 1: raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-action", "description": "Invalid action query parameter", }, )) action = action[0] action = { "capabilities" : self.doCapabilities, "list" : self.doList, "get" : self.doGet, "expand" : self.doExpand, }.get(action, None) if action is None: raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-action", "description": "Unknown action query parameter", }, )) return action(request) d = self.authorize(request, (davxml.Read(),)) d.addCallback(_gotResult) return d def doCapabilities(self, request): """ Return a list of all timezones known to the server. 
""" result = { "info" : { "version": "1", "primary-source" if self.primary else "secondary_source": self.info_source, "contacts" : [], }, "actions" : [ { "name": "capabilities", "parameters": [], }, { "name": "list", "parameters": [ {"name": "changedsince", "required": False, "multi": False, }, ], }, { "name": "get", "parameters": [ {"name": "format", "required": False, "multi": False, "values": ["text/calendar", "text/plain", ], }, {"name": "tzid", "required": True, "multi": False, }, ], }, { "name": "expand", "parameters": [ {"name": "tzid", "required": True, "multi": False, }, {"name": "start", "required": False, "multi": False, }, {"name": "end", "required": False, "multi": False, }, ], }, ] } return JSONResponse(responsecode.OK, result) def doList(self, request): """ Return a list of all timezones known to the server. """ changedsince = request.args.get("changedsince", ()) if len(changedsince) > 1: raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-changedsince", "description": "Invalid changedsince query parameter", }, )) if len(changedsince) == 1: # Validate a date-time stamp changedsince = changedsince[0] try: dt = PyCalendarDateTime.parseText(changedsince) except ValueError: raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-changedsince", "description": "Invalid changedsince query parameter", }, )) if not dt.utc(): raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, "Invalid changedsince query parameter value", )) timezones = [] for tz in self.timezones.listTimezones(changedsince): timezones.append({ "tzid": tz.tzid, "last-modified": tz.dtstamp, "aliases": tz.aliases, }) result = { "dtstamp": self.timezones.dtstamp, "timezones": timezones, } return JSONResponse(responsecode.OK, result) def doGet(self, request): """ Return the specified timezone data. 
""" tzids = request.args.get("tzid", ()) if len(tzids) != 1: raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-tzid", "description": "Invalid tzid query parameter", }, )) format = request.args.get("format", ("text/calendar",)) if len(format) != 1 or format[0] not in ("text/calendar", "text/plain",): raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-format", "description": "Invalid format query parameter", }, )) format = format[0] calendar = self.timezones.getTimezone(tzids[0]) if calendar is None: raise HTTPError(JSONResponse( responsecode.NOT_FOUND, { "error": "missing-tzid", "description": "Tzid could not be found", } )) tzdata = calendar.getText() response = Response() response.stream = MemoryStream(tzdata) response.headers.setHeader("content-type", MimeType.fromString("%s; charset=utf-8" % (format,))) return response def doExpand(self, request): """ Expand a timezone within specified start/end dates. """ tzids = request.args.get("tzid", ()) if len(tzids) != 1: raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-tzid", "description": "Invalid tzid query parameter", }, )) try: start = request.args.get("start", ()) if len(start) > 1: raise ValueError() elif len(start) == 1: start = PyCalendarDateTime.parseText(start[0]) else: start = PyCalendarDateTime.getToday() start.setDay(1) start.setMonth(1) except ValueError: raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-start", "description": "Invalid start query parameter", } )) try: end = request.args.get("end", ()) if len(end) > 1: raise ValueError() elif len(end) == 1: end = PyCalendarDateTime.parseText(end[0]) else: end = PyCalendarDateTime.getToday() end.setDay(1) end.setMonth(1) end.offsetYear(10) if end <= start: raise ValueError() except ValueError: raise HTTPError(JSONResponse( responsecode.BAD_REQUEST, { "error": "invalid-end", "description": "Invalid end query parameter", } )) tzid = tzids[0] tzdata = 
self.timezones.getTimezone(tzid) if tzdata is None: raise HTTPError(JSONResponse( responsecode.NOT_FOUND, { "error": "missing-tzid", "description": "Tzid could not be found", } )) # Now do the expansion (but use a cache to avoid re-calculating TZs) observances = self.expandcache.get((tzid, start, end), None) if observances is None: observances = tzexpandlocal(tzdata, start, end) self.expandcache[(tzid, start, end)] = observances # Turn into JSON result = { "dtstamp": self.timezones.dtstamp, "observances": [ { "name": name, "onset": onset.getXMLText(), "utc-offset-from": utc_offset_from, "utc-offset-to": utc_offset_to, } for onset, utc_offset_from, utc_offset_to, name in observances ], } return JSONResponse(responsecode.OK, result) class TimezoneInfo(object): """ Maintains information from an on-disk store of timezone files. """ def __init__(self, tzid, aliases, dtstamp, md5): self.tzid = tzid self.aliases = aliases self.dtstamp = dtstamp self.md5 = md5 @classmethod def readXML(cls, node): """ Parse XML data. """ if node.tag != "timezone": return None tzid = node.findtext("tzid") dtstamp = node.findtext("dtstamp") aliases = tuple([alias_node.text for alias_node in node.findall("alias")]) md5 = node.findtext("md5") return cls(tzid, aliases, dtstamp, md5) def generateXML(self, parent): """ Generate the XML element for this timezone info. """ node = xmlutil.addSubElement(parent, "timezone") xmlutil.addSubElement(node, "tzid", self.tzid) xmlutil.addSubElement(node, "dtstamp", self.dtstamp) for alias in self.aliases: xmlutil.addSubElement(node, "alias", alias) xmlutil.addSubElement(node, "md5", self.md5) class CommonTimezoneDatabase(object): """ Maintains the database of timezones read from an XML file. """ def __init__(self, basepath, xmlfile): self.basepath = basepath self.xmlfile = xmlfile self.dtstamp = None self.timezones = {} self.aliases = {} def onStartup(self): return succeed(None) def readDatabase(self): """ Read in XML data. 
""" _ignore, root = xmlutil.readXML(self.xmlfile, "timezones") self.dtstamp = root.findtext("dtstamp") for child in root: if child.tag == "timezone": tz = TimezoneInfo.readXML(child) if tz: self.timezones[tz.tzid] = tz for alias in tz.aliases: self.aliases[alias] = tz.tzid def listTimezones(self, changedsince): """ List timezones (not aliases) possibly changed since a particular dtstamp. """ for tzid, tzinfo in sorted(self.timezones.items(), key=lambda x: x[0]): # Ignore those that are aliases if tzid in self.aliases: continue # Detect timestamp changes if changedsince and tzinfo.dtstamp <= changedsince: continue yield tzinfo def getTimezone(self, tzid): """ Generate a PyCalendar containing the requested timezone. """ # We will just use our existing TimezoneCache here calendar = PyCalendar() try: vtz = readVTZ(tzid) calendar.addComponent(vtz.getComponents()[0].duplicate()) except TimezoneException: # Check if an alias exists and create data for that if tzid in self.aliases: try: vtz = readVTZ(self.aliases[tzid]) except TimezoneException: log.error("Failed to find timezone data for alias: %s" % (tzid,)) return None else: vtz = vtz.duplicate() vtz.getComponents()[0].getProperties("TZID")[0].setValue(tzid) addVTZ(tzid, vtz) calendar.addComponent(vtz.getComponents()[0].duplicate()) else: log.error("Failed to find timezone data for: %s" % (tzid,)) return None return calendar def _dumpTZs(self): _ignore, root = xmlutil.newElementTreeWithRoot("timezones") addSubElement(root, "dtstamp", self.dtstamp) for _ignore, v in sorted(self.timezones.items(), key=lambda x: x[0]): v.generateXML(root) xmlutil.writeXML(self.xmlfile, root) def _buildAliases(self): """ Rebuild aliases mappings from current tzinfo. """ self.aliases = {} for tzinfo in self.timezones.values(): for alias in tzinfo.aliases: self.aliases[alias] = tzinfo.tzid class PrimaryTimezoneDatabase(CommonTimezoneDatabase): """ Maintains the database of timezones read from an XML file. 
""" def __init__(self, basepath, xmlfile): super(PrimaryTimezoneDatabase, self).__init__(basepath, xmlfile) def createNewDatabase(self): """ Create a new DB xml file from scratch by scanning zoneinfo. """ self.dtstamp = PyCalendarDateTime.getNowUTC().getXMLText() self._scanTZs("") self._dumpTZs() def _scanTZs(self, path, checkIfChanged=False): # Read in all timezone files first for item in os.listdir(os.path.join(self.basepath, path)): fullPath = os.path.join(self.basepath, path, item) if item.find('.') == -1: self._scanTZs(os.path.join(path, item), checkIfChanged) elif item.endswith(".ics"): # Build TimezoneInfo object tzid = os.path.join(path, item[:-4]) try: md5 = hashlib.md5(open(fullPath).read()).hexdigest() except IOError: log.error("Unable to read timezone file: %s" % (fullPath,)) continue if checkIfChanged: oldtz = self.timezones.get(tzid) if oldtz != None and oldtz.md5 == md5: continue self.changeCount += 1 self.changed.add(tzid) self.timezones[tzid] = TimezoneInfo(tzid, (), self.dtstamp, md5) # Try links (aliases) file try: aliases = open(os.path.join(self.basepath, "links.txt")).read() except IOError, e: log.error("Unable to read links.txt file: %s" % (str(e),)) aliases = "" try: for alias in aliases.splitlines(): alias_from, alias_to = alias.split() tzinfo = self.timezones.get(alias_to) if tzinfo: if alias_from != alias_to: if alias_from not in tzinfo.aliases: tzinfo.aliases += (alias_from,) self.aliases[alias_from] = alias_to else: log.error("Missing alias from '%s' to '%s'" % (alias_from, alias_to,)) except ValueError: log.error("Unable to parse links.txt file: %s" % (str(e),)) def updateDatabase(self): """ Update existing DB info by comparing md5's. """ self.dtstamp = PyCalendarDateTime.getNowUTC().getXMLText() self.changeCount = 0 self.changed = set() self._scanTZs("", checkIfChanged=True) if self.changeCount: self._dumpTZs() class SecondaryTimezoneDatabase(CommonTimezoneDatabase): """ Caches a database of timezones from another timezone service. 
""" def __init__(self, basepath, xmlfile, uri): super(SecondaryTimezoneDatabase, self).__init__(basepath, xmlfile) self.uri = uri self.discovered = False self._url = None log.debug("Configuring secondary server with basepath: %s" % (self.basepath,)) if not os.path.exists(self.basepath): os.makedirs(self.basepath) # Paths need to be writeable if not os.access(basepath, os.W_OK): raise ValueError("Secondary Timezone Service needs writeable zoneinfo path at: %s" % (basepath,)) if os.path.exists(xmlfile) and not os.access(xmlfile, os.W_OK): raise ValueError("Secondary Timezone Service needs writeable xmlfile path at: %s" % (xmlfile,)) def onStartup(self): return self.syncWithServer() @inlineCallbacks def syncWithServer(self): """ Sync local data with that from the server we are replicating. """ log.debug("Sync'ing with secondary server") result = (yield self._getTimezoneListFromServer()) if result is None: # Nothing changed since last sync log.debug("No changes on secondary server") returnValue(None) newdtstamp, newtimezones = result # Compare timezone infos # New ones on the server newtzids = set(newtimezones.keys()) - set(self.timezones.keys()) # Check for changes changedtzids = set() for tzid in set(newtimezones.keys()) & set(self.timezones.keys()): if self.timezones[tzid].dtstamp < newtimezones[tzid].dtstamp: changedtzids.add(tzid) log.debug("Fetching %d new, %d changed timezones on secondary server" % (len(newtzids), len(changedtzids),)) # Now apply changes - do requests in parallel for speedier fetching BATCH = 5 tzids = list(itertools.chain(newtzids, changedtzids)) tzids.sort() while tzids: yield DeferredList([self._getTimezoneFromServer(newtimezones[tzid]) for tzid in tzids[0:BATCH]]) tzids = tzids[BATCH:] self.dtstamp = newdtstamp self._dumpTZs() self._buildAliases() log.debug("Sync with secondary server complete") returnValue((len(newtzids), len(changedtzids),)) @inlineCallbacks def _discoverServer(self): """ Make sure we know the timezone service path """ if 
self.uri is None: if config.TimezoneService.SecondaryService.Host: self.uri = "https://%s/.well-known/timezone" % (config.TimezoneService.SecondaryService.Host,) elif config.TimezoneService.SecondaryService.URI: self.uri = config.TimezoneService.SecondaryService.URI elif not self.uri.startswith("https:") and not self.uri.startswith("http:"): self.uri = "https://%s/.well-known/timezone" % (self.uri,) testURI = "%s?action=capabilities" % (self.uri,) log.debug("Discovering secondary server: %s" % (testURI,)) response = (yield getURL(testURI)) if response is None or response.code / 100 != 2: log.error("Unable to discover secondary server: %s" % (testURI,)) self.discovered = False returnValue(False) # Cache the redirect target if hasattr(response, "location"): self.uri = response.location log.debug("Redirected secondary server to: %s" % (self.uri,)) # TODO: Ignoring the data from capabilities for now self.discovered = True returnValue(True) @inlineCallbacks def _getTimezoneListFromServer(self): """ Retrieve the timezone list from the specified server """ # Make sure we have the server if not self.discovered: result = (yield self._discoverServer()) if not result: returnValue(None) # List all from the server url = "%s?action=list" % (self.uri,) if self.dtstamp: url = "%s&changedsince=%s" % (url, self.dtstamp,) log.debug("Getting timezone list from secondary server: %s" % (url,)) response = (yield getURL(url)) if response is None or response.code / 100 != 2: returnValue(None) ct = response.headers.getRawHeaders("content-type", ("bogus/type",))[0] ct = ct.split(";", 1) ct = ct[0] if ct not in ("application/json",): returnValue(None) try: jroot = json.loads(response.data) dtstamp = jroot["dtstamp"] timezones = {} for timezone in jroot["timezones"]: tzid = timezone["tzid"] lastmod = timezone["last-modified"] aliases = timezone.get("aliases", ()) timezones[tzid] = TimezoneInfo(tzid, aliases, lastmod, None) except (ValueError, KeyError): log.debug("Failed to parse JSON timezone 
list response: %s" % (response.data,)) returnValue(None) log.debug("Got %s timezones from secondary server" % (len(timezones),)) returnValue((dtstamp, timezones,)) @inlineCallbacks def _getTimezoneFromServer(self, tzinfo): # List all from the server url = "%s?action=get&tzid=%s" % (self.uri, tzinfo.tzid,) log.debug("Getting timezone from secondary server: %s" % (url,)) response = (yield getURL(url)) if response is None or response.code / 100 != 2: returnValue(None) ct = response.headers.getRawHeaders("content-type", ("bogus/type",))[0] ct = ct.split(";", 1) ct = ct[0] if ct not in ("text/calendar",): log.error("Invalid content-type '%s' for tzid : %s" % (ct, tzinfo.tzid,)) returnValue(None) ical = response.data try: calendar = PyCalendar.parseText(ical) except PyCalendarInvalidData: log.error("Invalid calendar data for tzid: %s" % (tzinfo.tzid,)) returnValue(None) ical = calendar.getText() tzinfo.md5 = hashlib.md5(ical).hexdigest() try: tzpath = os.path.join(self.basepath, tzinfo.tzid) + ".ics" if not os.path.exists(os.path.dirname(tzpath)): os.makedirs(os.path.dirname(tzpath)) f = open(tzpath, "w") f.write(ical) f.close() except IOError, e: log.error("Unable to write calendar file for %s: %s" % (tzinfo.tzid, str(e),)) else: self.timezones[tzinfo.tzid] = tzinfo def _removeTimezone(self, tzid): tzpath = os.path.join(self.basepath, tzid) + ".ics" try: os.remove(tzpath) del self.timezones[tzid] except IOError, e: log.error("Unable to write calendar file for %s: %s" % (tzid, str(e),)) calendarserver-5.2+dfsg/twistedcaldav/linkresource.py0000644000175000017500000001123612263343324022231 0ustar rahulrahul## # Copyright (c) 2010-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from twext.python.log import Logger from twext.web2 import responsecode, server, http from txdav.xml import element as davxml from twext.web2.http import HTTPError, StatusResponse from twext.web2.resource import WrapperResource from twisted.internet.defer import inlineCallbacks, returnValue, maybeDeferred from twistedcaldav.config import config __all__ = [ "LinkResource", ] # FIXME: copied from resource.py to avoid circular dependency class CalDAVComplianceMixIn(object): def davComplianceClasses(self): return ( tuple(super(CalDAVComplianceMixIn, self).davComplianceClasses()) + config.CalDAVComplianceClasses ) """ A resource that is a soft-link to another. """ class LinkResource(CalDAVComplianceMixIn, WrapperResource): """ This is similar to a WrapperResource except that we locate our resource dynamically. We need to deal with the case of a missing underlying resource (broken link) as indicated by self._linkedResource being None. 
""" log = Logger() def __init__(self, parent, link_url): self.parent = parent self.linkURL = link_url self.loopDetect = set() super(LinkResource, self).__init__(self.parent.principalCollections()) @inlineCallbacks def linkedResource(self, request): if not hasattr(self, "_linkedResource"): if self.linkURL in self.loopDetect: raise HTTPError(StatusResponse(responsecode.LOOP_DETECTED, "Recursive link target: %s" % (self.linkURL,))) else: self.loopDetect.add(self.linkURL) self._linkedResource = (yield request.locateResource(self.linkURL)) self.loopDetect.remove(self.linkURL) if self._linkedResource is None: raise HTTPError(StatusResponse(responsecode.NOT_FOUND, "Missing link target: %s" % (self.linkURL,))) returnValue(self._linkedResource) def isCollection(self): return True if hasattr(self, "_linkedResource") else False def resourceType(self): return self._linkedResource.resourceType() if hasattr(self, "_linkedResource") else davxml.ResourceType.link def locateChild(self, request, segments): def _defer(result): if result is None: return (self, server.StopTraversal) else: return (result, segments) d = self.linkedResource(request) d.addCallback(_defer) return d @inlineCallbacks def renderHTTP(self, request): linked_to = (yield self.linkedResource(request)) if linked_to: returnValue(linked_to) else: returnValue(http.StatusResponse(responsecode.OK, "Link resource with missing target: %s" % (self.linkURL,))) def getChild(self, name): return self._linkedResource.getChild(name) if hasattr(self, "_linkedResource") else None @inlineCallbacks def hasProperty(self, property, request): hosted = (yield self.linkedResource(request)) result = (yield hosted.hasProperty(property, request)) if hosted else False returnValue(result) @inlineCallbacks def readProperty(self, property, request): hosted = (yield self.linkedResource(request)) result = (yield hosted.readProperty(property, request)) if hosted else None returnValue(result) @inlineCallbacks def writeProperty(self, property, 
request): hosted = (yield self.linkedResource(request)) result = (yield hosted.writeProperty(property, request)) if hosted else None returnValue(result) class LinkFollowerMixIn(object): @inlineCallbacks def locateChild(self, req, segments): self._inside_locateChild = True resource, path = (yield maybeDeferred(super(LinkFollowerMixIn, self).locateChild, req, segments)) while isinstance(resource, LinkResource): linked_to = (yield resource.linkedResource(req)) if linked_to is None: break resource = linked_to returnValue((resource, path)) calendarserver-5.2+dfsg/twistedcaldav/timezonexml.py0000644000175000017500000001213512263343324022076 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ This module provides XML definitions for use with Timezone Standard Service. 
""" from txdav.xml.element import registerElement from txdav.xml.element import WebDAVElement, WebDAVEmptyElement, WebDAVTextElement ## # Timezone Service XML Definitions ## timezone_namespace = "urn:ietf:params:xml:ns:timezone-service" @registerElement class Capabilities (WebDAVElement): namespace = timezone_namespace name = "capabilities" allowed_children = { (timezone_namespace, "operation"): (0, None), } @registerElement class Operation (WebDAVElement): namespace = timezone_namespace name = "operation" allowed_children = { (timezone_namespace, "action"): (1, 1), (timezone_namespace, "description"): (0, 1), (timezone_namespace, "accept-parameter"): (0, None), } @registerElement class Action (WebDAVTextElement): namespace = timezone_namespace name = "action" @registerElement class Description (WebDAVTextElement): namespace = timezone_namespace name = "description" @registerElement class AcceptParameter (WebDAVElement): namespace = timezone_namespace name = "accept-parameter" allowed_children = { (timezone_namespace, "name"): (1, 1), (timezone_namespace, "required"): (1, 1), (timezone_namespace, "multi"): (1, 1), (timezone_namespace, "value"): (0, None), (timezone_namespace, "description"): (0, 1), } @registerElement class Name (WebDAVTextElement): namespace = timezone_namespace name = "name" @registerElement class Required (WebDAVTextElement): namespace = timezone_namespace name = "required" @registerElement class Multi (WebDAVTextElement): namespace = timezone_namespace name = "multi" @registerElement class Value (WebDAVTextElement): namespace = timezone_namespace name = "value" @registerElement class TimezoneList (WebDAVElement): namespace = timezone_namespace name = "timezone-list" allowed_children = { (timezone_namespace, "dtstamp"): (1, 1), (timezone_namespace, "summary"): (0, None), } @registerElement class Dtstamp (WebDAVTextElement): namespace = timezone_namespace name = "dtstamp" @registerElement class Summary (WebDAVElement): namespace = 
timezone_namespace name = "summary" allowed_children = { (timezone_namespace, "tzid"): (1, 1), (timezone_namespace, "last-modified"): (1, 1), (timezone_namespace, "local-name"): (0, None), (timezone_namespace, "alias"): (0, None), (timezone_namespace, "inactive"): (0, 1), } @registerElement class Tzid (WebDAVTextElement): namespace = timezone_namespace name = "tzid" @registerElement class LastModified (WebDAVTextElement): namespace = timezone_namespace name = "last-modified" @registerElement class LocalName (WebDAVTextElement): namespace = timezone_namespace name = "local-name" @registerElement class Alias (WebDAVTextElement): namespace = timezone_namespace name = "alias" @registerElement class Inactive (WebDAVEmptyElement): namespace = timezone_namespace name = "inactive" @registerElement class Timezones (WebDAVElement): namespace = timezone_namespace name = "timezones" allowed_children = { (timezone_namespace, "dtstamp"): (1, 1), (timezone_namespace, "tzdata"): (0, None), } @registerElement class Tzdata (WebDAVElement): namespace = timezone_namespace name = "tzdata" allowed_children = { (timezone_namespace, "tzid"): (1, 1), (timezone_namespace, "calscale"): (0, 1), (timezone_namespace, "observance"): (0, None), } @registerElement class Calscale (WebDAVTextElement): namespace = timezone_namespace name = "calscale" @registerElement class Observance (WebDAVElement): namespace = timezone_namespace name = "observance" allowed_children = { (timezone_namespace, "name"): (1, 1), (timezone_namespace, "local-name"): (0, None), (timezone_namespace, "onset"): (1, 1), (timezone_namespace, "utc-offset-from"): (1, 1), (timezone_namespace, "utc-offset-to"): (1, 1), } @registerElement class Onset (WebDAVTextElement): namespace = timezone_namespace name = "onset" @registerElement class UTCOffsetFrom (WebDAVTextElement): namespace = timezone_namespace name = "utc-offset-from" @registerElement class UTCOffsetTo (WebDAVTextElement): namespace = timezone_namespace name = 
"utc-offset-to" calendarserver-5.2+dfsg/twistedcaldav/query/0000755000175000017500000000000012322625314020312 5ustar rahulrahulcalendarserver-5.2+dfsg/twistedcaldav/query/sqlgenerator.py0000644000175000017500000002704112263343324023400 0ustar rahulrahul## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from __future__ import print_function """ SQL statement generator from query expressions. """ __version__ = "0.0" __all__ = [ "sqlgenerator", ] from twistedcaldav.query import expression import cStringIO as StringIO class sqlgenerator(object): FROM = " from " WHERE = " where " RESOURCEDB = "RESOURCE" TIMESPANDB = "TIMESPAN" TRANSPARENCYDB = "TRANSPARENCY" PERUSERDB = "PERUSER" NOTOP = "NOT " ANDOP = " AND " OROP = " OR " CONTAINSOP = " GLOB " NOTCONTAINSOP = " NOT GLOB " ISOP = " == " ISNOTOP = " != " STARTSWITHOP = " GLOB " NOTSTARTSWITHOP = " NOT GLOB " ENDSWITHOP = " GLOB " NOTENDSWITHOP = " NOT GLOB " INOP = " IN " NOTINOP = " NOT IN " FIELDS = { "TYPE": "RESOURCE.TYPE", "UID": "RESOURCE.UID", } TIMESPANTEST = "((TIMESPAN.FLOAT == 'N' AND TIMESPAN.START < %s AND TIMESPAN.END > %s) OR (TIMESPAN.FLOAT == 'Y' AND TIMESPAN.START < %s AND TIMESPAN.END > %s))" TIMESPANTEST_NOEND = "((TIMESPAN.FLOAT == 'N' AND TIMESPAN.END > %s) OR (TIMESPAN.FLOAT == 'Y' AND TIMESPAN.END > %s))" TIMESPANTEST_NOSTART = "((TIMESPAN.FLOAT == 'N' AND TIMESPAN.START < %s) OR (TIMESPAN.FLOAT == 'Y' AND TIMESPAN.START < %s))" TIMESPANTEST_TAIL_PIECE 
= " AND TIMESPAN.RESOURCEID == RESOURCE.RESOURCEID" TIMESPANTEST_JOIN_ON_PIECE = "TIMESPAN.INSTANCEID == TRANSPARENCY.INSTANCEID AND TRANSPARENCY.PERUSERID == %s" def __init__(self, expr, calendarid, userid, freebusy=False): """ @param expr: the query expression object model @type expr: L{twistedcaldav.query.calendarqueryfilter.Filter} @param calendarid: resource ID - not used for file-based per-calendar indexes @type calendarid: C{int} @param userid: user for whom query is being done - query will be scoped to that user's privileges and their transparency @type userid: C{str} @param freebusy: whether or not a freebusy query is being done - if it is, additional time range and transparency information is returned @type freebusy: C{bool} """ self.expression = expr self.calendarid = calendarid self.userid = userid if userid else "" self.freebusy = freebusy self.usedtimespan = False def generate(self): """ Generate the actual SQL 'where ...' expression from the passed in expression tree. @return: a C{tuple} of (C{str}, C{list}), where the C{str} is the partial SQL statement, and the C{list} is the list of argument substitutions to use with the SQL API execute method. """ # Init state self.sout = StringIO.StringIO() self.arguments = [] self.substitutions = [] self.usedtimespan = False # Generate ' where ...' partial statement self.generateExpression(self.expression) # Prefix with ' from ...' 
partial statement select = self.FROM + self.RESOURCEDB if self.usedtimespan: # Free busy needs transparency join if self.freebusy: self.frontArgument(self.userid) select += ", %s LEFT OUTER JOIN %s ON (%s)" % ( self.TIMESPANDB, self.TRANSPARENCYDB, self.TIMESPANTEST_JOIN_ON_PIECE ) else: select += ", %s" % ( self.TIMESPANDB, ) select += self.WHERE if self.usedtimespan: select += "(" select += self.sout.getvalue() if self.usedtimespan: if self.calendarid: self.setArgument(self.calendarid) select += ")%s" % (self.TIMESPANTEST_TAIL_PIECE,) select = select % tuple(self.substitutions) return select, self.arguments def generateExpression(self, expr): """ Generate an expression and all it's subexpressions. @param expr: the L{baseExpression} derived class to write out. @return: C{True} if the TIMESPAN table is used, C{False} otherwise. """ # Generate based on each type of expression we might encounter # ALL if isinstance(expr, expression.allExpression): # Wipe out the ' where ...' clause so everything is matched self.sout.truncate(0) self.arguments = [] self.substitutions = [] self.usedtimespan = False # NOT elif isinstance(expr, expression.notExpression): self.sout.write(self.NOTOP) self.generateSubExpression(expr.expressions[0]) # AND elif isinstance(expr, expression.andExpression): first = True for e in expr.expressions: if first: first = False else: self.sout.write(self.ANDOP) self.generateSubExpression(e) # OR elif isinstance(expr, expression.orExpression): first = True for e in expr.expressions: if first: first = False else: self.sout.write(self.OROP) self.generateSubExpression(e) # time-range elif isinstance(expr, expression.timerangeExpression): if expr.start and expr.end: self.setArgument(expr.end) self.setArgument(expr.start) self.setArgument(expr.endfloat) self.setArgument(expr.startfloat) test = self.TIMESPANTEST elif expr.start and expr.end is None: self.setArgument(expr.start) self.setArgument(expr.startfloat) test = self.TIMESPANTEST_NOEND elif not 
expr.start and expr.end: self.setArgument(expr.end) self.setArgument(expr.endfloat) test = self.TIMESPANTEST_NOSTART self.sout.write(test) self.usedtimespan = True # CONTAINS elif isinstance(expr, expression.containsExpression): self.sout.write(expr.field) self.sout.write(self.CONTAINSOP) self.addArgument(self.containsArgument(expr.text)) # NOT CONTAINS elif isinstance(expr, expression.notcontainsExpression): self.sout.write(expr.field) self.sout.write(self.NOTCONTAINSOP) self.addArgument(self.containsArgument(expr.text)) # IS elif isinstance(expr, expression.isExpression): self.sout.write(expr.field) self.sout.write(self.ISOP) self.addArgument(expr.text) # IS NOT elif isinstance(expr, expression.isnotExpression): self.sout.write(expr.field) self.sout.write(self.ISNOTOP) self.addArgument(expr.text) # STARTSWITH elif isinstance(expr, expression.startswithExpression): self.sout.write(expr.field) self.sout.write(self.STARTSWITHOP) self.addArgument(self.startswithArgument(expr.text)) # NOT STARTSWITH elif isinstance(expr, expression.notstartswithExpression): self.sout.write(expr.field) self.sout.write(self.NOTSTARTSWITHOP) self.addArgument(self.startswithArgument(expr.text)) # ENDSWITH elif isinstance(expr, expression.endswithExpression): self.sout.write(expr.field) self.sout.write(self.ENDSWITHOP) self.addArgument(self.endswithArgument(expr.text)) # NOT ENDSWITH elif isinstance(expr, expression.notendswithExpression): self.sout.write(expr.field) self.sout.write(self.NOTENDSWITHOP) self.addArgument(self.endswithArgument(expr.text)) # IN elif isinstance(expr, expression.inExpression): self.sout.write(expr.field) self.sout.write(self.INOP) self.sout.write("(") for count, item in enumerate(expr.text): if count != 0: self.sout.write(", ") self.addArgument(item) self.sout.write(")") # NOT IN elif isinstance(expr, expression.notinExpression): self.sout.write(expr.field) self.sout.write(self.NOTINOP) self.sout.write("(") for count, item in enumerate(expr.text): if count != 0: 
self.sout.write(", ") self.addArgument(item) self.sout.write(")") def generateSubExpression(self, expression): """ Generate an SQL expression possibly in parenthesis if its a compound expression. @param expression: the L{baseExpression} to write out. @return: C{True} if the TIMESPAN table is used, C{False} otherwise. """ if expression.multi(): self.sout.write("(") self.generateExpression(expression) if expression.multi(): self.sout.write(")") def addArgument(self, arg): """ @param arg: the C{str} of the argument to add """ # Append argument to the list and add the appropriate substitution string to the output stream. self.arguments.append(arg) self.substitutions.append(":" + str(len(self.arguments))) self.sout.write("%s") def setArgument(self, arg): """ @param arg: the C{str} of the argument to add @return: C{str} for argument substitution text """ # Append argument to the list and add the appropriate substitution string to the output stream. self.arguments.append(arg) self.substitutions.append(":" + str(len(self.arguments))) def frontArgument(self, arg): """ @param arg: the C{str} of the argument to add @return: C{str} for argument substitution text """ # Append argument to the list and add the appropriate substitution string to the output stream. 
self.arguments.insert(0, arg) self.substitutions.append(":" + str(len(self.arguments))) def containsArgument(self, arg): return "*%s*" % (arg,) def startswithArgument(self, arg): return "%s*" % (arg,) def endswithArgument(self, arg): return "*%s" % (arg,) if __name__ == "__main__": e1 = expression.isExpression("TYPE", "VEVENT", False) e2 = expression.timerangeExpression("20060101T120000Z", "20060101T130000Z", "20060101T080000Z", "20060101T090000Z") e3 = expression.notcontainsExpression("SUMMARY", "help", True) e5 = expression.andExpression([e1, e2, e3]) print(e5) sql = sqlgenerator(e5, 'dummy-cal', 'dummy-user') print(sql.generate()) e6 = expression.inExpression("TYPE", ("VEVENT", "VTODO",), False) print(e6) sql = sqlgenerator(e6, 'dummy-cal', 'dummy-user') print(sql.generate()) e7 = expression.notinExpression("TYPE", ("VEVENT", "VTODO",), False) print(e7) sql = sqlgenerator(e7, 'dummy-cal', 'dummy-user') print(sql.generate()) calendarserver-5.2+dfsg/twistedcaldav/query/expression.py0000644000175000017500000002210412263343324023064 0ustar rahulrahul## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from __future__ import print_function """ Query Expression Elements. These are used to build a 'generic' query expression tree that can then be used by different query language generators to produce the actual query syntax required (SQL, xpath eyc). 
""" __version__ = "0.0" __all__ = [ "allExpression", "notExpression", "andExpression", "orExpression", "timerangeExpression", "textcompareExpression", "containsExpression", "notcontainsExpression", "isExpression", "isnotExpression", "startswithExpression", "notstartswithExpression", "endswithExpression", "notendswithExpression", "inExpression", "notinExpression", ] class baseExpression(object): """ The base class for all types of expression. """ def __init__(self): pass def multi(self): """ Indicate whether this expression is composed of multiple sub-expressions. @return: C{True} if this expressions contains multiple sub-expressions, C{False} otherwise. """ return False def _collapsedExpression(self): return self def andWith(self, other): if isinstance(other, andExpression): return andExpression((self._collapsedExpression(),) + tuple(other.expressions)) else: return andExpression((self._collapsedExpression(), other._collapsedExpression(),)) def orWith(self, other): if isinstance(other, orExpression): return orExpression((self._collapsedExpression(),) + tuple(other.expressions)) else: return orExpression((self._collapsedExpression(), other._collapsedExpression(),)) class allExpression(baseExpression): """ Match everything. """ def __init__(self): pass class logicExpression(baseExpression): """ An expression representing a logical operation (boolean). """ def __init__(self, expressions): self.expressions = expressions def __str__(self): """ Generate a suitable text descriptor of this expression. @return: a C{str} of the text for this expression. """ result = "" for e in self.expressions: if len(result) != 0: result += " " + self.operator() + " " result += str(e) if len(result): result = "(" + result + ")" return result def multi(self): """ Indicate whether this expression is composed of multiple expressions. @return: C{True} if this expressions contains multiple sub-expressions, C{False} otherwise. 
""" return True def _collapsedExpression(self): if self.multi() and len(self.expressions) == 1: return self.expressions[0]._collapsedExpression() else: return self class notExpression(logicExpression): """ Logical NOT operation. """ def __init__(self, expression): super(notExpression, self).__init__([expression]) def operator(self): return "NOT" def __str__(self): result = self.operator() + " " + str(self.expressions[0]) return result def multi(self): """ Indicate whether this expression is composed of multiple expressions. @return: C{True} if this expressions contains multiple sub-expressions, C{False} otherwise. """ return False class andExpression(logicExpression): """ Logical AND operation. """ def __init__(self, expressions): super(andExpression, self).__init__(expressions) def operator(self): return "AND" def andWith(self, other): self.expressions = tuple(self.expressions) + (other._collapsedExpression(),) return self class orExpression(logicExpression): """ Logical OR operation. """ def __init__(self, expressions): super(orExpression, self).__init__(expressions) def operator(self): return "OR" def orWith(self, other): self.expressions = tuple(self.expressions) + (other._collapsedExpression(),) return self class timerangeExpression(baseExpression): """ CalDAV time-range comparison expression. """ def __init__(self, start, end, startfloat, endfloat): self.start = start self.end = end self.startfloat = startfloat self.endfloat = endfloat def __str__(self): return "timerange(" + str(self.start) + ", " + str(self.end) + ")" class textcompareExpression(baseExpression): """ Base class for text comparison expressions. """ def __init__(self, field, text, caseless): self.field = field self.text = text self.caseless = caseless def __str__(self): return self.operator() + "(" + self.field + ", " + self.text + ", " + str(self.caseless) + ")" class containsExpression(textcompareExpression): """ Text CONTAINS (sub-string match) expression. 
""" def __init__(self, field, text, caseless): super(containsExpression, self).__init__(field, text, caseless) def operator(self): return "contains" class notcontainsExpression(textcompareExpression): """ Text NOT CONTAINS (sub-string match) expression. """ def __init__(self, field, text, caseless): super(notcontainsExpression, self).__init__(field, text, caseless) def operator(self): return "does not contain" class isExpression(textcompareExpression): """ Text IS (exact string match) expression. """ def __init__(self, field, text, caseless): super(isExpression, self).__init__(field, text, caseless) def operator(self): return "is" class isnotExpression(textcompareExpression): """ Text IS NOT (exact string match) expression. """ def __init__(self, field, text, caseless): super(isnotExpression, self).__init__(field, text, caseless) def operator(self): return "is not" class startswithExpression(textcompareExpression): """ Text STARTSWITH (sub-string match) expression. """ def __init__(self, field, text, caseless): super(startswithExpression, self).__init__(field, text, caseless) def operator(self): return "starts with" class notstartswithExpression(textcompareExpression): """ Text NOT STARTSWITH (sub-string match) expression. """ def __init__(self, field, text, caseless): super(notstartswithExpression, self).__init__(field, text, caseless) def operator(self): return "does not start with" class endswithExpression(textcompareExpression): """ Text STARTSWITH (sub-string match) expression. """ def __init__(self, field, text, caseless): super(endswithExpression, self).__init__(field, text, caseless) def operator(self): return "ends with" class notendswithExpression(textcompareExpression): """ Text NOT STARTSWITH (sub-string match) expression. 
""" def __init__(self, field, text, caseless): super(notendswithExpression, self).__init__(field, text, caseless) def operator(self): return "does not end with" class inExpression(textcompareExpression): """ Text IN (exact string match to one of the supplied items) expression. """ def __init__(self, field, text_list, caseless): super(inExpression, self).__init__(field, text_list, caseless) def operator(self): return "in" def __str__(self): return self.operator() + "(" + self.field + ", " + str(self.text) + ", " + str(self.caseless) + ")" class notinExpression(textcompareExpression): """ Text NOT IN (exact string match to none of the supplied items) expression. """ def __init__(self, field, text, caseless): super(notinExpression, self).__init__(field, text, caseless) def operator(self): return "not in" def __str__(self): return self.operator() + "(" + self.field + ", " + str(self.text) + ", " + str(self.caseless) + ")" if __name__ == "__main__": e1 = isExpression("type", "vevent", False) e2 = timerangeExpression("20060101T120000Z", "20060101T130000Z", "20060101T120000Z", "20060101T130000Z") e3 = containsExpression("summary", "help", True) e4 = notExpression(e3) e5 = andExpression([e1, e2, e4]) print(e5) e6 = inExpression("type", ("vevent", "vtodo",), False) print(e6) e7 = notinExpression("type", ("vevent", "vtodo",), False) print(e7) calendarserver-5.2+dfsg/twistedcaldav/query/test/0000755000175000017500000000000012322625314021271 5ustar rahulrahulcalendarserver-5.2+dfsg/twistedcaldav/query/test/test_addressbookquery.py0000644000175000017500000000253312263343324026275 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from twistedcaldav import carddavxml from twistedcaldav.query import addressbookqueryfilter import twistedcaldav.test.util from twistedcaldav.query.addressbookquery import sqladdressbookquery class Tests(twistedcaldav.test.util.TestCase): def test_query(self): """ Basic query test - single term. Only UID can be queried via sql. """ filter = carddavxml.Filter( *[carddavxml.PropertyFilter( carddavxml.TextMatch.fromString("Example"), **{"name":"UID"} )] ) filter = addressbookqueryfilter.Filter(filter) sql, args = sqladdressbookquery(filter) self.assertTrue(sql.find("UID") != -1) self.assertTrue("*Example*" in args) calendarserver-5.2+dfsg/twistedcaldav/query/test/test_calendarquery.py0000644000175000017500000001071212263343324025544 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from twistedcaldav import caldavxml from twistedcaldav.query import calendarqueryfilter import twistedcaldav.test.util from pycalendar.timezone import PyCalendarTimezone from twistedcaldav.query.calendarquery import sqlcalendarquery class Tests(twistedcaldav.test.util.TestCase): def test_query(self): """ Basic query test - no time range """ filter = caldavxml.Filter( caldavxml.ComponentFilter( *[caldavxml.ComponentFilter( **{"name":("VEVENT", "VFREEBUSY", "VAVAILABILITY")} )], **{"name": "VCALENDAR"} ) ) filter = calendarqueryfilter.Filter(filter) filter.child.settzinfo(PyCalendarTimezone(tzid="America/New_York")) sql, args = sqlcalendarquery(filter) self.assertTrue(sql.find("RESOURCE") != -1) self.assertTrue(sql.find("TIMESPAN") == -1) self.assertTrue(sql.find("TRANSPARENCY") == -1) self.assertTrue("VEVENT" in args) def test_query_timerange(self): """ Basic query test - with time range """ filter = caldavxml.Filter( caldavxml.ComponentFilter( *[caldavxml.ComponentFilter( *[caldavxml.TimeRange(**{"start":"20060605T160000Z", "end":"20060605T170000Z"})], **{"name":("VEVENT", "VFREEBUSY", "VAVAILABILITY")} )], **{"name": "VCALENDAR"} ) ) filter = calendarqueryfilter.Filter(filter) filter.child.settzinfo(PyCalendarTimezone(tzid="America/New_York")) sql, args = sqlcalendarquery(filter) self.assertTrue(sql.find("RESOURCE") != -1) self.assertTrue(sql.find("TIMESPAN") != -1) self.assertTrue(sql.find("TRANSPARENCY") == -1) self.assertTrue("VEVENT" in args) def test_query_not_extended(self): """ Query test - two terms not anyof """ filter = caldavxml.Filter( caldavxml.ComponentFilter( *[ caldavxml.ComponentFilter( **{"name":("VEVENT")} ), caldavxml.ComponentFilter( **{"name":("VTODO")} ), ], **{"name": "VCALENDAR"} ) ) filter = calendarqueryfilter.Filter(filter) filter.child.settzinfo(PyCalendarTimezone(tzid="America/New_York")) sql, args = sqlcalendarquery(filter) self.assertTrue(sql.find("RESOURCE") != -1) self.assertTrue(sql.find("TIMESPAN") == -1) 
self.assertTrue(sql.find("TRANSPARENCY") == -1) self.assertTrue(sql.find(" OR ") == -1) self.assertTrue("VEVENT" in args) self.assertTrue("VTODO" in args) def test_query_extended(self): """ Extended query test - two terms with anyof """ filter = caldavxml.Filter( caldavxml.ComponentFilter( *[ caldavxml.ComponentFilter( *[caldavxml.TimeRange(**{"start":"20060605T160000Z", })], **{"name":("VEVENT")} ), caldavxml.ComponentFilter( **{"name":("VTODO")} ), ], **{"name": "VCALENDAR", "test": "anyof"} ) ) filter = calendarqueryfilter.Filter(filter) filter.child.settzinfo(PyCalendarTimezone(tzid="America/New_York")) sql, args = sqlcalendarquery(filter) self.assertTrue(sql.find("RESOURCE") != -1) self.assertTrue(sql.find("TIMESPAN") != -1) self.assertTrue(sql.find("TRANSPARENCY") == -1) self.assertTrue(sql.find(" OR ") != -1) self.assertTrue("VEVENT" in args) self.assertTrue("VTODO" in args) calendarserver-5.2+dfsg/twistedcaldav/query/test/test_expression.py0000644000175000017500000001434112263343324025106 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from twistedcaldav.query import expression import twistedcaldav.test.util class Tests(twistedcaldav.test.util.TestCase): def test_andWith(self): tests = ( ( expression.isExpression("A", "1", True), expression.isExpression("B", "2", True), "(is(A, 1, True) AND is(B, 2, True))" ), ( expression.isExpression("A", "1", True), expression.andExpression(( expression.isExpression("B", "2", True), )), "(is(A, 1, True) AND is(B, 2, True))" ), ( expression.isExpression("A", "1", True), expression.andExpression(( expression.isExpression("B", "2", True), expression.isExpression("C", "3", True), )), "(is(A, 1, True) AND is(B, 2, True) AND is(C, 3, True))" ), ( expression.isExpression("A", "1", True), expression.orExpression(( expression.isExpression("B", "2", True), )), "(is(A, 1, True) AND is(B, 2, True))" ), ( expression.isExpression("A", "1", True), expression.orExpression(( expression.isExpression("B", "2", True), expression.isExpression("C", "3", True), )), "(is(A, 1, True) AND (is(B, 2, True) OR is(C, 3, True)))" ), ( expression.andExpression(( expression.isExpression("A", "1", True), )), expression.isExpression("B", "2", True), "(is(A, 1, True) AND is(B, 2, True))" ), ( expression.andExpression(( expression.isExpression("A", "1", True), expression.isExpression("B", "2", True), )), expression.isExpression("C", "3", True), "(is(A, 1, True) AND is(B, 2, True) AND is(C, 3, True))" ), ( expression.orExpression(( expression.isExpression("A", "1", True), )), expression.isExpression("B", "2", True), "(is(A, 1, True) AND is(B, 2, True))" ), ( expression.orExpression(( expression.isExpression("A", "1", True), expression.isExpression("B", "2", True), )), expression.isExpression("C", "3", True), "((is(A, 1, True) OR is(B, 2, True)) AND is(C, 3, True))" ), ) for expr1, expr2, result in tests: self.assertEqual(str(expr1.andWith(expr2)), result, msg="Failed on %s" % (result,)) def test_orWith(self): tests = ( ( expression.isExpression("A", "1", True), expression.isExpression("B", "2", 
True), "(is(A, 1, True) OR is(B, 2, True))" ), ( expression.isExpression("A", "1", True), expression.andExpression(( expression.isExpression("B", "2", True), )), "(is(A, 1, True) OR is(B, 2, True))" ), ( expression.isExpression("A", "1", True), expression.andExpression(( expression.isExpression("B", "2", True), expression.isExpression("C", "3", True), )), "(is(A, 1, True) OR (is(B, 2, True) AND is(C, 3, True)))" ), ( expression.isExpression("A", "1", True), expression.orExpression(( expression.isExpression("B", "2", True), )), "(is(A, 1, True) OR is(B, 2, True))" ), ( expression.isExpression("A", "1", True), expression.orExpression(( expression.isExpression("B", "2", True), expression.isExpression("C", "3", True), )), "(is(A, 1, True) OR is(B, 2, True) OR is(C, 3, True))" ), ( expression.andExpression(( expression.isExpression("A", "1", True), )), expression.isExpression("B", "2", True), "(is(A, 1, True) OR is(B, 2, True))" ), ( expression.andExpression(( expression.isExpression("A", "1", True), expression.isExpression("B", "2", True), )), expression.isExpression("C", "3", True), "((is(A, 1, True) AND is(B, 2, True)) OR is(C, 3, True))" ), ( expression.orExpression(( expression.isExpression("A", "1", True), )), expression.isExpression("B", "2", True), "(is(A, 1, True) OR is(B, 2, True))" ), ( expression.orExpression(( expression.isExpression("A", "1", True), expression.isExpression("B", "2", True), )), expression.isExpression("C", "3", True), "(is(A, 1, True) OR is(B, 2, True) OR is(C, 3, True))" ), ) for expr1, expr2, result in tests: self.assertEqual(str(expr1.orWith(expr2)), result, msg="Failed on %s" % (result,)) calendarserver-5.2+dfsg/twistedcaldav/query/test/test_queryfilter.py0000644000175000017500000001327112263343324025263 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from twistedcaldav import caldavxml from twistedcaldav.query import calendarqueryfilter import twistedcaldav.test.util from twistedcaldav.caldavxml import TimeZone from pycalendar.timezone import PyCalendarTimezone class Tests(twistedcaldav.test.util.TestCase): def test_allQuery(self): xml_element = caldavxml.Filter( caldavxml.ComponentFilter( **{"name": "VCALENDAR"} ) ) calendarqueryfilter.Filter(xml_element) def test_simpleSummaryRangeQuery(self): xml_element = caldavxml.Filter( caldavxml.ComponentFilter( caldavxml.ComponentFilter( caldavxml.PropertyFilter( caldavxml.TextMatch.fromString("test"), **{"name": "SUMMARY", } ), **{"name": "VEVENT"} ), **{"name": "VCALENDAR"} ) ) calendarqueryfilter.Filter(xml_element) def test_simpleTimeRangeQuery(self): xml_element = caldavxml.Filter( caldavxml.ComponentFilter( caldavxml.ComponentFilter( caldavxml.TimeRange(**{"start": "20060605T160000Z", "end": "20060605T170000Z"}), **{"name": "VEVENT"} ), **{"name": "VCALENDAR"} ) ) calendarqueryfilter.Filter(xml_element) def test_multipleTimeRangeQuery(self): xml_element = caldavxml.Filter( caldavxml.ComponentFilter( caldavxml.ComponentFilter( caldavxml.TimeRange(**{"start": "20060605T160000Z", "end": "20060605T170000Z"}), **{"name": ("VEVENT", "VFREEBUSY", "VAVAILABILITY")} ), **{"name": "VCALENDAR"} ) ) calendarqueryfilter.Filter(xml_element) def test_queryWithTimezone(self): xml_element = caldavxml.Filter( caldavxml.ComponentFilter( caldavxml.ComponentFilter( caldavxml.TimeRange(**{"start": "20060605T160000Z", "end": "20060605T170000Z"}), **{"name": "VEVENT"} ), **{"name": 
"VCALENDAR"} ) ) filter = calendarqueryfilter.Filter(xml_element) tz = filter.settimezone(TimeZone.fromString("""BEGIN:VCALENDAR PRODID:-//CALENDARSERVER.ORG//NONSGML Version 1//EN VERSION:2.0 BEGIN:VTIMEZONE TZID:America/New_York X-LIC-LOCATION:America/New_York BEGIN:DAYLIGHT TZOFFSETFROM:-0500 TZOFFSETTO:-0400 TZNAME:EDT DTSTART:19180331T020000 RRULE:FREQ=YEARLY;BYMONTH=3;BYDAY=-1SU;UNTIL=19200328T070000Z END:DAYLIGHT BEGIN:STANDARD TZOFFSETFROM:-0400 TZOFFSETTO:-0500 TZNAME:EST DTSTART:19181027T020000 RRULE:FREQ=YEARLY;BYMONTH=10;BYDAY=-1SU;UNTIL=19201031T060000Z END:STANDARD BEGIN:DAYLIGHT TZOFFSETFROM:-0500 TZOFFSETTO:-0400 TZNAME:EDT DTSTART:19210424T020000 RRULE:FREQ=YEARLY;BYMONTH=4;BYDAY=-1SU;UNTIL=19410427T070000Z END:DAYLIGHT BEGIN:STANDARD TZOFFSETFROM:-0400 TZOFFSETTO:-0500 TZNAME:EST DTSTART:19210925T020000 RRULE:FREQ=YEARLY;BYMONTH=9;BYDAY=-1SU;UNTIL=19410928T060000Z END:STANDARD BEGIN:DAYLIGHT TZOFFSETFROM:-0500 TZOFFSETTO:-0400 TZNAME:EDT DTSTART:19460428T020000 RRULE:FREQ=YEARLY;BYMONTH=4;BYDAY=-1SU;UNTIL=19730429T070000Z END:DAYLIGHT BEGIN:STANDARD TZOFFSETFROM:-0400 TZOFFSETTO:-0500 TZNAME:EST DTSTART:19460929T020000 RRULE:FREQ=YEARLY;BYMONTH=9;BYDAY=-1SU;UNTIL=19540926T060000Z END:STANDARD BEGIN:STANDARD TZOFFSETFROM:-0400 TZOFFSETTO:-0500 TZNAME:EST DTSTART:19551030T020000 RRULE:FREQ=YEARLY;BYMONTH=10;BYDAY=-1SU;UNTIL=20061029T060000Z END:STANDARD BEGIN:DAYLIGHT TZOFFSETFROM:-0500 TZOFFSETTO:-0400 TZNAME:EDT DTSTART:19760425T020000 RRULE:FREQ=YEARLY;BYMONTH=4;BYDAY=-1SU;UNTIL=19860427T070000Z END:DAYLIGHT BEGIN:DAYLIGHT TZOFFSETFROM:-0500 TZOFFSETTO:-0400 TZNAME:EDT DTSTART:19870405T020000 RRULE:FREQ=YEARLY;BYMONTH=4;BYDAY=1SU;UNTIL=20060402T070000Z END:DAYLIGHT BEGIN:DAYLIGHT TZOFFSETFROM:-0500 TZOFFSETTO:-0400 TZNAME:EDT DTSTART:20070311T020000 RRULE:FREQ=YEARLY;BYMONTH=3;BYDAY=2SU END:DAYLIGHT BEGIN:STANDARD TZOFFSETFROM:-0400 TZOFFSETTO:-0500 TZNAME:EST DTSTART:20071104T020000 RRULE:FREQ=YEARLY;BYMONTH=11;BYDAY=1SU END:STANDARD 
BEGIN:STANDARD TZOFFSETFROM:-045602 TZOFFSETTO:-0500 TZNAME:EST DTSTART:18831118T120358 RDATE:18831118T120358 END:STANDARD BEGIN:STANDARD TZOFFSETFROM:-0500 TZOFFSETTO:-0500 TZNAME:EST DTSTART:19200101T000000 RDATE:19200101T000000 RDATE:19420101T000000 RDATE:19460101T000000 RDATE:19670101T000000 END:STANDARD BEGIN:DAYLIGHT TZOFFSETFROM:-0500 TZOFFSETTO:-0400 TZNAME:EWT DTSTART:19420209T020000 RDATE:19420209T020000 END:DAYLIGHT BEGIN:DAYLIGHT TZOFFSETFROM:-0400 TZOFFSETTO:-0400 TZNAME:EPT DTSTART:19450814T190000 RDATE:19450814T190000 END:DAYLIGHT BEGIN:STANDARD TZOFFSETFROM:-0400 TZOFFSETTO:-0500 TZNAME:EST DTSTART:19450930T020000 RDATE:19450930T020000 END:STANDARD BEGIN:DAYLIGHT TZOFFSETFROM:-0500 TZOFFSETTO:-0400 TZNAME:EDT DTSTART:19740106T020000 RDATE:19740106T020000 RDATE:19750223T020000 END:DAYLIGHT END:VTIMEZONE END:VCALENDAR """)) self.assertTrue(isinstance(tz, PyCalendarTimezone)) calendarserver-5.2+dfsg/twistedcaldav/query/test/__init__.py0000644000175000017500000000122112263343324023400 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Tests for the twistedcaldav.query module. """ calendarserver-5.2+dfsg/twistedcaldav/query/addressbookquery.py0000644000175000017500000001060312263343324024254 0ustar rahulrahul## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Convert a addressbook-query into an expression tree. Convert a addressbook-query into a partial SQL statement. """ __version__ = "0.0" __all__ = [ "addressbookquery", "sqladdressbookquery", ] from twistedcaldav.query import expression, sqlgenerator, addressbookqueryfilter # SQL Index column (field) names def addressbookquery(filter, fields): """ Convert the supplied addressbook-query into an expression tree. @param filter: the L{Filter} for the addressbook-query to convert. @return: a L{baseExpression} for the expression tree. """ # Lets assume we have a valid filter from the outset. # Top-level filter contains zero or more prop-filter element if len(filter.children) > 0: return propfilterListExpression(filter.children, fields) else: return expression.allExpression() def propfilterListExpression(propfilters, fields): """ Create an expression for a list of prop-filter elements. @param propfilters: the C{list} of L{ComponentFilter} elements. @return: a L{baseExpression} for the expression tree. """ if len(propfilters) == 1: return propfilterExpression(propfilters[0], fields) else: return expression.orExpression([propfilterExpression(c, fields) for c in propfilters]) def propfilterExpression(propfilter, fields): """ Create an expression for a single prop-filter element. @param propfilter: the L{PropertyFilter} element. @return: a L{baseExpression} for the expression tree. 
""" # Only handle UID right now if propfilter.filter_name != "UID": raise ValueError # Handle is-not-defined case if not propfilter.defined: # Test for <> != "*" return expression.isExpression(fields["UID"], "", True) # Handle embedded parameters/text-match params = [] for filter in propfilter.filters: if isinstance(filter, addressbookqueryfilter.TextMatch): if filter.match_type == "equals": tm = expression.isnotExpression if filter.negate else expression.isExpression elif filter.match_type == "contains": tm = expression.notcontainsExpression if filter.negate else expression.containsExpression elif filter.match_type == "starts-with": tm = expression.notstartswithExpression if filter.negate else expression.startswithExpression elif filter.match_type == "ends-with": tm = expression.notendswithExpression if filter.negate else expression.endswithExpression params.append(tm(fields[propfilter.filter_name], str(filter.text), True)) else: # No embedded parameters - not right now as our Index does not handle them raise ValueError # Now build return expression if len(params) > 1: if propfilter.propfilter_test == "anyof": return expression.orExpression(params) else: return expression.andExpression(params) elif len(params) == 1: return params[0] else: return None def sqladdressbookquery(filter, addressbookid=None, generator=sqlgenerator.sqlgenerator): """ Convert the supplied addressbook-query into a partial SQL statement. @param filter: the L{Filter} for the addressbook-query to convert. @return: a C{tuple} of (C{str}, C{list}), where the C{str} is the partial SQL statement, and the C{list} is the list of argument substitutions to use with the SQL API execute method. Or return C{None} if it is not possible to create an SQL query to fully match the addressbook-query. 
""" try: expression = addressbookquery(filter, generator.FIELDS) sql = generator(expression, addressbookid, None) return sql.generate() except ValueError: return None calendarserver-5.2+dfsg/twistedcaldav/query/addressbookqueryfilter.py0000644000175000017500000002214312263343324025464 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Object model of CARDAV:filter element used in an addressbook-query. """ __all__ = [ "Filter", ] from twext.python.log import Logger from twistedcaldav.carddavxml import carddav_namespace from twistedcaldav.vcard import Property log = Logger() class FilterBase(object): """ Determines which matching components are returned. """ def __init__(self, xml_element): self.xmlelement = xml_element def match(self, item, access=None): raise NotImplementedError def valid(self, level=0): raise NotImplementedError class Filter(FilterBase): """ Determines which matching components are returned. """ def __init__(self, xml_element): super(Filter, self).__init__(xml_element) filter_test = xml_element.attributes.get("test", "anyof") if filter_test not in ("anyof", "allof"): raise ValueError("Test must be only one of anyof, allof") self.filter_test = filter_test self.children = [PropertyFilter(child) for child in xml_element.children] def match(self, vcard): """ Returns True if the given address property matches this filter, False otherwise. Empty element means always match. 
""" if len(self.children) > 0: allof = self.filter_test == "allof" for propfilter in self.children: if allof != propfilter._match(vcard): return not allof return allof else: return True def valid(self): """ Indicate whether this filter element's structure is valid wrt vCard data object model. @return: True if valid, False otherwise """ # Test each property for propfilter in self.children: if not propfilter.valid(): return False else: return True class FilterChildBase(FilterBase): """ CardDAV filter element. """ def __init__(self, xml_element): super(FilterChildBase, self).__init__(xml_element) qualifier = None filters = [] for child in xml_element.children: qname = child.qname() if qname in ( (carddav_namespace, "is-not-defined"), ): if qualifier is not None: raise ValueError("Only one of CardDAV:is-not-defined allowed") qualifier = IsNotDefined(child) elif qname == (carddav_namespace, "text-match"): filters.append(TextMatch(child)) elif qname == (carddav_namespace, "param-filter"): filters.append(ParameterFilter(child)) else: raise ValueError("Unknown child element: %s" % (qname,)) if qualifier and isinstance(qualifier, IsNotDefined) and (len(filters) != 0): raise ValueError("No other tests allowed when CardDAV:is-not-defined is present") if xml_element.qname() == (carddav_namespace, "prop-filter"): propfilter_test = xml_element.attributes.get("test", "anyof") if propfilter_test not in ("anyof", "allof"): raise ValueError("Test must be only one of anyof, allof") else: propfilter_test = "anyof" self.propfilter_test = propfilter_test self.qualifier = qualifier self.filters = filters self.filter_name = xml_element.attributes["name"] if isinstance(self.filter_name, unicode): self.filter_name = self.filter_name.encode("utf-8") self.defined = not self.qualifier or not isinstance(qualifier, IsNotDefined) def match(self, item): """ Returns True if the given address book item (either a property or parameter value) matches this filter, False otherwise. 
""" # Always return True for the is-not-defined case as the result of this will # be negated by the caller if not self.defined: return True if self.qualifier and not self.qualifier.match(item): return False if len(self.filters) > 0: allof = self.propfilter_test == "allof" for filter in self.filters: if allof != filter._match(item): return not allof return allof else: return True class PropertyFilter (FilterChildBase): """ Limits a search to specific properties. """ def _match(self, vcard): # At least one property must match (or is-not-defined is set) for property in vcard.properties(): if property.name().upper() == self.filter_name.upper() and self.match(property): break else: return not self.defined return self.defined def valid(self): """ Indicate whether this filter element's structure is valid wrt vCard data object model. @return: True if valid, False otherwise """ # No tests return True class ParameterFilter (FilterChildBase): """ Limits a search to specific parameters. """ def _match(self, property): # At least one parameter must match (or is-not-defined is set) result = not self.defined for parameterName in property.parameterNames(): if parameterName.upper() == self.filter_name.upper() and self.match([property.parameterValues(parameterName)]): result = self.defined break return result class IsNotDefined (FilterBase): """ Specifies that the named iCalendar item does not exist. """ def match(self, component, access=None): # Oddly, this needs always to return True so that it appears there is # a match - but we then "negate" the result if is-not-defined is set. # Actually this method should never be called as we special case the # is-not-defined option. return True class TextMatch (FilterBase): """ Specifies a substring match on a property or parameter value. 
""" def __init__(self, xml_element): super(TextMatch, self).__init__(xml_element) self.text = str(xml_element) if "collation" in xml_element.attributes: self.collation = xml_element.attributes["collation"] else: self.collation = "i;unicode-casemap" if "negate-condition" in xml_element.attributes: self.negate = xml_element.attributes["negate-condition"] if self.negate not in ("yes", "no"): self.negate = "no" self.negate = {"yes": True, "no": False}[self.negate] else: self.negate = False if "match-type" in xml_element.attributes: self.match_type = xml_element.attributes["match-type"] if self.match_type not in ( "equals", "contains", "starts-with", "ends-with", ): self.match_type = "contains" else: self.match_type = "contains" def _match(self, item): """ Match the text for the item. If the item is a property, then match the property value, otherwise it may be a list of parameter values - try to match anyone of those """ if item is None: return False if isinstance(item, Property): values = [item.strvalue()] else: values = item test = unicode(self.text, "utf-8").lower() def _textCompare(s): # Currently ignores the collation and does caseless matching s = s.lower() if self.match_type == "equals": return s == test elif self.match_type == "contains": return s.find(test) != -1 elif self.match_type == "starts-with": return s.startswith(test) elif self.match_type == "ends-with": return s.endswith(test) else: return False for value in values: # NB Its possible that we have a text list value which appears as a Python list, # so we need to check for that and iterate over the list. if isinstance(value, list): for subvalue in value: if _textCompare(unicode(subvalue, "utf-8")): return not self.negate else: if _textCompare(unicode(value, "utf-8")): return not self.negate return self.negate calendarserver-5.2+dfsg/twistedcaldav/query/calendarqueryfilter.py0000644000175000017500000006407512263343324024747 0ustar rahulrahul## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Object model of CALDAV:filter element used in a calendar-query. """ __all__ = [ "Filter", ] from twext.python.log import Logger from twistedcaldav.caldavxml import caldav_namespace, CalDAVTimeZoneElement from twistedcaldav.dateops import timeRangesOverlap from twistedcaldav.ical import Component, Property from pycalendar.datetime import PyCalendarDateTime from pycalendar.timezone import PyCalendarTimezone log = Logger() class FilterBase(object): """ Determines which matching components are returned. """ def __init__(self, xml_element): self.xmlelement = xml_element def match(self, item, access=None): raise NotImplementedError def valid(self, level=0): raise NotImplementedError class Filter(FilterBase): """ Determines which matching components are returned. """ def __init__(self, xml_element): super(Filter, self).__init__(xml_element) # One comp-filter element must be present if len(xml_element.children) != 1 or xml_element.children[0].qname() != (caldav_namespace, "comp-filter"): raise ValueError("Invalid CALDAV:filter element: %s" % (xml_element,)) self.child = ComponentFilter(xml_element.children[0]) def match(self, component, access=None): """ Returns True if the given calendar component matches this filter, False otherwise. """ # We only care about certain access restrictions. 
if access not in (Component.ACCESS_CONFIDENTIAL, Component.ACCESS_RESTRICTED): access = None # We need to prepare ourselves for a time-range query by pre-calculating # the set of instances up to the latest time-range limit. That way we can # avoid having to do some form of recurrence expansion for each query sub-part. maxend, isStartTime = self.getmaxtimerange() if maxend: if isStartTime: if component.isRecurringUnbounded(): # Unbounded recurrence is always within a start-only time-range instances = None else: # Expand the instances up to infinity instances = component.expandTimeRanges(PyCalendarDateTime(2100, 1, 1, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), ignoreInvalidInstances=True) else: instances = component.expandTimeRanges(maxend, ignoreInvalidInstances=True) else: instances = None self.child.setInstances(instances) # contains exactly one return self.child.match(component, access) def valid(self): """ Indicate whether this filter element's structure is valid wrt iCalendar data object model. @return: True if valid, False otherwise """ # Must have one child element for VCALENDAR return self.child.valid(0) def settimezone(self, tzelement): """ Set the default timezone to use with this query. @param calendar: a L{Component} for the VCALENDAR containing the one VTIMEZONE that we want @return: the L{PyCalendarTimezone} derived from the VTIMEZONE or utc. """ if tzelement is None: tz = None elif isinstance(tzelement, CalDAVTimeZoneElement): tz = tzelement.gettimezone() elif isinstance(tzelement, Component): tz = tzelement.gettimezone() if tz is None: tz = PyCalendarTimezone(utc=True) self.child.settzinfo(tz) return tz def getmaxtimerange(self): """ Get the date farthest into the future in any time-range elements """ return self.child.getmaxtimerange(None, False) def getmintimerange(self): """ Get the date farthest into the past in any time-range elements. That is either the start date, or if start is not present, the end date. 
""" return self.child.getmintimerange(None, False) class FilterChildBase(FilterBase): """ CalDAV filter element. """ def __init__(self, xml_element): super(FilterChildBase, self).__init__(xml_element) qualifier = None filters = [] for child in xml_element.children: qname = child.qname() if qname in ( (caldav_namespace, "is-not-defined"), (caldav_namespace, "time-range"), (caldav_namespace, "text-match"), ): if qualifier is not None: raise ValueError("Only one of CalDAV:time-range, CalDAV:text-match allowed") if qname == (caldav_namespace, "is-not-defined"): qualifier = IsNotDefined(child) elif qname == (caldav_namespace, "time-range"): qualifier = TimeRange(child) elif qname == (caldav_namespace, "text-match"): qualifier = TextMatch(child) elif qname == (caldav_namespace, "comp-filter"): filters.append(ComponentFilter(child)) elif qname == (caldav_namespace, "prop-filter"): filters.append(PropertyFilter(child)) elif qname == (caldav_namespace, "param-filter"): filters.append(ParameterFilter(child)) else: raise ValueError("Unknown child element: %s" % (qname,)) if qualifier and isinstance(qualifier, IsNotDefined) and (len(filters) != 0): raise ValueError("No other tests allowed when CalDAV:is-not-defined is present") self.qualifier = qualifier self.filters = filters self.filter_name = xml_element.attributes["name"] if isinstance(self.filter_name, unicode): self.filter_name = self.filter_name.encode("utf-8") self.defined = not self.qualifier or not isinstance(qualifier, IsNotDefined) filter_test = xml_element.attributes.get("test", "allof") if filter_test not in ("anyof", "allof"): raise ValueError("Test must be only one of anyof, allof") self.filter_test = filter_test def match(self, item, access=None): """ Returns True if the given calendar item (either a component, property or parameter value) matches this filter, False otherwise. 
""" # Always return True for the is-not-defined case as the result of this will # be negated by the caller if not self.defined: return True if self.qualifier and not self.qualifier.match(item, access): return False if len(self.filters) > 0: allof = self.filter_test == "allof" for filter in self.filters: if allof != filter._match(item, access): return not allof return allof else: return True class ComponentFilter (FilterChildBase): """ Limits a search to only the chosen component types. """ def match(self, item, access): """ Returns True if the given calendar item (which is a component) matches this filter, False otherwise. This specialization uses the instance matching option of the time-range filter to minimize instance expansion. """ # Always return True for the is-not-defined case as the result of this will # be negated by the caller if not self.defined: return True if self.qualifier and not self.qualifier.matchinstance(item, self.instances): return False if len(self.filters) > 0: allof = self.filter_test == "allof" for filter in self.filters: if allof != filter._match(item, access): return not allof return allof else: return True def _match(self, component, access): # At least one subcomponent must match (or is-not-defined is set) for subcomponent in component.subcomponents(): # If access restrictions are in force, restrict matching to specific components only. # In particular do not match VALARM. if access and subcomponent.name() not in ("VEVENT", "VTODO", "VJOURNAL", "VFREEBUSY", "VTIMEZONE",): continue # Try to match the component name if isinstance(self.filter_name, str): if subcomponent.name() != self.filter_name: continue else: if subcomponent.name() not in self.filter_name: continue if self.match(subcomponent, access): break else: return not self.defined return self.defined def setInstances(self, instances): """ Give the list of instances to each comp-filter element. @param instances: the list of instances. 
""" self.instances = instances for compfilter in [x for x in self.filters if isinstance(x, ComponentFilter)]: compfilter.setInstances(instances) def valid(self, level): """ Indicate whether this filter element's structure is valid wrt iCalendar data object model. @param level: the nesting level of this filter element, 0 being the top comp-filter. @return: True if valid, False otherwise """ # Check for time-range timerange = self.qualifier and isinstance(self.qualifier, TimeRange) if level == 0: # Must have VCALENDAR at the top if (self.filter_name != "VCALENDAR") or timerange: log.info("Top-level comp-filter must be VCALENDAR, instead: %s" % (self.filter_name,)) return False elif level == 1: # Disallow VCALENDAR, VALARM, STANDARD, DAYLIGHT, AVAILABLE at the top, everything else is OK if self.filter_name in ("VCALENDAR", "VALARM", "STANDARD", "DAYLIGHT", "AVAILABLE"): log.info("comp-filter wrong component type: %s" % (self.filter_name,)) return False # time-range only on VEVENT, VTODO, VJOURNAL, VFREEBUSY, VAVAILABILITY if timerange and self.filter_name not in ("VEVENT", "VTODO", "VJOURNAL", "VFREEBUSY", "VAVAILABILITY"): log.info("time-range cannot be used with component %s" % (self.filter_name,)) return False elif level == 2: # Disallow VCALENDAR, VTIMEZONE, VEVENT, VTODO, VJOURNAL, VFREEBUSY, VAVAILABILITY at the top, everything else is OK if (self.filter_name in ("VCALENDAR", "VTIMEZONE", "VEVENT", "VTODO", "VJOURNAL", "VFREEBUSY", "VAVAILABILITY")): log.info("comp-filter wrong sub-component type: %s" % (self.filter_name,)) return False # time-range only on VALARM, AVAILABLE if timerange and self.filter_name not in ("VALARM", "AVAILABLE",): log.info("time-range cannot be used with sub-component %s" % (self.filter_name,)) return False else: # Disallow all standard iCal components anywhere else if (self.filter_name in ("VCALENDAR", "VTIMEZONE", "VEVENT", "VTODO", "VJOURNAL", "VFREEBUSY", "VALARM", "STANDARD", "DAYLIGHT", "AVAILABLE")) or timerange: 
log.info("comp-filter wrong standard component type: %s" % (self.filter_name,)) return False # Test each property for propfilter in [x for x in self.filters if isinstance(x, PropertyFilter)]: if not propfilter.valid(): return False # Test each component for compfilter in [x for x in self.filters if isinstance(x, ComponentFilter)]: if not compfilter.valid(level + 1): return False # Test the time-range if timerange: if not self.qualifier.valid(): return False return True def settzinfo(self, tzinfo): """ Set the default timezone to use with this query. @param tzinfo: a L{PyCalendarTimezone} to use. """ # Give tzinfo to any TimeRange we have if isinstance(self.qualifier, TimeRange): self.qualifier.settzinfo(tzinfo) # Pass down to sub components/properties for x in self.filters: x.settzinfo(tzinfo) def getmaxtimerange(self, currentMaximum, currentIsStartTime): """ Get the date farthest into the future in any time-range elements @param currentMaximum: current future value to compare with @type currentMaximum: L{PyCalendarDateTime} """ # Give tzinfo to any TimeRange we have isStartTime = False if isinstance(self.qualifier, TimeRange): isStartTime = self.qualifier.end is None compareWith = self.qualifier.start if isStartTime else self.qualifier.end if currentMaximum is None or currentMaximum < compareWith: currentMaximum = compareWith currentIsStartTime = isStartTime # Pass down to sub components/properties for x in self.filters: currentMaximum, currentIsStartTime = x.getmaxtimerange(currentMaximum, currentIsStartTime) return currentMaximum, currentIsStartTime def getmintimerange(self, currentMinimum, currentIsEndTime): """ Get the date farthest into the past in any time-range elements. That is either the start date, or if start is not present, the end date. 
""" # Give tzinfo to any TimeRange we have isEndTime = False if isinstance(self.qualifier, TimeRange): isEndTime = self.qualifier.start is None compareWith = self.qualifier.end if isEndTime else self.qualifier.start if currentMinimum is None or currentMinimum > compareWith: currentMinimum = compareWith currentIsEndTime = isEndTime # Pass down to sub components/properties for x in self.filters: currentMinimum, currentIsEndTime = x.getmintimerange(currentMinimum, currentIsEndTime) return currentMinimum, currentIsEndTime class PropertyFilter (FilterChildBase): """ Limits a search to specific properties. """ def _match(self, component, access): # When access restriction is in force, we need to only allow matches against the properties # allowed by the access restriction level. if access: allowedProperties = Component.confidentialPropertiesMap.get(component.name(), None) if allowedProperties and access == Component.ACCESS_RESTRICTED: allowedProperties += Component.extraRestrictedProperties else: allowedProperties = None # At least one property must match (or is-not-defined is set) for property in component.properties(): # Apply access restrictions, if any. if allowedProperties is not None and property.name().upper() not in allowedProperties: continue if property.name().upper() == self.filter_name.upper() and self.match(property, access): break else: return not self.defined return self.defined def valid(self): """ Indicate whether this filter element's structure is valid wrt iCalendar data object model. 
@return: True if valid, False otherwise """ # Check for time-range timerange = self.qualifier and isinstance(self.qualifier, TimeRange) # time-range only on COMPLETED, CREATED, DTSTAMP, LAST-MODIFIED if timerange and self.filter_name.upper() not in ("COMPLETED", "CREATED", "DTSTAMP", "LAST-MODIFIED"): log.info("time-range cannot be used with property %s" % (self.filter_name,)) return False # Test the time-range if timerange: if not self.qualifier.valid(): return False # No other tests return True def settzinfo(self, tzinfo): """ Set the default timezone to use with this query. @param tzinfo: a L{PyCalendarTimezone} to use. """ # Give tzinfo to any TimeRange we have if isinstance(self.qualifier, TimeRange): self.qualifier.settzinfo(tzinfo) def getmaxtimerange(self, currentMaximum, currentIsStartTime): """ Get the date farthest into the future in any time-range elements @param currentMaximum: current future value to compare with @type currentMaximum: L{PyCalendarDateTime} """ # Give tzinfo to any TimeRange we have isStartTime = False if isinstance(self.qualifier, TimeRange): isStartTime = self.qualifier.end is None compareWith = self.qualifier.start if isStartTime else self.qualifier.end if currentMaximum is None or currentMaximum < compareWith: currentMaximum = compareWith currentIsStartTime = isStartTime return currentMaximum, currentIsStartTime def getmintimerange(self, currentMinimum, currentIsEndTime): """ Get the date farthest into the past in any time-range elements. That is either the start date, or if start is not present, the end date. 
""" # Give tzinfo to any TimeRange we have isEndTime = False if isinstance(self.qualifier, TimeRange): isEndTime = self.qualifier.start is None compareWith = self.qualifier.end if isEndTime else self.qualifier.start if currentMinimum is None or currentMinimum > compareWith: currentMinimum = compareWith currentIsEndTime = isEndTime return currentMinimum, currentIsEndTime class ParameterFilter (FilterChildBase): """ Limits a search to specific parameters. """ def _match(self, property, access): # At least one parameter must match (or is-not-defined is set) result = not self.defined for parameterName in property.parameterNames(): if parameterName.upper() == self.filter_name.upper() and self.match([property.parameterValue(parameterName)], access): result = self.defined break return result class IsNotDefined (FilterBase): """ Specifies that the named iCalendar item does not exist. """ def match(self, component, access=None): # Oddly, this needs always to return True so that it appears there is # a match - but we then "negate" the result if is-not-defined is set. # Actually this method should never be called as we special case the # is-not-defined option. return True class TextMatch (FilterBase): """ Specifies a substring match on a property or parameter value. 
(CalDAV-access-09, section 9.6.4) """ def __init__(self, xml_element): super(TextMatch, self).__init__(xml_element) self.text = str(xml_element) if "caseless" in xml_element.attributes: caseless = xml_element.attributes["caseless"] if caseless == "yes": self.caseless = True elif caseless == "no": self.caseless = False else: self.caseless = True if "negate-condition" in xml_element.attributes: negate = xml_element.attributes["negate-condition"] if negate == "yes": self.negate = True elif negate == "no": self.negate = False else: self.negate = False if "match-type" in xml_element.attributes: self.match_type = xml_element.attributes["match-type"] if self.match_type not in ( "equals", "contains", "starts-with", "ends-with", ): self.match_type = "contains" else: self.match_type = "contains" def match(self, item, access): """ Match the text for the item. If the item is a property, then match the property value, otherwise it may be a list of parameter values - try to match anyone of those """ if item is None: return False if isinstance(item, Property): values = [item.strvalue()] else: values = item test = unicode(self.text, "utf-8") if self.caseless: test = test.lower() def _textCompare(s): if self.caseless: s = s.lower() if self.match_type == "equals": return s == test elif self.match_type == "contains": return s.find(test) != -1 elif self.match_type == "starts-with": return s.startswith(test) elif self.match_type == "ends-with": return s.endswith(test) else: return False for value in values: # NB Its possible that we have a text list value which appears as a Python list, # so we need to check for that and iterate over the list. if isinstance(value, list): for subvalue in value: if _textCompare(unicode(subvalue, "utf-8")): return not self.negate else: if _textCompare(unicode(value, "utf-8")): return not self.negate return self.negate class TimeRange (FilterBase): """ Specifies a time for testing components against. 
""" def __init__(self, xml_element): super(TimeRange, self).__init__(xml_element) # One of start or end must be present if "start" not in xml_element.attributes and "end" not in xml_element.attributes: raise ValueError("One of 'start' or 'end' must be present in CALDAV:time-range") self.start = PyCalendarDateTime.parseText(xml_element.attributes["start"]) if "start" in xml_element.attributes else None self.end = PyCalendarDateTime.parseText(xml_element.attributes["end"]) if "end" in xml_element.attributes else None self.tzinfo = None def settzinfo(self, tzinfo): """ Set the default timezone to use with this query. @param tzinfo: a L{PyCalendarTimezone} to use. """ # Give tzinfo to any TimeRange we have self.tzinfo = tzinfo def valid(self, level=0): """ Indicate whether the time-range is valid (must be date-time in UTC). @return: True if valid, False otherwise """ if self.start is not None and self.start.isDateOnly(): log.info("start attribute in is not a date-time: %s" % (self.start,)) return False if self.end is not None and self.end.isDateOnly(): log.info("end attribute in is not a date-time: %s" % (self.end,)) return False if self.start is not None and not self.start.utc(): log.info("start attribute in is not UTC: %s" % (self.start,)) return False if self.end is not None and not self.end.utc(): log.info("end attribute in is not UTC: %s" % (self.end,)) return False # No other tests return True def match(self, property, access=None): """ NB This is only called when doing a time-range match on a property. """ if property is None: return False else: return property.containsTimeRange(self.start, self.end, self.tzinfo) def matchinstance(self, component, instances): """ Test whether this time-range element causes a match to the specified component using the specified set of instances to determine the expanded time ranges. @param component: the L{Component} to test. @param instances: the list of expanded instances. 
@return: True if the time-range query matches, False otherwise. """ if component is None: return False assert instances is not None or self.end is None, "Failure to expand instance for time-range filter: %r" % (self,) # Special case open-ended unbounded if instances is None: if component.getRecurrenceIDUTC() is None: return True else: # See if the overridden component's start is past the start start, _ignore_end = component.getEffectiveStartEnd() if start is None: return True else: return start >= self.start # Handle alarms as a special case alarms = (component.name() == "VALARM") if alarms: testcomponent = component._parent else: testcomponent = component for key in instances: instance = instances[key] # First make sure components match if not testcomponent.same(instance.component): continue if alarms: # Get all the alarm triggers for this instance and test each one triggers = instance.getAlarmTriggers() for trigger in triggers: if timeRangesOverlap(trigger, None, self.start, self.end, self.tzinfo): return True else: # Regular instance overlap test if timeRangesOverlap(instance.start, instance.end, self.start, self.end, self.tzinfo): return True return False calendarserver-5.2+dfsg/twistedcaldav/query/calendarquery.py0000644000175000017500000002133012263343324023524 0ustar rahulrahul## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ Convert a calendar-query into an expression tree. 
Convert a calendar-query into a partial SQL statement. """ __version__ = "0.0" __all__ = [ "calendarquery", "sqlcalendarquery", ] from twistedcaldav.dateops import floatoffset, pyCalendarTodatetime from twistedcaldav.query import expression, sqlgenerator, calendarqueryfilter # SQL Index column (field) names def calendarquery(filter, fields): """ Convert the supplied calendar-query into an expression tree. @param filter: the L{Filter} for the calendar-query to convert. @return: a L{baseExpression} for the expression tree. """ # Lets assume we have a valid filter from the outset. # Top-level filter contains exactly one comp-filter element assert filter.child is not None vcalfilter = filter.child assert isinstance(vcalfilter, calendarqueryfilter.ComponentFilter) assert vcalfilter.filter_name == "VCALENDAR" if len(vcalfilter.filters) > 0: # Determine logical expression grouping logical = expression.andExpression if vcalfilter.filter_test == "allof" else expression.orExpression # Only comp-filters are handled for _ignore in [x for x in vcalfilter.filters if not isinstance(x, calendarqueryfilter.ComponentFilter)]: raise ValueError return compfilterListExpression(vcalfilter.filters, fields, logical) else: return expression.allExpression() def compfilterListExpression(compfilters, fields, logical): """ Create an expression for a list of comp-filter elements. @param compfilters: the C{list} of L{ComponentFilter} elements. @return: a L{baseExpression} for the expression tree. """ if len(compfilters) == 1: return compfilterExpression(compfilters[0], fields) else: return logical([compfilterExpression(c, fields) for c in compfilters]) def compfilterExpression(compfilter, fields): """ Create an expression for a single comp-filter element. @param compfilter: the L{ComponentFilter} element. @return: a L{baseExpression} for the expression tree. 
""" # Handle is-not-defined case if not compfilter.defined: # Test for TYPE != <> return expression.isnotExpression(fields["TYPE"], compfilter.filter_name, True) # Determine logical expression grouping logical = expression.andExpression if compfilter.filter_test == "allof" else expression.orExpression expressions = [] if isinstance(compfilter.filter_name, str): expressions.append(expression.isExpression(fields["TYPE"], compfilter.filter_name, True)) else: expressions.append(expression.inExpression(fields["TYPE"], compfilter.filter_name, True)) # Handle time-range if compfilter.qualifier and isinstance(compfilter.qualifier, calendarqueryfilter.TimeRange): start, end, startfloat, endfloat = getTimerangeArguments(compfilter.qualifier) expressions.append(expression.timerangeExpression(start, end, startfloat, endfloat)) # Handle properties - we can only do UID right now props = [] for p in [x for x in compfilter.filters if isinstance(x, calendarqueryfilter.PropertyFilter)]: props.append(propfilterExpression(p, fields)) if len(props) > 1: propsExpression = logical(props) elif len(props) == 1: propsExpression = props[0] else: propsExpression = None # Handle embedded components - we do not right now as our Index does not handle them comps = [] for _ignore in [x for x in compfilter.filters if isinstance(x, calendarqueryfilter.ComponentFilter)]: raise ValueError if len(comps) > 1: compsExpression = logical(comps) elif len(comps) == 1: compsExpression = comps[0] else: compsExpression = None # Now build compound expression if ((propsExpression is not None) and (compsExpression is not None)): expressions.append(logical([propsExpression, compsExpression])) elif propsExpression is not None: expressions.append(propsExpression) elif compsExpression is not None: expressions.append(compsExpression) # Now build return expression return expression.andExpression(expressions) def propfilterExpression(propfilter, fields): """ Create an expression for a single prop-filter element. 
@param propfilter: the L{PropertyFilter} element. @return: a L{baseExpression} for the expression tree. """ # Only handle UID right now if propfilter.filter_name != "UID": raise ValueError # Handle is-not-defined case if not propfilter.defined: # Test for <> != "*" return expression.isExpression(fields["UID"], "", True) # Determine logical expression grouping logical = expression.andExpression if propfilter.filter_test == "allof" else expression.orExpression # Handle time-range - we cannot do this with our Index right now if propfilter.qualifier and isinstance(propfilter.qualifier, calendarqueryfilter.TimeRange): raise ValueError # Handle text-match tm = None if propfilter.qualifier and isinstance(propfilter.qualifier, calendarqueryfilter.TextMatch): if propfilter.qualifier.match_type == "equals": tm = expression.isnotExpression if propfilter.qualifier.negate else expression.isExpression elif propfilter.qualifier.match_type == "contains": tm = expression.notcontainsExpression if propfilter.qualifier.negate else expression.containsExpression elif propfilter.qualifier.match_type == "starts-with": tm = expression.notstartswithExpression if propfilter.qualifier.negate else expression.startswithExpression elif propfilter.qualifier.match_type == "ends-with": tm = expression.notendswithExpression if propfilter.qualifier.negate else expression.endswithExpression tm = tm(fields[propfilter.filter_name], propfilter.qualifier.text, propfilter.qualifier.caseless) # Handle embedded parameters - we do not right now as our Index does not handle them params = [] for _ignore in propfilter.filters: raise ValueError if len(params) > 1: paramsExpression = logical(params) elif len(params) == 1: paramsExpression = params[0] else: paramsExpression = None # Now build return expression if (tm is not None) and (paramsExpression is not None): return logical([tm, paramsExpression]) elif tm is not None: return tm elif paramsExpression is not None: return paramsExpression else: return None def 
getTimerangeArguments(timerange): """ Get start/end and floating start/end (adjusted for timezone offset) values from the supplied time-range test. @param timerange: the L{TimeRange} used in the query. @return: C{tuple} of C{str} for start, end, startfloat, endfloat """ # Start/end in UTC start = timerange.start end = timerange.end # Get timezone tzinfo = timerange.tzinfo # Now force to floating UTC startfloat = floatoffset(start, tzinfo) if start else None endfloat = floatoffset(end, tzinfo) if end else None return ( pyCalendarTodatetime(start) if start else None, pyCalendarTodatetime(end) if end else None, pyCalendarTodatetime(startfloat) if startfloat else None, pyCalendarTodatetime(endfloat) if endfloat else None, ) def sqlcalendarquery(filter, calendarid=None, userid=None, freebusy=False, generator=sqlgenerator.sqlgenerator): """ Convert the supplied calendar-query into a partial SQL statement. @param filter: the L{Filter} for the calendar-query to convert. @return: a C{tuple} of (C{str}, C{list}), where the C{str} is the partial SQL statement, and the C{list} is the list of argument substitutions to use with the SQL API execute method. Or return C{None} if it is not possible to create an SQL query to fully match the calendar-query. """ try: expression = calendarquery(filter, generator.FIELDS) sql = generator(expression, calendarid, userid, freebusy) return sql.generate() except ValueError: return None calendarserver-5.2+dfsg/twistedcaldav/query/__init__.py0000644000175000017500000000120312263343324022421 0ustar rahulrahul## # Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## """ CalDAV and CardDAV queries. """ calendarserver-5.2+dfsg/twistedcaldav/test/0000755000175000017500000000000012322625316020126 5ustar rahulrahulcalendarserver-5.2+dfsg/twistedcaldav/test/test_caldavxml.py0000644000175000017500000000350212263343324023512 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from twistedcaldav import caldavxml import twistedcaldav.test.util class CustomXML (twistedcaldav.test.util.TestCase): def test_TimeRange(self): self.assertRaises(ValueError, caldavxml.CalDAVTimeRangeElement) tr = caldavxml.CalDAVTimeRangeElement(start="20110201T120000Z") self.assertTrue(tr.valid()) tr = caldavxml.CalDAVTimeRangeElement(start="20110201T120000") self.assertFalse(tr.valid()) tr = caldavxml.CalDAVTimeRangeElement(start="20110201") self.assertFalse(tr.valid()) tr = caldavxml.CalDAVTimeRangeElement(end="20110201T120000Z") self.assertTrue(tr.valid()) tr = caldavxml.CalDAVTimeRangeElement(end="20110201T120000") self.assertFalse(tr.valid()) tr = caldavxml.CalDAVTimeRangeElement(end="20110201") self.assertFalse(tr.valid()) tr = caldavxml.CalDAVTimeRangeElement(start="20110201T120000Z", end="20110202T120000Z") self.assertTrue(tr.valid()) tr = caldavxml.CalDAVTimeRangeElement(start="20110201T120000Z", end="20110202T120000") self.assertFalse(tr.valid()) tr = caldavxml.CalDAVTimeRangeElement(start="20110201T120000Z", end="20110202") self.assertFalse(tr.valid()) calendarserver-5.2+dfsg/twistedcaldav/test/test_stdconfig.py0000644000175000017500000001273612263343324023530 0ustar rahulrahul# -*- coding: utf-8 -*- ## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## from cStringIO import StringIO from twext.python.filepath import CachingFilePath as FilePath from twisted.trial.unittest import TestCase from twistedcaldav.config import Config, ConfigDict from twistedcaldav.stdconfig import NoUnicodePlistParser, PListConfigProvider,\ _updateDataStore, _updateMultiProcess import twistedcaldav.stdconfig nonASCIIValue = "→←" nonASCIIPlist = "%s" % ( nonASCIIValue, ) nonASCIIConfigPList = """ DataRoot %s """ % (nonASCIIValue,) class ConfigParsingTests(TestCase): """ Tests to verify the behavior of the configuration parser. """ def test_noUnicodePListParser(self): """ L{NoUnicodePlistParser.parse} retrieves non-ASCII property list values as (UTF-8 encoded) 'str' objects, so that a single type is consistently used regardless of the input data. """ parser = NoUnicodePlistParser() self.assertEquals(parser.parse(StringIO(nonASCIIPlist)), nonASCIIValue) def test_parseNonASCIIConfig(self): """ Non-ASCII s found as part of a configuration file will be retrieved as UTF-8 encoded 'str' objects, as parsed by L{NoUnicodePlistParser}. """ cfg = Config(PListConfigProvider({"DataRoot": ""})) tempfile = FilePath(self.mktemp()) tempfile.setContent(nonASCIIConfigPList) cfg.load(tempfile.path) self.assertEquals(cfg.DataRoot, nonASCIIValue) def test_relativeDefaultPaths(self): """ The paths specified in the default configuration should be interpreted as relative to the paths specified in the configuration file. 
""" cfg = Config(PListConfigProvider( {"AccountingLogRoot": "some-path", "LogRoot": "should-be-ignored"})) cfg.addPostUpdateHooks([_updateDataStore]) tempfile = FilePath(self.mktemp()) tempfile.setContent("" "LogRoot/some/root" "") cfg.load(tempfile.path) self.assertEquals(cfg.AccountingLogRoot, "/some/root/some-path") tempfile.setContent("" "LogRoot/other/root" "") cfg.load(tempfile.path) self.assertEquals(cfg.AccountingLogRoot, "/other/root/some-path") def test_includes(self): plist1 = """ ServerRoot /root DocumentRoot defaultdoc DataRoot defaultdata ConfigRoot defaultconfig LogRoot defaultlog RunRoot defaultrun Includes %s """ plist2 = """ DataRoot overridedata """ tempfile2 = FilePath(self.mktemp()) tempfile2.setContent(plist2) tempfile1 = FilePath(self.mktemp()) tempfile1.setContent(plist1 % (tempfile2.path,)) cfg = Config(PListConfigProvider({ "ServerRoot": "", "DocumentRoot": "", "DataRoot": "", "ConfigRoot": "", "LogRoot": "", "RunRoot": "", "Includes": [], })) cfg.addPostUpdateHooks([_updateDataStore]) cfg.load(tempfile1.path) self.assertEquals(cfg.DocumentRoot, "/root/overridedata/defaultdoc") self.assertEquals(cfg.DataRoot, "/root/overridedata") def test_updateDataStore(self): configDict = { "ServerRoot" : "/a/b/c/", } _updateDataStore(configDict) self.assertEquals(configDict["ServerRoot"], "/a/b/c") def test_updateMultiProcess(self): def stubProcessCount(*args): return 3 self.patch(twistedcaldav.stdconfig, "computeProcessCount", stubProcessCount) configDict = ConfigDict({ "MultiProcess" : { "ProcessCount" : 0, "MinProcessCount" : 2, "PerCPU" : 1, "PerGB" : 1, }, "Postgres" : { "ExtraConnections" : 5, "BuffersToConnectionsRatio" : 1.5, }, "SharedConnectionPool" : False, "MaxDBConnectionsPerPool" : 10, }) _updateMultiProcess(configDict) self.assertEquals(45, configDict.Postgres.MaxConnections) self.assertEquals(67, configDict.Postgres.SharedBuffers) calendarserver-5.2+dfsg/twistedcaldav/test/test_sql.py0000644000175000017500000002070512263343324022342 
0ustar rahulrahul## # Copyright (c) 2007-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from twistedcaldav.sql import AbstractSQLDatabase import twistedcaldav.test.util from threading import Thread import time import os class SQL (twistedcaldav.test.util.TestCase): """ Test abstract SQL DB class """ class TestDB(AbstractSQLDatabase): def __init__(self, path, persistent=False, autocommit=False, version="1"): self.version = version super(SQL.TestDB, self).__init__(path, persistent, autocommit=autocommit) def _db_version(self): """ @return: the schema version assigned to this index. """ return self.version def _db_type(self): """ @return: the collection type assigned to this index. """ return "TESTTYPE" def _db_init_data_tables(self, q): """ Initialise the underlying database tables. @param q: a database cursor to use. 
""" # # TESTTYPE table # q.execute( """ create table TESTTYPE ( KEY text unique, VALUE text ) """ ) class TestDBRecreateUpgrade(TestDB): class RecreateDBException(Exception): pass class UpgradeDBException(Exception): pass def __init__(self, path, persistent=False, autocommit=False): super(SQL.TestDBRecreateUpgrade, self).__init__(path, persistent, autocommit=autocommit, version="2") def _db_recreate(self, do_commit=True): raise self.RecreateDBException() class TestDBCreateIndexOnUpgrade(TestDB): def __init__(self, path, persistent=False, autocommit=False): super(SQL.TestDBCreateIndexOnUpgrade, self).__init__(path, persistent, autocommit=autocommit, version="2") def _db_upgrade_data_tables(self, q, old_version): q.execute( """ create index TESTING on TESTTYPE (VALUE) """ ) class TestDBPauseInInit(TestDB): def _db_init(self, db_filename, q): time.sleep(1) super(SQL.TestDBPauseInInit, self)._db_init(db_filename, q) def test_connect(self): """ Connect to database and create table """ db = SQL.TestDB(self.mktemp()) self.assertFalse(hasattr(db, "_db_connection")) self.assertTrue(db._db() is not None) self.assertTrue(db._db_connection is not None) def test_connect_autocommit(self): """ Connect to database and create table """ db = SQL.TestDB(self.mktemp(), autocommit=True) self.assertFalse(hasattr(db, "_db_connection")) self.assertTrue(db._db() is not None) self.assertTrue(db._db_connection is not None) def test_readwrite(self): """ Add a record, search for it """ db = SQL.TestDB(self.mktemp()) db._db().execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", ("FOO", "BAR",)) db._db_commit() q = db._db().execute("SELECT * from TESTTYPE") items = [i for i in q.fetchall()] self.assertEqual(items, [("FOO", "BAR")]) def test_readwrite_autocommit(self): """ Add a record, search for it """ db = SQL.TestDB(self.mktemp(), autocommit=True) db._db().execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", ("FOO", "BAR",)) q = db._db().execute("SELECT * from TESTTYPE") items 
= [i for i in q.fetchall()] self.assertEqual(items, [("FOO", "BAR")]) def test_readwrite_cursor(self): """ Add a record, search for it """ db = SQL.TestDB(self.mktemp()) db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR") items = db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, [("FOO", "BAR")]) def test_readwrite_cursor_autocommit(self): """ Add a record, search for it """ db = SQL.TestDB(self.mktemp(), autocommit=True) db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR") items = db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, [("FOO", "BAR")]) def test_readwrite_rollback(self): """ Add a record, search for it """ db = SQL.TestDB(self.mktemp()) db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR") db._db_rollback() items = db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, []) def test_close(self): """ Close database """ db = SQL.TestDB(self.mktemp()) self.assertFalse(hasattr(db, "_db_connection")) self.assertTrue(db._db() is not None) db._db_close() self.assertFalse(hasattr(db, "_db_connection")) db._db_close() def test_duplicate_create(self): dbname = self.mktemp() class DBThread(Thread): def run(self): try: db = SQL.TestDBPauseInInit(dbname) db._db() self.result = True except: self.result = False t1 = DBThread() t2 = DBThread() t1.start() t2.start() t1.join() t2.join() self.assertTrue(t1.result) self.assertTrue(t2.result) def test_version_upgrade_nonpersistent(self): """ Connect to database and create table """ db = SQL.TestDB(self.mktemp(), autocommit=True) self.assertTrue(db._db() is not None) db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR") items = db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, [("FOO", "BAR")]) db._db_close() db = None db = SQL.TestDBRecreateUpgrade(self.mktemp(), autocommit=True) self.assertRaises(SQL.TestDBRecreateUpgrade.RecreateDBException, db._db) items 
= db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, []) def test_version_upgrade_persistent(self): """ Connect to database and create table """ db_file = self.mktemp() db = SQL.TestDB(db_file, persistent=True, autocommit=True) self.assertTrue(db._db() is not None) db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR") items = db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, [("FOO", "BAR")]) db._db_close() db = None db = SQL.TestDBRecreateUpgrade(db_file, persistent=True, autocommit=True) self.assertRaises(NotImplementedError, db._db) self.assertTrue(os.path.exists(db_file)) db._db_close() db = None db = SQL.TestDB(db_file, persistent=True, autocommit=True) self.assertTrue(db._db() is not None) items = db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, [("FOO", "BAR")]) def test_version_upgrade_persistent_add_index(self): """ Connect to database and create table """ db_file = self.mktemp() db = SQL.TestDB(db_file, persistent=True, autocommit=True) self.assertTrue(db._db() is not None) db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR") items = db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, [("FOO", "BAR")]) db._db_close() db = None db = SQL.TestDBCreateIndexOnUpgrade(db_file, persistent=True, autocommit=True) self.assertTrue(db._db() is not None) items = db._db_execute("SELECT * from TESTTYPE") self.assertEqual(items, [("FOO", "BAR")]) calendarserver-5.2+dfsg/twistedcaldav/test/test_timezonestdservice.py0000644000175000017500000001151212263343324025465 0ustar rahulrahul## # Copyright (c) 2011-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from twistedcaldav.timezones import TimezoneCache from twistedcaldav.timezonestdservice import TimezoneInfo, \ PrimaryTimezoneDatabase from xml.etree.ElementTree import Element import hashlib import os import twistedcaldav.test.util class TestTimezoneInfo (twistedcaldav.test.util.TestCase): """ Timezone support tests """ def test_generateXML(self): hashed = hashlib.md5("test").hexdigest() info = TimezoneInfo("America/New_York", ("US/Eastern",), "20110517T120000Z", hashed) node = Element("root") info.generateXML(node) timezone = node.find("timezone") self.assertTrue(timezone is not None) self.assertEqual(timezone.findtext("tzid"), "America/New_York") self.assertEqual(timezone.findtext("dtstamp"), "20110517T120000Z") self.assertEqual(timezone.findtext("alias"), "US/Eastern") self.assertEqual(timezone.findtext("md5"), hashed) def test_parseXML(self): hashed = hashlib.md5("test").hexdigest() info1 = TimezoneInfo("America/New_York", ("US/Eastern",), "20110517T120000Z", hashed) node = Element("root") info1.generateXML(node) timezone = node.find("timezone") info2 = TimezoneInfo.readXML(timezone) self.assertEqual(info2.tzid, "America/New_York") self.assertEqual(info2.aliases, ("US/Eastern",)) self.assertEqual(info2.dtstamp, "20110517T120000Z") self.assertEqual(info2.md5, hashed) class TestPrimaryTimezoneDatabase (twistedcaldav.test.util.TestCase): """ Timezone support tests """ def setUp(self): TimezoneCache.create() def testCreate(self): xmlfile = self.mktemp() db = PrimaryTimezoneDatabase(TimezoneCache.getDBPath(), xmlfile) db.createNewDatabase() 
self.assertTrue(os.path.exists(xmlfile)) self.assertTrue(db.dtstamp is not None) self.assertTrue(len(db.timezones) > 0) def testUpdate(self): xmlfile = self.mktemp() db = PrimaryTimezoneDatabase(TimezoneCache.getDBPath(), xmlfile) db.createNewDatabase() self.assertTrue(os.path.exists(xmlfile)) db.updateDatabase() self.assertTrue(db.changeCount == 0) self.assertTrue(len(db.changed) == 0) def testRead(self): xmlfile = self.mktemp() db1 = PrimaryTimezoneDatabase(TimezoneCache.getDBPath(), xmlfile) db1.createNewDatabase() self.assertTrue(os.path.exists(xmlfile)) db2 = PrimaryTimezoneDatabase(TimezoneCache.getDBPath(), xmlfile) db2.readDatabase() self.assertEqual(db1.dtstamp, db2.dtstamp) self.assertEqual(len(db1.timezones), len(db2.timezones)) def testList(self): xmlfile = self.mktemp() db = PrimaryTimezoneDatabase(TimezoneCache.getDBPath(), xmlfile) db.createNewDatabase() self.assertTrue(os.path.exists(xmlfile)) tzids = set([tz.tzid for tz in db.listTimezones(None)]) self.assertTrue("America/New_York" in tzids) self.assertTrue("US/Eastern" not in tzids) def testListChangedSince(self): xmlfile = self.mktemp() db = PrimaryTimezoneDatabase(TimezoneCache.getDBPath(), xmlfile) db.createNewDatabase() self.assertTrue(os.path.exists(xmlfile)) tzids = set([tz.tzid for tz in db.listTimezones(db.dtstamp)]) self.assertTrue(len(tzids) == 0) def testGetNone(self): xmlfile = self.mktemp() db = PrimaryTimezoneDatabase(TimezoneCache.getDBPath(), xmlfile) db.createNewDatabase() self.assertTrue(os.path.exists(xmlfile)) tz = db.getTimezone("Bogus") self.assertEqual(tz, None) def testGetOne(self): xmlfile = self.mktemp() db = PrimaryTimezoneDatabase(TimezoneCache.getDBPath(), xmlfile) db.createNewDatabase() self.assertTrue(os.path.exists(xmlfile)) # Original tz1 = db.getTimezone("America/New_York") self.assertTrue(str(tz1).find("VTIMEZONE") != -1) self.assertTrue(str(tz1).find("TZID:America/New_York") != -1) # Alias tz1 = db.getTimezone("US/Eastern") 
self.assertTrue(str(tz1).find("VTIMEZONE") != -1) self.assertTrue(str(tz1).find("TZID:US/Eastern") != -1) calendarserver-5.2+dfsg/twistedcaldav/test/test_addressbookquery.py0000644000175000017500000001713112263343324025130 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## import os from twext.web2 import responsecode from twext.web2.iweb import IResponse from twext.web2.stream import MemoryStream from txdav.xml import element as davxml from twext.web2.dav.util import davXMLFromStream, joinURL from twistedcaldav import carddavxml, vcard from twistedcaldav.config import config from twistedcaldav.test.util import StoreTestCase, SimpleStoreRequest from twisted.internet.defer import inlineCallbacks, returnValue from twisted.python.filepath import FilePath class AddressBookQuery(StoreTestCase): """ addressbook-query REPORT """ data_dir = os.path.join(os.path.dirname(__file__), "data") vcards_dir = os.path.join(data_dir, "vCards") def test_addressbook_query_by_uid(self): """ vCard by UID. """ uid = "ED7A5AEC-AB19-4CE0-AD6A-2923A3E5C4E1:ABPerson" return self.simple_vcard_query( "/addressbook/", carddavxml.PropertyFilter( carddavxml.TextMatch.fromString(uid), name="UID", ), [uid] ) def test_addressbook_query_all_vcards(self): """ All vCards. 
""" uids = [r[0] for r in (os.path.splitext(f) for f in os.listdir(self.vcards_dir)) if r[1] == ".vcf"] return self.simple_vcard_query("/addressbook/", None, uids) def test_addressbook_query_limited_with_data(self): """ All vCards. """ oldValue = config.MaxQueryWithDataResults config.MaxQueryWithDataResults = 1 def _restoreValueOK(f): config.MaxQueryWithDataResults = oldValue return None def _restoreValueError(f): config.MaxQueryWithDataResults = oldValue self.fail("REPORT must not fail with 403") uids = [r[0] for r in (os.path.splitext(f) for f in os.listdir(self.vcards_dir)) if r[1] == ".vcf"] d = self.simple_vcard_query("/addressbook/", None, uids, limit=1) d.addCallbacks(_restoreValueOK, _restoreValueError) return d def test_addressbook_query_limited_without_data(self): """ All vCards. """ oldValue = config.MaxQueryWithDataResults config.MaxQueryWithDataResults = 1 def _restoreValueOK(f): config.MaxQueryWithDataResults = oldValue return None def _restoreValueError(f): config.MaxQueryWithDataResults = oldValue self.fail("REPORT must not fail with 403") uids = [r[0] for r in (os.path.splitext(f) for f in os.listdir(self.vcards_dir)) if r[1] == ".vcf"] d = self.simple_vcard_query("/addressbook/", None, uids, withData=False) d.addCallbacks(_restoreValueOK, _restoreValueError) return d def simple_vcard_query(self, vcard_uri, vcard_filter, uids, withData=True, limit=None): vcard_uri = joinURL("/addressbooks/users/wsanchez", vcard_uri) props = ( davxml.GETETag(), ) if withData: props += ( carddavxml.AddressData(), ) query = carddavxml.AddressBookQuery( davxml.PropertyContainer(*props), carddavxml.Filter( vcard_filter, ), ) def got_xml(doc): if not isinstance(doc.root_element, davxml.MultiStatus): self.fail("REPORT response XML root element is not multistatus: %r" % (doc.root_element,)) count = 0 for response in doc.root_element.childrenOfType(davxml.PropertyStatusResponse): for propstat in response.childrenOfType(davxml.PropertyStatus): status = 
propstat.childOfType(davxml.Status) if status.code == responsecode.INSUFFICIENT_STORAGE_SPACE and limit is not None: continue if status.code != responsecode.OK: self.fail("REPORT failed (status %s) to locate properties: %r" % (status.code, propstat)) elif limit is not None: count += 1 continue properties = propstat.childOfType(davxml.PropertyContainer).children for property in properties: qname = property.qname() if qname == (davxml.dav_namespace, "getetag"): continue if qname != (carddavxml.carddav_namespace, "address-data"): self.fail("Response included unexpected property %r" % (property,)) result_addressbook = property.address() if result_addressbook is None: self.fail("Invalid response CardDAV:address-data: %r" % (property,)) uid = result_addressbook.resourceUID() if uid in uids: uids.remove(uid) else: self.fail("Got addressbook for unexpected UID %r" % (uid,)) original_filename = file(os.path.join(self.vcards_dir, uid + ".vcf")) original_addressbook = vcard.Component.fromStream(original_filename) self.assertEqual(result_addressbook, original_addressbook) if limit is not None and count != limit: self.fail("Wrong number of limited results: %d" % (count,)) return self.addressbook_query(vcard_uri, query, got_xml) @inlineCallbacks def addressbook_query(self, addressbook_uri, query, got_xml): ''' FIXME: clear address book, possibly by removing mkcol = """ """ response = yield self.send(SimpleStoreRequest(self, "MKCOL", addressbook_uri, content=mkcol, authid="wsanchez")) response = IResponse(response) if response.code != responsecode.CREATED: self.fail("MKCOL failed: %s" % (response.code,)) ''' # Add vCards to addressbook for child in FilePath(self.vcards_dir).children(): if os.path.splitext(child.basename())[1] != ".vcf": continue request = SimpleStoreRequest(self, "PUT", joinURL(addressbook_uri, child.basename()), authid="wsanchez") request.stream = MemoryStream(child.getContent()) yield self.send(request) request = SimpleStoreRequest(self, "REPORT", 
addressbook_uri, authid="wsanchez") request.stream = MemoryStream(query.toxml()) response = yield self.send(request) response = IResponse(response) if response.code != responsecode.MULTI_STATUS: self.fail("REPORT failed: %s" % (response.code,)) returnValue( (yield davXMLFromStream(response.stream).addCallback(got_xml)) ) calendarserver-5.2+dfsg/twistedcaldav/test/test_upgrade.py0000644000175000017500000016321712263343324023200 0ustar rahulrahul## # Copyright (c) 2008-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
## import hashlib import os import zlib import cPickle from twisted.python.reflect import namedClass from twisted.internet.defer import inlineCallbacks from txdav.xml.parser import WebDAVDocument from txdav.caldav.datastore.index_file import db_basename from twistedcaldav.config import config from twistedcaldav.directory.xmlfile import XMLDirectoryService from twistedcaldav.directory.resourceinfo import ResourceInfoDatabase from txdav.caldav.datastore.scheduling.imip.mailgateway import MailGatewayTokensDatabase from twistedcaldav.upgrade import ( xattrname, upgradeData, updateFreeBusySet, removeIllegalCharacters, normalizeCUAddrs ) from twistedcaldav.test.util import TestCase from calendarserver.tools.util import getDirectory freeBusyAttr = xattrname( "{urn:ietf:params:xml:ns:caldav}calendar-free-busy-set" ) cTagAttr = xattrname( "{http:%2F%2Fcalendarserver.org%2Fns%2F}getctag" ) md5Attr = xattrname( "{http:%2F%2Ftwistedmatrix.com%2Fxml_namespace%2Fdav%2F}getcontentmd5" ) OLDPROXYFILE = ".db.calendaruserproxy" NEWPROXYFILE = "proxies.sqlite" class UpgradeTests(TestCase): def setUpXMLDirectory(self): xmlFile = os.path.join(os.path.dirname(os.path.dirname(__file__)), "directory", "test", "accounts.xml") config.DirectoryService.params.xmlFile = xmlFile xmlAugmentsFile = os.path.join(os.path.dirname(os.path.dirname(__file__)), "directory", "test", "augments.xml") config.AugmentService.type = "twistedcaldav.directory.augment.AugmentXMLDB" config.AugmentService.params.xmlFiles = (xmlAugmentsFile,) resourceFile = os.path.join(os.path.dirname(os.path.dirname(__file__)), "directory", "test", "resources.xml") config.ResourceService.params.xmlFile = resourceFile def doUpgrade(self, config): """ Perform the actual upgrade. (Hook for parallel tests.) 
""" return upgradeData(config) def setUpInitialStates(self): self.setUpXMLDirectory() self.setUpOldDocRoot() self.setUpOldDocRootWithoutDB() self.setUpNewDocRoot() self.setUpNewDataRoot() self.setUpDataRootWithProxyDB() def setUpOldDocRoot(self): # Set up doc root self.olddocroot = os.path.abspath(self.mktemp()) os.mkdir(self.olddocroot) principals = os.path.join(self.olddocroot, "principals") os.mkdir(principals) os.mkdir(os.path.join(principals, "__uids__")) os.mkdir(os.path.join(principals, "users")) os.mkdir(os.path.join(principals, "groups")) os.mkdir(os.path.join(principals, "locations")) os.mkdir(os.path.join(principals, "resources")) os.mkdir(os.path.join(principals, "sudoers")) open(os.path.join(principals, OLDPROXYFILE), "w").close() def setUpOldDocRootWithoutDB(self): # Set up doc root self.olddocrootnodb = os.path.abspath(self.mktemp()) os.mkdir(self.olddocrootnodb) principals = os.path.join(self.olddocrootnodb, "principals") os.mkdir(principals) os.mkdir(os.path.join(principals, "__uids__")) os.mkdir(os.path.join(principals, "users")) os.mkdir(os.path.join(principals, "groups")) os.mkdir(os.path.join(principals, "locations")) os.mkdir(os.path.join(principals, "resources")) os.mkdir(os.path.join(principals, "sudoers")) os.mkdir(os.path.join(self.olddocrootnodb, "calendars")) def setUpNewDocRoot(self): # Set up doc root self.newdocroot = os.path.abspath(self.mktemp()) os.mkdir(self.newdocroot) os.mkdir(os.path.join(self.newdocroot, "calendars")) def setUpNewDataRoot(self): # Set up data root self.newdataroot = os.path.abspath(self.mktemp()) os.mkdir(self.newdataroot) def setUpDataRootWithProxyDB(self): # Set up data root self.existingdataroot = os.path.abspath(self.mktemp()) os.mkdir(self.existingdataroot) principals = os.path.join(self.existingdataroot, "principals") os.mkdir(principals) open(os.path.join(self.existingdataroot, NEWPROXYFILE), "w").close() @inlineCallbacks def test_normalUpgrade(self): """ Test the behavior of normal upgrade from old 
server to new. """ self.setUpInitialStates() config.DocumentRoot = self.olddocroot config.DataRoot = self.newdataroot # Check pre-conditions self.assertTrue(os.path.exists(os.path.join(config.DocumentRoot, "principals"))) self.assertTrue(os.path.isdir(os.path.join(config.DocumentRoot, "principals"))) self.assertTrue(os.path.exists(os.path.join(config.DocumentRoot, "principals", OLDPROXYFILE))) self.assertFalse(os.path.exists(os.path.join(config.DataRoot, NEWPROXYFILE))) (yield self.doUpgrade(config)) # Check post-conditions self.assertFalse(os.path.exists(os.path.join(config.DocumentRoot, "principals",))) self.assertTrue(os.path.exists(os.path.join(config.DataRoot, NEWPROXYFILE))) @inlineCallbacks def test_noUpgrade(self): """ Test the behavior of running on a new server (i.e. no upgrade needed). """ self.setUpInitialStates() config.DocumentRoot = self.newdocroot config.DataRoot = self.existingdataroot # Check pre-conditions self.assertFalse(os.path.exists(os.path.join(config.DocumentRoot, "principals"))) self.assertTrue(os.path.exists(os.path.join(config.DataRoot, NEWPROXYFILE))) (yield self.doUpgrade(config)) # Check post-conditions self.assertFalse(os.path.exists(os.path.join(config.DocumentRoot, "principals",))) self.assertTrue(os.path.exists(os.path.join(config.DataRoot, NEWPROXYFILE))) def test_freeBusyUpgrade(self): """ Test the updating of calendar-free-busy-set xattrs on inboxes """ self.setUpInitialStates() directory = getDirectory() # # Verify these values require no updating: # # Uncompressed XML value = "\r\n\r\n /calendars/__uids__/BB05932F-DCE7-4195-9ED4-0896EAFF3B0B/calendar\r\n\r\n" self.assertEquals(updateFreeBusySet(value, directory), None) # Zlib compressed XML value = "\r\n\r\n /calendars/__uids__/BB05932F-DCE7-4195-9ED4-0896EAFF3B0B/calendar\r\n\r\n" value = zlib.compress(value) self.assertEquals(updateFreeBusySet(value, directory), None) # Pickled XML value = "\r\n\r\n /calendars/__uids__/BB05932F-DCE7-4195-9ED4-0896EAFF3B0B/calendar\r\n\r\n" 
doc = WebDAVDocument.fromString(value) value = cPickle.dumps(doc.root_element) self.assertEquals(updateFreeBusySet(value, directory), None) # # Verify these values do require updating: # expected = "\n\r\n /calendars/__uids__/6423F94A-6B76-4A3A-815B-D52CFD77935D/calendar/\r\n" # Uncompressed XML value = "\r\n\r\n /calendars/users/wsanchez/calendar\r\n\r\n" newValue = updateFreeBusySet(value, directory) newValue = zlib.decompress(newValue) self.assertEquals(newValue, expected) # Zlib compressed XML value = "\r\n\r\n /calendars/users/wsanchez/calendar\r\n\r\n" value = zlib.compress(value) newValue = updateFreeBusySet(value, directory) newValue = zlib.decompress(newValue) self.assertEquals(newValue, expected) # Pickled XML value = "\r\n\r\n /calendars/users/wsanchez/calendar\r\n\r\n" doc = WebDAVDocument.fromString(value) value = cPickle.dumps(doc.root_element) newValue = updateFreeBusySet(value, directory) newValue = zlib.decompress(newValue) self.assertEquals(newValue, expected) # # Shortname not in directory, return empty string # expected = "\n" value = "\r\n\r\n /calendars/users/nonexistent/calendar\r\n\r\n" newValue = updateFreeBusySet(value, directory) newValue = zlib.decompress(newValue) self.assertEquals(newValue, expected) @inlineCallbacks def verifyDirectoryComparison(self, before, after, reverify=False): """ Verify that the hierarchy described by "before", when upgraded, matches the hierarchy described by "after". @param before: a dictionary of the format accepted by L{TestCase.createHierarchy} @param after: a dictionary of the format accepted by L{TestCase.createHierarchy} @param reverify: if C{True}, re-verify the hierarchy by upgrading a second time and re-verifying the root again. @raise twisted.trial.unittest.FailTest: if the test fails. 
@return: C{None} """ root = self.createHierarchy(before) config.DocumentRoot = root config.DataRoot = root (yield self.doUpgrade(config)) self.assertTrue(self.verifyHierarchy(root, after)) if reverify: # Ensure that repeating the process doesn't change anything (yield self.doUpgrade(config)) self.assertTrue(self.verifyHierarchy(root, after)) @inlineCallbacks def test_removeNotificationDirectories(self): """ The upgrade process should remove unused notification directories in users' calendar homes, as well as the XML files found therein. """ self.setUpXMLDirectory() before = { "calendars": { "users": { "wsanchez": { "calendar" : { db_basename : { "@contents": "", }, }, "notifications": { "sample-notification.xml": { "@contents": "\n" } } } } } } after = { "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar": { db_basename : { "@contents": "", }, }, } } } } }, ".calendarserver_version" : { "@contents" : "2", }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after)) @inlineCallbacks def test_calendarsUpgradeWithTypes(self): """ Verify that calendar homes in the /calendars/// form are upgraded to /calendars/__uids__/XX/YY/ form """ self.setUpXMLDirectory() before = { "calendars" : { "users" : { "wsanchez" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_before, "@xattrs" : { md5Attr : "12345", }, }, "@xattrs" : { cTagAttr : "12345", }, }, "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { # Pickled XML Doc freeBusyAttr : cPickle.dumps(WebDAVDocument.fromString("\r\n\r\n /calendars/users/wsanchez/calendar\r\n\r\n").root_element), }, }, }, }, "groups" : { "managers" : { "calendar" : { db_basename : { "@contents": "", }, }, }, }, }, "principals" : { OLDPROXYFILE : { 
"@contents" : "", } } } after = { ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_after, "@xattrs" : { md5Attr : zlib.compress("\r\n%s\r\n" % (event01_after_md5,)), }, }, "@xattrs" : { cTagAttr : isValidCTag, # method below }, }, "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { freeBusyAttr : zlib.compress("\n\r\n /calendars/__uids__/6423F94A-6B76-4A3A-815B-D52CFD77935D/calendar/\r\n"), }, }, }, }, }, "9F" : { "F6" : { "9FF60DAD-0BDE-4508-8C77-15F0CA5C8DD1" : { "calendar" : { db_basename : { "@contents": "", }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after, reverify=True)) @inlineCallbacks def test_calendarsUpgradeWithOrphans(self): """ Verify that calendar homes in the /calendars/// form whose records don't exist are moved into dataroot/archived/ """ self.setUpXMLDirectory() before = { "calendars" : { "users" : { "unknownuser" : { }, }, "groups" : { "unknowngroup" : { }, }, }, "principals" : { OLDPROXYFILE : { "@contents" : "", } } } after = { "archived" : { "unknownuser" : { }, "unknowngroup" : { }, }, ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { "__uids__" : { }, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after, reverify=True)) @inlineCallbacks def test_calendarsUpgradeWithDuplicateOrphans(self): """ Verify that calendar homes in the /calendars/// form whose records don't exist 
are moved into dataroot/archived/ """ self.setUpXMLDirectory() before = { "archived" : { "unknownuser" : { }, "unknowngroup" : { }, }, "calendars" : { "users" : { "unknownuser" : { }, }, "groups" : { "unknowngroup" : { }, }, }, "principals" : { OLDPROXYFILE : { "@contents" : "", } } } after = { "archived" : { "unknownuser" : { }, "unknowngroup" : { }, "unknownuser.1" : { }, "unknowngroup.1" : { }, }, ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { "__uids__" : { }, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after, reverify=True)) @inlineCallbacks def test_calendarsUpgradeWithUnknownFiles(self): """ Unknown files, including .DS_Store files at any point in the hierarchy, as well as non-directory in a user's calendar home, will be ignored and not interrupt an upgrade. """ self.setUpXMLDirectory() ignoredUIDContents = { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, }, "garbage.ics" : { "@contents": "Oops, not actually an ICS file.", }, "other-file.txt": { "@contents": "Also not a calendar collection." 
}, } } }, ".DS_Store" : { "@contents" : "", } } before = { ".DS_Store" : { "@contents" : "", }, "calendars" : { ".DS_Store" : { "@contents" : "", }, "__uids__" : ignoredUIDContents, }, "principals" : { ".DS_Store" : { "@contents" : "", }, OLDPROXYFILE : { "@contents" : "", } } } after = { ".DS_Store" : { "@contents" : "", }, ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { ".DS_Store" : { "@contents" : "", }, "__uids__" : ignoredUIDContents, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after, reverify=True)) @inlineCallbacks def test_calendarsUpgradeWithNestedCollections(self): """ Unknown files, including .DS_Store files at any point in the hierarchy, as well as non-directory in a user's calendar home, will be ignored and not interrupt an upgrade. """ self.setUpXMLDirectory() beforeUIDContents = { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, }, "nested1": { "nested2": {}, }, } } }, ".DS_Store" : { "@contents" : "", } } afterUIDContents = { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, }, ".collection.nested1": { "nested2": {}, }, } } }, ".DS_Store" : { "@contents" : "", } } before = { ".DS_Store" : { "@contents" : "", }, "calendars" : { ".DS_Store" : { "@contents" : "", }, "__uids__" : beforeUIDContents, }, "principals" : { ".DS_Store" : { "@contents" : "", }, OLDPROXYFILE : { "@contents" : "", } } } after = { ".DS_Store" : { "@contents" : "", }, ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { ".DS_Store" : { "@contents" : "", }, "__uids__" : afterUIDContents, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % 
(MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after, reverify=True)) @inlineCallbacks def test_calendarsUpgradeWithUIDs(self): """ Verify that calendar homes in the /calendars/__uids__// form are upgraded to /calendars/__uids__/XX/YY// form """ self.setUpXMLDirectory() before = { "calendars" : { "__uids__" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_before, }, }, "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { # Plain XML freeBusyAttr : "\r\n\r\n /calendars/users/wsanchez/calendar\r\n\r\n", }, }, }, }, }, "principals" : { OLDPROXYFILE : { "@contents" : "", } } } after = { ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_after, }, "@xattrs" : { cTagAttr : isValidCTag, # method below }, }, "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { freeBusyAttr : zlib.compress("\n\r\n /calendars/__uids__/6423F94A-6B76-4A3A-815B-D52CFD77935D/calendar/\r\n"), }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after, reverify=True)) @inlineCallbacks def test_calendarsUpgradeWithUIDsMultilevel(self): """ Verify that calendar homes in the /calendars/__uids__/XX/YY// form are upgraded correctly in place """ self.setUpXMLDirectory() before = { "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, 
"1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_before, "@xattrs" : { md5Attr : "12345", }, }, "@xattrs" : { xattrname("ignore") : "extra", cTagAttr : "12345", }, }, "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { # Zlib compressed XML freeBusyAttr : zlib.compress("\r\n\r\n /calendars/users/wsanchez/calendar\r\n\r\n"), }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : "", } } after = { ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_after, "@xattrs" : { md5Attr : zlib.compress("\r\n%s\r\n" % (event01_after_md5,)), }, }, "@xattrs" : { xattrname("ignore") : "extra", cTagAttr : isValidCTag, # method below }, }, "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { freeBusyAttr : zlib.compress("\n\r\n /calendars/__uids__/6423F94A-6B76-4A3A-815B-D52CFD77935D/calendar/\r\n"), }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after, reverify=True)) @inlineCallbacks def test_calendarsUpgradeWithNoChange(self): """ Verify that calendar homes in the /calendars/__uids__/XX/YY// form which require no changes are untouched """ self.setUpXMLDirectory() before = { "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_after, "@xattrs" : { md5Attr : zlib.compress("\r\n%s\r\n" % (event01_after_md5,)), }, }, "@xattrs" : { xattrname("ignore") : "extra", cTagAttr : zlib.compress("\r\n2009-02-25 14:34:34.703093\r\n"), }, }, "inbox" : { db_basename : { 
"@contents": "", }, "@xattrs" : { # Zlib compressed XML freeBusyAttr : zlib.compress("\r\n\r\n /calendars/__uids__/6423F94A-6B76-4A3A-815B-D52CFD77935D/calendar/\r\n\r\n"), }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : "", } } after = { ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_after, "@xattrs" : { md5Attr : zlib.compress("\r\n%s\r\n" % (event01_after_md5,)), }, }, "@xattrs" : { xattrname("ignore") : "extra", cTagAttr : zlib.compress("\r\n2009-02-25 14:34:34.703093\r\n"), }, }, "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { freeBusyAttr : zlib.compress("\r\n\r\n /calendars/__uids__/6423F94A-6B76-4A3A-815B-D52CFD77935D/calendar/\r\n\r\n"), }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after)) @inlineCallbacks def test_calendarsUpgradeWithInboxItems(self): """ Verify that inbox items older than 60 days are deleted """ self.setUpXMLDirectory() before = { "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { # Zlib compressed XML freeBusyAttr : zlib.compress("\r\n\r\n /calendars/__uids__/6423F94A-6B76-4A3A-815B-D52CFD77935D/calendar/\r\n\r\n"), }, "oldinboxitem" : { "@contents": "", "@timestamp": 1, # really old file }, "newinboxitem" : { "@contents": "", }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : "", } } after = { ".calendarserver_version" : { "@contents" : "2", }, "inboxitems.txt" : { "@contents" : None, # ignore contents, the paths inside are random test directory paths }, "calendars" : { 
"__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935D" : { "inbox" : { db_basename : { "@contents": "", }, "@xattrs" : { freeBusyAttr : zlib.compress("\r\n\r\n /calendars/__uids__/6423F94A-6B76-4A3A-815B-D52CFD77935D/calendar/\r\n\r\n"), }, "newinboxitem" : { "@contents": "", }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } (yield self.verifyDirectoryComparison(before, after)) @inlineCallbacks def test_calendarsUpgradeWithError(self): """ Verify that a problem with one resource doesn't stop the process, but also doesn't write the new version file """ self.setUpXMLDirectory() before = { "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935E" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_before, }, "1E238CA1-3C95-4468-B8CD-C8A399F78C73.ics" : { "@contents" : event02_broken, }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : "", } } after = { ".calendarserver_version" : { "@contents" : "2", }, "calendars" : { "__uids__" : { "64" : { "23" : { "6423F94A-6B76-4A3A-815B-D52CFD77935E" : { "calendar" : { db_basename : { "@contents": "", }, "1E238CA1-3C95-4468-B8CD-C8A399F78C72.ics" : { "@contents" : event01_after, }, "1E238CA1-3C95-4468-B8CD-C8A399F78C73.ics" : { "@contents" : event02_broken, }, }, }, }, }, }, }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, } root = self.createHierarchy(before) config.DocumentRoot = root config.DataRoot = root (yield self.doUpgrade(config)) self.assertTrue(self.verifyHierarchy(root, after)) @inlineCallbacks def test_migrateResourceInfo(self): # Fake getResourceInfo( ) assignments 
= { 'guid1' : (False, None, None), 'guid2' : (True, 'guid1', None), 'guid3' : (False, 'guid1', 'guid2'), 'guid4' : (True, None, 'guid3'), } def _getResourceInfo(ignored): results = [] for guid, info in assignments.iteritems(): results.append((guid, info[0], info[1], info[2])) return results self.setUpInitialStates() # Override the normal getResourceInfo method with our own: # XMLDirectoryService.getResourceInfo = _getResourceInfo self.patch(XMLDirectoryService, "getResourceInfo", _getResourceInfo) before = { "trigger_resource_migration" : { "@contents" : "x", } } after = { ".calendarserver_version" : { "@contents" : "2", }, NEWPROXYFILE : { "@contents" : None, }, MailGatewayTokensDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (MailGatewayTokensDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, }, ResourceInfoDatabase.dbFilename : { "@contents" : None, }, "%s-journal" % (ResourceInfoDatabase.dbFilename,) : { "@contents" : None, "@optional" : True, } } root = self.createHierarchy(before) config.DocumentRoot = root config.DataRoot = root config.ServerRoot = root (yield self.doUpgrade(config)) self.assertTrue(self.verifyHierarchy(root, after)) proxydbClass = namedClass(config.ProxyDBService.type) calendarUserProxyDatabase = proxydbClass(**config.ProxyDBService.params) resourceInfoDatabase = ResourceInfoDatabase(root) for guid, info in assignments.iteritems(): proxyGroup = "%s#calendar-proxy-write" % (guid,) result = (yield calendarUserProxyDatabase.getMembers(proxyGroup)) if info[1]: self.assertTrue(info[1] in result) else: self.assertTrue(not result) readOnlyProxyGroup = "%s#calendar-proxy-read" % (guid,) result = (yield calendarUserProxyDatabase.getMembers(readOnlyProxyGroup)) if info[2]: self.assertTrue(info[2] in result) else: self.assertTrue(not result) autoSchedule = resourceInfoDatabase._db_value_for_sql("select AUTOSCHEDULE from RESOURCEINFO where GUID = :1", guid) autoSchedule = autoSchedule == 1 self.assertEquals(info[0], 
autoSchedule) def test_removeIllegalCharacters(self): """ Control characters aside from NL and CR are removed. """ data = "Contains\x03 control\x06 characters\x12 some\x0a\x09allowed\x0d" after, changed = removeIllegalCharacters(data) self.assertEquals(after, "Contains control characters some\x0a\x09allowed\x0d") self.assertTrue(changed) data = "Contains\x09only\x0a legal\x0d" after, changed = removeIllegalCharacters(data) self.assertEquals(after, "Contains\x09only\x0a legal\x0d") self.assertFalse(changed) def test_normalizeCUAddrs(self): """ Ensure that calendar user addresses (CUAs) are cached so we can reduce the number of principal lookup calls during upgrade. """ class StubPrincipal(object): def __init__(self, record): self.record = record class StubRecord(object): def __init__(self, fullName, guid, cuas): self.fullName = fullName self.guid = guid self.calendarUserAddresses = cuas class StubDirectory(object): def __init__(self): self.count = 0 def principalForCalendarUserAddress(self, cuaddr): self.count += 1 record = records.get(cuaddr, None) if record is not None: return StubPrincipal(record) else: raise Exception records = { "mailto:a@example.com" : StubRecord("User A", 123, ("mailto:a@example.com", "urn:uuid:123")), "mailto:b@example.com" : StubRecord("User B", 234, ("mailto:b@example.com", "urn:uuid:234")), "/principals/users/a" : StubRecord("User A", 123, ("mailto:a@example.com", "urn:uuid:123")), "/principals/users/b" : StubRecord("User B", 234, ("mailto:b@example.com", "urn:uuid:234")), } directory = StubDirectory() cuaCache = {} normalizeCUAddrs(normalizeEvent, directory, cuaCache) normalizeCUAddrs(normalizeEvent, directory, cuaCache) # Ensure we only called principalForCalendarUserAddress 3 times. It # would have been 8 times without the cuaCache. 
self.assertEquals(directory.count, 3) normalizeEvent = """BEGIN:VCALENDAR VERSION:2.0 BEGIN:VEVENT TRANSP:OPAQUE UID:1E238CA1-3C95-4468-B8CD-C8A399F78C71 DTSTART:20090203 DTEND:20090204 ORGANIZER;CN="User A":mailto:a@example.com SUMMARY:New Event DESCRIPTION:Foo ATTENDEE;CN="User A";CUTYPE=INDIVIDUAL;PARTSTAT=ACCEPTED:mailto:a@example.com ATTENDEE;CN="User B";CUTYPE=INDIVIDUAL;PARTSTAT=ACCEPTED:mailto:b@example.com ATTENDEE;CN="Unknown";CUTYPE=INDIVIDUAL;PARTSTAT=ACCEPTED:mailto:unknown@example.com END:VEVENT END:VCALENDAR """.replace("\n", "\r\n") event01_before = """BEGIN:VCALENDAR VERSION:2.0 PRODID:-//Apple Inc.//iCal 3.0//EN CALSCALE:GREGORIAN BEGIN:VTIMEZONE TZID:US/Pacific BEGIN:DAYLIGHT TZOFFSETFROM:-0800 TZOFFSETTO:-0700 DTSTART:20070311T020000 RRULE:FREQ=YEARLY;BYMONTH=3;BYDAY=2SU TZNAME:PDT END:DAYLIGHT BEGIN:STANDARD TZOFFSETFROM:-0700 TZOFFSETTO:-0800 DTSTART:20071104T020000 RRULE:FREQ=YEARLY;BYMONTH=11;BYDAY=1SU TZNAME:PST END:STANDARD END:VTIMEZONE BEGIN:VEVENT SEQUENCE:2 TRANSP:OPAQUE UID:1E238CA1-3C95-4468-B8CD-C8A399F78C71 DTSTART;TZID=US/Pacific:20090203T120000 ORGANIZER;CN="Cyrus":mailto:cdaboo@example.com DTSTAMP:20090203T181924Z SUMMARY:New Event DESCRIPTION:This has \\" Bad Quotes \\" in it ATTENDEE;CN="Wilfredo";CUTYPE=INDIVIDUAL;PARTSTAT=ACCEPTED:mailto:wsanchez @example.com ATTENDEE;CN="Double";CUTYPE=INDIVIDUAL;PARTSTAT=ACCEPTED:mailto:doublequotes @example.com ATTENDEE;CN="Cyrus";CUTYPE=INDIVIDUAL;PARTSTAT=ACCEPTED;ROLE=REQ-PARTICI PANT:mailto:cdaboo@example.com CREATED:20090203T181910Z DTEND;TZID=US/Pacific:20090203T130000 END:VEVENT END:VCALENDAR """.replace("\n", "\r\n") event01_after = """BEGIN:VCALENDAR VERSION:2.0 CALSCALE:GREGORIAN PRODID:-//Apple Inc.//iCal 3.0//EN BEGIN:VTIMEZONE TZID:US/Pacific BEGIN:DAYLIGHT DTSTART:20070311T020000 RRULE:FREQ=YEARLY;BYDAY=2SU;BYMONTH=3 TZNAME:PDT TZOFFSETFROM:-0800 TZOFFSETTO:-0700 END:DAYLIGHT BEGIN:STANDARD DTSTART:20071104T020000 RRULE:FREQ=YEARLY;BYDAY=1SU;BYMONTH=11 TZNAME:PST 
TZOFFSETFROM:-0700 TZOFFSETTO:-0800 END:STANDARD END:VTIMEZONE BEGIN:VEVENT UID:1E238CA1-3C95-4468-B8CD-C8A399F78C71 DTSTART;TZID=US/Pacific:20090203T120000 DTEND;TZID=US/Pacific:20090203T130000 ATTENDEE;CN=Wilfredo Sanchez;CUTYPE=INDIVIDUAL;EMAIL=wsanchez@example.com; PARTSTAT=ACCEPTED:urn:uuid:6423F94A-6B76-4A3A-815B-D52CFD77935D ATTENDEE;CN=Double 'quotey' Quotes;CUTYPE=INDIVIDUAL;EMAIL=doublequotes@ex ample.com;PARTSTAT=ACCEPTED:urn:uuid:8E04787E-336D-41ED-A70B-D233AD0DCE6F ATTENDEE;CN=Cyrus Daboo;CUTYPE=INDIVIDUAL;EMAIL=cdaboo@example.com;PARTSTA T=ACCEPTED;ROLE=REQ-PARTICIPANT:urn:uuid:5A985493-EE2C-4665-94CF-4DFEA3A89 500 CREATED:20090203T181910Z DESCRIPTION:This has " Bad Quotes " in it DTSTAMP:20090203T181924Z ORGANIZER;CN=Cyrus Daboo;EMAIL=cdaboo@example.com:urn:uuid:5A985493-EE2C-4 665-94CF-4DFEA3A89500 SEQUENCE:2 SUMMARY:New Event TRANSP:OPAQUE END:VEVENT END:VCALENDAR """.replace("\n", "\r\n") event02_broken = "Invalid!" event01_after_md5 = hashlib.md5(event01_after).hexdigest() def isValidCTag(value): """ Since ctag is generated from datetime.now(), let's make sure that at least the value is zlib compressed XML """ try: value = zlib.decompress(value) except zlib.error: return False try: WebDAVDocument.fromString(value) return True except ValueError: return False calendarserver-5.2+dfsg/twistedcaldav/test/test_accounting.py0000644000175000017500000000647112263343324023701 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. ## from twext.web2.channel.http import HTTPLoggingChannelRequest from twext.web2 import http_headers from twext.web2.channel.http import HTTPChannel from twistedcaldav.accounting import emitAccounting from twistedcaldav.config import config import twistedcaldav.test.util import os import stat class AccountingITIP (twistedcaldav.test.util.TestCase): def setUp(self): super(AccountingITIP, self).setUp() config.AccountingCategories.iTIP = True config.AccountingPrincipals = ["*", ] os.mkdir(config.AccountingLogRoot) class _Principal(object): class _Record(object): def __init__(self, guid): self.guid = guid def __init__(self, guid): self.record = self._Record(guid) def test_permissions_makedirs(self): """ Test permissions when creating accounting """ # Make log root non-writeable os.chmod(config.AccountingLogRoot, stat.S_IRUSR) emitAccounting("iTIP", self._Principal("1234-5678"), "bogus") def test_file_instead_of_directory(self): """ Test permissions when creating accounting """ # Make log root a file config.AccountingLogRoot = "other" open(config.AccountingLogRoot, "w").close() emitAccounting("iTIP", self._Principal("1234-5678"), "bogus") class AccountingHTTP (twistedcaldav.test.util.TestCase): def setUp(self): super(AccountingHTTP, self).setUp() config.AccountingCategories.HTTP = True config.AccountingPrincipals = ["*", ] def test_channel_request(self): """ Test permissions when creating accounting """ # Make channel request object channelRequest = HTTPLoggingChannelRequest(HTTPChannel()) self.assertTrue(channelRequest != None) def test_logging(self): """ Test permissions when creating accounting """ class FakeRequest(object): def handleContentChunk(self, data): pass def handleContentComplete(self): pass # Make log root a file channelRequest = HTTPLoggingChannelRequest(HTTPChannel(), queued=1) channelRequest.request = FakeRequest() channelRequest.gotInitialLine("GET / 
HTTP/1.1") channelRequest.lineReceived("Host:localhost") channelRequest.lineReceived("Content-Length:5") channelRequest.handleContentChunk("Bogus") channelRequest.handleContentComplete() channelRequest.writeHeaders(200, http_headers.Headers({"Content-Type": http_headers.MimeType('text', 'plain'), "Content-Length": "4"})) channelRequest.transport.write("Data") channelRequest.finish() calendarserver-5.2+dfsg/twistedcaldav/test/test_multiget.py0000644000175000017500000002535012263343324023376 0ustar rahulrahul# Copyright (c) 2006-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from twext.python.filepath import CachingFilePath as FilePath from twext.web2 import responsecode from twext.web2.dav.util import davXMLFromStream, joinURL from twext.web2.iweb import IResponse from twext.web2.stream import MemoryStream from twisted.internet.defer import inlineCallbacks, returnValue from twistedcaldav import caldavxml from twistedcaldav import ical from twistedcaldav.config import config from twistedcaldav.test.util import todo, StoreTestCase, SimpleStoreRequest from txdav.xml import element as davxml import os class CalendarMultiget (StoreTestCase): """ calendar-multiget REPORT """ data_dir = os.path.join(os.path.dirname(__file__), "data") holidays_dir = os.path.join(data_dir, "Holidays") def test_multiget_some_events(self): """ All events. 
(CalDAV-access-09, section 7.6.8) """ okuids = [r[0] for r in (os.path.splitext(f) for f in os.listdir(self.holidays_dir)) if r[1] == ".ics"] okuids[:] = okuids[1:10] baduids = ["12345%40example.com", "67890%40example.com"] return self.simple_event_multiget("/calendar_multiget_events/", okuids, baduids) def test_multiget_all_events(self): """ All events. (CalDAV-access-09, section 7.6.8) """ okuids = [r[0] for r in (os.path.splitext(f) for f in os.listdir(self.holidays_dir)) if r[1] == ".ics"] baduids = ["12345%40example.com", "67890%40example.com"] return self.simple_event_multiget("/calendar_multiget_events/", okuids, baduids) def test_multiget_limited_with_data(self): """ All events. (CalDAV-access-09, section 7.6.8) """ oldValue = config.MaxMultigetWithDataHrefs config.MaxMultigetWithDataHrefs = 1 def _restoreValueOK(f): config.MaxMultigetWithDataHrefs = oldValue self.fail("REPORT must fail with 403") def _restoreValueError(f): config.MaxMultigetWithDataHrefs = oldValue return None okuids = [r[0] for r in (os.path.splitext(f) for f in os.listdir(self.holidays_dir)) if r[1] == ".ics"] baduids = ["12345%40example.com", "67890%40example.com"] d = self.simple_event_multiget("/calendar_multiget_events/", okuids, baduids) d.addCallbacks(_restoreValueOK, _restoreValueError) return d def test_multiget_limited_no_data(self): """ All events. 
(CalDAV-access-09, section 7.6.8) """ oldValue = config.MaxMultigetWithDataHrefs config.MaxMultigetWithDataHrefs = 1 def _restoreValueOK(f): config.MaxMultigetWithDataHrefs = oldValue return None def _restoreValueError(f): config.MaxMultigetWithDataHrefs = oldValue self.fail("REPORT must not fail with 403") okuids = [r[0] for r in (os.path.splitext(f) for f in os.listdir(self.holidays_dir)) if r[1] == ".ics"] baduids = ["12345%40example.com", "67890%40example.com"] return self.simple_event_multiget("/calendar_multiget_events/", okuids, baduids, withData=False) @todo("Remove: Does not work with new store") @inlineCallbacks def test_multiget_one_broken_event(self): """ All events. (CalDAV-access-09, section 7.6.8) """ okuids = ["good", "bad", ] baduids = [] data = { "good": """BEGIN:VCALENDAR CALSCALE:GREGORIAN PRODID:-//Apple Computer\, Inc//iCal 2.0//EN VERSION:2.0 BEGIN:VEVENT UID:good DTSTART;VALUE=DATE:20020101 DTEND;VALUE=DATE:20020102 DTSTAMP:20020101T121212Z RRULE:FREQ=YEARLY;INTERVAL=1;UNTIL=20031231;BYMONTH=1 SUMMARY:New Year's Day END:VEVENT END:VCALENDAR """.replace("\n", "\r\n"), "bad": """BEGIN:VCALENDAR CALSCALE:GREGORIAN PRODID:-//Apple Computer\, Inc//iCal 2.0//EN VERSION:2.0 BEGIN:VEVENT UID:bad DTSTART;VALUE=DATE:20020214 DTEND;VALUE=DATE:20020215 DTSTAMP:20020101T121212Z RRULE:FREQ=YEARLY;INTERVAL=1;BYMONTH=2 SUMMARY:Valentine's Day END:VEVENT END:VCALENDAR """.replace("\n", "\r\n") } yield self.simple_event_multiget("/calendar_multiget_events/", okuids, baduids, data) # Now forcibly corrupt one piece of calendar data calendar_path = os.path.join(self.docroot, "calendar_multiget_events/", "bad.ics") f = open(calendar_path, "w") f.write("""BEGIN:VCALENDAR CALSCALE:GREGORIAN PRODID:-//Apple Computer\, Inc//iCal 2.0//EN VERSION:2.0 BEGIN:VEVENT UID:bad DTSTART;VALUE=DATE:20020214 DTEND;VALUE=DATE:20020 DTSTAMP:20020101T121212Z END:VCALENDAR """.replace("\n", "\r\n")) f.close okuids = ["good", ] baduids = ["bad", ] yield 
self.simple_event_multiget("/calendar_multiget_events/", okuids, baduids, data, no_init=True) def simple_event_multiget(self, cal_uri, okuids, baduids, data=None, no_init=False, withData=True): cal_uri = joinURL("/calendars/users/wsanchez", cal_uri) props = ( davxml.GETETag(), ) if withData: props += ( caldavxml.CalendarData(), ) children = [] children.append(davxml.PropertyContainer(*props)) okhrefs = [joinURL(cal_uri, x + ".ics") for x in okuids] badhrefs = [joinURL(cal_uri, x + ".ics") for x in baduids] for href in okhrefs + badhrefs: children.append(davxml.HRef.fromString(href)) query = caldavxml.CalendarMultiGet(*children) def got_xml(doc): if not isinstance(doc.root_element, davxml.MultiStatus): self.fail("REPORT response XML root element is not multistatus: %r" % (doc.root_element,)) for response in doc.root_element.childrenOfType(davxml.PropertyStatusResponse): href = str(response.childOfType(davxml.HRef)) for propstat in response.childrenOfType(davxml.PropertyStatus): status = propstat.childOfType(davxml.Status) if status.code != responsecode.OK: self.fail("REPORT failed (status %s) to locate properties: %r" % (status.code, href)) properties = propstat.childOfType(davxml.PropertyContainer).children for property in properties: qname = property.qname() if qname == (davxml.dav_namespace, "getetag"): continue if qname != (caldavxml.caldav_namespace, "calendar-data"): self.fail("Response included unexpected property %r" % (property,)) result_calendar = property.calendar() if result_calendar is None: self.fail("Invalid response CalDAV:calendar-data: %r" % (property,)) uid = result_calendar.resourceUID() if uid in okuids: okuids.remove(uid) else: self.fail("Got calendar for unexpected UID %r" % (uid,)) if data: original_calendar = ical.Component.fromString(data[uid]) else: original_filename = file(os.path.join(self.holidays_dir, uid + ".ics")) original_calendar = ical.Component.fromStream(original_filename) self.assertEqual(result_calendar, original_calendar) for 
response in doc.root_element.childrenOfType(davxml.StatusResponse): href = str(response.childOfType(davxml.HRef)) propstatus = response.childOfType(davxml.PropertyStatus) if propstatus is not None: status = propstatus.childOfType(davxml.Status) else: status = response.childOfType(davxml.Status) if status.code != responsecode.OK: if href in okhrefs: self.fail("REPORT failed (status %s) to locate properties: %r" % (status.code, href)) else: if href in badhrefs: badhrefs.remove(href) continue else: self.fail("Got unexpected href %r" % (href,)) if withData and (len(okuids) + len(badhrefs)): self.fail("Some components were not returned: %r, %r" % (okuids, badhrefs)) return self.calendar_query(cal_uri, query, got_xml, data, no_init) @inlineCallbacks def calendar_query(self, calendar_uri, query, got_xml, data, no_init): if not no_init: response = yield self.send(SimpleStoreRequest(self, "MKCALENDAR", calendar_uri, authid="wsanchez")) response = IResponse(response) if response.code != responsecode.CREATED: self.fail("MKCALENDAR failed: %s" % (response.code,)) if data: for filename, icaldata in data.iteritems(): request = SimpleStoreRequest(self, "PUT", joinURL(calendar_uri, filename + ".ics"), authid="wsanchez") request.stream = MemoryStream(icaldata) yield self.send(request) else: # Add holiday events to calendar for child in FilePath(self.holidays_dir).children(): if os.path.splitext(child.basename())[1] != ".ics": continue request = SimpleStoreRequest(self, "PUT", joinURL(calendar_uri, child.basename()), authid="wsanchez") request.stream = MemoryStream(child.getContent()) yield self.send(request) request = SimpleStoreRequest(self, "REPORT", calendar_uri, authid="wsanchez") request.stream = MemoryStream(query.toxml()) response = yield self.send(request) response = IResponse(response) if response.code != responsecode.MULTI_STATUS: self.fail("REPORT failed: %s" % (response.code,)) returnValue( (yield davXMLFromStream(response.stream).addCallback(got_xml)) ) 
calendarserver-5.2+dfsg/twistedcaldav/test/test_dateops.py0000644000175000017500000003132112263343324023176 0ustar rahulrahul## # Copyright (c) 2005-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## import twistedcaldav.test.util from twisted.trial.unittest import SkipTest from pycalendar.datetime import PyCalendarDateTime from twistedcaldav.dateops import parseSQLTimestampToPyCalendar, \ parseSQLDateToPyCalendar, pyCalendarTodatetime, \ normalizeForExpand, normalizeForIndex, normalizeToUTC, timeRangesOverlap import datetime import dateutil from pycalendar.timezone import PyCalendarTimezone from twistedcaldav.timezones import TimezoneCache class Dateops(twistedcaldav.test.util.TestCase): """ dateops.py tests """ def setUp(self): super(Dateops, self).setUp() TimezoneCache.create() def test_normalizeForIndex(self): """ Test that dateops.normalizeForIndex works correctly on all four types of date/time: date only, floating, UTC and local time. 
""" data = ( (PyCalendarDateTime(2012, 1, 1), PyCalendarDateTime(2012, 1, 1, 0, 0, 0)), (PyCalendarDateTime(2012, 1, 1, 10, 0, 0), PyCalendarDateTime(2012, 1, 1, 10, 0, 0)), (PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True))), (PyCalendarDateTime(2012, 1, 1, 12, 0, 0, tzid=PyCalendarTimezone(tzid="America/New_York")), PyCalendarDateTime(2012, 1, 1, 17, 0, 0, tzid=PyCalendarTimezone(utc=True))), ) for value, result in data: self.assertEqual(normalizeForIndex(value), result) def test_normalizeToUTC(self): """ Test that dateops.normalizeToUTC works correctly on all four types of date/time: date only, floating, UTC and local time. """ data = ( (PyCalendarDateTime(2012, 1, 1), PyCalendarDateTime(2012, 1, 1, 0, 0, 0, tzid=PyCalendarTimezone(utc=True))), (PyCalendarDateTime(2012, 1, 1, 10, 0, 0), PyCalendarDateTime(2012, 1, 1, 10, 0, 0, tzid=PyCalendarTimezone(utc=True))), (PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True))), (PyCalendarDateTime(2012, 1, 1, 12, 0, 0, tzid=PyCalendarTimezone(tzid="America/New_York")), PyCalendarDateTime(2012, 1, 1, 17, 0, 0, tzid=PyCalendarTimezone(utc=True))), ) for value, result in data: self.assertEqual(normalizeToUTC(value), result) def test_normalizeForExpand(self): """ Test that dateops.normalizeForExpand works correctly on all four types of date/time: date only, floating, UTC and local time. 
""" data = ( (PyCalendarDateTime(2012, 1, 1), PyCalendarDateTime(2012, 1, 1)), (PyCalendarDateTime(2012, 1, 1, 10, 0, 0), PyCalendarDateTime(2012, 1, 1, 10, 0, 0)), (PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True))), (PyCalendarDateTime(2012, 1, 1, 12, 0, 0, tzid=PyCalendarTimezone(tzid="America/New_York")), PyCalendarDateTime(2012, 1, 1, 17, 0, 0, tzid=PyCalendarTimezone(utc=True))), ) for value, result in data: self.assertEqual(normalizeForExpand(value), result) def test_floatoffset(self): raise SkipTest("test unimplemented") def test_adjustFloatingToTimezone(self): raise SkipTest("test unimplemented") def test_compareDateTime(self): raise SkipTest("test unimplemented") def test_differenceDateTime(self): raise SkipTest("test unimplemented") def test_timeRangesOverlap(self): data = ( # Timed ( "Start within, end within - overlap", PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 12, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), True, ), ( "Start before, end before - no overlap", PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 12, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 3, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), False, ), ( "Start before, end right before - no overlap", PyCalendarDateTime(2012, 1, 1, 23, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 3, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), False, ), ( "Start 
before, end within - overlap", PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 3, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), True, ), ( "Start after, end after - no overlap", PyCalendarDateTime(2012, 1, 2, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 12, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), False, ), ( "Start right after, end after - no overlap", PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 1, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), False, ), ( "Start within, end after - overlap", PyCalendarDateTime(2012, 1, 1, 12, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 12, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 1, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), True, ), ( "Start before, end after - overlap", PyCalendarDateTime(2012, 1, 1, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 3, 11, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 2, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 3, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), True, ), # All day ( "All day: Start within, end within - overlap", PyCalendarDateTime(2012, 1, 9), PyCalendarDateTime(2012, 1, 10), PyCalendarDateTime(2012, 1, 8, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 
15, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), True, ), ( "All day: Start before, end before - no overlap", PyCalendarDateTime(2012, 1, 1), PyCalendarDateTime(2012, 1, 2), PyCalendarDateTime(2012, 1, 8, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 15, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), False, ), ( "All day: Start before, end right before - no overlap", PyCalendarDateTime(2012, 1, 7), PyCalendarDateTime(2012, 1, 8), PyCalendarDateTime(2012, 1, 8, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 15, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), False, ), ( "All day: Start before, end within - overlap", PyCalendarDateTime(2012, 1, 7), PyCalendarDateTime(2012, 1, 9), PyCalendarDateTime(2012, 1, 8, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 15, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), True, ), ( "All day: Start after, end after - no overlap", PyCalendarDateTime(2012, 1, 16), PyCalendarDateTime(2012, 1, 17), PyCalendarDateTime(2012, 1, 8, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 15, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), False, ), ( "All day: Start right after, end after - no overlap", PyCalendarDateTime(2012, 1, 15), PyCalendarDateTime(2012, 1, 16), PyCalendarDateTime(2012, 1, 8, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 15, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), False, ), ( "All day: Start within, end after - overlap", PyCalendarDateTime(2012, 1, 14), PyCalendarDateTime(2012, 1, 16), PyCalendarDateTime(2012, 1, 8, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 15, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), True, ), ( "All day: Start before, end after - overlap", PyCalendarDateTime(2012, 1, 7), PyCalendarDateTime(2012, 1, 16), PyCalendarDateTime(2012, 1, 8, 0, 0, 0, tzid=PyCalendarTimezone(utc=True)), PyCalendarDateTime(2012, 1, 15, 0, 0, 0, 
tzid=PyCalendarTimezone(utc=True)), True, ), ) for title, start1, end1, start2, end2, result in data: self.assertEqual(timeRangesOverlap(start1, end1, start2, end2), result, msg="Failed: %s" % (title,)) def test_normalizePeriodList(self): raise SkipTest("test unimplemented") def test_clipPeriod(self): raise SkipTest("test unimplemented") def test_pyCalendarTodatetime(self): """ dateops.pyCalendarTodatetime """ tests = ( (PyCalendarDateTime(2012, 4, 4, 12, 34, 56), datetime.datetime(2012, 4, 4, 12, 34, 56, tzinfo=dateutil.tz.tzutc())), (PyCalendarDateTime(2012, 12, 31), datetime.date(2012, 12, 31)), ) for pycal, result in tests: self.assertEqual(pyCalendarTodatetime(pycal), result) def test_parseSQLTimestampToPyCalendar(self): """ dateops.parseSQLTimestampToPyCalendar """ tests = ( ("2012-04-04 12:34:56", PyCalendarDateTime(2012, 4, 4, 12, 34, 56)), ("2012-12-31 01:01:01", PyCalendarDateTime(2012, 12, 31, 1, 1, 1)), ) for sqlStr, result in tests: self.assertEqual(parseSQLTimestampToPyCalendar(sqlStr), result) def test_parseSQLDateToPyCalendar(self): """ dateops.parseSQLDateToPyCalendar """ tests = ( ("2012-04-04", PyCalendarDateTime(2012, 4, 4)), ("2012-12-31 00:00:00", PyCalendarDateTime(2012, 12, 31)), ) for sqlStr, result in tests: self.assertEqual(parseSQLDateToPyCalendar(sqlStr), result) def test_datetimeMktime(self): raise SkipTest("test unimplemented") calendarserver-5.2+dfsg/twistedcaldav/test/test_memcachelock.py0000644000175000017500000002120212145176743024157 0ustar rahulrahul# Copyright (c) 2007 Twisted Matrix Laboratories. # See LICENSE for details. """ Test the memcache client protocol. 
""" from twext.protocols.memcache import MemCacheProtocol from twisted.test.proto_helpers import StringTransportWithDisconnection from twisted.internet.task import Clock from twisted.internet.defer import inlineCallbacks from twistedcaldav.memcachelock import MemcacheLock, MemcacheLockTimeoutError from twistedcaldav.test.util import TestCase class MemCacheTestCase(TestCase): """ Test client protocol class L{MemCacheProtocol}. """ class FakedMemcacheLock(MemcacheLock): def __init__(self, faked, namespace, locktoken, timeout=5.0, retry_interval=0.1, expire_time=0): """ @param namespace: a unique namespace for this lock's tokens @type namespace: C{str} @param locktoken: the name of the locktoken @type locktoken: C{str} @param timeout: the maximum time in seconds that the lock should block @type timeout: C{float} @param retry_interval: the interval to retry acquiring the lock @type retry_interval: C{float} @param expiryTime: the time in seconds for the lock to expire. Zero: no expiration. @type expiryTime: C{float} """ super(MemCacheTestCase.FakedMemcacheLock, self).__init__(namespace, locktoken, timeout, retry_interval, expire_time) self.faked = faked def _getMemcacheProtocol(self): return self.faked def setUp(self): """ Create a memcache client, connect it to a string protocol, and make it use a deterministic clock. """ TestCase.setUp(self) self.proto = MemCacheProtocol() self.clock = Clock() self.proto.callLater = self.clock.callLater self.transport = StringTransportWithDisconnection() self.transport.protocol = self.proto self.proto.makeConnection(self.transport) def _test(self, d, send, recv, result): """ Shortcut method for classic tests. @param d: the resulting deferred from the memcache command. @type d: C{Deferred} @param send: the expected data to be sent. @type send: C{str} @param recv: the data to simulate as reception. @type recv: C{str} @param result: the expected result. 
@type result: C{any} """ def cb(res): self.assertEquals(res, result) self.assertEquals(self.transport.value(), send) self.transport.clear() d.addCallback(cb) self.proto.dataReceived(recv) return d def test_get(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. """ lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", "locking") return self._test( lock.get("foo"), "get lock:foo-acbd18db4cc2f85cedef654fccc4a4d8\r\n", "VALUE lock:foo-acbd18db4cc2f85cedef654fccc4a4d8 0 3\r\nbar\r\nEND\r\n", "bar" ) def test_set(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. """ lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", "locking") return self._test( lock.set("foo", "bar"), "set lock:foo-acbd18db4cc2f85cedef654fccc4a4d8 0 0 3\r\nbar\r\n", "STORED\r\n", True ) @inlineCallbacks def test_acquire(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. """ lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", "locking") yield self._test( lock.acquire(), "add lock:locking-559159aa00cc525bfe5c4b34cf16cccb 0 0 1\r\n1\r\n", "STORED\r\n", True ) self.assertTrue(lock._hasLock) @inlineCallbacks def test_acquire_ok_timeout_0(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. 
""" lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", "locking", timeout=0) yield self._test( lock.acquire(), "add lock:locking-559159aa00cc525bfe5c4b34cf16cccb 0 0 1\r\n1\r\n", "STORED\r\n", True ) self.assertTrue(lock._hasLock) @inlineCallbacks def test_acquire_fails_timeout_0(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. """ lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", "locking", timeout=0) try: yield self._test( lock.acquire(), "add lock:locking-559159aa00cc525bfe5c4b34cf16cccb 0 0 1\r\n1\r\n", "NOT_STORED\r\n", True ) except MemcacheLockTimeoutError: pass except Exception, e: self.fail("Unknown exception thrown: %s" % (e,)) else: self.fail("No timeout exception thrown") self.assertFalse(lock._hasLock) @inlineCallbacks def test_acquire_release(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. """ lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", "locking") yield self._test( lock.acquire(), "add lock:locking-559159aa00cc525bfe5c4b34cf16cccb 0 0 1\r\n1\r\n", "STORED\r\n", True ) self.assertTrue(lock._hasLock) yield self._test( lock.release(), "delete lock:locking-559159aa00cc525bfe5c4b34cf16cccb\r\n", "DELETED\r\n", True ) self.assertFalse(lock._hasLock) @inlineCallbacks def test_acquire_clean(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. 
""" lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", "locking") yield self._test( lock.acquire(), "add lock:locking-559159aa00cc525bfe5c4b34cf16cccb 0 0 1\r\n1\r\n", "STORED\r\n", True ) yield self._test( lock.clean(), "delete lock:locking-559159aa00cc525bfe5c4b34cf16cccb\r\n", "DELETED\r\n", True ) @inlineCallbacks def test_acquire_unicode(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. """ lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", u"locking") yield self._test( lock.acquire(), "add lock:locking-559159aa00cc525bfe5c4b34cf16cccb 0 0 1\r\n1\r\n", "STORED\r\n", True ) self.assertTrue(lock._hasLock) @inlineCallbacks def test_acquire_invalid_token1(self): """ L{MemCacheProtocol.get} should return a L{Deferred} which is called back with the value and the flag associated with the given key if the server returns a successful result. """ try: lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", 1) yield lock.acquire() self.fail("AssertionError not raised") except AssertionError: pass except: self.fail("AssertionError not raised") try: lock = MemCacheTestCase.FakedMemcacheLock(self.proto, "lock", ("abc",)) yield lock.acquire() self.fail("AssertionError not raised") except AssertionError: pass except: self.fail("AssertionError not raised") calendarserver-5.2+dfsg/twistedcaldav/test/test_extensions.py0000644000175000017500000002273612263343324023750 0ustar rahulrahul# -*- coding: utf-8 -*- ## # Copyright (c) 2009-2014 Apple Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## from twext.python.filepath import CachingFilePath as FilePath from twext.web2.http_headers import MimeType from twext.web2.static import MetaDataMixin from twisted.internet.defer import inlineCallbacks, Deferred, succeed from twisted.web.microdom import parseString from twistedcaldav.extensions import DAVFile, DAVResourceWithChildrenMixin, extractCalendarServerPrincipalSearchData, validateTokens from twistedcaldav.test.util import TestCase from txdav.xml.element import WebDAVElement, ResourceType from txdav.xml.parser import WebDAVDocument from xml.etree.cElementTree import XML class UnicodeProperty(WebDAVElement): """ An element with a unicode name. """ name = u'unicode' allowed_children = {} class StrProperty(WebDAVElement): """ An element with a unicode name. """ name = 'str' allowed_children = {} class SimpleFakeRequest(object): """ Emulate a very small portion of the web2 'Request' API, just enough to render a L{DAVFile}. @ivar path: the path portion of the URL being rendered. """ def __init__(self, path): self.path = path def urlForResource(self, resource): """ @return: this L{SimpleFakeRequest}'s 'path' attribute, since this request can render only one thing. """ return self.path def browserHTML2ETree(htmlString): """ Loosely interpret an HTML string as XML and return an ElementTree object for it. 
We're not promising strict XML (in fact, we're specifically saying HTML) in the content-type of certain responses, but it's much easier to work with the ElementTree data structures present in Python 2.5+ for testing; so we'll use Twisted's built-in facilities to sanitize the inputs before making any structured assertions about them. A more precise implementation would use U{HTML5Lib}'s etree bindings to do the parsing, as that is more directly 'what a browser would do', but Twisted's built-in stuff is a good approximation and doesn't drag in another dependency. @param htmlString: a L{str}, encoded in UTF-8, representing a pile of browser-friendly HTML tag soup. @return: an object implementing the standard library ElementTree interface. """ return XML(parseString(htmlString, beExtremelyLenient=True).toxml()) nonASCIIFilename = u"アニメ.txt" class DirectoryListingTest(TestCase): """ Test cases for HTML directory listing. """ @inlineCallbacks def doDirectoryTest(self, addedNames, modify=lambda x: None, expectedNames=None): """ Do a test of a L{DAVFile} pointed at a directory, verifying that files existing with the given names will be faithfully 'played back' via HTML rendering. """ if expectedNames is None: expectedNames = addedNames fp = FilePath(self.mktemp()) fp.createDirectory() for sampleName in expectedNames: fp.child(sampleName).touch() df = DAVFile(fp) modify(df) responseText = (yield df.render(SimpleFakeRequest('/'))).stream.read() responseXML = browserHTML2ETree(responseText) names = set([element.text.encode("utf-8") for element in responseXML.findall(".//a")]) self.assertEquals(set(expectedNames), names) def test_simpleList(self): """ Rendering a L{DAVFile} that is backed by a directory will produce an HTML document including links to its contents. """ return self.doDirectoryTest([u'gamma.txt', u'beta.html', u'alpha.xml']) def test_emptyList(self): """ Listing a directory with no files in it will produce an index with no links. 
""" return self.doDirectoryTest([]) def test_nonASCIIList(self): """ Listing a directory with a file in it that includes characters that fall outside of the 'Basic Latin' and 'Latin-1 Supplement' unicode blocks should result in those characters being rendered as links in the index. """ return self.doDirectoryTest([nonASCIIFilename.encode("utf-8")]) @inlineCallbacks def test_nonASCIIListMixedChildren(self): """ Listing a directory that contains unicode content-type metadata and non-ASCII characters in a filename should result in a listing that contains the names of both entities. """ unicodeChildName = "test" def addUnicodeChild(davFile): m = MetaDataMixin() m.contentType = lambda: MimeType.fromString('text/plain') m.resourceType = lambda: ResourceType() m.isCollection = lambda: False davFile.putChild(unicodeChildName, m) yield self.doDirectoryTest([nonASCIIFilename], addUnicodeChild, [nonASCIIFilename.encode("utf-8"), unicodeChildName]) @inlineCallbacks def test_nonASCIIListMixedProperties(self): """ Listing a directory that contains unicode DAV properties and non-ASCII characters in a filename should result in a listing that contains the names of both entities. """ def addUnicodeChild(davFile): davFile.writeProperty(UnicodeProperty(), None) davFile.writeProperty(StrProperty(), None) yield self.doDirectoryTest([nonASCIIFilename], addUnicodeChild, [nonASCIIFilename.encode("utf-8")]) def test_quotedCharacters(self): """ Filenames might contain < or > characters, which need to be quoted in HTML. """ return self.doDirectoryTest([u'.txt', u'