calendarserver-5.2+dfsg/ 0000755 0001750 0001750 00000000000 12322625327 014313 5 ustar rahul rahul calendarserver-5.2+dfsg/python 0000755 0001750 0001750 00000000157 11615630672 015570 0 ustar rahul rahul #!/usr/bin/env bash
wd="$(cd "$(dirname "$0")" && pwd)";
. "${wd}/support/shell.sh"
exec "${python}" "$@";
calendarserver-5.2+dfsg/conf/ 0000755 0001750 0001750 00000000000 12322625306 015235 5 ustar rahul rahul calendarserver-5.2+dfsg/conf/remoteservers.xml 0000644 0001750 0001750 00000001665 12263343324 020675 0 ustar rahul rahul
calendarserver-5.2+dfsg/conf/caldavd-partitioning-secondary.plist 0000644 0001750 0001750 00000004400 12263343324 024401 0 ustar rahul rahul
Servers
Enabled
ConfigFile
localservers.xml
MaxClients
5
ServerPartitionID
00002
ProxyDBService
type
twistedcaldav.directory.calendaruserproxy.ProxyPostgreSQLDB
params
host
localhost
database
proxies
Memcached
Pools
CommonToAllNodes
ClientEnabled
ServerEnabled
BindAddress
localhost
Port
11311
HandleCacheTypes
ProxyDB
PrincipalToken
DIGESTCREDENTIALS
MaxClients
5
memcached
../memcached/_root/bin/memcached
Options
calendarserver-5.2+dfsg/conf/mime.types 0000644 0001750 0001750 00000035254 10535615510 017263 0 ustar rahul rahul # This is a comment. I love comments.
# This file controls what Internet media types are sent to the client for
# given file extension(s). Sending the correct media type to the client
# is important so they know how to handle the content of the file.
# Extra types can either be added here or by using an AddType directive
# in your config files. For more information about Internet media types,
# please read RFC 2045, 2046, 2047, 2048, and 2077. The Internet media type
# registry is at .
# MIME type Extensions
application/activemessage
application/andrew-inset ez
application/applefile
application/atom+xml atom
application/atomicmail
application/batch-smtp
application/beep+xml
application/cals-1840
application/cnrp+xml
application/commonground
application/cpl+xml
application/cybercash
application/dca-rft
application/dec-dx
application/dvcs
application/edi-consent
application/edifact
application/edi-x12
application/eshop
application/font-tdpfr
application/http
application/hyperstudio
application/iges
application/index
application/index.cmd
application/index.obj
application/index.response
application/index.vnd
application/iotp
application/ipp
application/isup
application/mac-binhex40 hqx
application/mac-compactpro cpt
application/macwriteii
application/marc
application/mathematica
application/mathml+xml mathml
application/msword doc
application/news-message-id
application/news-transmission
application/ocsp-request
application/ocsp-response
application/octet-stream bin dms lha lzh exe class so dll dmg
application/oda oda
application/ogg ogg
application/parityfec
application/pdf pdf
application/pgp-encrypted
application/pgp-keys
application/pgp-signature
application/pkcs10
application/pkcs7-mime
application/pkcs7-signature
application/pkix-cert
application/pkix-crl
application/pkixcmp
application/postscript ai eps ps
application/prs.alvestrand.titrax-sheet
application/prs.cww
application/prs.nprend
application/prs.plucker
application/qsig
application/rdf+xml rdf
application/reginfo+xml
application/remote-printing
application/riscos
application/rtf
application/sdp
application/set-payment
application/set-payment-initiation
application/set-registration
application/set-registration-initiation
application/sgml
application/sgml-open-catalog
application/sieve
application/slate
application/smil smi smil
application/srgs gram
application/srgs+xml grxml
application/timestamp-query
application/timestamp-reply
application/tve-trigger
application/vemmi
application/vnd.3gpp.pic-bw-large
application/vnd.3gpp.pic-bw-small
application/vnd.3gpp.pic-bw-var
application/vnd.3gpp.sms
application/vnd.3m.post-it-notes
application/vnd.accpac.simply.aso
application/vnd.accpac.simply.imp
application/vnd.acucobol
application/vnd.acucorp
application/vnd.adobe.xfdf
application/vnd.aether.imp
application/vnd.amiga.ami
application/vnd.anser-web-certificate-issue-initiation
application/vnd.anser-web-funds-transfer-initiation
application/vnd.audiograph
application/vnd.blueice.multipass
application/vnd.bmi
application/vnd.businessobjects
application/vnd.canon-cpdl
application/vnd.canon-lips
application/vnd.cinderella
application/vnd.claymore
application/vnd.commerce-battelle
application/vnd.commonspace
application/vnd.contact.cmsg
application/vnd.cosmocaller
application/vnd.criticaltools.wbs+xml
application/vnd.ctc-posml
application/vnd.cups-postscript
application/vnd.cups-raster
application/vnd.cups-raw
application/vnd.curl
application/vnd.cybank
application/vnd.data-vision.rdz
application/vnd.dna
application/vnd.dpgraph
application/vnd.dreamfactory
application/vnd.dxr
application/vnd.ecdis-update
application/vnd.ecowin.chart
application/vnd.ecowin.filerequest
application/vnd.ecowin.fileupdate
application/vnd.ecowin.series
application/vnd.ecowin.seriesrequest
application/vnd.ecowin.seriesupdate
application/vnd.enliven
application/vnd.epson.esf
application/vnd.epson.msf
application/vnd.epson.quickanime
application/vnd.epson.salt
application/vnd.epson.ssf
application/vnd.ericsson.quickcall
application/vnd.eudora.data
application/vnd.fdf
application/vnd.ffsns
application/vnd.fints
application/vnd.flographit
application/vnd.framemaker
application/vnd.fsc.weblaunch
application/vnd.fujitsu.oasys
application/vnd.fujitsu.oasys2
application/vnd.fujitsu.oasys3
application/vnd.fujitsu.oasysgp
application/vnd.fujitsu.oasysprs
application/vnd.fujixerox.ddd
application/vnd.fujixerox.docuworks
application/vnd.fujixerox.docuworks.binder
application/vnd.fut-misnet
application/vnd.grafeq
application/vnd.groove-account
application/vnd.groove-help
application/vnd.groove-identity-message
application/vnd.groove-injector
application/vnd.groove-tool-message
application/vnd.groove-tool-template
application/vnd.groove-vcard
application/vnd.hbci
application/vnd.hhe.lesson-player
application/vnd.hp-hpgl
application/vnd.hp-hpid
application/vnd.hp-hps
application/vnd.hp-pcl
application/vnd.hp-pclxl
application/vnd.httphone
application/vnd.hzn-3d-crossword
application/vnd.ibm.afplinedata
application/vnd.ibm.electronic-media
application/vnd.ibm.minipay
application/vnd.ibm.modcap
application/vnd.ibm.rights-management
application/vnd.ibm.secure-container
application/vnd.informix-visionary
application/vnd.intercon.formnet
application/vnd.intertrust.digibox
application/vnd.intertrust.nncp
application/vnd.intu.qbo
application/vnd.intu.qfx
application/vnd.irepository.package+xml
application/vnd.is-xpr
application/vnd.japannet-directory-service
application/vnd.japannet-jpnstore-wakeup
application/vnd.japannet-payment-wakeup
application/vnd.japannet-registration
application/vnd.japannet-registration-wakeup
application/vnd.japannet-setstore-wakeup
application/vnd.japannet-verification
application/vnd.japannet-verification-wakeup
application/vnd.jisp
application/vnd.kde.karbon
application/vnd.kde.kchart
application/vnd.kde.kformula
application/vnd.kde.kivio
application/vnd.kde.kontour
application/vnd.kde.kpresenter
application/vnd.kde.kspread
application/vnd.kde.kword
application/vnd.kenameaapp
application/vnd.koan
application/vnd.liberty-request+xml
application/vnd.llamagraphics.life-balance.desktop
application/vnd.llamagraphics.life-balance.exchange+xml
application/vnd.lotus-1-2-3
application/vnd.lotus-approach
application/vnd.lotus-freelance
application/vnd.lotus-notes
application/vnd.lotus-organizer
application/vnd.lotus-screencam
application/vnd.lotus-wordpro
application/vnd.mcd
application/vnd.mediastation.cdkey
application/vnd.meridian-slingshot
application/vnd.micrografx.flo
application/vnd.micrografx.igx
application/vnd.mif mif
application/vnd.minisoft-hp3000-save
application/vnd.mitsubishi.misty-guard.trustweb
application/vnd.mobius.daf
application/vnd.mobius.dis
application/vnd.mobius.mbk
application/vnd.mobius.mqy
application/vnd.mobius.msl
application/vnd.mobius.plc
application/vnd.mobius.txf
application/vnd.mophun.application
application/vnd.mophun.certificate
application/vnd.motorola.flexsuite
application/vnd.motorola.flexsuite.adsi
application/vnd.motorola.flexsuite.fis
application/vnd.motorola.flexsuite.gotap
application/vnd.motorola.flexsuite.kmr
application/vnd.motorola.flexsuite.ttc
application/vnd.motorola.flexsuite.wem
application/vnd.mozilla.xul+xml xul
application/vnd.ms-artgalry
application/vnd.ms-asf
application/vnd.ms-excel xls
application/vnd.ms-lrm
application/vnd.ms-powerpoint ppt
application/vnd.ms-project
application/vnd.ms-tnef
application/vnd.ms-works
application/vnd.ms-wpl
application/vnd.mseq
application/vnd.msign
application/vnd.music-niff
application/vnd.musician
application/vnd.netfpx
application/vnd.noblenet-directory
application/vnd.noblenet-sealer
application/vnd.noblenet-web
application/vnd.novadigm.edm
application/vnd.novadigm.edx
application/vnd.novadigm.ext
application/vnd.obn
application/vnd.osa.netdeploy
application/vnd.palm
application/vnd.pg.format
application/vnd.pg.osasli
application/vnd.powerbuilder6
application/vnd.powerbuilder6-s
application/vnd.powerbuilder7
application/vnd.powerbuilder7-s
application/vnd.powerbuilder75
application/vnd.powerbuilder75-s
application/vnd.previewsystems.box
application/vnd.publishare-delta-tree
application/vnd.pvi.ptid1
application/vnd.pwg-multiplexed
application/vnd.pwg-xhtml-print+xml
application/vnd.quark.quarkxpress
application/vnd.rapid
application/vnd.s3sms
application/vnd.sealed.net
application/vnd.seemail
application/vnd.shana.informed.formdata
application/vnd.shana.informed.formtemplate
application/vnd.shana.informed.interchange
application/vnd.shana.informed.package
application/vnd.smaf
application/vnd.sss-cod
application/vnd.sss-dtf
application/vnd.sss-ntf
application/vnd.street-stream
application/vnd.svd
application/vnd.swiftview-ics
application/vnd.triscape.mxs
application/vnd.trueapp
application/vnd.truedoc
application/vnd.ufdl
application/vnd.uplanet.alert
application/vnd.uplanet.alert-wbxml
application/vnd.uplanet.bearer-choice
application/vnd.uplanet.bearer-choice-wbxml
application/vnd.uplanet.cacheop
application/vnd.uplanet.cacheop-wbxml
application/vnd.uplanet.channel
application/vnd.uplanet.channel-wbxml
application/vnd.uplanet.list
application/vnd.uplanet.list-wbxml
application/vnd.uplanet.listcmd
application/vnd.uplanet.listcmd-wbxml
application/vnd.uplanet.signal
application/vnd.vcx
application/vnd.vectorworks
application/vnd.vidsoft.vidconference
application/vnd.visio
application/vnd.visionary
application/vnd.vividence.scriptfile
application/vnd.vsf
application/vnd.wap.sic
application/vnd.wap.slc
application/vnd.wap.wbxml wbxml
application/vnd.wap.wmlc wmlc
application/vnd.wap.wmlscriptc wmlsc
application/vnd.webturbo
application/vnd.wrq-hp3000-labelled
application/vnd.wt.stf
application/vnd.wv.csp+wbxml
application/vnd.xara
application/vnd.xfdl
application/vnd.yamaha.hv-dic
application/vnd.yamaha.hv-script
application/vnd.yamaha.hv-voice
application/vnd.yellowriver-custom-menu
application/voicexml+xml vxml
application/watcherinfo+xml
application/whoispp-query
application/whoispp-response
application/wita
application/wordperfect5.1
application/x-bcpio bcpio
application/x-cdlink vcd
application/x-chess-pgn pgn
application/x-compress
application/x-cpio cpio
application/x-csh csh
application/x-director dcr dir dxr
application/x-dvi dvi
application/x-futuresplash spl
application/x-gtar gtar
application/x-gzip
application/x-hdf hdf
application/x-javascript js
application/x-koan skp skd skt skm
application/x-latex latex
application/x-netcdf nc cdf
application/x-sh sh
application/x-shar shar
application/x-shockwave-flash swf
application/x-stuffit sit
application/x-sv4cpio sv4cpio
application/x-sv4crc sv4crc
application/x-tar tar
application/x-tcl tcl
application/x-tex tex
application/x-texinfo texinfo texi
application/x-troff t tr roff
application/x-troff-man man
application/x-troff-me me
application/x-troff-ms ms
application/x-ustar ustar
application/x-wais-source src
application/x400-bp
application/xhtml+xml xhtml xht
application/xslt+xml xslt
application/xml xml xsl
application/xml-dtd dtd
application/xml-external-parsed-entity
application/zip zip
audio/32kadpcm
audio/amr
audio/amr-wb
audio/basic au snd
audio/cn
audio/dat12
audio/dsr-es201108
audio/dvi4
audio/evrc
audio/evrc0
audio/g722
audio/g.722.1
audio/g723
audio/g726-16
audio/g726-24
audio/g726-32
audio/g726-40
audio/g728
audio/g729
audio/g729D
audio/g729E
audio/gsm
audio/gsm-efr
audio/l8
audio/l16
audio/l20
audio/l24
audio/lpc
audio/midi mid midi kar
audio/mpa
audio/mpa-robust
audio/mp4a-latm
audio/mpeg mpga mp2 mp3
audio/parityfec
audio/pcma
audio/pcmu
audio/prs.sid
audio/qcelp
audio/red
audio/smv
audio/smv0
audio/telephone-event
audio/tone
audio/vdvi
audio/vnd.3gpp.iufp
audio/vnd.cisco.nse
audio/vnd.cns.anp1
audio/vnd.cns.inf1
audio/vnd.digital-winds
audio/vnd.everad.plj
audio/vnd.lucent.voice
audio/vnd.nortel.vbk
audio/vnd.nuera.ecelp4800
audio/vnd.nuera.ecelp7470
audio/vnd.nuera.ecelp9600
audio/vnd.octel.sbc
audio/vnd.qcelp
audio/vnd.rhetorex.32kadpcm
audio/vnd.vmx.cvsd
audio/x-aiff aif aiff aifc
audio/x-alaw-basic
audio/x-mpegurl m3u
audio/x-pn-realaudio ram ra
audio/x-pn-realaudio-plugin
application/vnd.rn-realmedia rm
audio/x-wav wav
chemical/x-pdb pdb
chemical/x-xyz xyz
image/bmp bmp
image/cgm cgm
image/g3fax
image/gif gif
image/ief ief
image/jpeg jpeg jpg jpe
image/naplps
image/png png
image/prs.btif
image/prs.pti
image/svg+xml svg
image/t38
image/tiff tiff tif
image/tiff-fx
image/vnd.cns.inf2
image/vnd.djvu djvu djv
image/vnd.dwg
image/vnd.dxf
image/vnd.fastbidsheet
image/vnd.fpx
image/vnd.fst
image/vnd.fujixerox.edmics-mmr
image/vnd.fujixerox.edmics-rlc
image/vnd.globalgraphics.pgb
image/vnd.mix
image/vnd.ms-modi
image/vnd.net-fpx
image/vnd.svf
image/vnd.wap.wbmp wbmp
image/vnd.xiff
image/x-cmu-raster ras
image/x-icon ico
image/x-portable-anymap pnm
image/x-portable-bitmap pbm
image/x-portable-graymap pgm
image/x-portable-pixmap ppm
image/x-rgb rgb
image/x-xbitmap xbm
image/x-xpixmap xpm
image/x-xwindowdump xwd
message/delivery-status
message/disposition-notification
message/external-body
message/http
message/news
message/partial
message/rfc822
message/s-http
message/sip
message/sipfrag
model/iges igs iges
model/mesh msh mesh silo
model/vnd.dwf
model/vnd.flatland.3dml
model/vnd.gdl
model/vnd.gs-gdl
model/vnd.gtw
model/vnd.mts
model/vnd.parasolid.transmit.binary
model/vnd.parasolid.transmit.text
model/vnd.vtu
model/vrml wrl vrml
multipart/alternative
multipart/appledouble
multipart/byteranges
multipart/digest
multipart/encrypted
multipart/form-data
multipart/header-set
multipart/mixed
multipart/parallel
multipart/related
multipart/report
multipart/signed
multipart/voice-message
text/calendar ics ifb
text/css css
text/directory
text/enriched
text/html html htm
text/parityfec
text/plain asc txt
text/prs.lines.tag
text/rfc822-headers
text/richtext rtx
text/rtf rtf
text/sgml sgml sgm
text/t140
text/tab-separated-values tsv
text/uri-list
text/vnd.abc
text/vnd.curl
text/vnd.dmclientscript
text/vnd.fly
text/vnd.fmi.flexstor
text/vnd.in3d.3dml
text/vnd.in3d.spot
text/vnd.iptc.nitf
text/vnd.iptc.newsml
text/vnd.latex-z
text/vnd.motorola.reflex
text/vnd.ms-mediapackage
text/vnd.net2phone.commcenter.command
text/vnd.sun.j2me.app-descriptor
text/vnd.wap.si
text/vnd.wap.sl
text/vnd.wap.wml wml
text/vnd.wap.wmlscript wmls
text/x-setext etx
text/xml
text/xml-external-parsed-entity
video/bmpeg
video/bt656
video/celb
video/dv
video/h261
video/h263
video/h263-1998
video/h263-2000
video/jpeg
video/mp1s
video/mp2p
video/mp2t
video/mp4v-es
video/mpv
video/mpeg mpeg mpg mpe
video/nv
video/parityfec
video/pointer
video/quicktime qt mov
video/smpte292m
video/vnd.fvt
video/vnd.motorola.video
video/vnd.motorola.videop
video/vnd.mpegurl mxu m4u
video/vnd.nokia.interleaved-multimedia
video/vnd.objectvideo
video/vnd.vivo
video/x-msvideo avi
video/x-sgi-movie movie
x-conference/x-cooltalk ice
calendarserver-5.2+dfsg/conf/servers.dtd 0000644 0001750 0001750 00000001662 12263343324 017431 0 ustar rahul rahul
calendarserver-5.2+dfsg/conf/test/ 0000755 0001750 0001750 00000000000 12322625306 016214 5 ustar rahul rahul calendarserver-5.2+dfsg/conf/test/accounts.xml 0000644 0001750 0001750 00000011225 12263343324 020557 0 ustar rahul rahul
admin
admin
admin
Super User
apprentice
apprentice
Apprentice Super User
apprentice
wsanchez
wsanchez
wsanchez@example.com
Wilfredo Sanchez Vega
test
cdaboo
cdaboo
cdaboo@example.com
Cyrus Daboo
test
sagen
sagen
sagen@example.com
Morgen Sagen
test
andre
dre
dre@example.com
Andre LaBranche
test
glyph
glyph
glyph@example.com
Glyph Lefkowitz
test
i18nuser
i18nuser
i18nuser@example.com
まだ
i18nuser
user%02d
user%02d
User %02d
User %02d
user%02d@example.com
user%02d
public%02d
public%02d
Public %02d
public%02d
group01
group01
Group 01
group01
user01
group02
group02
Group 02
group02
user06
user07
group03
group03
Group 03
group03
user08
user09
group04
group04
Group 04
group04
group02
group03
user10
group05
group05
Group 05
group05
group06
user20
group06
group06
Group 06
group06
user21
group07
group07
Group 07
group07
user22
user23
user24
disabledgroup
disabledgroup
Disabled Group
disabledgroup
user01
calendarserver-5.2+dfsg/conf/auth/ 0000755 0001750 0001750 00000000000 12322625306 016176 5 ustar rahul rahul calendarserver-5.2+dfsg/conf/auth/accounts.xml 0000644 0001750 0001750 00000002317 12263343324 020543 0 ustar rahul rahul
admin
admin
Super User
test
test
Test User
users
users
Users Group
test
mercury
mercury
Mecury Conference Room, Building 1, 2nd Floor
calendarserver-5.2+dfsg/conf/auth/accounts.dtd 0000644 0001750 0001750 00000003001 12263343324 020505 0 ustar rahul rahul
>
calendarserver-5.2+dfsg/conf/auth/resources-test.xml 0000755 0001750 0001750 00000015244 12262624042 021717 0 ustar rahul rahul
fantastic
4D66A20A-1437-437D-8069-2F14E8322234
Fantastic Conference Room
63A2F949-2D8D-4C8D-B8A5-DCF2A94610F3
jupiter
jupiter
Jupiter Conference Room, Building 2, 1st Floor
uranus
uranus
Uranus Conference Room, Building 3, 1st Floor
morgensroom
03DFF660-8BCC-4198-8588-DD77F776F518
Morgen's Room
mercury
mercury
Mercury Conference Room, Building 1, 2nd Floor
location09
location09
Room 09
location08
location08
Room 08
location07
location07
Room 07
location06
location06
Room 06
location05
location05
Room 05
location04
location04
Room 04
location03
location03
Room 03
location02
location02
Room 02
location01
location01
Room 01
delegatedroom
delegatedroom
Delegated Conference Room
mars
redplanet
Mars Conference Room, Building 1, 1st Floor
sharissroom
80689D41-DAF8-4189-909C-DB017B271892
Shari's Room
6F9EE33B-78F6-481B-9289-3D0812FF0D64
pluto
pluto
Pluto Conference Room, Building 2, 1st Floor
saturn
saturn
Saturn Conference Room, Building 2, 1st Floor
location10
location10
Room 10
pretend
06E3BDCB-9C19-485A-B14E-F146A80ADDC6
Pretend Conference Room
76E7ECA6-08BC-4AE7-930D-F2E7453993A5
neptune
neptune
Neptune Conference Room, Building 2, 1st Floor
Earth
Earth
Earth Conference Room, Building 1, 1st Floor
venus
venus
Venus Conference Room, Building 1, 2nd Floor
sharisotherresource
CCE95217-A57B-481A-AC3D-FEC9AB6CE3A9
Shari's Other Resource
resource15
resource15
Resource 15
resource14
resource14
Resource 14
resource17
resource17
Resource 17
resource16
resource16
Resource 16
resource11
resource11
Resource 11
resource10
resource10
Resource 10
resource13
resource13
Resource 13
resource12
resource12
Resource 12
resource19
resource19
Resource 19
resource18
resource18
Resource 18
sharisresource
C38BEE7A-36EE-478C-9DCB-CBF4612AFE65
Shari's Resource
resource20
resource20
Resource 20
resource06
resource06
Resource 06
resource07
resource07
Resource 07
resource04
resource04
Resource 04
resource05
resource05
Resource 05
resource02
resource02
Resource 02
resource03
resource03
Resource 03
resource01
resource01
Resource 01
sharisotherresource1
0CE0BF31-5F9E-4801-A489-8C70CF287F5F
Shari's Other Resource1
resource08
resource08
Resource 08
resource09
resource09
Resource 09
testaddress1
6F9EE33B-78F6-481B-9289-3D0812FF0D64
Test Address One
20300 Stevens Creek Blvd, Cupertino, CA 95014
37.322281,-122.028345
il2
63A2F949-2D8D-4C8D-B8A5-DCF2A94610F3
IL2
2 Infinite Loop, Cupertino, CA 95014
37.332633,-122.030502
il1
76E7ECA6-08BC-4AE7-930D-F2E7453993A5
IL1
1 Infinite Loop, Cupertino, CA 95014
37.331741,-122.030333
calendarserver-5.2+dfsg/conf/auth/augments.dtd 0000644 0001750 0001750 00000002423 12263343324 020520 0 ustar rahul rahul
>
calendarserver-5.2+dfsg/conf/auth/proxies.dtd 0000644 0001750 0001750 00000001565 12263343324 020374 0 ustar rahul rahul
>
calendarserver-5.2+dfsg/conf/auth/accounts-test.xml 0000644 0001750 0001750 00000012047 12263343324 021521 0 ustar rahul rahul
admin
admin
admin
Super User
Super
User
apprentice
apprentice
apprentice
Apprentice Super User
Apprentice
Super User
wsanchez
wsanchez
wsanchez@example.com
test
Wilfredo Sanchez Vega
Wilfredo
Sanchez Vega
cdaboo
cdaboo
cdaboo@example.com
test
Cyrus Daboo
Cyrus
Daboo
sagen
sagen
sagen@example.com
test
Morgen Sagen
Morgen
Sagen
dre
andre
dre@example.com
test
Andre LaBranche
Andre
LaBranche
glyph
glyph
glyph@example.com
test
Glyph Lefkowitz
Glyph
Lefkowitz
i18nuser
i18nuser
i18nuser@example.com
i18nuser
まだ
ま
だ
user%02d
User %02d
user%02d
user%02d
User %02d
User
%02d
user%02d@example.com
public%02d
public%02d
public%02d
Public %02d
Public
%02d
group01
group01
group01
Group 01
user01
group02
group02
group02
Group 02
user06
user07
group03
group03
group03
Group 03
user08
user09
group04
group04
group04
Group 04
group02
group03
user10
group05
group05
group05
Group 05
group06
user20
group06
group06
group06
Group 06
user21
group07
group07
group07
Group 07
user22
user23
user24
disabledgroup
disabledgroup
disabledgroup
Disabled Group
user01
calendarserver-5.2+dfsg/conf/auth/augments-default.xml 0000644 0001750 0001750 00000001540 12263343324 022166 0 ustar rahul rahul
Default
true
true
true
calendarserver-5.2+dfsg/conf/auth/proxies-test.xml 0000644 0001750 0001750 00000002103 12263343324 021363 0 ustar rahul rahul
resource%02d
user01
user03
delegatedroom
group05
group07
calendarserver-5.2+dfsg/conf/auth/augments-test.xml 0000644 0001750 0001750 00000014031 12262624042 021516 0 ustar rahul rahul
Default
true
true
true
location%02d
true
true
true
true
resource%02d
true
true
true
true
resource05
true
true
true
true
none
resource06
true
true
true
true
accept-always
resource07
true
true
true
true
decline-always
resource08
true
true
true
true
accept-if-free
resource09
true
true
true
true
decline-if-busy
resource10
true
true
true
true
automatic
resource11
true
true
true
true
decline-always
group01
group%02d
true
disabledgroup
false
delegatedroom
true
true
false
false
03DFF660-8BCC-4198-8588-DD77F776F518
true
true
true
true
true
80689D41-DAF8-4189-909C-DB017B271892
true
true
true
true
true
default
C38BEE7A-36EE-478C-9DCB-CBF4612AFE65
true
true
true
true
true
default
group01
CCE95217-A57B-481A-AC3D-FEC9AB6CE3A9
true
true
true
true
true
0CE0BF31-5F9E-4801-A489-8C70CF287F5F
true
true
true
true
true
6F9EE33B-78F6-481B-9289-3D0812FF0D64
true
true
true
true
false
default
76E7ECA6-08BC-4AE7-930D-F2E7453993A5
true
true
true
true
false
default
63A2F949-2D8D-4C8D-B8A5-DCF2A94610F3
true
true
true
true
false
default
06E3BDCB-9C19-485A-B14E-F146A80ADDC6
true
true
true
true
true
default
4D66A20A-1437-437D-8069-2F14E8322234
true
true
true
true
true
default
calendarserver-5.2+dfsg/conf/sudoers.plist 0000644 0001750 0001750 00000001664 10550570476 020015 0 ustar rahul rahul
users
username
superuser
password
superuser
calendarserver-5.2+dfsg/conf/caldavd-test.plist 0000644 0001750 0001750 00000061036 12263344114 020672 0 ustar rahul rahul
ServerHostName
localhost
EnableCalDAV
EnableCardDAV
HTTPPort
8008
SSLPort
8443
EnableSSL
RedirectHTTPToHTTPS
BindAddresses
BindHTTPPorts
8008
8800
BindSSLPorts
8443
8843
ServerRoot
./data
DataRoot
Data
DatabaseRoot
Database
DocumentRoot
Documents
ConfigRoot
./conf
RunRoot
Logs/state
Aliases
UserQuota
104857600
MaxCollectionsPerHome
50
MaxResourcesPerCollection
10000
MaxResourceSize
1048576
MaxAttendeesPerInstance
100
MaxAllowedInstances
3000
DirectoryService
type
twistedcaldav.directory.xmlfile.XMLDirectoryService
params
xmlFile
./conf/auth/accounts-test.xml
ResourceService
Enabled
type
twistedcaldav.directory.xmlfile.XMLDirectoryService
params
xmlFile
./conf/auth/resources-test.xml
AugmentService
type
twistedcaldav.directory.augment.AugmentXMLDB
params
xmlFiles
./conf/auth/augments-test.xml
ProxyDBService
type
twistedcaldav.directory.calendaruserproxy.ProxySqliteDB
params
dbpath
proxies.sqlite
ProxyLoadFromFile
./conf/auth/proxies-test.xml
AdminPrincipals
/principals/__uids__/admin/
ReadPrincipals
EnableProxyPrincipals
EnableAnonymousReadRoot
EnableAnonymousReadNav
EnablePrincipalListings
EnableMonolithicCalendars
Authentication
Basic
Enabled
AllowedOverWireUnencrypted
Digest
Enabled
AllowedOverWireUnencrypted
Algorithm
md5
Qop
Kerberos
Enabled
AllowedOverWireUnencrypted
ServicePrincipal
Wiki
Enabled
Cookie
sessionID
URL
http://127.0.0.1/RPC2
UserMethod
userForSession
WikiMethod
accessLevelForUserWikiCalendar
LogRoot
Logs
AccessLogFile
access.log
RotateAccessLog
ErrorLogFile
error.log
DefaultLogLevel
info
LogLevels
PIDFile
caldavd.pid
AccountingCategories
iTIP
HTTP
AccountingPrincipals
SSLCertificate
twistedcaldav/test/data/server.pem
SSLAuthorityChain
SSLPrivateKey
twistedcaldav/test/data/server.pem
UserName
GroupName
ProcessType
Combined
MultiProcess
ProcessCount
2
Notifications
CoalesceSeconds
3
Services
AMP
Enabled
Port
62311
EnableStaggering
StaggerSeconds
3
Scheduling
CalDAV
EmailDomain
HTTPDomain
AddressPatterns
OldDraftCompatibility
ScheduleTagCompatibility
EnablePrivateComments
iSchedule
Enabled
AddressPatterns
RemoteServers
remoteservers-test.xml
iMIP
Enabled
MailGatewayServer
localhost
MailGatewayPort
62310
Sending
Server
Port
587
UseSSL
Username
Password
Address
SupressionDays
7
Receiving
Server
Port
995
Type
UseSSL
Username
Password
PollingSeconds
30
AddressPatterns
mailto:.*
Options
AllowGroupAsOrganizer
AllowLocationAsOrganizer
AllowResourceAsOrganizer
AttendeeRefreshBatch
0
AttendeeRefreshCountLimit
50
AutoSchedule
Enabled
Always
DefaultMode
automatic
FreeBusyURL
Enabled
TimePeriod
14
AnonymousAccess
EnableDropBox
EnableManagedAttachments
EnablePrivateEvents
RemoveDuplicatePrivateComments
EnableTimezoneService
TimezoneService
Enabled
Mode
primary
BasePath
XMLInfoPath
SecondaryService
Host
URI
UpdateIntervalMinutes
1440
UsePackageTimezones
EnableBatchUpload
Sharing
Enabled
AllowExternalUsers
Calendars
Enabled
AddressBooks
Enabled
EnableSACLs
EnableReadOnlyServer
EnableWebAdmin
ResponseCompression
HTTPRetryAfter
180
ControlSocket
caldavd.sock
Memcached
MaxClients
5
memcached
memcached
Options
EnableResponseCache
ResponseCacheTimeout
30
Postgres
Options
QueryCaching
Enabled
MemcachedPool
Default
ExpireSeconds
3600
GroupCaching
Enabled
EnableUpdater
MemcachedPool
Default
UpdateSeconds
300
ExpireSeconds
3600
LockSeconds
300
UseExternalProxies
MaxPrincipalSearchReportResults
500
Twisted
twistd
../Twisted/bin/twistd
Localization
TranslationsDirectory
locales
LocalesDirectory
locales
Language
en
calendarserver-5.2+dfsg/conf/resources.xml 0000644 0001750 0001750 00000001312 12263343324 017767 0 ustar rahul rahul
calendarserver-5.2+dfsg/conf/remoteservers-test.xml 0000644 0001750 0001750 00000001650 12263343324 021644 0 ustar rahul rahul
https://localhost:8543/inbox
example.org
127.0.0.1
calendarserver-5.2+dfsg/conf/localservers-test.xml 0000644 0001750 0001750 00000001746 12263343324 021451 0 ustar rahul rahul
00001
http://localhost:8008
00001
http://localhost:8008
00002
http://localhost:8108
calendarserver-5.2+dfsg/conf/servertoserver.dtd 0000644 0001750 0001750 00000002140 12263343324 021030 0 ustar rahul rahul
calendarserver-5.2+dfsg/conf/resources/ 0000755 0001750 0001750 00000000000 12322625306 017247 5 ustar rahul rahul calendarserver-5.2+dfsg/conf/resources/users-groups.xml 0000644 0001750 0001750 00000005427 12263343324 022460 0 ustar rahul rahul
admin
admin
admin
Super User
Super
User
apprentice
apprentice
apprentice
Apprentice Super User
Apprentice
Super User
user%02d
User %02d
user%02d
user%02d
User %02d
User
%02d
user%02d@example.com
public%02d
public%02d
public%02d
Public %02d
Public
%02d
group01
group01
group01
Group 01
user01
group02
group02
group02
Group 02
user06
user07
group03
group03
group03
Group 03
user08
user09
group04
group04
group04
Group 04
group02
group03
user10
disabledgroup
disabledgroup
disabledgroup
Disabled Group
user01
calendarserver-5.2+dfsg/conf/resources/caldavd-resources.plist 0000644 0001750 0001750 00000043060 12263343324 023736 0 ustar rahul rahul
ServerHostName
HTTPPort
8008
SSLPort
8443
RedirectHTTPToHTTPS
BindAddresses
BindHTTPPorts
BindSSLPorts
DataRoot
data/
DocumentRoot
twistedcaldav/test/data/
Aliases
UserQuota
104857600
MaximumAttachmentSize
1048576
MaxAttendeesPerInstance
100
MaxInstancesForRRULE
400
DirectoryService
type
twistedcaldav.directory.xmlfile.XMLDirectoryService
params
xmlFile
conf/resources/users-groups.xml
recordTypes
users
groups
ResourceService
Enabled
type
twistedcaldav.directory.xmlfile.XMLDirectoryService
params
xmlFile
conf/resources/locations-resources.xml
recordTypes
locations
resources
AugmentService
type
twistedcaldav.directory.augment.AugmentXMLDB
params
xmlFiles
conf/auth/augments-test.xml
ProxyDBService
type
twistedcaldav.directory.calendaruserproxy.ProxySqliteDB
params
dbpath
data/proxies.sqlite
ProxyLoadFromFile
conf/auth/proxies-test.xml
AdminPrincipals
/principals/__uids__/admin/
ReadPrincipals
EnableProxyPrincipals
EnableAnonymousReadRoot
EnableAnonymousReadNav
EnablePrincipalListings
EnableMonolithicCalendars
Authentication
Basic
Enabled
Digest
Enabled
Algorithm
md5
Qop
Kerberos
Enabled
ServicePrincipal
Wiki
Enabled
Cookie
sessionID
URL
http://127.0.0.1/RPC2
UserMethod
userForSession
WikiMethod
accessLevelForUserWikiCalendar
AccessLogFile
logs/access.log
RotateAccessLog
ErrorLogFile
logs/error.log
DefaultLogLevel
info
LogLevels
ServerStatsFile
logs/stats.plist
PIDFile
logs/caldavd.pid
AccountingCategories
iTIP
HTTP
AccountingPrincipals
SSLCertificate
twistedcaldav/test/data/server.pem
SSLAuthorityChain
SSLPrivateKey
twistedcaldav/test/data/server.pem
UserName
GroupName
ProcessType
Combined
MultiProcess
ProcessCount
2
Notifications
CoalesceSeconds
3
InternalNotificationHost
localhost
InternalNotificationPort
62309
Services
SimpleLineNotifier
Service
twistedcaldav.notify.SimpleLineNotifierService
Enabled
Port
62308
XMPPNotifier
Service
twistedcaldav.notify.XMPPNotifierService
Enabled
Host
xmpp.host.name
Port
5222
JID
jid@xmpp.host.name/resource
Password
password_goes_here
ServiceAddress
pubsub.xmpp.host.name
NodeConfiguration
pubsub#deliver_payloads
1
pubsub#persist_items
1
KeepAliveSeconds
120
HeartbeatMinutes
30
AllowedJIDs
Scheduling
CalDAV
EmailDomain
HTTPDomain
AddressPatterns
OldDraftCompatibility
ScheduleTagCompatibility
EnablePrivateComments
iSchedule
Enabled
AddressPatterns
Servers
conf/servertoserver-test.xml
iMIP
Enabled
MailGatewayServer
localhost
MailGatewayPort
62310
Sending
Server
Port
587
UseSSL
Username
Password
Address
Receiving
Server
Port
995
Type
UseSSL
Username
Password
PollingSeconds
30
AddressPatterns
mailto:.*
Options
AllowGroupAsOrganizer
AllowLocationAsOrganizer
AllowResourceAsOrganizer
FreeBusyURL
Enabled
TimePeriod
14
AnonymousAccess
EnableDropBox
EnablePrivateEvents
EnableTimezoneService
EnableSACLs
EnableWebAdmin
ResponseCompression
HTTPRetryAfter
180
ControlSocket
logs/caldavd.sock
Memcached
MaxClients
5
memcached
memcached
Options
EnableResponseCache
ResponseCacheTimeout
30
Twisted
twistd
../Twisted/bin/twistd
Localization
LocalesDirectory
locales
Language
English
calendarserver-5.2+dfsg/conf/resources/locations-resources-orig.xml 0000644 0001750 0001750 00000002025 12263343324 024732 0 ustar rahul rahul
location%02d
location%02d
location%02d
Room %02d
resource%02d
resource%02d
resource%02d
Resource %02d
calendarserver-5.2+dfsg/conf/resources/locations-resources.xml 0000644 0001750 0001750 00000002025 12263343324 023774 0 ustar rahul rahul
location%02d
location%02d
location%02d
Room %02d
resource%02d
resource%02d
resource%02d
Resource %02d
calendarserver-5.2+dfsg/conf/caldavd-partitioning-primary.plist 0000644 0001750 0001750 00000004377 12263343324 024112 0 ustar rahul rahul
Servers
Enabled
ConfigFile
localservers.xml
MaxClients
5
ServerPartitionID
00001
ProxyDBService
type
twistedcaldav.directory.calendaruserproxy.ProxyPostgreSQLDB
params
host
localhost
database
proxies
Memcached
Pools
CommonToAllNodes
ClientEnabled
ServerEnabled
BindAddress
localhost
Port
11311
HandleCacheTypes
ProxyDB
PrincipalToken
DIGESTCREDENTIALS
MaxClients
5
memcached
../memcached/_root/bin/memcached
Options
calendarserver-5.2+dfsg/conf/caldavd.plist 0000644 0001750 0001750 00000027076 12263343324 017725 0 ustar rahul rahul
ServerHostName
HTTPPort
80
RedirectHTTPToHTTPS
BindAddresses
BindHTTPPorts
BindSSLPorts
ServerRoot
/var/db/caldavd
DataRoot
Data
DocumentRoot
Documents
ConfigRoot
/etc/caldavd
RunRoot
/var/run
Aliases
UserQuota
104857600
MaxCollectionsPerHome
50
MaxResourcesPerCollection
10000
MaxResourceSize
1048576
MaxAttendeesPerInstance
100
MaxAllowedInstances
3000
DirectoryService
type
twistedcaldav.directory.xmlfile.XMLDirectoryService
params
xmlFile
accounts.xml
AdminPrincipals
ReadPrincipals
EnableProxyPrincipals
EnableAnonymousReadRoot
EnableAnonymousReadNav
EnablePrincipalListings
EnableMonolithicCalendars
Authentication
Basic
Enabled
Digest
Enabled
Algorithm
md5
Qop
Kerberos
Enabled
ServicePrincipal
LogRoot
/var/log/caldavd
AccessLogFile
access.log
RotateAccessLog
ErrorLogFile
error.log
DefaultLogLevel
warn
PIDFile
caldavd.pid
SSLCertificate
SSLAuthorityChain
SSLPrivateKey
UserName
daemon
GroupName
daemon
ProcessType
Combined
MultiProcess
ProcessCount
0
Notifications
CoalesceSeconds
3
Services
XMPPNotifier
Service
twistedcaldav.notify.XMPPNotifierService
Enabled
Host
xmpp.host.name
Port
5222
JID
jid@xmpp.host.name/resource
Password
password_goes_here
ServiceAddress
pubsub.xmpp.host.name
Scheduling
CalDAV
EmailDomain
HTTPDomain
AddressPatterns
iSchedule
Enabled
AddressPatterns
RemoteServers
remoteservers.xml
iMIP
Enabled
MailGatewayServer
localhost
MailGatewayPort
62310
Sending
Server
Port
587
UseSSL
Username
Password
Address
Receiving
Server
Port
995
Type
UseSSL
Username
Password
PollingSeconds
30
AddressPatterns
mailto:.*
FreeBusyURL
Enabled
TimePeriod
14
AnonymousAccess
EnablePrivateEvents
Sharing
Enabled
EnableWebAdmin
calendarserver-5.2+dfsg/conf/caldavd-apple.plist 0000644 0001750 0001750 00000032200 12263344251 021005 0 ustar rahul rahul
ServerHostName
EnableCalDAV
EnableCardDAV
HTTPPort
8008
SSLPort
8443
EnableSSL
RedirectHTTPToHTTPS
BindAddresses
BindHTTPPorts
8008
8800
BindSSLPorts
8443
8843
ServerRoot
/Library/Server/Calendar and Contacts
DBType
DSN
DBImportFile
/Library/Server/Calendar and Contacts/DataDump.sql
Postgres
Ctl
xpg_ctl
Options
-c log_lock_waits=TRUE
-c deadlock_timeout=10
-c log_line_prefix='%m [%p] '
-c logging_collector=on
-c log_truncate_on_rotation=on
-c log_directory=/var/log/caldavd/postgresql
-c log_filename=postgresql_%w.log
-c log_rotation_age=1440
ExtraConnections
20
ClusterName
cluster.pg
LogFile
xpg_ctl.log
SocketDirectory
/var/run/caldavd/PostgresSocket
DataRoot
Data
DatabaseRoot
Database.xpg
DocumentRoot
Documents
ConfigRoot
Config
RunRoot
/var/run/caldavd
Aliases
UserQuota
104857600
MaxCollectionsPerHome
50
MaxResourcesPerCollection
10000
MaxResourceSize
1048576
MaxAttendeesPerInstance
100
MaxAllowedInstances
3000
DirectoryService
type
twistedcaldav.directory.appleopendirectory.OpenDirectoryService
params
node
/Search
AdminPrincipals
ReadPrincipals
EnableProxyPrincipals
EnableAnonymousReadRoot
EnableAnonymousReadNav
EnablePrincipalListings
EnableMonolithicCalendars
Authentication
Basic
Enabled
Digest
Enabled
Algorithm
md5
Qop
Kerberos
Enabled
ServicePrincipal
Wiki
Enabled
LogRoot
/var/log/caldavd
AccessLogFile
access.log
RotateAccessLog
ErrorLogFile
error.log
DefaultLogLevel
warn
PIDFile
caldavd.pid
SSLCertificate
SSLAuthorityChain
SSLPrivateKey
UserName
calendar
GroupName
calendar
ProcessType
Combined
MultiProcess
ProcessCount
0
Notifications
CoalesceSeconds
3
Services
Scheduling
CalDAV
EmailDomain
HTTPDomain
AddressPatterns
iSchedule
Enabled
AddressPatterns
RemoteServers
remoteservers.xml
iMIP
Enabled
MailGatewayServer
localhost
MailGatewayPort
62310
Sending
Server
Port
587
UseSSL
Username
Password
Address
Receiving
Server
Port
995
Type
UseSSL
Username
Password
PollingSeconds
30
AddressPatterns
mailto:.*
FreeBusyURL
Enabled
TimePeriod
14
AnonymousAccess
EnableDropBox
EnableManagedAttachments
EnablePrivateEvents
EnableTimezoneService
Sharing
Enabled
EnableSACLs
EnableWebAdmin
WebCalendarAuthPath
/auth
DirectoryAddressBook
Enabled
params
queryUserRecords
queryPeopleRecords
EnableSearchAddressBook
Includes
/Library/Server/Calendar and Contacts/Config/caldavd-system.plist
/Library/Server/Calendar and Contacts/Config/caldavd-user.plist
WritableConfigFile
/Library/Server/Calendar and Contacts/Config/caldavd-system.plist
calendarserver-5.2+dfsg/conf/localservers.xml 0000644 0001750 0001750 00000002351 12263343324 020465 0 ustar rahul rahul
calendarserver-5.2+dfsg/twext/ 0000755 0001750 0001750 00000000000 12322625326 015465 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/python/ 0000755 0001750 0001750 00000000000 12322625326 017006 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/python/filepath.py 0000644 0001750 0001750 00000010502 12263343324 021151 0 ustar rahul rahul # -*- test-case-name: twext.python.test.test_filepath -*-
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Extend L{twisted.python.filepath} to provide performance enhancements for
calendar server.
"""
from os import listdir as _listdir
from os.path import (join as _joinpath,
basename as _basename,
exists as _exists,
dirname as _dirname)
from time import sleep as _sleep
from types import FunctionType, MethodType
from errno import EINVAL
from twisted.python.filepath import FilePath as _FilePath
from stat import S_ISDIR
class CachingFilePath(_FilePath, object):
    """
    A descendent of L{_FilePath} which implements a more aggressive caching
    policy: C{exists}/C{isdir} results are cached on the instance (in
    C{existsCached} / C{isDirCached}) and invalidated via L{changed}.  It
    also works around transient C{EINVAL} errors from C{os.listdir} by
    retrying with exponential backoff.
    """

    _listdir = _listdir  # integration points for tests
    _sleep = _sleep

    # Maximum time (in seconds) to wait between retried calls to listdir().
    BACKOFF_MAX = 5.0

    def __init__(self, path, alwaysCreate=False):
        super(CachingFilePath, self).__init__(path, alwaysCreate)
        # None means "unknown"; True/False are cached results of the last
        # restat().
        self.existsCached = None
        self.isDirCached = None

    @property
    def siblingExtensionSearch(self):
        """
        Dynamically create a version of L{_FilePath.siblingExtensionSearch}
        that uses a pluggable 'listdir' implementation, so that the EINVAL
        retry logic in L{_retryListdir} also applies to it.
        """
        # Rebind the original function's code object against a globals dict
        # that substitutes our retrying listdir for the module-level one.
        return MethodType(FunctionType(
            _FilePath.siblingExtensionSearch.im_func.func_code,
            {'listdir': self._retryListdir,
             'basename': _basename,
             'dirname': _dirname,
             'joinpath': _joinpath,
             'exists': _exists}), self, self.__class__)

    def changed(self):
        """
        This path may have changed in the filesystem, so forget all cached
        information about it.
        """
        self.statinfo = None
        self.existsCached = None
        self.isDirCached = None

    def _retryListdir(self, pathname):
        """
        Implementation of retry logic for C{listdir} and
        C{siblingExtensionSearch}: retry with exponential backoff (capped at
        C{BACKOFF_MAX}) for as long as C{os.listdir} raises C{EINVAL}.

        @raise OSError: for any error other than C{EINVAL}.
        """
        delay = 0.1
        while True:
            try:
                return self._listdir(pathname)
            # 'except E, e' is Python-2-only syntax; 'as' works on 2.6+.
            except OSError as e:
                if e.errno == EINVAL:
                    self._sleep(delay)
                    delay = min(self.BACKOFF_MAX, delay * 2.0)
                else:
                    raise
        raise RuntimeError("unreachable code.")

    def listdir(self):
        """
        List the directory which C{self.path} points to, compensating for
        EINVAL from C{os.listdir}.
        """
        return self._retryListdir(self.path)

    def restat(self, reraise=True):
        """
        Re-cache stat information, updating the extended C{existsCached} /
        C{isDirCached} attributes from the refreshed C{statinfo}.
        """
        try:
            return super(CachingFilePath, self).restat(reraise)
        finally:
            if self.statinfo:
                self.existsCached = True
                self.isDirCached = S_ISDIR(self.statinfo.st_mode)
            else:
                self.existsCached = False
                self.isDirCached = None

    def moveTo(self, destination, followLinks=True):
        """
        Override L{_FilePath.moveTo}, updating extended cache information if
        necessary.
        """
        result = super(CachingFilePath, self).moveTo(destination, followLinks)
        self.changed()

        # Work with vanilla FilePath destinations to pacify the tests.
        if hasattr(destination, "changed"):
            destination.changed()
        return result

    def remove(self):
        """
        Override L{_FilePath.remove}, updating extended cache information if
        necessary.
        """
        try:
            return super(CachingFilePath, self).remove()
        finally:
            self.changed()
# Ensure that derived paths (children, siblings, parents, etc.) are also
# CachingFilePath instances rather than plain FilePath objects.
CachingFilePath.clonePath = CachingFilePath

__all__ = ["CachingFilePath"]
calendarserver-5.2+dfsg/twext/python/parallel.py 0000644 0001750 0001750 00000006324 12263343324 021160 0 ustar rahul rahul # -*- test-case-name: twext.python.test.test_parallel -*-
##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Utilities for parallelizing tasks.
"""
from twisted.internet.defer import inlineCallbacks, DeferredList, returnValue
class Parallelizer(object):
    """
    Do some operation with a degree of parallelism, using a set of resources
    which may each only be used for one task at a time, given some underlying
    API that returns L{Deferreds}.

    @ivar available: A list of available resources from the C{resources}
        constructor parameter.

    @ivar busy: A list of resources which are currently being used by
        operations.
    """

    def __init__(self, resources):
        """
        Initialize a L{Parallelizer} with a list of objects that will be passed
        to the callables sent to L{Parallelizer.do}.

        @param resources: objects which may be of any arbitrary type.
        @type resources: C{list}
        """
        self.available = list(resources)
        self.busy = []
        self.activeDeferreds = []

    @inlineCallbacks
    def do(self, operation):
        """
        Call C{operation} with one of the resources in C{self.available},
        removing that value for use by other callers of C{do} until the task
        performed by C{operation} is complete (in other words, the L{Deferred}
        returned by C{operation} has fired).

        @param operation: a 1-argument callable taking a resource from
            C{self.active} and returning a L{Deferred} when it's done using
            that resource.
        @type operation: C{callable}

        @return: a L{Deferred} that fires as soon as there are resources
            available such that this task can be I{started} - not completed.
        """
        # Use 'while', not 'if': when several do() calls are blocked waiting
        # for a resource, the completion of a single operation fires all of
        # their DeferredLists, but only one waiter will find a resource in
        # self.available.  The others must resume waiting rather than pop
        # from an empty list.
        while not self.available:
            yield DeferredList(self.activeDeferreds, fireOnOneCallback=True,
                               fireOnOneErrback=True)
        active = self.available.pop(0)
        self.busy.append(active)
        o = operation(active)
        def andFinally(whatever):
            # Return the resource to the pool whether the operation
            # succeeded or failed, and pass the result/failure through.
            self.activeDeferreds.remove(o)
            self.busy.remove(active)
            self.available.append(active)
            return whatever
        self.activeDeferreds.append(o)
        o.addBoth(andFinally)
        returnValue(None)

    def done(self):
        """
        Wait until all operations started by L{Parallelizer.do} are completed.

        @return: a L{Deferred} that fires (with C{None}) when all the currently
            pending work on this L{Parallelizer} is completed and C{busy} is
            empty again.
        """
        return (DeferredList(self.activeDeferreds)
                .addCallback(lambda ignored: None))
calendarserver-5.2+dfsg/twext/python/timezone.py 0000644 0001750 0001750 00000003075 12263343324 021216 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
from twistedcaldav.config import config
import twistedcaldav.timezones
DEFAULT_TIMEZONE = "America/Los_Angeles"
try:
    from Foundation import NSTimeZone

    def lookupSystemTimezone():
        """
        Return the system's configured timezone name via the OS X
        Foundation framework.

        @return: the timezone name, UTF-8 encoded.
        """
        return NSTimeZone.localTimeZone().name().encode("utf-8")

# Only the import can fail here; a bare 'except:' would also swallow
# KeyboardInterrupt/SystemExit, so catch ImportError specifically.
except ImportError:
    def lookupSystemTimezone():
        """
        Foundation (PyObjC) is unavailable on this platform; no system
        timezone can be determined.
        """
        return ""
def getLocalTimezone():
    """
    Returns the default timezone for the server. The order of precedence is:
    config.DefaultTimezone, lookupSystemTimezone( ), DEFAULT_TIMEZONE.
    Also, if neither of the first two values in that list are in the timezone
    database, DEFAULT_TIMEZONE is returned.

    @return: The server's local timezone name
    @rtype: C{str}
    """
    # Highest precedence: an explicitly configured timezone, if it is known
    # to the timezone database.
    configuredTimezone = config.DefaultTimezone
    if configuredTimezone and twistedcaldav.timezones.hasTZ(configuredTimezone):
        return configuredTimezone

    # Next: whatever the operating system reports, if recognized.
    detectedTimezone = lookupSystemTimezone()
    if twistedcaldav.timezones.hasTZ(detectedTimezone):
        return detectedTimezone

    # Fall back to the hard-coded default.
    return DEFAULT_TIMEZONE
calendarserver-5.2+dfsg/twext/python/test/ 0000755 0001750 0001750 00000000000 12322625326 017765 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/python/test/test_filepath.py 0000644 0001750 0001750 00000012576 12263343324 023204 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for specialized behavior of L{CachingFilePath}
"""
from errno import EINVAL
from os.path import join as pathjoin
from twisted.internet.task import Clock
from twisted.trial.unittest import TestCase
from twext.python.filepath import CachingFilePath
# Cheat and pull in the Twisted test cases for FilePath. XXX: Twisteds should
# provide a supported way of doing this for exported interfaces. Also, it
# should export IFilePath. --glyph
from twisted.test.test_paths import FilePathTestCase
class BaseVerification(FilePathTestCase):
    """
    Make sure that L{CachingFilePath} doesn't break the contracts that
    L{FilePath} tries to provide.
    """

    def setUp(self):
        """
        Run the stock L{FilePathTestCase} setup, then replace the fixture
        paths with L{CachingFilePath} equivalents so every inherited test
        exercises the caching subclass.
        """
        FilePathTestCase.setUp(self)
        for attributeName in ("root", "path"):
            plainPath = getattr(self, attributeName)
            setattr(self, attributeName, CachingFilePath(plainPath.path))
class EINVALTestCase(TestCase):
    """
    Sometimes, L{os.listdir} will raise C{EINVAL}. This is a transient error,
    and L{CachingFilePath.listdir} should work around it by retrying the
    C{listdir} operation until it succeeds.
    """

    def setUp(self):
        """
        Create a L{CachingFilePath} for the test to use.
        """
        self.cfp = CachingFilePath(self.mktemp())
        self.clock = Clock()
        # Replace the real sleep with the deterministic test clock so the
        # backoff delays can be observed without actually waiting.
        self.cfp._sleep = self.clock.advance

    def test_testValidity(self):
        """
        If C{listdir} is replaced on a L{CachingFilePath}, we should be able to
        observe exceptions raised by the replacement.  This verifies that the
        test patching done here is actually testing something.
        """
        class CustomException(Exception): "Just for testing."

        def blowUp(dirname):
            raise CustomException()

        self.cfp._listdir = blowUp
        self.assertRaises(CustomException, self.cfp.listdir)
        self.assertRaises(CustomException, self.cfp.children)

    def test_retryLoop(self):
        """
        L{CachingFilePath} should catch C{EINVAL} and respond by retrying the
        C{listdir} operation until it succeeds.
        """
        calls = []

        def raiseEINVAL(dirname):
            # Fail with EINVAL four times, then succeed.
            calls.append(dirname)
            if len(calls) < 5:
                raise OSError(EINVAL, "This should be caught by the test.")
            return ['a', 'b', 'c']

        self.cfp._listdir = raiseEINVAL
        self.assertEquals(self.cfp.listdir(), ['a', 'b', 'c'])
        self.assertEquals(self.cfp.children(), [
            CachingFilePath(pathjoin(self.cfp.path, 'a')),
            CachingFilePath(pathjoin(self.cfp.path, 'b')),
            CachingFilePath(pathjoin(self.cfp.path, 'c')),])

    def requireTimePassed(self, filenames):
        """
        Create a replacement for listdir() which only fires after a certain
        amount of time.
        """
        self.calls = []

        def thunk(dirname):
            # Keep raising EINVAL until the simulated clock reaches 20s.
            now = self.clock.seconds()
            if now < 20.0:
                self.calls.append(now)
                raise OSError(EINVAL, "Not enough time has passed yet.")
            else:
                return filenames

        self.cfp._listdir = thunk

    def assertRequiredTimePassed(self):
        """
        Assert that calls to the simulated time.sleep() installed by
        C{requireTimePassed} have been invoked the required number of times.
        """
        # Waiting should be growing by *2 each time until the additional wait
        # exceeds BACKOFF_MAX (5), at which point we should wait for 5s each
        # time.
        def cumulative(values):
            current = 0.0
            for value in values:
                current += value
                yield current

        self.assertEquals(self.calls,
                          list(cumulative(
                              [0.0, 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 5.0, 5.0])))

    def test_backoff(self):
        """
        L{CachingFilePath} will wait for an increasing interval up to
        C{BACKOFF_MAX} between calls to listdir().
        """
        self.requireTimePassed(['a', 'b', 'c'])
        self.assertEquals(self.cfp.listdir(), ['a', 'b', 'c'])

    def test_siblingExtensionSearch(self):
        """
        L{FilePath.siblingExtensionSearch} is unfortunately not implemented in
        terms of L{FilePath.listdir}, so we need to verify that it will also
        retry.
        """
        filenames = [self.cfp.basename()+'.a',
                     self.cfp.basename() + '.b',
                     self.cfp.basename() + '.c']
        siblings = map(self.cfp.sibling, filenames)
        for sibling in siblings:
            sibling.touch()
        self.requireTimePassed(filenames)
        self.assertEquals(self.cfp.siblingExtensionSearch("*"),
                          siblings[0])
        self.assertRequiredTimePassed()
calendarserver-5.2+dfsg/twext/python/test/test_log.py 0000644 0001750 0001750 00000076162 12263343324 022172 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
from zope.interface.verify import verifyObject, BrokenMethodImplementation
from twisted.python import log as twistedLogging
from twisted.python.failure import Failure
from twisted.trial import unittest
from twext.python.log import (
LogLevel, InvalidLogLevelError,
pythonLogLevelMapping,
formatEvent, formatUnformattableEvent, formatWithCall,
Logger, LegacyLogger,
ILogObserver, LogPublisher, DefaultLogPublisher,
FilteringLogObserver, PredicateResult,
LogLevelFilterPredicate, OBSERVER_REMOVED
)
# Convenience aliases for the global log-level registry that these tests
# manipulate and reset (see SetUpTearDown below).
defaultLogLevel = LogLevelFilterPredicate().defaultLogLevel
clearLogLevels = Logger.publisher.levels.clearLogLevels
logLevelForNamespace = Logger.publisher.levels.logLevelForNamespace
setLogLevelForNamespace = Logger.publisher.levels.setLogLevelForNamespace
class TestLogger(Logger):
    """
    L{Logger} subclass that records the last event seen by the Twisted
    logging system (in C{self.event}) and the raw arguments passed to
    C{emit} (in C{self.emitted}) so tests can inspect them.
    """

    def emit(self, level, format=None, **kwargs):
        # Flip to True for noisy debugging of these tests themselves.
        if False:
            print "*"*60
            print "level =", level
            print "format =", format
            for key, value in kwargs.items():
                print key, "=", value
            print "*"*60

        def observer(event):
            self.event = event

        # Temporarily observe the global Twisted log so the event produced
        # by Logger.emit() can be captured, then always detach.
        twistedLogging.addObserver(observer)
        try:
            Logger.emit(self, level, format, **kwargs)
        finally:
            twistedLogging.removeObserver(observer)

        self.emitted = {
            "level": level,
            "format": format,
            "kwargs": kwargs,
        }
class TestLegacyLogger(LegacyLogger):
    """
    L{LegacyLogger} backed by a L{TestLogger} so that emitted events can be
    inspected by tests.
    """

    def __init__(self, logger=None):
        # A default of 'logger=TestLogger()' would be evaluated once at
        # class-definition time, silently sharing a single TestLogger (and
        # its recorded state) across every instance.  Create a fresh one
        # per instance instead.
        if logger is None:
            logger = TestLogger()
        LegacyLogger.__init__(self, logger=logger)
class LogComposedObject(object):
    """
    Just a regular object.
    """

    # Class-attribute logger: exercises Logger's descriptor behavior
    # (namespace/source derived from the owning class).
    log = TestLogger()

    def __init__(self, state=None):
        self.state = state

    def __str__(self):
        # NOTE(review): this format string is empty, so "state" is ignored
        # and "" is always returned.  The original literal (presumably
        # something like "<State: {state}>") appears to have been lost when
        # angle-bracketed text was stripped from this copy -- confirm
        # against upstream before relying on this.
        return "".format(state=self.state)
class SetUpTearDown(object):
    """
    Mixin that clears the global log-level registry before and after each
    test, so tests cannot leak level settings into one another.
    """

    def setUp(self):
        super(SetUpTearDown, self).setUp()
        clearLogLevels()

    def tearDown(self):
        super(SetUpTearDown, self).tearDown()
        clearLogLevels()
class LoggingTests(SetUpTearDown, unittest.TestCase):
    """
    General module tests.
    """

    def test_levelWithName(self):
        """
        Look up log level by name.
        """
        for level in LogLevel.iterconstants():
            self.assertIdentical(LogLevel.levelWithName(level.name), level)

    def test_levelWithInvalidName(self):
        """
        You can't make up log level names.
        """
        bogus = "*bogus*"
        try:
            LogLevel.levelWithName(bogus)
        except InvalidLogLevelError as e:
            self.assertIdentical(e.level, bogus)
        else:
            self.fail("Expected InvalidLogLevelError.")

    def test_defaultLogLevel(self):
        """
        Default log level is used.
        """
        # NOTE(review): failUnless() takes (condition, message); the second
        # argument here is a message, not a comparison, so these only
        # assert truthiness of the returned level.  assertEquals may have
        # been intended -- confirm before changing.
        self.failUnless(logLevelForNamespace(None), defaultLogLevel)
        self.failUnless(logLevelForNamespace(""), defaultLogLevel)
        self.failUnless(logLevelForNamespace("rocker.cool.namespace"),
                        defaultLogLevel)

    def test_setLogLevel(self):
        """
        Setting and retrieving log levels.
        """
        setLogLevelForNamespace(None, LogLevel.error)
        setLogLevelForNamespace("twext.web2", LogLevel.debug)
        setLogLevelForNamespace("twext.web2.dav", LogLevel.warn)

        # Sub-namespaces inherit the nearest ancestor's level.
        self.assertEquals(logLevelForNamespace(None),
                          LogLevel.error)
        self.assertEquals(logLevelForNamespace("twisted"),
                          LogLevel.error)
        self.assertEquals(logLevelForNamespace("twext.web2"),
                          LogLevel.debug)
        self.assertEquals(logLevelForNamespace("twext.web2.dav"),
                          LogLevel.warn)
        self.assertEquals(logLevelForNamespace("twext.web2.dav.test"),
                          LogLevel.warn)
        self.assertEquals(logLevelForNamespace("twext.web2.dav.test1.test2"),
                          LogLevel.warn)

    def test_setInvalidLogLevel(self):
        """
        Can't pass invalid log levels to setLogLevelForNamespace().
        """
        self.assertRaises(InvalidLogLevelError, setLogLevelForNamespace,
                          "twext.web2", object())

        # Level must be a constant, not the name of a constant
        self.assertRaises(InvalidLogLevelError, setLogLevelForNamespace,
                          "twext.web2", "debug")

    def test_clearLogLevels(self):
        """
        Clearing log levels.
        """
        setLogLevelForNamespace("twext.web2", LogLevel.debug)
        setLogLevelForNamespace("twext.web2.dav", LogLevel.error)

        clearLogLevels()

        # All namespaces should be back at the default level.
        self.assertEquals(logLevelForNamespace("twisted"), defaultLogLevel)
        self.assertEquals(logLevelForNamespace("twext.web2"), defaultLogLevel)
        self.assertEquals(logLevelForNamespace("twext.web2.dav"),
                          defaultLogLevel)
        self.assertEquals(logLevelForNamespace("twext.web2.dav.test"),
                          defaultLogLevel)
        self.assertEquals(logLevelForNamespace("twext.web2.dav.test1.test2"),
                          defaultLogLevel)

    def test_namespace_default(self):
        """
        Default namespace is module name.
        """
        log = Logger()
        self.assertEquals(log.namespace, __name__)

    def test_formatWithCall(self):
        """
        L{formatWithCall} is an extended version of L{unicode.format} that will
        interpret a set of parentheses "C{()}" at the end of a format key to
        mean that the format key ought to be I{called} rather than stringified.
        """
        self.assertEquals(
            formatWithCall(
                u"Hello, {world}. {callme()}.",
                dict(world="earth", callme=lambda: "maybe")
            ),
            "Hello, earth. maybe."
        )
        self.assertEquals(
            formatWithCall(
                u"Hello, {repr()!r}.",
                dict(repr=lambda: "repr")
            ),
            "Hello, 'repr'."
        )

    def test_formatEvent(self):
        """
        L{formatEvent} will format an event according to several rules:
            - A string with no formatting instructions will be passed straight
              through.
            - PEP 3101 strings will be formatted using the keys and values of
              the event as named fields.
            - PEP 3101 keys ending with C{()} will be treated as instructions
              to call that key (which ought to be a callable) before
              formatting.
        L{formatEvent} will always return L{unicode}, and if given
        bytes, will always treat its format string as UTF-8 encoded.
        """
        def format(log_format, **event):
            event["log_format"] = log_format
            result = formatEvent(event)
            self.assertIdentical(type(result), unicode)
            return result

        self.assertEquals(u"", format(b""))
        self.assertEquals(u"", format(u""))
        self.assertEquals(u"abc", format("{x}", x="abc"))
        self.assertEquals(u"no, yes.",
                          format("{not_called}, {called()}.",
                                 not_called="no", called=lambda: "yes"))
        # UTF-8 bytes decode cleanly; invalid UTF-8 yields a formatting error.
        self.assertEquals(u'S\xe1nchez', format("S\xc3\xa1nchez"))
        self.assertIn(u"Unable to format event", format(b"S\xe1nchez"))
        self.assertIn(u"Unable to format event",
                      format(b"S{a}nchez", a=b"\xe1"))
        self.assertIn(u"S'\\xe1'nchez",
                      format(b"S{a!r}nchez", a=b"\xe1"))

    def test_formatEventNoFormat(self):
        """
        Formatting an event with no format.
        """
        event = dict(foo=1, bar=2)
        result = formatEvent(event)

        self.assertIn("Unable to format event", result)
        self.assertIn(repr(event), result)

    def test_formatEventWeirdFormat(self):
        """
        Formatting an event with a bogus format.
        """
        event = dict(log_format=object(), foo=1, bar=2)
        result = formatEvent(event)

        self.assertIn("Log format must be unicode or bytes", result)
        self.assertIn(repr(event), result)

    def test_formatUnformattableEvent(self):
        """
        Formatting an event that's just plain out to get us.
        """
        event = dict(log_format="{evil()}", evil=lambda: 1/0)
        result = formatEvent(event)

        self.assertIn("Unable to format event", result)
        self.assertIn(repr(event), result)

    def test_formatUnformattableEventWithUnformattableKey(self):
        """
        Formatting an unformattable event that has an unformattable key.
        """
        event = {
            "log_format": "{evil()}",
            "evil": lambda: 1/0,
            Unformattable(): "gurk",
        }
        result = formatEvent(event)
        self.assertIn("MESSAGE LOST: unformattable object logged:", result)
        self.assertIn("Recoverable data:", result)
        self.assertIn("Exception during formatting:", result)

    def test_formatUnformattableEventWithUnformattableValue(self):
        """
        Formatting an unformattable event that has an unformattable value.
        """
        event = dict(
            log_format="{evil()}",
            evil=lambda: 1/0,
            gurk=Unformattable(),
        )
        result = formatEvent(event)
        self.assertIn("MESSAGE LOST: unformattable object logged:", result)
        self.assertIn("Recoverable data:", result)
        self.assertIn("Exception during formatting:", result)

    def test_formatUnformattableEventWithUnformattableErrorOMGWillItStop(self):
        """
        Formatting an unformattable event that has an unformattable value.
        """
        event = dict(
            log_format="{evil()}",
            evil=lambda: 1/0,
            recoverable="okay",
        )
        # Call formatUnformattableEvent() directly with a bogus exception.
        result = formatUnformattableEvent(event, Unformattable())
        self.assertIn("MESSAGE LOST: unformattable object logged:", result)
        self.assertIn(repr("recoverable") + " = " + repr("okay"), result)
class LoggerTests(SetUpTearDown, unittest.TestCase):
    """
    Tests for L{Logger}.
    """

    def test_repr(self):
        """
        repr() on Logger
        """
        namespace = "bleargh"
        log = Logger(namespace)
        # NOTE(review): the expected-value format string is empty, so this
        # effectively asserts repr(log) == "".  The original literal
        # (presumably "<Logger {}>") appears to have been lost when
        # angle-bracketed text was stripped from this copy -- confirm
        # against upstream.
        self.assertEquals(repr(log), "".format(repr(namespace)))

    def test_namespace_attribute(self):
        """
        Default namespace for classes using L{Logger} as a descriptor is the
        class name they were retrieved from.
        """
        obj = LogComposedObject()
        self.assertEquals(obj.log.namespace,
                          "twext.python.test.test_log.LogComposedObject")
        self.assertEquals(LogComposedObject.log.namespace,
                          "twext.python.test.test_log.LogComposedObject")
        self.assertIdentical(LogComposedObject.log.source, LogComposedObject)
        self.assertIdentical(obj.log.source, obj)
        self.assertIdentical(Logger().source, None)

    def test_sourceAvailableForFormatting(self):
        """
        On instances that have a L{Logger} class attribute, the C{log_source}
        key is available to format strings.
        """
        obj = LogComposedObject("hello")
        log = obj.log
        log.error("Hello, {log_source}.")

        self.assertIn("log_source", log.event)
        self.assertEquals(log.event["log_source"], obj)
        stuff = formatEvent(log.event)
        # NOTE(review): LogComposedObject.__str__ currently returns "" (see
        # note on that class), which is why the expected text is "Hello, .".
        self.assertIn("Hello, .", stuff)

    def test_basic_Logger(self):
        """
        Test that log levels and messages are emitted correctly for
        Logger.
        """
        # FIXME: Need a basic test like this for logger attached to a class.
        # At least: source should not be None in that case.
        log = TestLogger()

        for level in LogLevel.iterconstants():
            format = "This is a {level_name} message"
            message = format.format(level_name=level.name)

            method = getattr(log, level.name)
            method(format, junk=message, level_name=level.name)

            # Ensure that test_emit got called with expected arguments
            self.assertEquals(log.emitted["level"], level)
            self.assertEquals(log.emitted["format"], format)
            self.assertEquals(log.emitted["kwargs"]["junk"], message)

            if level >= logLevelForNamespace(log.namespace):
                self.assertTrue(hasattr(log, "event"), "No event observed.")
                self.assertEquals(log.event["log_format"], format)
                self.assertEquals(log.event["log_level"], level)
                self.assertEquals(log.event["log_namespace"], __name__)
                self.assertEquals(log.event["log_source"], None)
                self.assertEquals(log.event["logLevel"],
                                  pythonLogLevelMapping[level])
                self.assertEquals(log.event["junk"], message)

                # FIXME: this checks the end of message because we do
                # formatting in emit()
                self.assertEquals(
                    formatEvent(log.event),
                    message
                )
            else:
                self.assertFalse(hasattr(log, "event"))

    def test_defaultFailure(self):
        """
        Test that log.failure() emits the right data.
        """
        log = TestLogger()
        try:
            raise RuntimeError("baloney!")
        except RuntimeError:
            log.failure("Whoops")

        #
        # log.failure() will cause trial to complain, so here we check that
        # trial saw the correct error and remove it from the list of things to
        # complain about.
        #
        errors = self.flushLoggedErrors(RuntimeError)
        self.assertEquals(len(errors), 1)

        self.assertEquals(log.emitted["level"], LogLevel.error)
        self.assertEquals(log.emitted["format"], "Whoops")

    def test_conflicting_kwargs(self):
        """
        Make sure that kwargs conflicting with args don't pass through.
        """
        log = TestLogger()

        log.warn(
            "*",
            log_format="#",
            log_level=LogLevel.error,
            log_namespace="*namespace*",
            log_source="*source*",
        )

        # FIXME: Should conflicts log errors?
        self.assertEquals(log.event["log_format"], "*")
        self.assertEquals(log.event["log_level"], LogLevel.warn)
        self.assertEquals(log.event["log_namespace"], log.namespace)
        self.assertEquals(log.event["log_source"], None)

    def test_logInvalidLogLevel(self):
        """
        Test passing in a bogus log level to C{emit()}.
        """
        log = TestLogger()
        log.emit("*bogus*")

        errors = self.flushLoggedErrors(InvalidLogLevelError)
        self.assertEquals(len(errors), 1)
class LogPublisherTests(SetUpTearDown, unittest.TestCase):
    """
    Tests for L{LogPublisher}.
    """

    def test_interface(self):
        """
        L{LogPublisher} is an L{ILogObserver}.
        """
        publisher = LogPublisher()
        try:
            verifyObject(ILogObserver, publisher)
        except BrokenMethodImplementation as e:
            self.fail(e)

    def test_observers(self):
        """
        L{LogPublisher.observers} returns the observers.
        """
        o1 = lambda e: None
        o2 = lambda e: None

        publisher = LogPublisher(o1, o2)
        self.assertEquals(set((o1, o2)), set(publisher.observers))

    def test_addObserver(self):
        """
        L{LogPublisher.addObserver} adds an observer.
        """
        o1 = lambda e: None
        o2 = lambda e: None
        o3 = lambda e: None

        publisher = LogPublisher(o1, o2)
        publisher.addObserver(o3)
        self.assertEquals(set((o1, o2, o3)), set(publisher.observers))

    def test_removeObserver(self):
        """
        L{LogPublisher.removeObserver} removes an observer.
        """
        o1 = lambda e: None
        o2 = lambda e: None
        o3 = lambda e: None

        publisher = LogPublisher(o1, o2, o3)
        publisher.removeObserver(o2)
        self.assertEquals(set((o1, o3)), set(publisher.observers))

    def test_removeObserverNotRegistered(self):
        """
        L{LogPublisher.removeObserver} removes an observer that is not
        registered.
        """
        o1 = lambda e: None
        o2 = lambda e: None
        o3 = lambda e: None

        publisher = LogPublisher(o1, o2)
        publisher.removeObserver(o3)
        self.assertEquals(set((o1, o2)), set(publisher.observers))

    def test_fanOut(self):
        """
        L{LogPublisher} calls its observers.
        """
        event = dict(foo=1, bar=2)

        events1 = []
        events2 = []
        events3 = []

        o1 = lambda e: events1.append(e)
        o2 = lambda e: events2.append(e)
        o3 = lambda e: events3.append(e)

        publisher = LogPublisher(o1, o2, o3)
        publisher(event)

        self.assertIn(event, events1)
        self.assertIn(event, events2)
        self.assertIn(event, events3)

    def test_observerRaises(self):
        """
        An observer that raises is removed and the failure is logged to the
        global publisher.
        """
        # NOTE(review): this observer is added to the global publisher and
        # never removed; it appears to leak into later tests -- confirm.
        nonTestEvents = []
        Logger.publisher.addObserver(lambda e: nonTestEvents.append(e))

        event = dict(foo=1, bar=2)
        exception = RuntimeError("ARGH! EVIL DEATH!")

        events = []

        def observer(event):
            events.append(event)
            raise exception

        publisher = LogPublisher(observer)
        publisher(event)

        # Verify that the observer saw my event
        self.assertIn(event, events)

        # Verify that the observer raised my exception
        errors = self.flushLoggedErrors(exception.__class__)
        self.assertEquals(len(errors), 1)
        self.assertIdentical(errors[0].value, exception)

        # Verify that the exception was logged
        for event in nonTestEvents:
            if (
                event.get("log_format", None) == OBSERVER_REMOVED and
                getattr(event.get("failure", None), "value") is exception
            ):
                break
        else:
            self.fail("Observer raised an exception "
                      "and the exception was not logged.")

    def test_observerRaisesAndLoggerHatesMe(self):
        """
        If the publisher's own logger raises while reporting a failing
        observer, the publisher must still not propagate the exception.
        """
        nonTestEvents = []
        Logger.publisher.addObserver(lambda e: nonTestEvents.append(e))

        event = dict(foo=1, bar=2)
        exception = RuntimeError("ARGH! EVIL DEATH!")

        def observer(event):
            raise RuntimeError("Sad panda")

        class GurkLogger(Logger):
            def failure(self, *args, **kwargs):
                raise exception

        publisher = LogPublisher(observer)
        publisher.log = GurkLogger()
        publisher(event)

        # Here, the lack of an exception thus far is a success, of sorts
class DefaultLogPublisherTests(SetUpTearDown, unittest.TestCase):
    """
    Tests for L{DefaultLogPublisher}'s filtered/unfiltered observer
    registration and event routing.
    """

    def test_addObserver(self):
        """
        Observers land on the filtered publisher (the default) or the root
        publisher, as requested.
        """
        def first(event):
            pass

        def second(event):
            pass

        def third(event):
            pass

        publisher = DefaultLogPublisher()
        publisher.addObserver(first)
        publisher.addObserver(second, filtered=True)
        publisher.addObserver(third, filtered=False)

        self.assertEqual(
            {first, second, publisher.legacyLogObserver},
            set(publisher.filteredPublisher.observers),
            "Filtered observers do not match expected set"
        )
        self.assertEqual(
            {third, publisher.filters},
            set(publisher.rootPublisher.observers),
            "Root observers do not match expected set"
        )


    def test_addObserverAgain(self):
        """
        Re-adding an observer with a different filtered-ness moves it rather
        than registering it twice.
        """
        def first(event):
            pass

        def second(event):
            pass

        def third(event):
            pass

        publisher = DefaultLogPublisher()
        publisher.addObserver(first)
        publisher.addObserver(second, filtered=True)
        publisher.addObserver(third, filtered=False)

        # Flip the filtered-ness of the second and third observers.
        publisher.addObserver(first)
        publisher.addObserver(second, filtered=False)
        publisher.addObserver(third, filtered=True)

        self.assertEqual(
            {first, third, publisher.legacyLogObserver},
            set(publisher.filteredPublisher.observers),
            "Filtered observers do not match expected set"
        )
        self.assertEqual(
            {second, publisher.filters},
            set(publisher.rootPublisher.observers),
            "Root observers do not match expected set"
        )


    def test_removeObserver(self):
        """
        Removal works regardless of which publisher the observer was added
        to.
        """
        def first(event):
            pass

        def second(event):
            pass

        def third(event):
            pass

        publisher = DefaultLogPublisher()
        publisher.addObserver(first)
        publisher.addObserver(second, filtered=True)
        publisher.addObserver(third, filtered=False)

        publisher.removeObserver(second)
        publisher.removeObserver(third)

        self.assertEqual(
            {first, publisher.legacyLogObserver},
            set(publisher.filteredPublisher.observers),
            "Filtered observers do not match expected set"
        )
        self.assertEqual(
            {publisher.filters},
            set(publisher.rootPublisher.observers),
            "Root observers do not match expected set"
        )


    def test_filteredObserver(self):
        """
        A filtered observer does not see debug-level events, but does see
        error-level ones.
        """
        namespace = __name__

        debugEvent = dict(log_namespace=namespace,
                          log_level=LogLevel.debug, log_format="")
        errorEvent = dict(log_namespace=namespace,
                          log_level=LogLevel.error, log_format="")

        received = []

        def observer(event):
            received.append(event)

        publisher = DefaultLogPublisher()
        publisher.addObserver(observer, filtered=True)

        publisher(debugEvent)
        publisher(errorEvent)

        self.assertNotIn(debugEvent, received)
        self.assertIn(errorEvent, received)


    def test_filteredObserverNoFilteringKeys(self):
        """
        Events lacking the keys used for filtering (here: no log_namespace,
        or no keys at all) are not delivered to a filtered observer.
        """
        debugEvent = dict(log_level=LogLevel.debug)
        errorEvent = dict(log_level=LogLevel.error)
        emptyEvent = dict()

        received = []

        def observer(event):
            received.append(event)

        publisher = DefaultLogPublisher()
        publisher.addObserver(observer, filtered=True)

        for event in (debugEvent, errorEvent, emptyEvent):
            publisher(event)

        for event in (debugEvent, errorEvent, emptyEvent):
            self.assertNotIn(event, received)


    def test_unfilteredObserver(self):
        """
        An unfiltered observer sees every event, including debug-level ones.
        """
        namespace = __name__

        debugEvent = dict(log_namespace=namespace, log_level=LogLevel.debug,
                          log_format="")
        errorEvent = dict(log_namespace=namespace, log_level=LogLevel.error,
                          log_format="")

        received = []

        def observer(event):
            received.append(event)

        publisher = DefaultLogPublisher()
        publisher.addObserver(observer, filtered=False)

        publisher(debugEvent)
        publisher(errorEvent)

        self.assertIn(debugEvent, received)
        self.assertIn(errorEvent, received)
class FilteringLogObserverTests(SetUpTearDown, unittest.TestCase):
    """
    Tests for L{FilteringLogObserver}.
    """

    def test_interface(self):
        """
        L{FilteringLogObserver} provides L{ILogObserver}.
        """
        filteringObserver = FilteringLogObserver(lambda e: None, ())
        try:
            verifyObject(ILogObserver, filteringObserver)
        except BrokenMethodImplementation as e:
            self.fail(e)


    def filterWith(self, *filterNames):
        """
        Run four events (count=0..3) through a L{FilteringLogObserver}
        configured with the named predicates, and return the counts of the
        events that got through, in order.
        """
        events = [dict(count=n) for n in range(4)]

        def twoMinus(event):
            if event["count"] <= 2:
                return PredicateResult.yes
            return PredicateResult.maybe

        def twoPlus(event):
            if event["count"] >= 2:
                return PredicateResult.yes
            return PredicateResult.maybe

        def notTwo(event):
            if event["count"] == 2:
                return PredicateResult.no
            return PredicateResult.maybe

        def no(event):
            return PredicateResult.no

        def bogus(event):
            # Deliberately not a valid PredicateResult.
            return None

        available = dict(twoMinus=twoMinus, twoPlus=twoPlus, notTwo=notTwo,
                         no=no, bogus=bogus)
        predicates = (available[name] for name in filterNames)

        seen = []
        filteringObserver = FilteringLogObserver(seen.append, predicates)
        for event in events:
            filteringObserver(event)

        return [event["count"] for event in seen]


    def test_shouldLogEvent_noFilters(self):
        """
        With no predicates, every event is logged.
        """
        self.assertEqual([0, 1, 2, 3], self.filterWith())


    def test_shouldLogEvent_noFilter(self):
        """
        A C{no} answer drops the matching event.
        """
        self.assertEqual([0, 1, 3], self.filterWith("notTwo"))


    def test_shouldLogEvent_yesFilter(self):
        """
        A C{yes} answer lets matching events through.
        """
        self.assertEqual([0, 1, 2, 3], self.filterWith("twoPlus"))


    def test_shouldLogEvent_yesNoFilter(self):
        """
        An earlier C{yes} wins over a later C{no}.
        """
        self.assertEqual([2, 3], self.filterWith("twoPlus", "no"))


    def test_shouldLogEvent_yesYesNoFilter(self):
        """
        Multiple C{yes} answers together still win over a later C{no}.
        """
        self.assertEqual([0, 1, 2, 3],
                         self.filterWith("twoPlus", "twoMinus", "no"))


    def test_shouldLogEvent_badPredicateResult(self):
        """
        A predicate returning something that isn't a L{PredicateResult}
        raises L{TypeError}.
        """
        self.assertRaises(TypeError, self.filterWith, "bogus")


    def test_call(self):
        """
        Events are forwarded for C{yes} and C{maybe} answers, and dropped for
        C{no}.
        """
        event = dict(obj=object())

        def resultsFor(result):
            seen = []

            def predicate(e):
                return result

            FilteringLogObserver(seen.append, (predicate,))(event)
            return seen

        self.assertIn(event, resultsFor(PredicateResult.yes))
        self.assertIn(event, resultsFor(PredicateResult.maybe))
        self.assertNotIn(event, resultsFor(PredicateResult.no))
class LegacyLoggerTests(SetUpTearDown, unittest.TestCase):
    """
    Tests for L{LegacyLogger}.
    """

    def test_namespace_default(self):
        """
        Default namespace is module name.
        """
        log = TestLegacyLogger(logger=None)
        self.assertEquals(log.newStyleLogger.namespace, __name__)


    def test_passThroughAttributes(self):
        """
        C{__getattribute__} on L{LegacyLogger} is passing through to Twisted's
        logging module.
        """
        log = TestLegacyLogger()

        # Not passed through: msg() and err() are reimplemented, and their
        # docstrings advertise API compatibility.
        self.assertIn("API-compatible", log.msg.__doc__)
        self.assertIn("API-compatible", log.err.__doc__)

        # Passed through: other attributes resolve to twisted.python.log.
        self.assertIdentical(log.addObserver, twistedLogging.addObserver)


    def test_legacy_msg(self):
        """
        Test LegacyLogger's log.msg()
        """
        log = TestLegacyLogger()

        message = "Hi, there."
        kwargs = {"foo": "bar", "obj": object()}

        log.msg(message, **kwargs)

        # msg() emits at info level with the message as the format, passing
        # the extra keyword arguments through unchanged (identity, not just
        # equality).
        self.assertIdentical(log.newStyleLogger.emitted["level"],
                             LogLevel.info)
        self.assertEquals(log.newStyleLogger.emitted["format"], message)

        for key, value in kwargs.items():
            self.assertIdentical(log.newStyleLogger.emitted["kwargs"][key],
                                 value)

        # With no positional message, the emitted format is None.
        log.msg(foo="")

        self.assertIdentical(log.newStyleLogger.emitted["level"],
                             LogLevel.info)
        self.assertIdentical(log.newStyleLogger.emitted["format"], None)


    def test_legacy_err_implicit(self):
        """
        Test LegacyLogger's log.err() capturing the in-flight exception.
        """
        log = TestLegacyLogger()

        exception = RuntimeError("Oh me, oh my.")
        kwargs = {"foo": "bar", "obj": object()}

        # err() with no explicit error argument picks up the exception
        # currently being handled; "why" is therefore None.
        try:
            raise exception
        except RuntimeError:
            log.err(**kwargs)

        self.legacy_err(log, kwargs, None, exception)


    def test_legacy_err_exception(self):
        """
        Test LegacyLogger's log.err() with a given exception.
        """
        log = TestLegacyLogger()

        exception = RuntimeError("Oh me, oh my.")
        kwargs = {"foo": "bar", "obj": object()}
        why = "Because I said so."

        try:
            raise exception
        except RuntimeError as e:
            log.err(e, why, **kwargs)

        self.legacy_err(log, kwargs, why, exception)


    def test_legacy_err_failure(self):
        """
        Test LegacyLogger's log.err() with a given L{Failure}.
        """
        log = TestLegacyLogger()

        exception = RuntimeError("Oh me, oh my.")
        kwargs = {"foo": "bar", "obj": object()}
        why = "Because I said so."

        try:
            raise exception
        except RuntimeError:
            # Failure() with no arguments wraps the in-flight exception.
            log.err(Failure(), why, **kwargs)

        self.legacy_err(log, kwargs, why, exception)


    def test_legacy_err_bogus(self):
        """
        Test LegacyLogger's log.err() with a bogus argument.
        """
        log = TestLegacyLogger()

        exception = RuntimeError("Oh me, oh my.")
        kwargs = {"foo": "bar", "obj": object()}
        why = "Because I said so."
        bogus = object()

        try:
            raise exception
        except RuntimeError:
            log.err(bogus, why, **kwargs)

        # A non-exception, non-Failure argument is not reported to trial as
        # an error...
        errors = self.flushLoggedErrors(exception.__class__)
        self.assertEquals(len(errors), 0)

        # ...but the event is still emitted at error level, with the repr()
        # of the bogus object as the format.
        self.assertIdentical(log.newStyleLogger.emitted["level"],
                             LogLevel.error)
        self.assertEquals(log.newStyleLogger.emitted["format"], repr(bogus))
        self.assertIdentical(log.newStyleLogger.emitted["kwargs"]["why"], why)

        for key, value in kwargs.items():
            self.assertIdentical(log.newStyleLogger.emitted["kwargs"][key],
                                 value)


    def legacy_err(self, log, kwargs, why, exception):
        """
        Common assertions for the log.err() tests above: the emitted event is
        at error level with no format, carries a L{Failure} wrapping
        C{exception} plus the given C{why}, and passes C{kwargs} through by
        identity.
        """
        #
        # log.failure() will cause trial to complain, so here we check that
        # trial saw the correct error and remove it from the list of things to
        # complain about.
        #
        errors = self.flushLoggedErrors(exception.__class__)
        self.assertEquals(len(errors), 1)

        self.assertIdentical(log.newStyleLogger.emitted["level"],
                             LogLevel.error)
        self.assertEquals(log.newStyleLogger.emitted["format"], None)
        emittedKwargs = log.newStyleLogger.emitted["kwargs"]
        self.assertIdentical(emittedKwargs["failure"].__class__, Failure)
        self.assertIdentical(emittedKwargs["failure"].value, exception)
        self.assertIdentical(emittedKwargs["why"], why)

        for key, value in kwargs.items():
            self.assertIdentical(log.newStyleLogger.emitted["kwargs"][key],
                                 value)
class Unformattable(object):
    """
    An object whose C{__repr__} raises an exception (a L{ZeroDivisionError}),
    for exercising how logging copes with unformattable values.
    """

    def __repr__(self):
        # 1/0 raises ZeroDivisionError before str() is ever reached.
        return str(1 / 0)
calendarserver-5.2+dfsg/twext/python/test/test_launchd.py 0000644 0001750 0001750 00000030701 12263343324 023014 0 ustar rahul rahul ##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for L{twext.python.launchd}.
"""
import sys, os, plistlib, socket, json
if __name__ == '__main__':
    # This module is loaded as a launchd job by test-cases below; the following
    # code looks up an appropriate function to run.
    testID = sys.argv[1]
    # testID looks like "package.module.Class.test_foo"; dispatch to the
    # matching "job_foo" sibling by rewriting the last dotted component.
    a, b = testID.rsplit(".", 1)
    from twisted.python.reflect import namedAny
    try:
        namedAny(".".join([a, b.replace("test_", "job_")]))()
    finally:
        # Flush output so the parent test can read it back from the
        # Standard{Out,Error}Path files, then connect to the parent's TCP
        # port (see CheckInTests.setUp) to signal that the job is done.
        sys.stdout.flush()
        sys.stderr.flush()
        skt = socket.socket()
        skt.connect(("127.0.0.1", int(os.environ["TESTING_PORT"])))
    sys.exit(0)
try:
from twext.python.launchd import (
lib, ffi, _LaunchDictionary, _LaunchArray, _managed, constants,
plainPython, checkin, _launchify, getLaunchDSocketFDs
)
except ImportError:
skip = "LaunchD not available."
else:
skip = False
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
class LaunchDataStructures(TestCase):
    """
    Tests for L{_launchify} converting data structures from launchd's internals
    to Python objects.
    """

    def test_fd(self):
        """
        A launchd file-descriptor object converts to the underlying integer.
        """
        descriptor = _managed(lib.launch_data_new_fd(2))
        self.assertEqual(2, _launchify(descriptor))


    def test_bool(self):
        """
        Launchd booleans convert to the corresponding Python booleans.
        """
        for expected in (True, False):
            wrapped = _managed(lib.launch_data_new_bool(expected))
            self.assertEqual(expected, _launchify(wrapped))


    def test_real(self):
        """
        A launchd real converts to a Python float.
        """
        value = 3.14158
        almostPi = _managed(lib.launch_data_new_real(value))
        self.assertEqual(value, _launchify(almostPi))
class DictionaryTests(TestCase):
    """
    Tests for L{_LaunchDictionary}
    """

    def setUp(self):
        """
        Assemble a test dictionary.
        """
        self.testDict = _managed(
            lib.launch_data_alloc(lib.LAUNCH_DATA_DICTIONARY)
        )
        key1 = ffi.new("char[]", "alpha")
        val1 = lib.launch_data_new_string("alpha-value")
        key2 = ffi.new("char[]", "beta")
        val2 = lib.launch_data_new_string("beta-value")
        key3 = ffi.new("char[]", "gamma")
        val3 = lib.launch_data_new_integer(3)
        # Note argument order: launch_data_dict_insert(dict, value, key) --
        # value before key, per the cdef signature.
        # NOTE(review): the ffi.new() key buffers are only kept alive by the
        # locals above for the duration of setUp; presumably launchd copies
        # the key string on insert -- confirm against the launch(3) API.
        lib.launch_data_dict_insert(self.testDict, val1, key1)
        lib.launch_data_dict_insert(self.testDict, val2, key2)
        lib.launch_data_dict_insert(self.testDict, val3, key3)
        # Sanity-check the fixture itself before any test runs.
        self.assertEquals(lib.launch_data_dict_get_count(self.testDict), 3)


    def test_len(self):
        """
        C{len(_LaunchDictionary())} returns the number of keys in the
        dictionary.
        """
        self.assertEquals(len(_LaunchDictionary(self.testDict)), 3)


    def test_keys(self):
        """
        L{_LaunchDictionary.keys} returns keys present in a C{launch_data_dict}.
        """
        dictionary = _LaunchDictionary(self.testDict)
        self.assertEquals(set(dictionary.keys()),
                          set([b"alpha", b"beta", b"gamma"]))


    def test_values(self):
        """
        L{_LaunchDictionary.values} returns keys present in a
        C{launch_data_dict}.
        """
        dictionary = _LaunchDictionary(self.testDict)
        self.assertEquals(set(dictionary.values()),
                          set([b"alpha-value", b"beta-value", 3]))


    def test_items(self):
        """
        L{_LaunchDictionary.items} returns all (key, value) tuples present in a
        C{launch_data_dict}.
        """
        dictionary = _LaunchDictionary(self.testDict)
        self.assertEquals(set(dictionary.items()),
                          set([(b"alpha", b"alpha-value"),
                               (b"beta", b"beta-value"), (b"gamma", 3)]))


    def test_plainPython(self):
        """
        L{plainPython} will convert a L{_LaunchDictionary} into a Python
        dictionary.
        """
        self.assertEquals({b"alpha": b"alpha-value", b"beta": b"beta-value",
                           b"gamma": 3},
                          plainPython(_LaunchDictionary(self.testDict)))


    def test_plainPythonNested(self):
        """
        L{plainPython} will convert a L{_LaunchDictionary} containing another
        L{_LaunchDictionary} into a nested Python dictionary.
        """
        otherDict = lib.launch_data_alloc(lib.LAUNCH_DATA_DICTIONARY)
        lib.launch_data_dict_insert(otherDict,
                                    lib.launch_data_new_string("bar"), "foo")
        lib.launch_data_dict_insert(self.testDict, otherDict, "delta")
        self.assertEquals({b"alpha": b"alpha-value", b"beta": b"beta-value",
                           b"gamma": 3, b"delta": {b"foo": b"bar"}},
                          plainPython(_LaunchDictionary(self.testDict)))
class ArrayTests(TestCase):
    """
    Tests for L{_LaunchArray}
    """

    def setUp(self):
        """
        Assemble a test array of two strings and an integer; ffi.gc arranges
        for launch_data_free when the wrapper is garbage collected.
        """
        self.testArray = ffi.gc(
            lib.launch_data_alloc(lib.LAUNCH_DATA_ARRAY),
            lib.launch_data_free
        )

        lib.launch_data_array_set_index(
            self.testArray, lib.launch_data_new_string("test-string-1"), 0
        )
        lib.launch_data_array_set_index(
            self.testArray, lib.launch_data_new_string("another string."), 1
        )
        lib.launch_data_array_set_index(
            self.testArray, lib.launch_data_new_integer(4321), 2
        )


    def test_length(self):
        """
        C{len(_LaunchArray(...))} returns the number of elements in the array.
        """
        self.assertEquals(len(_LaunchArray(self.testArray)), 3)


    def test_indexing(self):
        """
        C{_LaunchArray(...)[n]} returns the n'th element in the array.
        """
        array = _LaunchArray(self.testArray)
        self.assertEquals(array[0], b"test-string-1")
        self.assertEquals(array[1], b"another string.")
        self.assertEquals(array[2], 4321)


    def test_indexTooBig(self):
        """
        C{_LaunchArray(...)[n]}, where C{n} is greater than the length of the
        array, raises an L{IndexError}.
        """
        array = _LaunchArray(self.testArray)
        self.assertRaises(IndexError, lambda: array[3])


    def test_iterating(self):
        """
        Iterating over a C{_LaunchArray} returns each item in sequence.
        """
        array = _LaunchArray(self.testArray)
        i = iter(array)
        # Use the next() builtin rather than the Python-2-only i.next()
        # method, so this test is correct under both Python 2.6+ and 3.
        self.assertEquals(next(i), b"test-string-1")
        self.assertEquals(next(i), b"another string.")
        self.assertEquals(next(i), 4321)
        self.assertRaises(StopIteration, next, i)


    def test_plainPython(self):
        """
        L{plainPython} converts a L{_LaunchArray} into a Python list.
        """
        array = _LaunchArray(self.testArray)
        self.assertEquals(plainPython(array),
                          [b"test-string-1", b"another string.", 4321])


    def test_plainPythonNested(self):
        """
        L{plainPython} converts a L{_LaunchArray} containing another
        L{_LaunchArray} into a Python list.
        """
        sub = lib.launch_data_alloc(lib.LAUNCH_DATA_ARRAY)
        lib.launch_data_array_set_index(sub, lib.launch_data_new_integer(7), 0)
        lib.launch_data_array_set_index(self.testArray, sub, 3)
        array = _LaunchArray(self.testArray)
        self.assertEqual(plainPython(array), [b"test-string-1",
                                              b"another string.", 4321, [7]])
class SimpleStringConstants(TestCase):
    """
    Tests for bytestring-constants wrapping.
    """

    def test_constant(self):
        """
        C{launchd.constants.LAUNCH_*} exposes launchd's C constants: string
        constants as bytes, enumeration constants as integers, and nothing
        that isn't a constant.
        """
        self.assertEqual(b"Sockets", constants.LAUNCH_JOBKEY_SOCKETS)
        # Non-constant names from launch.h must not leak through.
        self.assertRaises(
            AttributeError, getattr, constants, "launch_data_alloc"
        )
        # Enumeration members come through as plain integers.
        self.assertEqual(2, constants.LAUNCH_DATA_ARRAY)
class CheckInTests(TestCase):
    """
    Integration tests making sure that actual checkin with launchd results in
    the expected values.
    """

    def setUp(self):
        """
        Write a launchd job plist that re-runs this module as a subprocess
        (dispatched by the C{__main__} block above), load it via
        C{launchctl}, and wait for the job to connect back on a local TCP
        port as its completion signal.
        """
        fp = FilePath(self.mktemp())
        fp.makedirs()
        from twisted.internet.protocol import Protocol, Factory
        from twisted.internet import reactor, defer

        d = defer.Deferred()

        class JustLetMeMoveOn(Protocol):
            # The subprocess connects here when it is done; the connection
            # itself is the only signal needed, so drop it immediately.
            def connectionMade(self):
                d.callback(None)
                self.transport.abortConnection()

        f = Factory()
        f.protocol = JustLetMeMoveOn
        port = reactor.listenTCP(0, f, interface="127.0.0.1")

        @self.addCleanup
        def goodbyePort():
            return port.stopListening()

        env = dict(os.environ)
        # Tell the subprocess which port to connect back to.
        env["TESTING_PORT"] = repr(port.getHost().port)

        self.stdout = fp.child("stdout.txt")
        self.stderr = fp.child("stderr.txt")
        self.launchLabel = ("org.calendarserver.UNIT-TESTS." +
                            str(os.getpid()) + "." + self.id())

        # Job definition: run this module with the current test's ID so the
        # __main__ block dispatches to the matching job_* method, capturing
        # stdout/stderr into files this test can read afterwards.
        plist = {
            "Label": self.launchLabel,
            "ProgramArguments": [sys.executable, "-m", __name__, self.id()],
            "EnvironmentVariables": env,
            "KeepAlive": False,
            "StandardOutPath": self.stdout.path,
            "StandardErrorPath": self.stderr.path,
            "Sockets": {
                "Awesome": [{"SecureSocketWithKey": "GeneratedSocket"}]
            },
            "RunAtLoad": True,
        }

        self.job = fp.child("job.plist")
        self.job.setContent(plistlib.writePlistToString(plist))

        # Block until launchctl has loaded (and, per RunAtLoad, started) the
        # job; the returned Deferred then waits for the connect-back.
        os.spawnlp(os.P_WAIT, "launchctl", "launchctl", "load", self.job.path)
        return d


    @staticmethod
    def job_test():
        """
        Do something observable in a subprocess.
        """
        sys.stdout.write("Sample Value.")
        sys.stdout.flush()


    def test_test(self):
        """
        Since this test framework is somewhat finicky, let's just make sure
        that a test can complete.
        """
        self.assertEquals("Sample Value.", self.stdout.getContent())


    @staticmethod
    def job_checkin():
        """
        Check in with launchd in the subprocess, dumping the result as JSON
        on stdout for the test to inspect.
        """
        sys.stdout.write(json.dumps(plainPython(checkin())))


    def test_checkin(self):
        """
        L{checkin} performs launchd checkin and returns a launchd data
        structure.
        """
        d = json.loads(self.stdout.getContent())
        self.assertEqual(d[constants.LAUNCH_JOBKEY_LABEL], self.launchLabel)
        self.assertIsInstance(d, dict)
        sockets = d[constants.LAUNCH_JOBKEY_SOCKETS]
        self.assertEquals(len(sockets), 1)
        self.assertEqual(['Awesome'], sockets.keys())
        awesomeSocket = sockets['Awesome']
        self.assertEqual(len(awesomeSocket), 1)
        # JSON can't carry an FD, so by this point it is just an integer.
        self.assertIsInstance(awesomeSocket[0], int)


    @staticmethod
    def job_getFDs():
        """
        Check-in via the high-level C{getLaunchDSocketFDs} API, that just gives
        us listening FDs.
        """
        sys.stdout.write(json.dumps(getLaunchDSocketFDs()))


    def test_getFDs(self):
        """
        L{getLaunchDSocketFDs} returns a Python dictionary mapping the names of
        sockets specified in the property list to lists of integers
        representing FDs.
        """
        sockets = json.loads(self.stdout.getContent())
        self.assertEquals(len(sockets), 1)
        self.assertEqual(['Awesome'], sockets.keys())
        awesomeSocket = sockets['Awesome']
        self.assertEqual(len(awesomeSocket), 1)
        self.assertIsInstance(awesomeSocket[0], int)


    def tearDown(self):
        """
        Un-load the launchd job and report any errors it encountered.
        """
        os.spawnlp(os.P_WAIT, "launchctl",
                   "launchctl", "unload", self.job.path)
        # A traceback in the job's captured stderr means the subprocess half
        # of the test failed; surface it as this test's failure message.
        err = self.stderr.getContent()
        if 'Traceback' in err:
            self.fail(err)
calendarserver-5.2+dfsg/twext/python/test/test_sendmsg.py 0000644 0001750 0001750 00000012207 12263343324 023037 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
import socket
from os import pipe, read, close, environ
from twext.python.filepath import CachingFilePath as FilePath
import sys
from twisted.internet.defer import Deferred
from twisted.internet.error import ProcessDone
from twisted.trial.unittest import TestCase
from twisted.internet.defer import inlineCallbacks
from twisted.internet import reactor
from twext.python.sendmsg import sendmsg, recvmsg
from twext.python.sendfd import sendfd
from twisted.internet.protocol import ProcessProtocol
class ExitedWithStderr(Exception):
"""
A process exited with some stderr.
"""
def __str__(self):
"""
Dump the errors in a pretty way in the event of a subprocess traceback.
"""
return '\n'.join([''] + list(self.args))
class StartStopProcessProtocol(ProcessProtocol):
    """
    An L{IProcessProtocol} with a Deferred for events where the subprocess
    starts and stops.
    """

    def __init__(self):
        # "started" fires with the transport once the process is connected;
        # "stopped" fires with accumulated stdout on clean exit, or errbacks
        # with ExitedWithStderr otherwise.
        self.started = Deferred()
        self.stopped = Deferred()
        self.output = self.errors = ''


    def connectionMade(self):
        self.started.callback(self.transport)


    def outReceived(self, data):
        self.output += data


    def errReceived(self, data):
        self.errors += data


    def processEnded(self, reason):
        if not reason.check(ProcessDone):
            self.stopped.errback(ExitedWithStderr(
                self.errors, self.output))
            return
        self.stopped.callback(self.output)
def bootReactor():
    """
    Return a L{Deferred} that fires on the next reactor iteration. Yielding
    this from a trial test spins the reactor up before spawning subprocesses,
    avoiding PotentialZombieWarning; the underlying bug was fixed in Twisted
    10.1, after which this hack is unnecessary.
    """
    bootstrapped = Deferred()
    reactor.callLater(0, bootstrapped.callback, None)
    return bootstrapped
class SendmsgTestCase(TestCase):
    """
    Tests for sendmsg extension module and associated file-descriptor sending
    functionality in L{twext.python.sendfd}.
    """

    def setUp(self):
        """
        Create a pair of UNIX sockets.
        """
        self.input, self.output = socket.socketpair(socket.AF_UNIX)


    def tearDown(self):
        """
        Close the sockets opened by setUp.
        """
        self.input.close()
        self.output.close()


    def test_roundtrip(self):
        """
        L{recvmsg} will retrieve a message sent via L{sendmsg}.
        """
        sendmsg(self.input.fileno(), "hello, world!", 0)

        result = recvmsg(fd=self.output.fileno())
        # Result tuple is (data, flags, ancillary); no ancillary data sent.
        self.assertEquals(result, ("hello, world!", 0, []))


    def test_wrongTypeAncillary(self):
        """
        L{sendmsg} will show a helpful exception message when given the wrong
        type of object for the 'ancillary' argument.
        """
        error = self.assertRaises(TypeError,
                                  sendmsg, self.input.fileno(),
                                  "hello, world!", 0, 4321)
        self.assertEquals(str(error),
                          "sendmsg argument 3 expected list, got int")


    def spawn(self, script):
        """
        Start a script that is a peer of this test as a subprocess.

        @param script: the module name of the script in this directory (no
            package prefix, no '.py')
        @type script: C{str}

        @rtype: L{StartStopProcessProtocol}
        """
        sspp = StartStopProcessProtocol()
        reactor.spawnProcess(
            sspp, sys.executable, [
                sys.executable,
                FilePath(__file__).sibling(script + ".py").path,
                str(self.output.fileno()),
            ],
            environ,
            # Map this test's socket FD into the child unchanged (same
            # number), so the child can receive a descriptor over it.
            childFDs={0: "w", 1: "r", 2: "r",
                      self.output.fileno(): self.output.fileno()}
        )
        return sspp


    @inlineCallbacks
    def test_sendSubProcessFD(self):
        """
        Calling L{sendsmsg} with SOL_SOCKET, SCM_RIGHTS, and a platform-endian
        packed file descriptor number should send that file descriptor to a
        different process, where it can be retrieved by using L{recvmsg}.
        """
        yield bootReactor()
        sspp = self.spawn("pullpipe")
        yield sspp.started
        pipeOut, pipeIn = pipe()
        self.addCleanup(close, pipeOut)

        # Hand the pipe's write end to the child (see pullpipe.py), then
        # close our copy so EOF on the read end means the child closed its.
        sendfd(self.input.fileno(), pipeIn, "blonk")
        close(pipeIn)
        yield sspp.stopped
        self.assertEquals(read(pipeOut, 1024), "Test fixture data: blonk.\n")
        # Make sure that the pipe is actually closed now.
        self.assertEquals(read(pipeOut, 1024), "")
calendarserver-5.2+dfsg/twext/python/test/pullpipe.py 0000644 0001750 0001750 00000001607 12263343324 022174 0 ustar rahul rahul #!/usr/bin/python
# -*- test-case-name: twext.python.test.test_sendmsg -*-
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
if __name__ == '__main__':
    from twext.python.sendfd import recvfd
    import sys, os
    # argv[1] is the number of a UNIX-socket FD inherited from the parent
    # test process; receive a file descriptor (and its description) over it.
    fd, description = recvfd(int(sys.argv[1]))
    # Write fixture data into the received descriptor so the parent can
    # verify the FD actually arrived, then close it so the parent sees EOF.
    os.write(fd, "Test fixture data: %s.\n" % (description,))
    os.close(fd)
calendarserver-5.2+dfsg/twext/python/test/test_parallel.py 0000644 0001750 0001750 00000003760 12263343324 023177 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for L{twext.python.parallel}.
"""
from twisted.internet.defer import Deferred
from twext.python.parallel import Parallelizer
from twisted.trial.unittest import TestCase
class ParallelizerTests(TestCase):
    """
    Tests for L{Parallelizer}.
    """

    def test_doAndDone(self):
        """
        Blanket catch-all test. (TODO: split this up into more nice
        fine-grained tests.)
        """
        d1 = Deferred()
        d2 = Deferred()
        d3 = Deferred()
        d4 = Deferred()
        doing = []
        done = []
        allDone = []
        # Three tokens, so at most three operations may run concurrently.
        p = Parallelizer(['a', 'b', 'c'])
        p.do(lambda a: doing.append(a) or d1).addCallback(done.append)
        p.do(lambda b: doing.append(b) or d2).addCallback(done.append)
        p.do(lambda c: doing.append(c) or d3).addCallback(done.append)
        p.do(lambda b1: doing.append(b1) or d4).addCallback(done.append)
        p.done().addCallback(allDone.append)
        self.assertEqual(allDone, [])
        # Only three operations have started (the fourth is queued), and
        # do()'s Deferred fires with None as each operation is started --
        # hence three entries in "done" already.
        self.assertEqual(doing, ['a', 'b', 'c'])
        self.assertEqual(done, [None, None, None])
        d2.callback(1)
        # Finishing the second operation frees token 'b', which is
        # immediately reused by the queued fourth operation.
        self.assertEqual(doing, ['a', 'b', 'c', 'b'])
        self.assertEqual(done, [None, None, None, None])
        self.assertEqual(allDone, [])
        d3.callback(2)
        d4.callback(3)
        d1.callback(4)
        # done() fires (with None) only after every operation has completed,
        # regardless of completion order.
        self.assertEqual(done, [None, None, None, None])
        self.assertEqual(allDone, [None])
calendarserver-5.2+dfsg/twext/python/test/test_timezone.py 0000644 0001750 0001750 00000005011 12263343324 023224 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
from twistedcaldav.test.util import TestCase
from twistedcaldav.config import config
import twext.python.timezone
import twistedcaldav.timezones
from twext.python.timezone import getLocalTimezone, DEFAULT_TIMEZONE
class DefaultTimezoneTests(TestCase):
    """
    Tests for L{getLocalTimezone}'s fallback chain: configured value, then
    system timezone, then L{DEFAULT_TIMEZONE}.
    """

    def stubLookup(self):
        # Stand-in for lookupSystemTimezone: returns the canned value.
        return self._storedLookup

    def stubHasTZ(self, ignored):
        # Stand-in for hasTZ: pops canned answers off the END of
        # _storedHasTZ, so the list is written in reverse call order
        # (last element is the answer to the FIRST call).
        return self._storedHasTZ.pop()

    def setUp(self):
        self.patch(twext.python.timezone, "lookupSystemTimezone", self.stubLookup)
        self.patch(twistedcaldav.timezones,
                   "hasTZ", self.stubHasTZ)

    def test_getLocalTimezone(self):
        # Empty config, system timezone known = use system timezone
        self.patch(config, "DefaultTimezone", "")
        self._storedLookup = "America/New_York"
        self._storedHasTZ = [True]
        self.assertEquals(getLocalTimezone(), "America/New_York")

        # Empty config, system timezone unknown = use DEFAULT_TIMEZONE
        self.patch(config, "DefaultTimezone", "")
        self._storedLookup = "Unknown/Unknown"
        self._storedHasTZ = [False]
        self.assertEquals(getLocalTimezone(), DEFAULT_TIMEZONE)

        # Known config value = use config value
        self.patch(config, "DefaultTimezone", "America/New_York")
        self._storedHasTZ = [True]
        self.assertEquals(getLocalTimezone(), "America/New_York")

        # Unknown config value, system timezone known = use system timezone
        # (pop order: first hasTZ call answers False for the config value,
        # second answers True for the system timezone)
        self.patch(config, "DefaultTimezone", "Unknown/Unknown")
        self._storedLookup = "America/New_York"
        self._storedHasTZ = [True, False]
        self.assertEquals(getLocalTimezone(), "America/New_York")

        # Unknown config value, system timezone unknown = use DEFAULT_TIMEZONE
        self.patch(config, "DefaultTimezone", "Unknown/Unknown")
        self._storedLookup = "Unknown/Unknown"
        self._storedHasTZ = [False, False]
        self.assertEquals(getLocalTimezone(), DEFAULT_TIMEZONE)
calendarserver-5.2+dfsg/twext/python/test/__init__.py 0000644 0001750 0001750 00000001212 12263343324 022071 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Test extensions to twisted.python.
"""
calendarserver-5.2+dfsg/twext/python/launchd.py 0000644 0001750 0001750 00000020040 12263343324 020771 0 ustar rahul rahul # -*- test-case-name: twext.python.test.test_launchd -*-
##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Bindings for launchd check-in API.
@see: U{SampleD.c
}
@var ffi: a L{cffi.FFI} instance wrapping the functions exposed by C{launch.h}.
@var lib: a L{cffi} "U{dynamic library object
}"
wrapping the functions exposed by C{launch.h}.
@var constants: Select C{LAUNCH_*} constants from C{launch.h}, exposed as plain
Python values. Note that this is not a complete wrapping, but as the
header file suggests, these APIs are only for use during check-in.
"""
from __future__ import print_function
from cffi import FFI, VerificationError
ffi = FFI()
ffi.cdef("""
static const char* LAUNCH_KEY_CHECKIN;
static const char* LAUNCH_JOBKEY_LABEL;
static const char* LAUNCH_JOBKEY_SOCKETS;
typedef enum {
LAUNCH_DATA_DICTIONARY = 1,
LAUNCH_DATA_ARRAY,
LAUNCH_DATA_FD,
LAUNCH_DATA_INTEGER,
LAUNCH_DATA_REAL,
LAUNCH_DATA_BOOL,
LAUNCH_DATA_STRING,
LAUNCH_DATA_OPAQUE,
LAUNCH_DATA_ERRNO,
LAUNCH_DATA_MACHPORT,
} launch_data_type_t;
typedef struct _launch_data *launch_data_t;
bool launch_data_dict_insert(launch_data_t, const launch_data_t, const char *);
launch_data_t launch_data_alloc(launch_data_type_t);
launch_data_t launch_data_new_string(const char *);
launch_data_t launch_data_new_integer(long long);
launch_data_t launch_data_new_fd(int);
launch_data_t launch_data_new_bool(bool);
launch_data_t launch_data_new_real(double);
launch_data_t launch_msg(const launch_data_t);
launch_data_type_t launch_data_get_type(const launch_data_t);
launch_data_t launch_data_dict_lookup(const launch_data_t, const char *);
size_t launch_data_dict_get_count(const launch_data_t);
long long launch_data_get_integer(const launch_data_t);
void launch_data_dict_iterate(
const launch_data_t, void (*)(const launch_data_t, const char *, void *),
void *);
int launch_data_get_fd(const launch_data_t);
bool launch_data_get_bool(const launch_data_t);
const char * launch_data_get_string(const launch_data_t);
double launch_data_get_real(const launch_data_t);
size_t launch_data_array_get_count(const launch_data_t);
launch_data_t launch_data_array_get_index(const launch_data_t, size_t);
bool launch_data_array_set_index(launch_data_t, const launch_data_t, size_t);
void launch_data_free(launch_data_t);
""")
# Compile and link the launchd glue at import time.  The header operand of
# the #include below was lost in this copy of the file; restored to
# <launch.h>, the header these bindings wrap (see module docstring).
# If cffi cannot verify (e.g. not on OS X, launch.h missing, no compiler),
# surface the failure as ImportError so callers can treat this module as
# simply unavailable on this platform.
try:
    lib = ffi.verify("""
    #include <launch.h>
    """,
    tag=__name__.replace(".", "_"))
except VerificationError as ve:
    raise ImportError(ve)
class _LaunchArray(object):
    """
    Read-only, sequence-like wrapper around a launchd array
    (C{LAUNCH_DATA_ARRAY}).  Supports C{len()} and indexing; elements are
    converted to Python values via L{_launchify}.
    """

    def __init__(self, launchdata):
        self.launchdata = launchdata

    def __len__(self):
        return lib.launch_data_array_get_count(self.launchdata)

    def __getitem__(self, index):
        # Bounds-check explicitly; the C accessor does not.
        if index >= len(self):
            raise IndexError(index)
        element = lib.launch_data_array_get_index(self.launchdata, index)
        return _launchify(element)
class _LaunchDictionary(object):
    """
    Read-only, mapping-like wrapper around a launchd dictionary
    (C{LAUNCH_DATA_DICTIONARY}).  Supports C{len()}, key lookup, and the
    C{keys()}/C{values()}/C{items()} traversals.
    """

    def __init__(self, launchdata):
        self.launchdata = launchdata

    def _eachEntry(self, visit):
        """
        Invoke C{visit(value, key)} once per entry, via
        C{launch_data_dict_iterate}'s C callback mechanism.
        """
        @ffi.callback("void (*)(const launch_data_t, const char *, void *)")
        def callback(value, key, context):
            visit(value, key)
        lib.launch_data_dict_iterate(self.launchdata, callback, ffi.NULL)

    def keys(self):
        """
        Return keys in the dictionary.
        """
        collected = []
        self._eachEntry(lambda value, key: collected.append(ffi.string(key)))
        return collected

    def values(self):
        """
        Return values in the dictionary.
        """
        collected = []
        self._eachEntry(lambda value, key: collected.append(_launchify(value)))
        return collected

    def items(self):
        """
        Return items in the dictionary.
        """
        collected = []
        self._eachEntry(
            lambda value, key:
                collected.append((ffi.string(key), _launchify(value)))
        )
        return collected

    def __getitem__(self, key):
        launchvalue = lib.launch_data_dict_lookup(self.launchdata, key)
        try:
            return _launchify(launchvalue)
        except LaunchErrno:
            # launchd reports a missing key as an errno value.
            raise KeyError(key)

    def __len__(self):
        return lib.launch_data_dict_get_count(self.launchdata)
def plainPython(x):
    """
    Convert a launchd python-like data structure into regular Python
    dictionaries and lists.  L{_LaunchDictionary} becomes a C{dict},
    L{_LaunchArray} becomes a C{list}; anything else passes through
    unchanged.
    """
    if isinstance(x, _LaunchDictionary):
        return dict(
            (key, plainPython(value)) for (key, value) in x.items()
        )
    if isinstance(x, _LaunchArray):
        return map(plainPython, x)
    return x
class LaunchErrno(Exception):
    """
    Error from launchd; raised by L{_launchify} when a value carries the
    C{LAUNCH_DATA_ERRNO} type.
    """
def _launchify(launchvalue):
    """
    Convert a cffi value wrapping a C{_launch_data} structure into the
    relevant Python object (integer, bytes, L{_LaunchDictionary},
    L{_LaunchArray}).

    @param launchvalue: a C{launch_data_t} value, possibly C{ffi.NULL}.

    @raise LaunchErrno: if the value carries C{LAUNCH_DATA_ERRNO}.
    @raise TypeError: if the value's type tag is not recognized.
    """
    if launchvalue == ffi.NULL:
        return None
    dtype = lib.launch_data_get_type(launchvalue)
    # Dispatch on the launchd type tag to the matching accessor/wrapper.
    if dtype == lib.LAUNCH_DATA_DICTIONARY:
        return _LaunchDictionary(launchvalue)
    elif dtype == lib.LAUNCH_DATA_ARRAY:
        return _LaunchArray(launchvalue)
    elif dtype == lib.LAUNCH_DATA_FD:
        return lib.launch_data_get_fd(launchvalue)
    elif dtype == lib.LAUNCH_DATA_INTEGER:
        return lib.launch_data_get_integer(launchvalue)
    elif dtype == lib.LAUNCH_DATA_REAL:
        return lib.launch_data_get_real(launchvalue)
    elif dtype == lib.LAUNCH_DATA_BOOL:
        return lib.launch_data_get_bool(launchvalue)
    elif dtype == lib.LAUNCH_DATA_STRING:
        cvalue = lib.launch_data_get_string(launchvalue)
        # A string-typed value may still hold a NULL char pointer.
        if cvalue == ffi.NULL:
            return None
        return ffi.string(cvalue)
    elif dtype == lib.LAUNCH_DATA_OPAQUE:
        # No conversion defined; hand back the raw launch_data_t.
        return launchvalue
    elif dtype == lib.LAUNCH_DATA_ERRNO:
        raise LaunchErrno(launchvalue)
    elif dtype == lib.LAUNCH_DATA_MACHPORT:
        # NOTE(review): launch_data_get_machport is not declared in the
        # ffi.cdef above, so this branch would raise AttributeError if it
        # were ever reached -- confirm and add it to the cdef if needed.
        return lib.launch_data_get_machport(launchvalue)
    else:
        raise TypeError("Unknown Launch Data Type", dtype)
def checkin():
    """
    Perform a launchd checkin, returning a Pythonic wrapped data structure
    representing the retrieved check-in plist.

    @return: a C{dict}-like object.
    """
    checkinRequest = lib.launch_data_new_string(lib.LAUNCH_KEY_CHECKIN)
    checkinResponse = lib.launch_msg(checkinRequest)
    return _launchify(checkinResponse)
def _managed(obj):
    """
    Automatically free an object that was allocated with a launch_data_*
    function, or raise L{MemoryError} if it's C{NULL}.
    """
    if obj != ffi.NULL:
        # Attach launch_data_free as the garbage-collection destructor.
        return ffi.gc(obj, lib.launch_data_free)
    raise MemoryError()
class _Strings(object):
    """
    Expose constants as Python-readable values rather than wrapped cffi
    pointers: integers pass through, C{char *} constants are converted to
    Python strings, anything else is AttributeError.
    """
    def __getattribute__(self, name):
        value = getattr(lib, name)
        if isinstance(value, int):
            return value
        if ffi.typeof(value) == ffi.typeof("char *"):
            return ffi.string(value)
        raise AttributeError("no such constant", name)


constants = _Strings()
def getLaunchDSocketFDs():
    """
    Perform checkin via L{checkin} and return just a dictionary mapping the
    sockets to file descriptors.
    """
    checkinData = checkin()
    sockets = checkinData[constants.LAUNCH_JOBKEY_SOCKETS]
    return plainPython(sockets)
__all__ = [
'checkin',
'lib',
'ffi',
'plainPython',
]
calendarserver-5.2+dfsg/twext/python/sendfd.py 0000644 0001750 0001750 00000004537 12263343324 020633 0 ustar rahul rahul # -*- test-case-name: twext.python.test.test_sendmsg -*-
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
from struct import pack, unpack, calcsize
from socket import SOL_SOCKET
from twext.python.sendmsg import sendmsg, recvmsg, SCM_RIGHTS
def sendfd(socketfd, fd, description):
    """
    Send the given FD to another process via L{sendmsg} on the given C{AF_UNIX}
    socket.

    @param socketfd: An C{AF_UNIX} socket, attached to another process waiting
        to receive sockets via the ancillary data mechanism in L{sendmsg}.
    @type socketfd: C{int}

    @param fd: A file descriptor to be sent to the other process.
    @type fd: C{int}

    @param description: a string describing the socket that was passed.
    @type description: C{str}
    """
    # The descriptor travels as SCM_RIGHTS ancillary data; the description
    # string is the regular message payload.
    ancillary = [(SOL_SOCKET, SCM_RIGHTS, pack("i", fd))]
    sendmsg(socketfd, description, 0, ancillary)
def recvfd(socketfd):
    """
    Receive a file descriptor from a L{sendmsg} message on the given C{AF_UNIX}
    socket.

    @param socketfd: An C{AF_UNIX} socket, attached to another process waiting
        to send sockets via the ancillary data mechanism in L{sendmsg}.
    @type socketfd: C{int}

    @return: a 2-tuple of (new file descriptor, description).
    @rtype: 2-tuple of (C{int}, C{str})
    """
    data, _ignore_flags, ancillary = recvmsg(socketfd)
    # Exactly one control message is expected, carrying the packed FD.
    [(_ignore_cmsg_level, _ignore_cmsg_type, packedFD)] = ancillary
    # cmsg_level and cmsg_type really need to be SOL_SOCKET / SCM_RIGHTS, but
    # since those are the *only* standard values, there's not much point in
    # checking.
    unpackedFD = 0
    int_size = calcsize("i")
    # Some 64-bit platforms (e.g. FreeBSD) deliver a control payload larger
    # than sizeof(int); unpack only the leading int in that case.
    if len(packedFD) > int_size:
        [unpackedFD] = unpack("i", packedFD[0:int_size])
    else:
        [unpackedFD] = unpack("i", packedFD)
    return (unpackedFD, data)
calendarserver-5.2+dfsg/twext/python/_plistlib.py 0000644 0001750 0001750 00000035314 12113213176 021341 0 ustar rahul rahul #
# Added to standard library in Python 2.6 (Mac only in prior versions)
#
from __future__ import print_function
"""plistlib.py -- a tool to generate and parse MacOSX .plist files.
The PropertList (.plist) file format is a simple XML pickle supporting
basic object types, like dictionaries, lists, numbers and strings.
Usually the top level object is a dictionary.
To write out a plist file, use the writePlist(rootObject, pathOrFile)
function. 'rootObject' is the top level object, 'pathOrFile' is a
filename or a (writable) file object.
To parse a plist from a file, use the readPlist(pathOrFile) function,
with a file name or a (readable) file object as the only argument. It
returns the top level object (again, usually a dictionary).
To work with plist data in strings, you can use readPlistFromString()
and writePlistToString().
Values can be strings, integers, floats, booleans, tuples, lists,
dictionaries, Data or datetime.datetime objects. String values (including
dictionary keys) may be unicode strings -- they will be written out as
UTF-8.
The plist type is supported through the Data class. This is a
thin wrapper around a Python string.
Generate Plist example::
pl = dict(
aString="Doodah",
aList=["A", "B", 12, 32.1, [1, 2, 3]],
aFloat = 0.1,
anInt = 728,
aDict=dict(
anotherString="",
aUnicodeValue=u'M\xe4ssig, Ma\xdf',
aTrueValue=True,
aFalseValue=False,
),
someData = Data(""),
someMoreData = Data("" * 10),
aDate = datetime.datetime.fromtimestamp(time.mktime(time.gmtime())),
)
# unicode keys are possible, but a little awkward to use:
pl[u'\xc5benraa'] = "That was a unicode key."
writePlist(pl, fileName)
Parse Plist example::
pl = readPlist(pathOrFile)
print(pl["aKey"])
"""
__all__ = [
"readPlist", "writePlist", "readPlistFromString", "writePlistToString",
"readPlistFromResource", "writePlistToResource",
"Plist", "Data", "Dict"
]
# Note: the Plist and Dict classes have been deprecated.
import binascii
import datetime
from cStringIO import StringIO
import re
def readPlist(pathOrFile):
    """Read a .plist file. 'pathOrFile' may either be a file name or a
    (readable) file object. Return the unpacked root object (which
    usually is a dictionary).
    """
    didOpen = 0
    if isinstance(pathOrFile, (str, unicode)):
        pathOrFile = open(pathOrFile)
        didOpen = 1
    try:
        p = PlistParser()
        rootObject = p.parse(pathOrFile)
    finally:
        # BUG FIX: previously the file was only closed on success, leaking
        # the handle when parsing raised; close it on all paths, but only
        # if we opened it ourselves.
        if didOpen:
            pathOrFile.close()
    return rootObject
def writePlist(rootObject, pathOrFile):
    """Write 'rootObject' to a .plist file. 'pathOrFile' may either be a
    file name or a (writable) file object.
    """
    didOpen = 0
    if isinstance(pathOrFile, (str, unicode)):
        pathOrFile = open(pathOrFile, "w")
        didOpen = 1
    try:
        writer = PlistWriter(pathOrFile)
        # BUG FIX: the <plist> wrapper element text was lost in this copy
        # of the file (both writeln calls had empty strings); restored per
        # the stdlib plistlib.
        writer.writeln("<plist version=\"1.0\">")
        writer.writeValue(rootObject)
        writer.writeln("</plist>")
    finally:
        # Close on all paths (previously leaked on error), but only if we
        # opened the file ourselves.
        if didOpen:
            pathOrFile.close()
def readPlistFromString(data):
    """Read a plist data from a string. Return the root object.
    """
    stream = StringIO(data)
    return readPlist(stream)
def writePlistToString(rootObject):
    """Return 'rootObject' as a plist-formatted string.
    """
    out = StringIO()
    writePlist(rootObject, out)
    return out.getvalue()
def readPlistFromResource(path, restype='plst', resid=0):
    """Read plst resource from the resource fork of path.

    Mac OS only: relies on the Carbon bindings, hence the local imports.
    """
    from Carbon.File import FSRef, FSGetResourceForkName
    from Carbon.Files import fsRdPerm
    from Carbon import Res
    fsRef = FSRef(path)
    # Open the resource fork read-only, pull the raw resource bytes out,
    # then close before parsing.
    resNum = Res.FSOpenResourceFile(fsRef, FSGetResourceForkName(), fsRdPerm)
    Res.UseResFile(resNum)
    plistData = Res.Get1Resource(restype, resid).data
    Res.CloseResFile(resNum)
    return readPlistFromString(plistData)
def writePlistToResource(rootObject, path, restype='plst', resid=0):
    """Write 'rootObject' as a plst resource to the resource fork of path.

    Mac OS only: relies on the Carbon bindings, hence the local imports.
    """
    from Carbon.File import FSRef, FSGetResourceForkName
    from Carbon.Files import fsRdWrPerm
    from Carbon import Res
    plistData = writePlistToString(rootObject)
    fsRef = FSRef(path)
    resNum = Res.FSOpenResourceFile(fsRef, FSGetResourceForkName(), fsRdWrPerm)
    Res.UseResFile(resNum)
    # Remove any existing resource with the same type/id before writing the
    # replacement; a missing resource is not an error.
    try:
        Res.Get1Resource(restype, resid).RemoveResource()
    except Res.Error:
        pass
    res = Res.Resource(plistData)
    res.AddResource(restype, resid, '')
    res.WriteResource()
    Res.CloseResFile(resNum)
class DumbXMLWriter:
    """
    Minimal XML writer: maintains an element stack and indentation level,
    writing tags line-by-line to 'file'.
    """

    def __init__(self, file, indentLevel=0, indent="\t"):
        self.file = file
        self.stack = []
        self.indentLevel = indentLevel
        self.indent = indent

    def beginElement(self, element):
        # Push so endElement can assert proper nesting.
        self.stack.append(element)
        self.writeln("<%s>" % element)
        self.indentLevel += 1

    def endElement(self, element):
        assert self.indentLevel > 0
        assert self.stack.pop() == element
        self.indentLevel -= 1
        # BUG FIX: the closing-tag format string had lost its "</" in this
        # copy of the file (it emitted "name>"); restored per stdlib
        # plistlib.
        self.writeln("</%s>" % element)

    def simpleElement(self, element, value=None):
        if value is not None:
            value = _escapeAndEncode(value)
            # BUG FIX: restored the "<e>value</e>" form (the "</" of the
            # closing tag was lost in this copy of the file).
            self.writeln("<%s>%s</%s>" % (element, value, element))
        else:
            self.writeln("<%s/>" % element)

    def writeln(self, line):
        if line:
            self.file.write(self.indentLevel * self.indent + line + "\n")
        else:
            # An empty line is written without indentation.
            self.file.write("\n")
# Contents should conform to a subset of ISO 8601
# (in particular, YYYY '-' MM '-' DD 'T' HH ':' MM ':' SS 'Z'. Smaller units may be omitted with
# a loss of precision)
_dateParser = re.compile(r"(?P\d\d\d\d)(?:-(?P\d\d)(?:-(?P\d\d)(?:T(?P\d\d)(?::(?P\d\d)(?::(?P\d\d))?)?)?)?)?Z")
def _dateFromString(s):
order = ('year', 'month', 'day', 'hour', 'minute', 'second')
gd = _dateParser.match(s).groupdict()
lst = []
for key in order:
val = gd[key]
if val is None:
break
lst.append(int(val))
return datetime.datetime(*lst)
def _dateToString(d):
return '%04d-%02d-%02dT%02d:%02d:%02dZ' % (
d.year, d.month, d.day,
d.hour, d.minute, d.second
)
# Regex to find any control chars, except for \t \n and \r
_controlCharPat = re.compile(
r"[\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0b\x0c\x0e\x0f"
r"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f]")
def _escapeAndEncode(text):
m = _controlCharPat.search(text)
if m is not None:
raise ValueError("strings can't contains control characters; "
"use plistlib.Data instead")
text = text.replace("\r\n", "\n") # convert DOS line endings
text = text.replace("\r", "\n") # convert Mac line endings
text = text.replace("&", "&") # escape '&'
text = text.replace("<", "<") # escape '<'
text = text.replace(">", ">") # escape '>'
return text.encode("utf-8") # encode as UTF-8
# BUG FIX: the XML declaration and DOCTYPE lines were lost in this copy of
# the file (the header was empty); restored per the stdlib plistlib.
PLISTHEADER = """\
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
"""
class PlistWriter(DumbXMLWriter):
    """
    XML writer that serializes Python values (str/unicode, bool, int,
    float, dict, Data, datetime, tuple/list) as plist elements.
    """

    def __init__(self, file, indentLevel=0, indent="\t", writeHeader=1):
        # Emit the XML prolog/DOCTYPE unless the caller is embedding this
        # output inside a larger document.
        if writeHeader:
            file.write(PLISTHEADER)
        DumbXMLWriter.__init__(self, file, indentLevel, indent)

    def writeValue(self, value):
        """
        Write a single value as the appropriate plist element; raises
        TypeError for unsupported types.
        """
        if isinstance(value, (str, unicode)):
            self.simpleElement("string", value)
        elif isinstance(value, bool):
            # must switch for bool before int, as bool is a
            # subclass of int...
            if value:
                self.simpleElement("true")
            else:
                self.simpleElement("false")
        elif isinstance(value, int):
            self.simpleElement("integer", str(value))
        elif isinstance(value, float):
            # repr() preserves full float precision across a round trip.
            self.simpleElement("real", repr(value))
        elif isinstance(value, dict):
            self.writeDict(value)
        elif isinstance(value, Data):
            self.writeData(value)
        elif isinstance(value, datetime.datetime):
            self.simpleElement("date", _dateToString(value))
        elif isinstance(value, (tuple, list)):
            self.writeArray(value)
        else:
            raise TypeError("unsuported type: %s" % type(value))

    def writeData(self, data):
        """
        Write a Data value as base64 inside a <data> element, with line
        length chosen to keep the output within ~76 columns.
        """
        self.beginElement("data")
        # Temporarily dedent: base64 content is written one level left of
        # the element body, restored before the closing tag.
        self.indentLevel -= 1
        maxlinelength = 76 - len(self.indent.replace("\t", " " * 8) *
                                 self.indentLevel)
        for line in data.asBase64(maxlinelength).split("\n"):
            if line:
                self.writeln(line)
        self.indentLevel += 1
        self.endElement("data")

    def writeDict(self, d):
        """
        Write a dict as alternating <key>/<value> pairs, sorted by key for
        deterministic output.  Keys must be strings.
        """
        self.beginElement("dict")
        for key, value in sorted(d.items()):
            if not isinstance(key, (str, unicode)):
                raise TypeError("keys must be strings")
            self.simpleElement("key", key)
            self.writeValue(value)
        self.endElement("dict")

    def writeArray(self, array):
        """
        Write a list/tuple as an <array> of serialized values.
        """
        self.beginElement("array")
        for value in array:
            self.writeValue(value)
        self.endElement("array")
class _InternalDict(dict):

    # This class is needed while Dict is scheduled for deprecation:
    # we only need to warn when a *user* instantiates Dict or when
    # the "attribute notation for dict keys" is used.

    def __getattr__(self, attr):
        # Map attribute access onto item lookup; a missing key surfaces as
        # AttributeError so hasattr() etc. behave normally.
        try:
            value = self[attr]
        except KeyError:
            raise AttributeError, attr
        from warnings import warn
        warn("Attribute access from plist dicts is deprecated, use d[key] "
             "notation instead", PendingDeprecationWarning)
        return value

    def __setattr__(self, attr, value):
        from warnings import warn
        warn("Attribute access from plist dicts is deprecated, use d[key] "
             "notation instead", PendingDeprecationWarning)
        self[attr] = value

    def __delattr__(self, attr):
        try:
            del self[attr]
        except KeyError:
            raise AttributeError, attr
        from warnings import warn
        warn("Attribute access from plist dicts is deprecated, use d[key] "
             "notation instead", PendingDeprecationWarning)
class Dict(_InternalDict):
    """
    Deprecated dict subclass kept for backward compatibility; warns on
    instantiation.  Use a builtin dict instead.
    """

    def __init__(self, **kwargs):
        from warnings import warn
        warn("The plistlib.Dict class is deprecated, use builtin dict instead",
             PendingDeprecationWarning)
        super(Dict, self).__init__(**kwargs)
class Plist(_InternalDict):

    """This class has been deprecated. Use readPlist() and writePlist()
    functions instead, together with regular dict objects.
    """

    def __init__(self, **kwargs):
        from warnings import warn
        warn("The Plist class is deprecated, use the readPlist() and "
             "writePlist() functions instead", PendingDeprecationWarning)
        super(Plist, self).__init__(**kwargs)

    def fromFile(cls, pathOrFile):
        """Deprecated. Use the readPlist() function instead."""
        # Alternate constructor: parse the file, then copy the root dict's
        # contents into a new Plist instance.
        rootObject = readPlist(pathOrFile)
        plist = cls()
        plist.update(rootObject)
        return plist
    fromFile = classmethod(fromFile)

    def write(self, pathOrFile):
        """Deprecated. Use the writePlist() function instead."""
        writePlist(self, pathOrFile)
def _encodeBase64(s, maxlinelength=76):
    # copied from base64.encodestring(), with added maxlinelength argument
    # Each output line of N base64 characters covers N//4*3 input bytes.
    maxbinsize = (maxlinelength // 4) * 3
    chunks = [
        binascii.b2a_base64(s[offset:offset + maxbinsize])
        for offset in xrange(0, len(s), maxbinsize)
    ]
    return "".join(chunks)
class Data:

    """Wrapper for binary data.

    Marks a byte string so the writer emits a <data> (base64) element
    rather than a <string> element.
    """

    def __init__(self, data):
        # data: the raw byte string being wrapped.
        self.data = data

    def fromBase64(cls, data):
        # base64.decodestring just calls binascii.a2b_base64;
        # it seems overkill to use both base64 and binascii.
        return cls(binascii.a2b_base64(data))
    fromBase64 = classmethod(fromBase64)

    def asBase64(self, maxlinelength=76):
        # Return the payload base64-encoded, wrapped at maxlinelength.
        return _encodeBase64(self.data, maxlinelength)

    def __cmp__(self, other):
        # Compare by payload against Data or plain strings; fall back to
        # identity ordering for anything else.
        if isinstance(other, self.__class__):
            return cmp(self.data, other.data)
        elif isinstance(other, str):
            return cmp(self.data, other)
        else:
            return cmp(id(self), id(other))

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, repr(self.data))
class PlistParser:
    """
    Expat-based parser that rebuilds the Python object tree from a plist
    XML stream.  Containers in progress live on self.stack; self.currentKey
    holds a pending <key> awaiting its value.
    """

    def __init__(self):
        self.stack = []
        self.currentKey = None
        self.root = None

    def parse(self, fileobj):
        """
        Parse the given (readable) file object; return the root object.
        """
        from xml.parsers.expat import ParserCreate
        parser = ParserCreate()
        parser.StartElementHandler = self.handleBeginElement
        parser.EndElementHandler = self.handleEndElement
        parser.CharacterDataHandler = self.handleData
        parser.ParseFile(fileobj)
        return self.root

    def handleBeginElement(self, element, attrs):
        # Reset the character-data accumulator, then dispatch to an
        # optional begin_<element> handler.
        self.data = []
        handler = getattr(self, "begin_" + element, None)
        if handler is not None:
            handler(attrs)

    def handleEndElement(self, element):
        # Dispatch to an optional end_<element> handler.
        handler = getattr(self, "end_" + element, None)
        if handler is not None:
            handler()

    def handleData(self, data):
        # Expat may deliver text in multiple chunks; accumulate them.
        self.data.append(data)

    def addObject(self, value):
        """
        Attach a finished value: to the pending dict key, as the root, or
        appended to the open array, in that order of precedence.
        """
        if self.currentKey is not None:
            self.stack[-1][self.currentKey] = value
            self.currentKey = None
        elif not self.stack:
            # this is the root object
            self.root = value
        else:
            self.stack[-1].append(value)

    def getData(self):
        """
        Join and clear accumulated character data, downcasting to a byte
        string when it is pure ASCII.
        """
        data = "".join(self.data)
        try:
            data = data.encode("ascii")
        except UnicodeError:
            pass
        self.data = []
        return data

    # element handlers

    def begin_dict(self, attrs):
        d = _InternalDict()
        self.addObject(d)
        self.stack.append(d)

    def end_dict(self):
        self.stack.pop()

    def end_key(self):
        self.currentKey = self.getData()

    def begin_array(self, attrs):
        a = []
        self.addObject(a)
        self.stack.append(a)

    def end_array(self):
        self.stack.pop()

    def end_true(self):
        self.addObject(True)

    def end_false(self):
        self.addObject(False)

    def end_integer(self):
        self.addObject(int(self.getData()))

    def end_real(self):
        self.addObject(float(self.getData()))

    def end_string(self):
        self.addObject(self.getData())

    def end_data(self):
        self.addObject(Data.fromBase64(self.getData()))

    def end_date(self):
        self.addObject(_dateFromString(self.getData()))
calendarserver-5.2+dfsg/twext/python/sendmsg.c 0000644 0001750 0001750 00000027610 12263343324 020617 0 ustar rahul rahul /*
* Copyright (c) 2010-2014 Apple Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define PY_SSIZE_T_CLEAN 1
#include
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
/* This may cause some warnings, but if you want to get rid of them, upgrade
* your Python version. */
typedef int Py_ssize_t;
#endif
#include
#include
#include
/*
* As per
* :
*
* "To forestall portability problems, it is recommended that applications
* not use values larger than (2**31)-1 for the socklen_t type."
*/
#define SOCKLEN_MAX 0x7FFFFFFF
PyObject *sendmsg_socket_error;
static PyObject *sendmsg_sendmsg(PyObject *self, PyObject *args, PyObject *keywds);
static PyObject *sendmsg_recvmsg(PyObject *self, PyObject *args, PyObject *keywds);
static PyObject *sendmsg_getsockfam(PyObject *self, PyObject *args, PyObject *keywds);
static PyMethodDef sendmsg_methods[] = {
{"sendmsg", (PyCFunction) sendmsg_sendmsg, METH_VARARGS | METH_KEYWORDS,
NULL},
{"recvmsg", (PyCFunction) sendmsg_recvmsg, METH_VARARGS | METH_KEYWORDS,
NULL},
{"getsockfam", (PyCFunction) sendmsg_getsockfam,
METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL, 0, NULL}
};
/*
 * Python 2 module initialization: register the method table, export the
 * SCM_* constants available on this platform, and cache socket.error for
 * raising errno-based exceptions later.  On any failure we return early,
 * leaving the partially-initialized module for Python to report.
 */
PyMODINIT_FUNC initsendmsg(void) {
    PyObject *module;

    sendmsg_socket_error = NULL; /* Make sure that this has a known value
                                    before doing anything that might exit. */

    module = Py_InitModule("sendmsg", sendmsg_methods);

    if (!module) {
        return;
    }

    /*
      The following is the only value mentioned by POSIX:
      http://www.opengroup.org/onlinepubs/9699919799/basedefs/sys_socket.h.html
    */

    if (-1 == PyModule_AddIntConstant(module, "SCM_RIGHTS", SCM_RIGHTS)) {
        return;
    }

    /* BSD, Darwin, Hurd */
#if defined(SCM_CREDS)
    if (-1 == PyModule_AddIntConstant(module, "SCM_CREDS", SCM_CREDS)) {
        return;
    }
#endif

    /* Linux */
#if defined(SCM_CREDENTIALS)
    if (-1 == PyModule_AddIntConstant(module, "SCM_CREDENTIALS", SCM_CREDENTIALS)) {
        return;
    }
#endif

    /* Apparently everywhere, but not standardized. */
#if defined(SCM_TIMESTAMP)
    if (-1 == PyModule_AddIntConstant(module, "SCM_TIMESTAMP", SCM_TIMESTAMP)) {
        return;
    }
#endif

    /* Reuse the local: import the socket module and keep a reference to
       its error class for PyErr_SetFromErrno in the methods below. */
    module = PyImport_ImportModule("socket");
    if (!module) {
        return;
    }

    sendmsg_socket_error = PyObject_GetAttrString(module, "error");
    if (!sendmsg_socket_error) {
        return;
    }
}
/*
 * sendmsg(fd, data, flags=0, ancillary=None) -> bytes-sent
 *
 * Thin wrapper around sendmsg(2).  'ancillary' is a list of
 * (level, type, data) tuples packed into the control-message buffer.
 * The list is iterated twice: first to size the buffer (with overflow
 * checks), then to fill it.  Raises socket.error on send failure.
 */
static PyObject *sendmsg_sendmsg(PyObject *self, PyObject *args, PyObject *keywds) {

    int fd;
    int flags = 0;
    Py_ssize_t sendmsg_result, iovec_length;
    struct msghdr message_header;
    struct iovec iov[1];
    PyObject *ancillary = NULL;
    PyObject *ultimate_result = NULL;
    static char *kwlist[] = {"fd", "data", "flags", "ancillary", NULL};

    if (!PyArg_ParseTupleAndKeywords(
            args, keywds, "it#|iO:sendmsg", kwlist,
            &fd,
            &iov[0].iov_base,
            &iovec_length,
            &flags,
            &ancillary)) {
        return NULL;
    }

    iov[0].iov_len = iovec_length;

    /* No destination address: the socket is assumed connected. */
    message_header.msg_name = NULL;
    message_header.msg_namelen = 0;

    message_header.msg_iov = iov;
    message_header.msg_iovlen = 1;

    message_header.msg_control = NULL;
    message_header.msg_controllen = 0;

    message_header.msg_flags = 0;

    if (ancillary) {

        if (!PyList_Check(ancillary)) {
            PyErr_Format(PyExc_TypeError,
                         "sendmsg argument 3 expected list, got %s",
                         ancillary->ob_type->tp_name);
            goto finished;
        }

        PyObject *iterator = PyObject_GetIter(ancillary);
        PyObject *item = NULL;

        if (iterator == NULL) {
            goto finished;
        }

        size_t all_data_len = 0;

        /* First we need to know how big the buffer needs to be in order to
           have enough space for all of the messages. */
        while ( (item = PyIter_Next(iterator)) ) {
            int type, level;
            Py_ssize_t data_len;
            size_t prev_all_data_len;
            char *data;

            if (!PyArg_ParseTuple(
                    item, "iit#:sendmsg ancillary data (level, type, data)",
                    &level, &type, &data, &data_len)) {
                Py_DECREF(item);
                Py_DECREF(iterator);
                goto finished;
            }

            prev_all_data_len = all_data_len;
            all_data_len += CMSG_SPACE(data_len);

            Py_DECREF(item);

            /* Detect size_t wraparound from the addition above. */
            if (all_data_len < prev_all_data_len) {
                Py_DECREF(iterator);
                PyErr_Format(PyExc_OverflowError,
                             "Too much msg_control to fit in a size_t: %zu",
                             prev_all_data_len);
                goto finished;
            }
        }

        Py_DECREF(iterator);
        iterator = NULL;

        /* Allocate the buffer for all of the ancillary elements, if we have
         * any.  */
        if (all_data_len) {
            if (all_data_len > SOCKLEN_MAX) {
                PyErr_Format(PyExc_OverflowError,
                             "Too much msg_control to fit in a socklen_t: %zu",
                             all_data_len);
                goto finished;
            }
            message_header.msg_control = malloc(all_data_len);
            if (!message_header.msg_control) {
                PyErr_NoMemory();
                goto finished;
            }
        }

        message_header.msg_controllen = (socklen_t) all_data_len;

        iterator = PyObject_GetIter(ancillary); /* again */
        item = NULL;

        if (!iterator) {
            goto finished;
        }

        /* Unpack the tuples into the control message. */
        struct cmsghdr *control_message = CMSG_FIRSTHDR(&message_header);
        while ( (item = PyIter_Next(iterator)) ) {
            int type, level;
            Py_ssize_t data_len;
            size_t data_size;
            unsigned char *data, *cmsg_data;

            /* We explicitly allocated enough space for all ancillary data
               above; if there isn't enough room, all bets are off. */
            assert(control_message);

            if (!PyArg_ParseTuple(item,
                                  "iit#:sendmsg ancillary data (level, type, data)",
                                  &level,
                                  &type,
                                  &data,
                                  &data_len)) {
                Py_DECREF(item);
                Py_DECREF(iterator);
                goto finished;
            }

            control_message->cmsg_level = level;
            control_message->cmsg_type = type;
            data_size = CMSG_LEN(data_len);

            if (data_size > SOCKLEN_MAX) {
                Py_DECREF(item);
                Py_DECREF(iterator);
                PyErr_Format(PyExc_OverflowError,
                             "CMSG_LEN(%zd) > SOCKLEN_MAX", data_len);
                goto finished;
            }

            control_message->cmsg_len = (socklen_t) data_size;

            cmsg_data = CMSG_DATA(control_message);
            memcpy(cmsg_data, data, data_len);

            Py_DECREF(item);

            control_message = CMSG_NXTHDR(&message_header, control_message);
        }

        Py_DECREF(iterator);

        if (PyErr_Occurred()) {
            goto finished;
        }
    }

    sendmsg_result = sendmsg(fd, &message_header, flags);

    if (sendmsg_result < 0) {
        PyErr_SetFromErrno(sendmsg_socket_error);
        goto finished;
    } else {
        ultimate_result = Py_BuildValue("n", sendmsg_result);
    }

  finished:
    /* Single cleanup point: error paths fall through with
       ultimate_result == NULL and an exception set. */
    if (message_header.msg_control) {
        free(message_header.msg_control);
    }
    return ultimate_result;
}
/*
 * recvmsg(fd, flags=0, maxsize=8192, cmsg_size=4096)
 *     -> (data, msg_flags, [(level, type, data), ...])
 *
 * Thin wrapper around recvmsg(2): receives up to 'maxsize' bytes of
 * payload plus up to CMSG_SPACE(cmsg_size) bytes of ancillary data, and
 * returns the control messages as a list of (level, type, bytes) tuples.
 * Raises socket.error on receive failure.
 */
static PyObject *sendmsg_recvmsg(PyObject *self, PyObject *args, PyObject *keywds) {
    int fd = -1;
    int flags = 0;
    int maxsize = 8192;
    int cmsg_size = 4*1024;
    size_t cmsg_space;
    Py_ssize_t recvmsg_result;

    struct msghdr message_header;
    struct cmsghdr *control_message;
    struct iovec iov[1];
    char *cmsgbuf;
    PyObject *ancillary;
    PyObject *final_result = NULL;

    static char *kwlist[] = {"fd", "flags", "maxsize", "cmsg_size", NULL};

    if (!PyArg_ParseTupleAndKeywords(args, keywds, "i|iii:recvmsg", kwlist,
                                     &fd, &flags, &maxsize, &cmsg_size)) {
        return NULL;
    }

    cmsg_space = CMSG_SPACE(cmsg_size);

    /* overflow check */
    if (cmsg_space > SOCKLEN_MAX) {
        PyErr_Format(PyExc_OverflowError,
                     "CMSG_SPACE(cmsg_size) greater than SOCKLEN_MAX: %d",
                     cmsg_size);
        return NULL;
    }

    /* No source address requested. */
    message_header.msg_name = NULL;
    message_header.msg_namelen = 0;

    iov[0].iov_len = maxsize;
    iov[0].iov_base = malloc(maxsize);

    if (!iov[0].iov_base) {
        PyErr_NoMemory();
        return NULL;
    }

    message_header.msg_iov = iov;
    message_header.msg_iovlen = 1;

    cmsgbuf = malloc(cmsg_space);

    if (!cmsgbuf) {
        free(iov[0].iov_base);
        PyErr_NoMemory();
        return NULL;
    }

    memset(cmsgbuf, 0, cmsg_space);
    message_header.msg_control = cmsgbuf;
    /* see above for overflow check */
    message_header.msg_controllen = (socklen_t) cmsg_space;

    recvmsg_result = recvmsg(fd, &message_header, flags);
    if (recvmsg_result < 0) {
        PyErr_SetFromErrno(sendmsg_socket_error);
        goto finished;
    }

    ancillary = PyList_New(0);
    if (!ancillary) {
        goto finished;
    }

    /* Walk the received control messages and convert each to a tuple. */
    for (control_message = CMSG_FIRSTHDR(&message_header);
         control_message;
         control_message = CMSG_NXTHDR(&message_header,
                                       control_message)) {
        PyObject *entry;

        /* Some platforms apparently always fill out the ancillary data
           structure with a single bogus value if none is provided; ignore it,
           if that is the case. */
        if ((!(control_message->cmsg_level)) &&
            (!(control_message->cmsg_type))) {
            continue;
        }

        entry = Py_BuildValue(
            "(iis#)",
            control_message->cmsg_level,
            control_message->cmsg_type,
            CMSG_DATA(control_message),
            /* payload length = total cmsg length minus its header */
            (Py_ssize_t) (control_message->cmsg_len - sizeof(struct cmsghdr)));

        if (!entry) {
            Py_DECREF(ancillary);
            goto finished;
        }

        if (PyList_Append(ancillary, entry) < 0) {
            Py_DECREF(ancillary);
            Py_DECREF(entry);
            goto finished;
        } else {
            Py_DECREF(entry);
        }
    }

    /* "s#" consumes iov_base + recvmsg_result (actual bytes received). */
    final_result = Py_BuildValue(
        "s#iO",
        iov[0].iov_base,
        recvmsg_result,
        message_header.msg_flags,
        ancillary);

    Py_DECREF(ancillary);

  finished:
    free(iov[0].iov_base);
    free(cmsgbuf);
    return final_result;
}
/*
 * getsockfam(fd) -> int
 *
 * Return the address family (sa_family) of the given socket, via
 * getsockname(2).  A plain 'struct sockaddr' is large enough here because
 * only the family field is read; any truncated address bytes are ignored.
 * Raises socket.error on failure.
 */
static PyObject *sendmsg_getsockfam(PyObject *self, PyObject *args,
                                    PyObject *keywds) {
    int fd;
    struct sockaddr sa;
    static char *kwlist[] = {"fd", NULL};
    /* NOTE(review): the format string lacks a ":getsockfam" function-name
       suffix, so argument errors will not mention the function -- confirm
       whether that is intentional. */
    if (!PyArg_ParseTupleAndKeywords(args, keywds, "i", kwlist, &fd)) {
        return NULL;
    }
    socklen_t sz = sizeof(sa);
    if (getsockname(fd, &sa, &sz)) {
        PyErr_SetFromErrno(sendmsg_socket_error);
        return NULL;
    }
    return Py_BuildValue("i", sa.sa_family);
}
calendarserver-5.2+dfsg/twext/python/plistlib.py 0000644 0001750 0001750 00000001364 12263343324 021205 0 ustar rahul rahul ##
# Copyright (c) 2008-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
# Prefer the standard library's plistlib (added to the stdlib in Python 2.6;
# Mac-only before that); fall back to the bundled copy when it is missing.
try:
    _plistlib = __import__("plistlib")
except ImportError:
    from twext.python import _plistlib
import sys
# Alias this module name to whichever implementation was found, so that
# "import twext.python.plistlib" yields the real plistlib module object.
sys.modules[__name__] = _plistlib
calendarserver-5.2+dfsg/twext/python/log.py 0000644 0001750 0001750 00000067476 12263343324 020164 0 ustar rahul rahul # -*- test-case-name: twext.python.test.test_log-*-
##
# Copyright (c) 2006-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Classes and functions to do granular logging.
Example usage in a module C{some.module}::
from twext.python.log import Logger
log = Logger()
def handleData(data):
log.debug("Got data: {data!r}.", data=data)
Or in a class::
from twext.python.log import Logger
class Foo(object):
log = Logger()
def oops(self, data):
self.log.error("Oops! Invalid data from server: {data!r}",
data=data)
C{Logger}s have namespaces, for which logging can be configured independently.
Namespaces may be specified by passing in a C{namespace} argument to L{Logger}
when instantiating it, but if none is given, the logger will derive its own
namespace by using the module name of the callable that instantiated it, or, in
the case of a class, by using the fully qualified name of the class.
In the first example above, the namespace would be C{some.module}, and in the
second example, it would be C{some.module.Foo}.
"""
# Public API of this module.  StandardIOObserver is deliberately excluded
# (see the FIXME near its definition at the bottom of the file).
__all__ = [
    "InvalidLogLevelError",
    "LogLevel",
    "formatEvent",
    "Logger",
    "LegacyLogger",
    "ILogObserver",
    "ILegacyLogObserver",
    "LogPublisher",
    "PredicateResult",
    "ILogFilterPredicate",
    "FilteringLogObserver",
    "LogLevelFilterPredicate",
    "LegacyLogObserver",
    "replaceTwistedLoggers",
    #"StandardIOObserver",
]
import sys
from sys import stdout, stderr
from string import Formatter
import inspect
import logging
import time
from zope.interface import Interface, implementer
from twisted.python.constants import NamedConstant, Names
from twisted.python.failure import Failure
from twisted.python.reflect import safe_str, safe_repr
import twisted.python.log
from twisted.python.log import msg as twistedLogMessage
from twisted.python.log import addObserver, removeObserver
from twisted.python.log import ILogObserver as ILegacyLogObserver
# Message logged when a failing observer is temporarily removed from a
# publisher (see LogPublisher.__call__).
OBSERVER_REMOVED = (
    "Temporarily removing observer {observer} due to exception: {e}"
)
#
# Log level definitions
#
class InvalidLogLevelError(Exception):
    """
    Raised when a L{LogLevel} unknown to the logging system is used.
    """
    def __init__(self, level):
        """
        @param level: the unrecognized L{LogLevel}; kept on the exception as
            C{self.level} for callers that want to inspect it.
        """
        Exception.__init__(self, str(level))
        self.level = level
class LogLevel(Names):
    """
    Constants denoting log levels:

     - C{debug}: Information of use to a developer of the software, not
       generally of interest to someone running the software unless they are
       attempting to diagnose a software issue.

     - C{info}: Informational events: Routine information about the status of
       an application, such as incoming connections, startup of a subsystem,
       etc.

     - C{warn}: Warnings events: Events that may require greater attention than
       informational events but are not a systemic failure condition, such as
       authorization failures, bad data from a network client, etc.

     - C{error}: Error conditions: Events indicating a systemic failure, such
       as unhandled exceptions, loss of connectivity to a back-end database,
       etc.
    """
    debug = NamedConstant()
    info = NamedConstant()
    warn = NamedConstant()
    error = NamedConstant()

    @classmethod
    def levelWithName(cls, name):
        """
        @param name: the name of a L{LogLevel}

        @return: the L{LogLevel} with the specified C{name}

        @raise InvalidLogLevelError: if C{name} does not name a known level
        """
        try:
            return cls.lookupByName(name)
        except ValueError:
            raise InvalidLogLevelError(name)

    @classmethod
    def _priorityForLevel(cls, constant):
        """
        We want log levels to have defined ordering - the order of definition -
        but they aren't value constants (the only value is the name).  This is
        arguably a bug in Twisted, so this is just a workaround until
        this is fixed in some way.
        """
        # _levelPriorities is attached to the class immediately after its
        # definition (see module level, below).
        return cls._levelPriorities[constant]
# Map each LogLevel constant to its priority, i.e. its index in definition
# order.  Consumed by LogLevel._priorityForLevel().
LogLevel._levelPriorities = dict(
    (constant, idx) for (idx, constant) in
    (enumerate(LogLevel.iterconstants()))
)
#
# Mappings to Python's logging module
#
# Map our log levels onto the stdlib logging module's numeric levels, for
# interoperability with observers that understand python logging.
pythonLogLevelMapping = {
    LogLevel.debug: logging.DEBUG,
    LogLevel.info: logging.INFO,
    LogLevel.warn: logging.WARNING,
    LogLevel.error: logging.ERROR,
    # LogLevel.critical: logging.CRITICAL,
}
##
# Loggers
##
def formatEvent(event):
    """
    Format an event as a L{unicode}, using the format string found in
    C{event["log_format"]}.

    This function tries very hard never to raise: if the event cannot be
    formatted for any reason, a generic description of the event is returned
    instead, so that a useful message is always emitted.

    @param event: a logging event

    @return: a L{unicode}
    """
    try:
        fmt = event.get("log_format", None)

        if fmt is None:
            raise ValueError("No log format provided")

        # Normalize the format to unicode before rendering.
        if isinstance(fmt, bytes):
            # If we get bytes, assume it's UTF-8 bytes
            fmt = fmt.decode("utf-8")
        elif not isinstance(fmt, unicode):
            raise TypeError("Log format must be unicode or bytes, not {0!r}"
                            .format(fmt))

        return formatWithCall(fmt, event)

    except BaseException as formattingError:
        # Fall back to a generic, can't-fail description of the event.
        return formatUnformattableEvent(event, formattingError)
def formatUnformattableEvent(event, error):
    """
    Formats an event as a L{unicode} that describes the event generically and a
    formatting error.

    @param event: a logging event
    @type event: L{dict}

    @param error: the formatting error
    @type error: L{Exception}

    @return: a L{unicode}
    """
    try:
        return (
            u"Unable to format event {event!r}: {error}"
            .format(event=event, error=error)
        )
    except BaseException:
        # Yikes, something really nasty happened.
        #
        # Try to recover as much formattable data as possible; hopefully at
        # least the namespace is sane, which will help you find the offending
        # logger.
        failure = Failure()

        # safe_repr() cannot raise, so each key/value is individually rescued.
        text = ", ".join(" = ".join((safe_repr(key), safe_repr(value)))
                         for key, value in event.items())

        return (
            u"MESSAGE LOST: unformattable object logged: {error}\n"
            u"Recoverable data: {text}\n"
            u"Exception during formatting:\n{failure}"
            .format(error=safe_repr(error), failure=failure, text=text)
        )
class Logger(object):
    """
    Logging object.
    """
    # Placeholder publisher (a no-op); replaced at module load time with a
    # DefaultLogPublisher instance.
    publisher = lambda e: None

    @staticmethod
    def _namespaceFromCallingContext():
        """
        Derive a namespace from the module containing the caller's caller.

        @return: a namespace
        """
        # f_back.f_back: skip this frame and the immediate caller (e.g.
        # __init__) to reach the frame that instantiated the Logger.
        return inspect.currentframe().f_back.f_back.f_globals["__name__"]

    def __init__(self, namespace=None, source=None):
        """
        @param namespace: The namespace for this logger.  Uses a dotted
            notation, as used by python modules.  If C{None}, then the name
            of the module of the caller is used.

        @param source: The object which is emitting events to this
            logger; this is automatically set on instances of a class
            if this L{Logger} is an attribute of that class.
        """
        if namespace is None:
            namespace = self._namespaceFromCallingContext()

        self.namespace = namespace
        self.source = source

    def __get__(self, oself, type=None):
        """
        When used as a descriptor, i.e.::

            # athing.py
            class Something(object):
                log = Logger()
                def hello(self):
                    self.log.info("Hello")

        a L{Logger}'s namespace will be set to the name of the class it is
        declared on.  In the above example, the namespace would be
        C{athing.Something}.

        Additionally, its source will be set to the actual object referring to
        the L{Logger}.  In the above example, C{Something.log.source} would be
        C{Something}, and C{Something().log.source} would be an instance of
        C{Something}.
        """
        if oself is None:
            source = type
        else:
            source = oself

        # Return a fresh Logger scoped to the owning class so its namespace
        # and source reflect that class rather than where this Logger was
        # originally created.
        return self.__class__(
            '.'.join([type.__module__, type.__name__]),
            source
        )

    def __repr__(self):
        return "<%s %r>" % (self.__class__.__name__, self.namespace)

    def emit(self, level, format=None, **kwargs):
        """
        Emit a log event to all log observers at the given level.

        @param level: a L{LogLevel}

        @param format: a message format using new-style (PEP 3101)
            formatting.  The logging event (which is a L{dict}) is
            used to render this format string.

        @param kwargs: additional keyword parameters to include with
            the event.
        """
        # FIXME: Updated Twisted supports 'in' on constants container
        if level not in LogLevel.iterconstants():
            # Report the bogus level (at error level) instead of raising, so
            # that a bad logging call never takes down the caller.
            self.failure(
                "Got invalid log level {invalidLevel!r} in {logger}.emit().",
                Failure(InvalidLogLevelError(level)),
                invalidLevel=level,
                logger=self,
            )
            #level = LogLevel.error
            # FIXME: continue to emit?
            return

        kwargs.update(
            log_logger=self, log_level=level, log_namespace=self.namespace,
            log_source=self.source, log_format=format, log_time=time.time(),
        )

        self.publisher(kwargs)

    def failure(self, format, failure=None, level=LogLevel.error, **kwargs):
        """
        Log a failure and emit a traceback.

        For example::

            try:
                frob(knob)
            except Exception:
                log.failure("While frobbing {knob}", knob=knob)

        or::

            d = deferredFrob(knob)
            d.addErrback(lambda f: log.failure("While frobbing {knob}",
                                               f, knob=knob))

        @param format: a message format using new-style (PEP 3101)
            formatting.  The logging event (which is a L{dict}) is
            used to render this format string.

        @param failure: a L{Failure} to log.  If C{None}, a L{Failure} is
            created from the exception in flight.

        @param level: a L{LogLevel} to use.

        @param kwargs: additional keyword parameters to include with the
            event.
        """
        if failure is None:
            # Capture the exception currently being handled.
            failure = Failure()

        self.emit(level, format, log_failure=failure, **kwargs)
class LegacyLogger(object):
    """
    A logging object that provides some compatibility with the
    L{twisted.python.log} module.
    """
    def __init__(self, logger=None):
        """
        @param logger: a new-style L{Logger} to emit through; if C{None}, one
            is created with the caller's module name as its namespace.
        """
        if logger is None:
            self.newStyleLogger = Logger(Logger._namespaceFromCallingContext())
        else:
            self.newStyleLogger = logger

    def __getattribute__(self, name):
        # Any attribute not defined on this object is transparently delegated
        # to the twisted.python.log module, so instances can stand in for
        # that module.
        try:
            return super(LegacyLogger, self).__getattribute__(name)
        except AttributeError:
            return getattr(twisted.python.log, name)

    def msg(self, *message, **kwargs):
        """
        This method is API-compatible with L{twisted.python.log.msg} and exists
        for compatibility with that API.
        """
        if message:
            message = " ".join(map(safe_str, message))
        else:
            message = None
        return self.newStyleLogger.emit(LogLevel.info, message, **kwargs)

    def err(self, _stuff=None, _why=None, **kwargs):
        """
        This method is API-compatible with L{twisted.python.log.err} and exists
        for compatibility with that API.
        """
        if _stuff is None:
            # Capture the exception currently being handled.
            _stuff = Failure()
        elif isinstance(_stuff, Exception):
            _stuff = Failure(_stuff)

        if isinstance(_stuff, Failure):
            self.newStyleLogger.emit(LogLevel.error, failure=_stuff, why=_why,
                                     isError=1, **kwargs)
        else:
            # We got called with an invalid _stuff.
            self.newStyleLogger.emit(LogLevel.error, repr(_stuff), why=_why,
                                     isError=1, **kwargs)
def bindEmit(level):
    """
    Attach a convenience emitter method named after C{level} (e.g.
    C{Logger.debug}, C{Logger.info}) to the L{Logger} class.

    @param level: the L{LogLevel} to bind.
    """
    doc = """
        Emit a log event at log level L{{{level}}}.

        @param format: a message format using new-style (PEP 3101)
            formatting.  The logging event (which is a L{{dict}}) is used to
            render this format string.

        @param kwargs: additional keyword parameters to include with the
            event.
        """.format(level=level.name)

    #
    # Attach methods to Logger
    #
    def log_emit(self, format=None, **kwargs):
        self.emit(level, format, **kwargs)

    log_emit.__doc__ = doc

    setattr(Logger, level.name, log_emit)
def _bindLevels():
    # Give Logger one convenience emitter (debug/info/warn/error) per
    # defined log level.
    for eachLevel in LogLevel.iterconstants():
        bindEmit(eachLevel)

_bindLevels()
#
# Observers
#
class ILogObserver(Interface):
    """
    An observer which can handle log events.
    """
    def __call__(event):
        """
        Log an event.

        @type event: C{dict} with (native) C{str} keys.

        @param event: A dictionary with arbitrary keys as defined by
            the application emitting logging events, as well as keys
            added by the logging system, which are:
            ...
        """
@implementer(ILogObserver)
class LogPublisher(object):
    """
    I{ILogObserver} that fans out events to other observers.

    Keeps track of a set of L{ILogObserver} objects and forwards
    events to each.
    """
    log = Logger()

    def __init__(self, *observers):
        # Internal mutable set; exposed read-only via the observers property.
        self._observers = set(observers)

    @property
    def observers(self):
        # Immutable snapshot of the currently registered observers.
        return frozenset(self._observers)

    def addObserver(self, observer):
        """
        Registers an observer with this publisher.

        @param observer: An L{ILogObserver} to add.
        """
        self._observers.add(observer)

    def removeObserver(self, observer):
        """
        Unregisters an observer with this publisher.

        @param observer: An L{ILogObserver} to remove.
        """
        try:
            self._observers.remove(observer)
        except KeyError:
            # Removing an observer that was never added is a no-op.
            pass

    def __call__(self, event):
        # Iterate the frozen snapshot so that temporarily removing a failing
        # observer below doesn't mutate the collection being iterated.
        for observer in self.observers:
            try:
                observer(event)
            except BaseException as e:
                #
                # We have to remove the offending observer because
                # we're going to badmouth it to all of its friends
                # (other observers) and it might get offended and
                # raise again, causing an infinite loop.
                #
                self.removeObserver(observer)
                try:
                    self.log.failure(OBSERVER_REMOVED, observer=observer, e=e)
                except BaseException:
                    pass
                finally:
                    self.addObserver(observer)
class PredicateResult(Names):
    """
    Predicate results.
    """
    yes = NamedConstant()    # Log this
    no = NamedConstant()     # Don't log this
    maybe = NamedConstant()  # No opinion
class ILogFilterPredicate(Interface):
    """
    A predicate that determines whether an event should be logged.
    """
    def __call__(event):
        """
        Determine whether an event should be logged.

        @returns: a L{PredicateResult}.
        """
@implementer(ILogObserver)
class FilteringLogObserver(object):
    """
    L{ILogObserver} that wraps another L{ILogObserver}, but filters
    out events based on applying a series of L{ILogFilterPredicate}s.
    """
    def __init__(self, observer, predicates):
        """
        @param observer: an L{ILogObserver} to which this observer
            will forward events.

        @param predicates: an ordered iterable of predicates to apply
            to events before forwarding to the wrapped observer.
        """
        self.observer = observer
        self.predicates = list(predicates)

    def shouldLogEvent(self, event):
        """
        Determine whether an event should be logged, based on
        C{self.predicates}.

        The first predicate answering C{yes} or C{no} decides; C{maybe}
        defers to the next predicate.  An event no predicate objects to is
        logged.

        @param event: an event
        """
        for predicate in self.predicates:
            result = predicate(event)
            if result == PredicateResult.yes:
                return True
            if result == PredicateResult.no:
                return False
            if result == PredicateResult.maybe:
                continue
            raise TypeError("Invalid predicate result: {0!r}".format(result))
        return True

    def __call__(self, event):
        if self.shouldLogEvent(event):
            self.observer(event)
@implementer(ILogFilterPredicate)
class LogLevelFilterPredicate(object):
    """
    L{ILogFilterPredicate} that filters out events with a log level
    lower than the log level for the event's namespace.

    Events that do not have a log level or namespace are also dropped.
    """
    def __init__(self):
        # FIXME: Make this a class variable. But that raises an
        # _initializeEnumerants constants error in Twisted 12.2.0.
        self.defaultLogLevel = LogLevel.info

        self._logLevelsByNamespace = {}
        self.clearLogLevels()

    def logLevelForNamespace(self, namespace):
        """
        @param namespace: a logging namespace, or C{None} for the default
            namespace.

        @return: the L{LogLevel} for the specified namespace.  The most
            specific configured ancestor wins ("a.b.c", then "a.b", then
            "a"), falling back to the default level.
        """
        if not namespace:
            return self._logLevelsByNamespace[None]

        if namespace in self._logLevelsByNamespace:
            return self._logLevelsByNamespace[namespace]

        # Walk up the dotted namespace, most specific prefix first.
        segments = namespace.split(".")
        index = len(segments) - 1

        while index > 0:
            namespace = ".".join(segments[:index])
            if namespace in self._logLevelsByNamespace:
                return self._logLevelsByNamespace[namespace]
            index -= 1

        return self._logLevelsByNamespace[None]

    def setLogLevelForNamespace(self, namespace, level):
        """
        Sets the global log level for a logging namespace.

        @param namespace: a logging namespace

        @param level: the L{LogLevel} for the given namespace.

        @raise InvalidLogLevelError: if C{level} is not a known L{LogLevel}.
        """
        if level not in LogLevel.iterconstants():
            raise InvalidLogLevelError(level)

        if namespace:
            self._logLevelsByNamespace[namespace] = level
        else:
            self._logLevelsByNamespace[None] = level

    def clearLogLevels(self):
        """
        Clears all global log levels to the default.
        """
        self._logLevelsByNamespace.clear()
        self._logLevelsByNamespace[None] = self.defaultLogLevel

    def __call__(self, event):
        level = event.get("log_level", None)
        namespace = event.get("log_namespace", None)

        # Veto events with no level/namespace, or below the configured
        # threshold for their namespace; otherwise express no opinion.
        if (
            level is None or
            namespace is None or
            LogLevel._priorityForLevel(level) <
            LogLevel._priorityForLevel(self.logLevelForNamespace(namespace))
        ):
            return PredicateResult.no

        return PredicateResult.maybe
@implementer(ILogObserver)
class LegacyLogObserver(object):
    """
    L{ILogObserver} that wraps an L{ILegacyLogObserver}.
    """
    def __init__(self, legacyObserver):
        """
        @param legacyObserver: an L{ILegacyLogObserver} to which this
            observer will forward events.
        """
        self.legacyObserver = legacyObserver

    def __call__(self, event):
        prefix = "[{log_namespace}#{log_level.name}] ".format(**event)

        level = event["log_level"]

        #
        # Twisted's logging supports indicating a python log level, so let's
        # provide the equivalent to our logging levels.
        #
        if level in pythonLogLevelMapping:
            event["logLevel"] = pythonLogLevelMapping[level]

        # Format new style -> old style
        if event["log_format"]:
            #
            # Create an object that implements __str__() in order to
            # defer the work of formatting until it's needed by a
            # legacy log observer.
            #
            class LegacyFormatStub(object):
                def __str__(oself):
                    # Returns UTF-8-encoded bytes (this module is Python 2
                    # code, where __str__ may return a byte string).
                    return formatEvent(event).encode("utf-8")

            event["format"] = prefix + "%(log_legacy)s"
            event["log_legacy"] = LegacyFormatStub()

        # log.failure() -> isError blah blah
        if "log_failure" in event:
            event["failure"] = event["log_failure"]
            event["isError"] = 1
            event["why"] = "{prefix}{message}".format(
                prefix=prefix, message=formatEvent(event)
            )

        # NOTE(review): the event is expanded as keyword arguments here, which
        # suits log.msg()'s (*args, **kw) signature; a plain legacy observer
        # expecting a single event-dict argument could not accept this call --
        # confirm whether anything other than twisted.python.log.msg is ever
        # wrapped.
        self.legacyObserver(**event)
# FIXME: This could have a better name.
class DefaultLogPublisher(object):
    """
    This observer sets up a set of chained observers as follows:

        1. B{rootPublisher} - a L{LogPublisher}

        2. B{filters}: a L{FilteringLogObserver} that filters out messages
           using a L{LogLevelFilterPredicate}

        3. B{filteredPublisher} - a L{LogPublisher}

        4. B{legacyLogObserver} - a L{LegacyLogObserver} wired up to
           L{twisted.python.log.msg}. This allows any observers registered
           with Twisted's logging (that is, most observers in presently use) to
           receive (filtered) events.

    The purpose of this class is to provide a default log observer with
    sufficient hooks to enable applications to add observers that can either
    receive all log messages, or only log messages that are configured to pass
    though the L{LogLevelFilterPredicate}::

        from twext.python.log import Logger, ILogObserver

        log = Logger()

        @implementer(ILogObserver)
        class AMPObserver(object):
            def __call__(self, event):
                # eg.: Hold events in a ring buffer and expose them via AMP.
                ...

        @implementer(ILogObserver)
        class FileObserver(object):
            def __call__(self, event):
                # eg.: Take events and write them into a file.
                ...

        # Send all events to the AMPObserver
        log.publisher.addObserver(AMPObserver(), filtered=False)

        # Send filtered events to the FileObserver
        log.publisher.addObserver(FileObserver())

    With no observers added, the default behavior is that the legacy Twisted
    logging system sees messages as controlled by L{LogLevelFilterPredicate}.
    """
    def __init__(self):
        # Chain: rootPublisher -> filters -> filteredPublisher ->
        # legacyLogObserver -> twisted.python.log.msg
        self.legacyLogObserver = LegacyLogObserver(twistedLogMessage)
        self.filteredPublisher = LogPublisher(self.legacyLogObserver)
        self.levels = LogLevelFilterPredicate()
        self.filters = FilteringLogObserver(self.filteredPublisher,
                                            (self.levels,))
        self.rootPublisher = LogPublisher(self.filters)

    def addObserver(self, observer, filtered=True):
        """
        Registers an observer with this publisher.

        @param observer: An L{ILogObserver} to add.

        @param filtered: If true, registers C{observer} after filters are
            applied; otherwise C{observer} will get all events.
        """
        # Remove from the other publisher so re-adding with a different
        # `filtered` value moves the observer instead of duplicating it.
        if filtered:
            self.filteredPublisher.addObserver(observer)
            self.rootPublisher.removeObserver(observer)
        else:
            self.rootPublisher.addObserver(observer)
            self.filteredPublisher.removeObserver(observer)

    def removeObserver(self, observer):
        """
        Unregisters an observer with this publisher.

        @param observer: An L{ILogObserver} to remove.
        """
        self.rootPublisher.removeObserver(observer)
        self.filteredPublisher.removeObserver(observer)

    def __call__(self, event):
        self.rootPublisher(event)
# Install the default publisher on Logger, replacing the no-op placeholder
# defined on the class.
Logger.publisher = DefaultLogPublisher()
#
# Utilities
#
class CallMapping(object):
    """
    Mapping wrapper used by L{formatWithCall}: a key that ends in C{"()"}
    looks up the name without the parentheses and calls the resulting value.
    """
    def __init__(self, submapping):
        self._submapping = submapping

    def __getitem__(self, key):
        if key.endswith(u"()"):
            # Strip the parentheses, look up the callable, and invoke it.
            return self._submapping[key[:-2]]()
        return self._submapping[key]
def formatWithCall(formatString, mapping):
    """
    Format a string like L{unicode.format}, but:

        - taking only a name mapping; no positional arguments

        - with the additional syntax that an empty set of parentheses
          correspond to a formatting item that should be called, and its result
          C{str}'d, rather than calling C{str} on the element directly as
          normal.

    For example::

        >>> formatWithCall("{string}, {function()}.",
        ...                dict(string="just a string",
        ...                     function=lambda: "a function"))
        'just a string, a function.'

    @param formatString: A PEP-3101 format string.
    @type formatString: L{unicode}

    @param mapping: A L{dict}-like object to format.

    @return: The string with formatted values interpolated.
    @rtype: L{unicode}
    """
    # CallMapping handles the "name()" call syntax; vformat does the rest.
    return unicode(
        theFormatter.vformat(formatString, (), CallMapping(mapping))
    )

# Shared Formatter instance used by formatWithCall().
theFormatter = Formatter()
def replaceTwistedLoggers():
    """
    Visit all Python modules that have been loaded and:

     - replace L{twisted.python.log} with a L{LegacyLogger}

     - replace L{twisted.python.log.msg} with a L{LegacyLogger}'s C{msg}

     - replace L{twisted.python.log.err} with a L{LegacyLogger}'s C{err}
    """
    log = Logger()

    # items() (not iteritems()) gives a snapshot, since importing observers
    # could in principle perturb sys.modules while we walk it.
    for moduleName, module in sys.modules.items():
        # Oddly, this happens
        if module is None:
            continue

        # Don't patch Twisted's logging module
        if module in (twisted.python, twisted.python.log):
            continue

        # Don't patch this module.
        # (Fixed: compare module names with ==, not `is`; string identity
        # is an implementation detail of interning and not guaranteed.)
        if moduleName == __name__:
            continue

        # One replacement logger pair per module is sufficient; hoisted out
        # of the attribute loop, which previously rebuilt the pair for every
        # attribute examined.
        newLogger = Logger(namespace=module.__name__)
        legacyLogger = LegacyLogger(logger=newLogger)

        # Iterate a snapshot of the module namespace, because setattr()
        # below mutates it while we look at it.
        for name, obj in list(module.__dict__.items()):
            if obj is twisted.python.log:
                log.info("Replacing Twisted log module object {0} in {1}"
                         .format(name, module.__name__))
                setattr(module, name, legacyLogger)
            elif obj is twisted.python.log.msg:
                log.info("Replacing Twisted log.msg object {0} in {1}"
                         .format(name, module.__name__))
                setattr(module, name, legacyLogger.msg)
            elif obj is twisted.python.log.err:
                log.info("Replacing Twisted log.err object {0} in {1}"
                         .format(name, module.__name__))
                setattr(module, name, legacyLogger.err)
######################################################################
# FIXME: This may not be needed; look into removing it.
class StandardIOObserver(object):
    """
    (Legacy) log observer that writes to standard I/O.
    """
    def emit(self, eventDict):
        """
        Write a single legacy event: errors (with their traceback, when a
        failure is attached) go to stderr, everything else to stdout.
        """
        if eventDict["isError"]:
            stream = stderr
            if "failure" in eventDict:
                text = eventDict["failure"].getTraceback()
            else:
                text = None
        else:
            stream = stdout
            text = None

        if not text:
            # Fall back to joining the event's message parts.
            text = " ".join(str(m) for m in eventDict["message"]) + "\n"

        stream.write(text)
        stream.flush()

    def start(self):
        """
        Register this observer with the legacy logging system.
        """
        addObserver(self.emit)

    def stop(self):
        """
        Unregister this observer.
        """
        removeObserver(self.emit)
calendarserver-5.2+dfsg/twext/python/memcacheclient.py 0000644 0001750 0001750 00000143101 12113213176 022313 0 ustar rahul rahul #!/usr/bin/env python
from __future__ import print_function
"""
client module for memcached (memory cache daemon)
Overview
========
See U{the MemCached homepage} for more about memcached.
Usage summary
=============
This should give you a feel for how this module operates::
import memcacheclient
mc = memcacheclient.Client(['127.0.0.1:11211'], debug=0)
mc.set("some_key", "Some value")
value = mc.get("some_key")
mc.set("another_key", 3)
mc.delete("another_key")
mc.set("key", "1") # note that the key used for incr/decr must be a string.
mc.incr("key")
mc.decr("key")
The standard way to use memcache with a database is like this::
key = derive_key(obj)
obj = mc.get(key)
if not obj:
obj = backend_api.get(...)
mc.set(obj)
# we now have obj, and future passes through this code
# will use the object from the cache.
Detailed Documentation
======================
More detailed documentation is available in the L{Client} class.
"""
import sys
import socket
import time
import os
import re
import types
from twext.python.log import Logger
from twistedcaldav.config import config
log = Logger()
try:
import cPickle as pickle
except ImportError:
import pickle
try:
from zlib import compress, decompress
_supports_compress = True
except ImportError:
_supports_compress = False
# quickly define a decompress just in case we recv compressed data.
def decompress(val):
raise _Error("received compressed data but I don't support compession (import error)")
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
from binascii import crc32 # zlib version is not cross-platform
serverHashFunction = crc32
__author__ = "Evan Martin "
__version__ = "1.44"
__copyright__ = "Copyright (C) 2003 Danga Interactive"
__license__ = "Python"

# Memcached protocol limit on key length, in bytes.
SERVER_MAX_KEY_LENGTH = 250
# Storing values larger than 1MB requires recompiling memcached. If you do,
# this value can be changed by doing "memcacheclient.SERVER_MAX_VALUE_LENGTH = N"
# after importing this module.
SERVER_MAX_VALUE_LENGTH = 1024*1024
class _Error(Exception):
    """
    Base exception for all errors raised by this module.
    """
    pass

class MemcacheError(_Error):
    """
    Memcache connection error
    """

class NotFoundError(MemcacheError):
    """
    NOT_FOUND error
    """

class TokenMismatchError(MemcacheError):
    """
    Check-and-set token mismatch
    """
# Client subclasses threading.local so each thread gets its own connection
# state; degrade to a plain object when thread-local storage is missing.
try:
    # Only exists in Python 2.4+
    from threading import local
except ImportError:
    # TODO: add the pure-python local implementation
    class local(object):
        pass
class ClientFactory(object):
    """
    Chooses among the real, test, and disabled memcache clients based on
    configuration.
    """
    # unit tests should set this to True to enable the fake test cache
    allowTestCache = False

    @classmethod
    def getClient(cls, servers, debug=0, pickleProtocol=0,
                  pickler=pickle.Pickler, unpickler=pickle.Unpickler,
                  pload=None, pid=None):
        """
        Return a memcache client: a C{TestClient} when the test cache is
        enabled, a real C{Client} when memcached is enabled in the
        configuration, and C{None} otherwise.
        """
        if cls.allowTestCache:
            return TestClient(servers, debug=debug,
                pickleProtocol=pickleProtocol, pickler=pickler,
                unpickler=unpickler, pload=pload, pid=pid)
        elif config.Memcached.Pools.Default.ClientEnabled:
            return Client(servers, debug=debug, pickleProtocol=pickleProtocol,
                pickler=pickler, unpickler=unpickler, pload=pload, pid=pid)
        else:
            return None
class Client(local):
"""
Object representing a pool of memcache servers.
See L{memcache} for an overview.
In all cases where a key is used, the key can be either:
1. A simple hashable type (string, integer, etc.).
2. A tuple of C{(hashvalue, key)}. This is useful if you want to avoid
making this module calculate a hash value. You may prefer, for
example, to keep all of a given user's objects on the same memcache
server, so you could use the user's unique id as the hash value.
@group Setup: __init__, set_servers, forget_dead_hosts, disconnect_all, debuglog
@group Insertion: set, add, replace, set_multi
@group Retrieval: get, get_multi
@group Integers: incr, decr
@group Removal: delete, delete_multi
@sort: __init__, set_servers, forget_dead_hosts, disconnect_all, debuglog,\
set, set_multi, add, replace, get, get_multi, incr, decr, delete, delete_multi
"""
    # Storage flags recorded alongside each value so retrieval knows how to
    # deserialize it.
    _FLAG_PICKLE = 1<<0
    _FLAG_INTEGER = 1<<1
    _FLAG_LONG = 1<<2
    _FLAG_COMPRESSED = 1<<3

    _SERVER_RETRIES = 10  # how many times to try finding a free server.

    # exceptions for Client
    class MemcachedKeyError(Exception):
        pass
    class MemcachedKeyLengthError(MemcachedKeyError):
        pass
    class MemcachedKeyCharacterError(MemcachedKeyError):
        pass
    class MemcachedKeyNoneError(MemcachedKeyError):
        pass
    class MemcachedKeyTypeError(MemcachedKeyError):
        pass
    class MemcachedStringEncodingError(Exception):
        pass
def __init__(self, servers, debug=0, pickleProtocol=0,
pickler=pickle.Pickler, unpickler=pickle.Unpickler,
pload=None, pid=None):
"""
Create a new Client object with the given list of servers.
@param servers: C{servers} is passed to L{set_servers}.
@param debug: whether to display error messages when a server can't be
contacted.
@param pickleProtocol: number to mandate protocol used by (c)Pickle.
@param pickler: optional override of default Pickler to allow subclassing.
@param unpickler: optional override of default Unpickler to allow subclassing.
@param pload: optional persistent_load function to call on pickle loading.
Useful for cPickle since subclassing isn't allowed.
@param pid: optional persistent_id function to call on pickle storing.
Useful for cPickle since subclassing isn't allowed.
"""
local.__init__(self)
self.set_servers(servers)
self.debug = debug
self.stats = {}
# Allow users to modify pickling/unpickling behavior
self.pickleProtocol = pickleProtocol
self.pickler = pickler
self.unpickler = unpickler
self.persistent_load = pload
self.persistent_id = pid
# figure out the pickler style
file = StringIO()
try:
pickler = self.pickler(file, protocol = self.pickleProtocol)
self.picklerIsKeyword = True
except TypeError:
self.picklerIsKeyword = False
    def set_servers(self, servers):
        """
        Set the pool of servers used by this client.

        @param servers: an array of servers.
        Servers can be passed in two forms:
            1. Strings of the form C{"host:port"}, which implies a default weight of 1.
            2. Tuples of the form C{("host:port", weight)}, where C{weight} is
               an integer weight value.
        """
        self.servers = [_Host(s, self.debuglog) for s in servers]
        # Rebuild the weighted bucket list used for key -> server hashing.
        self._init_buckets()
    def get_stats(self):
        '''Get statistics from each of the servers.

        @return: A list of tuples ( server_identifier, stats_dictionary ).
            The dictionary contains a number of name/value pairs specifying
            the name of the status field and the string value associated with
            it.  The values are not converted from strings.
        '''
        data = []
        for s in self.servers:
            # Skip servers we can't reach; their stats are simply omitted.
            if not s.connect(): continue
            if s.family == socket.AF_INET:
                name = '%s:%s (%s)' % ( s.ip, s.port, s.weight )
            else:
                name = 'unix:%s (%s)' % ( s.address, s.weight )
            s.send_cmd('stats')
            serverData = {}
            data.append(( name, serverData ))
            readline = s.readline
            # Read "STAT <name> <value>" lines until the terminating END.
            while 1:
                line = readline()
                if not line or line.strip() == 'END': break
                stats = line.split(' ', 2)
                serverData[stats[1]] = stats[2]

        return(data)
def get_slabs(self):
data = []
for s in self.servers:
if not s.connect(): continue
if s.family == socket.AF_INET:
name = '%s:%s (%s)' % ( s.ip, s.port, s.weight )
else:
name = 'unix:%s (%s)' % ( s.address, s.weight )
serverData = {}
data.append(( name, serverData ))
s.send_cmd('stats items')
readline = s.readline
while 1:
line = readline()
if not line or line.strip() == 'END': break
item = line.split(' ', 2)
#0 = STAT, 1 = ITEM, 2 = Value
slab = item[1].split(':', 2)
#0 = items, 1 = Slab #, 2 = Name
if not serverData.has_key(slab[1]):
serverData[slab[1]] = {}
serverData[slab[1]][slab[2]] = item[2]
return data
def flush_all(self):
'Expire all data currently in the memcache servers.'
for s in self.servers:
if not s.connect(): continue
s.send_cmd('flush_all')
s.expect("OK")
    def debuglog(self, str):
        # Write a debug message to stderr when debugging is enabled.
        # NOTE(review): the parameter name `str` shadows the builtin; renaming
        # it would alter the keyword interface, so it is only flagged here.
        if self.debug:
            sys.stderr.write("MemCached: %s\n" % str)
def _statlog(self, func):
if not self.stats.has_key(func):
self.stats[func] = 1
else:
self.stats[func] += 1
def forget_dead_hosts(self):
"""
Reset every host in the pool to an "alive" state.
"""
for s in self.servers:
s.deaduntil = 0
def _init_buckets(self):
self.buckets = []
for server in self.servers:
for i in range(server.weight):
self.buckets.append(server)
def _get_server(self, key):
if type(key) == types.TupleType:
serverhash, key = key
else:
serverhash = serverHashFunction(key)
for i in range(Client._SERVER_RETRIES):
server = self.buckets[serverhash % len(self.buckets)]
if server.connect():
#print("(using server %s)" % server, end="")
return server, key
serverhash = serverHashFunction(str(serverhash) + str(i))
log.error("Memcacheclient _get_server( ) failed to connect")
return None, None
def disconnect_all(self):
for s in self.servers:
s.close_socket()
    def delete_multi(self, keys, time=0, key_prefix=''):
        '''
        Delete multiple keys in the memcache doing just one query.

        >>> notset_keys = mc.set_multi({'key1' : 'val1', 'key2' : 'val2'})
        >>> mc.get_multi(['key1', 'key2']) == {'key1' : 'val1', 'key2' : 'val2'}
        1
        >>> mc.delete_multi(['key1', 'key2'])
        1
        >>> mc.get_multi(['key1', 'key2']) == {}
        1

        This method is recommended over iterated regular L{delete}s as it reduces total latency, since
        your app doesn't have to wait for each round-trip of L{delete} before sending
        the next one.

        @param keys: An iterable of keys to clear
        @param time: number of seconds any subsequent set / update commands should fail. Defaults to 0 for no delay.
        @param key_prefix: Optional string to prepend to each key when sending to memcache.
            See docs for L{get_multi} and L{set_multi}.

        @return: 1 if no failure in communication with any memcacheds.
        @rtype: int
        '''
        self._statlog('delete_multi')

        server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(keys, key_prefix)

        # send out all requests on each server before reading anything
        dead_servers = []

        rc = 1
        for server in server_keys.iterkeys():
            # Batch all deletes for this server into one pipelined write.
            bigcmd = []
            write = bigcmd.append
            if time != None:
                for key in server_keys[server]: # These are mangled keys
                    write("delete %s %d\r\n" % (key, time))
            else:
                for key in server_keys[server]: # These are mangled keys
                    write("delete %s\r\n" % key)
            try:
                server.send_cmds(''.join(bigcmd))
            except socket.error, msg:
                # Connection-level failure: remember the server as dead and
                # report overall failure, but keep going for other servers.
                rc = 0
                if type(msg) is types.TupleType: msg = msg[1]
                server.mark_dead(msg)
                dead_servers.append(server)

        # if any servers died on the way, don't expect them to respond.
        for server in dead_servers:
            del server_keys[server]

        # Now collect one DELETED response per key from each live server.
        for server, keys in server_keys.iteritems():
            try:
                for key in keys:
                    server.expect("DELETED")
            except socket.error, msg:
                if type(msg) is types.TupleType: msg = msg[1]
                server.mark_dead(msg)
                rc = 0
        return rc
    def delete(self, key, time=0):
        '''Deletes a key from the memcache.

        @return: Nonzero on success.
        @param time: number of seconds any subsequent set / update commands should fail. Defaults to 0 for no delay.
        @rtype: int
        '''
        check_key(key)
        server, key = self._get_server(key)
        if not server:
            # No reachable server owns this key; report failure.
            return 0
        self._statlog('delete')
        # The optional expiry argument tells the server how long subsequent
        # set/update commands for this key should fail.
        if time != None:
            cmd = "delete %s %d" % (key, time)
        else:
            cmd = "delete %s" % key

        try:
            server.send_cmd(cmd)
            server.expect("DELETED")
        except socket.error, msg:
            if type(msg) is types.TupleType: msg = msg[1]
            server.mark_dead(msg)
            return 0
        return 1
def incr(self, key, delta=1):
"""
Sends a command to the server to atomically increment the value for C{key} by
C{delta}, or by 1 if C{delta} is unspecified. Returns None if C{key} doesn't
exist on server, otherwise it returns the new value after incrementing.
Note that the value for C{key} must already exist in the memcache, and it
must be the string representation of an integer.
>>> mc.set("counter", "20") # returns 1, indicating success
1
>>> mc.incr("counter")
21
>>> mc.incr("counter")
22
Overflow on server is not checked. Be aware of values approaching
2**32. See L{decr}.
@param delta: Integer amount to increment by (should be zero or greater).
@return: New value after incrementing.
@rtype: int
"""
return self._incrdecr("incr", key, delta)
def decr(self, key, delta=1):
"""
Like L{incr}, but decrements. Unlike L{incr}, underflow is checked and
new values are capped at 0. If server value is 1, a decrement of 2
returns 0, not -1.
@param delta: Integer amount to decrement by (should be zero or greater).
@return: New value after decrementing.
@rtype: int
"""
return self._incrdecr("decr", key, delta)
    def _incrdecr(self, cmd, key, delta):
        # Shared implementation of L{incr} and L{decr}: sends
        # "incr <key> <delta>" / "decr <key> <delta>" and parses the reply.
        check_key(key)
        server, key = self._get_server(key)
        if not server:
            # NOTE(review): returns 0 here but None on socket error below,
            # so callers cannot distinguish "no server" from a stored 0.
            return 0
        self._statlog(cmd)
        cmd = "%s %s %d" % (cmd, key, delta)
        try:
            server.send_cmd(cmd)
            line = server.readline()
            # The server replies with the new value.  NOTE(review): a
            # "NOT_FOUND" reply makes int() raise ValueError, which is not
            # caught here and propagates to the caller.
            return int(line)
        except socket.error, msg:
            if type(msg) is types.TupleType: msg = msg[1]
            server.mark_dead(msg)
            return None
def add(self, key, val, time = 0, min_compress_len = 0):
'''
Add new key with value.
Like L{set}, but only stores in memcache if the key doesn't already exist.
@return: Nonzero on success.
@rtype: int
'''
return self._set("add", key, val, time, min_compress_len)
def append(self, key, val, time=0, min_compress_len=0):
'''Append the value to the end of the existing key's value.
Only stores in memcache if key already exists.
Also see L{prepend}.
@return: Nonzero on success.
@rtype: int
'''
return self._set("append", key, val, time, min_compress_len)
def prepend(self, key, val, time=0, min_compress_len=0):
'''Prepend the value to the beginning of the existing key's value.
Only stores in memcache if key already exists.
Also see L{append}.
@return: Nonzero on success.
@rtype: int
'''
return self._set("prepend", key, val, time, min_compress_len)
def replace(self, key, val, time=0, min_compress_len=0):
'''Replace existing key with value.
Like L{set}, but only stores in memcache if the key already exists.
The opposite of L{add}.
@return: Nonzero on success.
@rtype: int
'''
return self._set("replace", key, val, time, min_compress_len)
def set(self, key, val, time=0, min_compress_len=0, token=None):
'''Unconditionally sets a key to a given value in the memcache.
The C{key} can optionally be an tuple, with the first element
being the server hash value and the second being the key.
If you want to avoid making this module calculate a hash value.
You may prefer, for example, to keep all of a given user's objects
on the same memcache server, so you could use the user's unique
id as the hash value.
@return: Nonzero on success.
@rtype: int
@param time: Tells memcached the time which this value should expire, either
as a delta number of seconds, or an absolute unix time-since-the-epoch
value. See the memcached protocol docs section "Storage Commands"
for more info on . We default to 0 == cache forever.
@param min_compress_len: The threshold length to kick in auto-compression
of the value using the zlib.compress() routine. If the value being cached is
a string, then the length of the string is measured, else if the value is an
, then the length of the pickle result is measured. If the resulting
attempt at compression yeilds a larger string than the input, then it is
discarded. For backwards compatability, this parameter defaults to 0,
indicating don't ever try to compress.
'''
return self._set("set", key, val, time, min_compress_len, token=token)
def _map_and_prefix_keys(self, key_iterable, key_prefix):
"""Compute the mapping of server (_Host instance) -> list of keys to stuff onto that server, as well as the mapping of
prefixed key -> original key.
"""
# Check it just once ...
key_extra_len=len(key_prefix)
if key_prefix:
check_key(key_prefix)
# server (_Host) -> list of unprefixed server keys in mapping
server_keys = {}
prefixed_to_orig_key = {}
# build up a list for each server of all the keys we want.
for orig_key in key_iterable:
if type(orig_key) is types.TupleType:
# Tuple of hashvalue, key ala _get_server(). Caller is essentially telling us what server to stuff this on.
# Ensure call to _get_server gets a Tuple as well.
str_orig_key = str(orig_key[1])
server, key = self._get_server((orig_key[0], key_prefix + str_orig_key)) # Gotta pre-mangle key before hashing to a server. Returns the mangled key.
else:
str_orig_key = str(orig_key) # set_multi supports int / long keys.
server, key = self._get_server(key_prefix + str_orig_key)
# Now check to make sure key length is proper ...
check_key(str_orig_key, key_extra_len=key_extra_len)
if not server:
continue
if not server_keys.has_key(server):
server_keys[server] = []
server_keys[server].append(key)
prefixed_to_orig_key[key] = orig_key
return (server_keys, prefixed_to_orig_key)
    def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0):
        '''
        Sets multiple keys in the memcache doing just one query.

        >>> notset_keys = mc.set_multi({'key1' : 'val1', 'key2' : 'val2'})
        >>> mc.get_multi(['key1', 'key2']) == {'key1' : 'val1', 'key2' : 'val2'}
        1

        This method is recommended over regular L{set} as it lowers the number of
        total packets flying around your network, reducing total latency, since
        your app doesn't have to wait for each round-trip of L{set} before sending
        the next one.

        @param mapping: A dict of key/value pairs to set.
        @param time: Tells memcached the time which this value should expire, either
            as a delta number of seconds, or an absolute unix time-since-the-epoch
            value. See the memcached protocol docs section "Storage Commands"
            for more info on expiration times. We default to 0 == cache forever.
        @param key_prefix: Optional string to prepend to each key when sending to memcache. Allows you to efficiently stuff these keys into a pseudo-namespace in memcache:

            >>> notset_keys = mc.set_multi({'key1' : 'val1', 'key2' : 'val2'}, key_prefix='subspace_')
            >>> len(notset_keys) == 0
            True
            >>> mc.get_multi(['subspace_key1', 'subspace_key2']) == {'subspace_key1' : 'val1', 'subspace_key2' : 'val2'}
            True

            Causes key 'subspace_key1' and 'subspace_key2' to be set. Useful in conjunction with a higher-level layer which applies namespaces to data in memcache.
            In this case, the return result would be the list of notset original keys, prefix not applied.
        @param min_compress_len: The threshold length to kick in auto-compression
            of the value using the zlib.compress() routine. If the value being cached is
            a string, then the length of the string is measured, else if the value is an
            object, then the length of the pickle result is measured. If the resulting
            attempt at compression yields a larger string than the input, then it is
            discarded. For backwards compatibility, this parameter defaults to 0,
            indicating don't ever try to compress.
        @return: List of keys which failed to be stored [ memcache out of memory, etc. ].
        @rtype: list
        '''
        self._statlog('set_multi')

        # Bucket the (prefixed) keys by the server that owns them.
        server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(mapping.iterkeys(), key_prefix)

        # Phase 1: pipeline every "set" command to each server before
        # reading any replies.
        dead_servers = []

        for server in server_keys.iterkeys():
            bigcmd = []
            write = bigcmd.append
            try:
                for key in server_keys[server]: # These are mangled keys
                    store_info = self._val_to_store_info(mapping[prefixed_to_orig_key[key]], min_compress_len)
                    write("set %s %d %d %d\r\n%s\r\n" % (key, store_info[0], time, store_info[1], store_info[2]))
                server.send_cmds(''.join(bigcmd))
            except socket.error, msg:
                if type(msg) is types.TupleType: msg = msg[1]
                server.mark_dead(msg)
                dead_servers.append(server)

        # if any servers died on the way, don't expect them to respond.
        for server in dead_servers:
            del server_keys[server]

        # short-circuit if there are no servers, just return all keys
        if not server_keys: return(mapping.keys())

        # Phase 2: read one reply per pipelined key; anything other than
        # "STORED" marks that (original, un-prefixed) key as not stored.
        notstored = [] # original keys.
        for server, keys in server_keys.iteritems():
            try:
                for key in keys:
                    line = server.readline()
                    if line == 'STORED':
                        continue
                    else:
                        notstored.append(prefixed_to_orig_key[key]) #un-mangle.
            except (_Error, socket.error), msg:
                if type(msg) is types.TupleType: msg = msg[1]
                server.mark_dead(msg)
        return notstored
    def _val_to_store_info(self, val, min_compress_len):
        """
        Transform val to a storable representation, returning a tuple of the
        flags, the length of the new value, and the new value itself.
        """
        flags = 0
        # Native strings are stored verbatim; ints and longs are stored as
        # decimal text with a type flag so _recv_value can restore them;
        # anything else is pickled.
        if isinstance(val, str):
            pass
        elif isinstance(val, int):
            flags |= Client._FLAG_INTEGER
            val = "%d" % val
            # force no attempt to compress this silly string.
            min_compress_len = 0
        elif isinstance(val, long):
            flags |= Client._FLAG_LONG
            val = "%d" % val
            # force no attempt to compress this silly string.
            min_compress_len = 0
        else:
            flags |= Client._FLAG_PICKLE
            file = StringIO()
            # picklerIsKeyword distinguishes pickler classes whose protocol
            # argument is keyword-only (e.g. cPickle vs others).
            if self.picklerIsKeyword:
                pickler = self.pickler(file, protocol = self.pickleProtocol)
            else:
                pickler = self.pickler(file, self.pickleProtocol)
            if self.persistent_id:
                pickler.persistent_id = self.persistent_id
            pickler.dump(val)
            val = file.getvalue()

        lv = len(val)
        # We should try to compress if min_compress_len > 0 and we could
        # import zlib and this string is longer than our min threshold.
        if min_compress_len and _supports_compress and lv > min_compress_len:
            comp_val = compress(val)
            # Only retain the result if the compression result is smaller
            # than the original.
            if len(comp_val) < lv:
                flags |= Client._FLAG_COMPRESSED
                val = comp_val

        # silently do not store if value length exceeds maximum
        # NOTE(review): this path returns 0 (not a tuple); callers must
        # check for a falsy result before indexing the return value.
        if len(val) >= SERVER_MAX_VALUE_LENGTH: return(0)

        return (flags, len(val), val)
    def _set(self, cmd, key, val, time, min_compress_len = 0, token=None):
        # Shared implementation behind set/add/replace/append/prepend.
        # When token is not None the command is sent as "cas" using the
        # token obtained from a previous gets(); an "EXISTS" reply means
        # the value changed underneath us and TokenMismatchError is raised.
        check_key(key)
        server, key = self._get_server(key)
        if not server:
            return 0

        self._statlog(cmd)

        store_info = self._val_to_store_info(val, min_compress_len)
        # _val_to_store_info returns a falsy 0 for oversized values.
        if not store_info: return(0)

        if token is not None:
            cmd = "cas"
            fullcmd = "cas %s %d %d %d %s\r\n%s" % (key, store_info[0], time, store_info[1], token, store_info[2])
        else:
            fullcmd = "%s %s %d %d %d\r\n%s" % (cmd, key, store_info[0], time, store_info[1], store_info[2])
        try:
            server.send_cmd(fullcmd)
            result = server.expect("STORED")
            if (result == "STORED"):
                return True
            if (result == "NOT_FOUND"):
                raise NotFoundError(key)
            if token and result == "EXISTS":
                log.debug("Memcacheclient check-and-set failed")
                raise TokenMismatchError(key)
            # Any other reply is an unexpected protocol response.
            log.error("Memcacheclient %s command failed with result (%s)" %
                (cmd, result))
            return False
        except socket.error, msg:
            if type(msg) is types.TupleType: msg = msg[1]
            server.mark_dead(msg)
            return 0
    def get(self, key):
        '''Retrieves a key from the memcache.

        @return: The value or None if the key is not present.
        @raise MemcacheError: if no server is reachable for this key or the
            socket fails mid-conversation.
        '''
        check_key(key)
        server, key = self._get_server(key)
        if not server:
            raise MemcacheError("Memcache connection error")

        self._statlog('get')

        try:
            server.send_cmd("get %s" % key)
            rkey, flags, rlen, = self._expectvalue(server)
            if not rkey:
                # No VALUE header: the key is not on the server.
                return None
            value = self._recv_value(server, flags, rlen)
            server.expect("END")
        except (_Error, socket.error), msg:
            if type(msg) is types.TupleType: msg = msg[1]
            server.mark_dead(msg)
            raise MemcacheError("Memcache connection error")

        return value
    def gets(self, key):
        '''Retrieves a key from the memcache together with its cas token.

        @return: A (value, cas_token) tuple, or (None, None) if the key is
            not present.  The token can be passed to L{set} for
            check-and-set semantics.
        @raise MemcacheError: if no server is reachable for this key or the
            socket fails mid-conversation.
        '''
        check_key(key)
        server, key = self._get_server(key)
        if not server:
            raise MemcacheError("Memcache connection error")

        self._statlog('get')

        try:
            server.send_cmd("gets %s" % key)
            rkey, flags, rlen, cas_token = self._expectvalue_cas(server)
            if not rkey:
                return (None, None)
            value = self._recv_value(server, flags, rlen)
            server.expect("END")
        except (_Error, socket.error), msg:
            if type(msg) is types.TupleType: msg = msg[1]
            server.mark_dead(msg)
            raise MemcacheError("Memcache connection error")

        return (value, cas_token)
    def get_multi(self, keys, key_prefix=''):
        '''
        Retrieves multiple keys from the memcache doing just one query.

        >>> success = mc.set("foo", "bar")
        >>> success = mc.set("baz", 42)
        >>> mc.get_multi(["foo", "baz", "foobar"]) == {"foo": "bar", "baz": 42}
        1
        >>> mc.set_multi({'k1' : 1, 'k2' : 2}, key_prefix='pfx_') == []
        1

        This looks up keys 'pfx_k1', 'pfx_k2', ... . Returned dict will just have unprefixed keys 'k1', 'k2'.

        >>> mc.get_multi(['k1', 'k2', 'nonexist'], key_prefix='pfx_') == {'k1' : 1, 'k2' : 2}
        1

        get_multi [ and L{set_multi} ] can take str()-ables like ints / longs as keys too. Such as your db pri key fields.
        They're rotored through str() before being passed off to memcache, with or without the use of a key_prefix.
        In this mode, the key_prefix could be a table name, and the key itself a db primary key number.

        >>> mc.set_multi({42: 'douglass adams', 46 : 'and 2 just ahead of me'}, key_prefix='numkeys_') == []
        1
        >>> mc.get_multi([46, 42], key_prefix='numkeys_') == {42: 'douglass adams', 46 : 'and 2 just ahead of me'}
        1

        This method is recommended over regular L{get} as it lowers the number of
        total packets flying around your network, reducing total latency, since
        your app doesn't have to wait for each round-trip of L{get} before sending
        the next one.

        See also L{set_multi}.

        @param keys: An array of keys.
        @param key_prefix: A string to prefix each key when we communicate with memcache.
            Facilitates pseudo-namespaces within memcache. Returned dictionary keys will not have this prefix.
        @return: A dictionary of key/value pairs that were available. If key_prefix was provided, the keys in the returned dictionary will not have it present.
        '''
        self._statlog('get_multi')

        # Bucket the (prefixed) keys by the server that owns them.
        server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(keys, key_prefix)

        # Phase 1: pipeline one multi-key "get" to each server before
        # reading anything.
        dead_servers = []
        for server in server_keys.iterkeys():
            try:
                server.send_cmd("get %s" % " ".join(server_keys[server]))
            except socket.error, msg:
                if type(msg) is types.TupleType: msg = msg[1]
                server.mark_dead(msg)
                dead_servers.append(server)

        # if any servers died on the way, don't expect them to respond.
        for server in dead_servers:
            del server_keys[server]

        # Phase 2: read VALUE blocks until each server sends "END".
        retvals = {}
        for server in server_keys.iterkeys():
            try:
                line = server.readline()
                while line and line != 'END':
                    rkey, flags, rlen = self._expectvalue(server, line)
                    #  Bo Yang reports that this can sometimes be None
                    if rkey is not None:
                        val = self._recv_value(server, flags, rlen)
                        try:
                            retvals[prefixed_to_orig_key[rkey]] = val   # un-prefix returned key.
                        except KeyError:
                            pass
                    line = server.readline()
            except (_Error, socket.error), msg:
                if type(msg) is types.TupleType: msg = msg[1]
                server.mark_dead(msg)
        return retvals
    def gets_multi(self, keys, key_prefix=''):
        '''
        Retrieves multiple keys from the memcache doing just one query,
        returning C{(value, cas_token)} pairs suitable for check-and-set.

        See also L{gets} and L{get_multi}.
        '''
        self._statlog('gets_multi')

        # Bucket the (prefixed) keys by the server that owns them.
        server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(keys, key_prefix)

        # Phase 1: pipeline one multi-key "gets" to each server before
        # reading anything.
        dead_servers = []
        for server in server_keys.iterkeys():
            try:
                server.send_cmd("gets %s" % " ".join(server_keys[server]))
            except socket.error, msg:
                if type(msg) is types.TupleType: msg = msg[1]
                server.mark_dead(msg)
                dead_servers.append(server)

        # if any servers died on the way, don't expect them to respond.
        for server in dead_servers:
            del server_keys[server]

        # Phase 2: read VALUE blocks (with cas tokens) until "END".
        retvals = {}
        for server in server_keys.iterkeys():
            try:
                line = server.readline()
                while line and line != 'END':
                    rkey, flags, rlen, cas_token = self._expectvalue_cas(server, line)
                    #  Bo Yang reports that this can sometimes be None
                    if rkey is not None:
                        val = self._recv_value(server, flags, rlen)
                        try:
                            retvals[prefixed_to_orig_key[rkey]] = (val, cas_token)   # un-prefix returned key.
                        except KeyError:
                            pass
                    line = server.readline()
            except (_Error, socket.error), msg:
                if type(msg) is types.TupleType: msg = msg[1]
                server.mark_dead(msg)
        return retvals
def _expectvalue(self, server, line=None):
if not line:
line = server.readline()
if line[:5] == 'VALUE':
resp, rkey, flags, len = line.split()
flags = int(flags)
rlen = int(len)
return (rkey, flags, rlen)
else:
return (None, None, None)
def _expectvalue_cas(self, server, line=None):
if not line:
line = server.readline()
if line[:5] == 'VALUE':
resp, rkey, flags, len, rtoken = line.split()
flags = int(flags)
rlen = int(len)
return (rkey, flags, rlen, rtoken)
else:
return (None, None, None, None)
    def _recv_value(self, server, flags, rlen):
        # Read the data block announced by a VALUE header and decode it
        # according to the stored type/compression flags.
        rlen += 2 # include \r\n
        buf = server.recv(rlen)
        if len(buf) != rlen:
            raise _Error("received %d bytes when expecting %d" % (len(buf), rlen))

        if len(buf) == rlen:
            buf = buf[:-2]  # strip \r\n

        if flags & Client._FLAG_COMPRESSED:
            buf = decompress(buf)

        if flags == 0 or flags == Client._FLAG_COMPRESSED:
            # Either a bare string or a compressed string now decompressed...
            val = buf
        elif flags & Client._FLAG_INTEGER:
            val = int(buf)
        elif flags & Client._FLAG_LONG:
            val = long(buf)
        elif flags & Client._FLAG_PICKLE:
            try:
                file = StringIO(buf)
                unpickler = self.unpickler(file)
                if self.persistent_load:
                    unpickler.persistent_load = self.persistent_load
                val = unpickler.load()
            except Exception, e:
                self.debuglog('Pickle error: %s\n' % e)
                val = None
        else:
            # NOTE(review): 'val' is never assigned on this path, so the
            # return below raises UnboundLocalError for unknown flags.
            self.debuglog("unknown flags on get: %x\n" % flags)

        return val
class TestClient(Client):
    """
    Fake memcache client for unit tests

    Keeps all data in an in-process dict (C{self.data}) mapping key ->
    (pickled value, cas token string) instead of talking to a memcached
    server.  Operations the tests don't need raise NotImplementedError.
    """

    def __init__(self, servers, debug=0, pickleProtocol=0,
        pickler=pickle.Pickler, unpickler=pickle.Unpickler,
        pload=None, pid=None):
        # Client derives from a thread-local base; initialize that first.
        local.__init__(self)
        super(TestClient, self).__init__(servers, debug=debug,
            pickleProtocol=pickleProtocol, pickler=pickler, unpickler=unpickler,
            pload=pload, pid=pid)

        self.data = {}   # key -> (pickled value, cas token string)
        self.token = 0   # monotonically increasing cas token source

    def get_stats(self):
        raise NotImplementedError()

    def get_slabs(self):
        raise NotImplementedError()

    def flush_all(self):
        raise NotImplementedError()

    def forget_dead_hosts(self):
        raise NotImplementedError()

    def delete_multi(self, keys, time=0, key_prefix=''):
        '''
        Delete multiple keys in the memcache doing just one query.

        >>> notset_keys = mc.set_multi({'key1' : 'val1', 'key2' : 'val2'})
        >>> mc.get_multi(['key1', 'key2']) == {'key1' : 'val1', 'key2' : 'val2'}
        1
        >>> mc.delete_multi(['key1', 'key2'])
        1
        >>> mc.get_multi(['key1', 'key2']) == {}
        1
        '''
        self._statlog('delete_multi')
        # NOTE(review): unlike the real client, deleting a missing key
        # raises KeyError here rather than being ignored.
        for key in keys:
            key = key_prefix + key
            del self.data[key]
        return 1

    def delete(self, key, time=0):
        '''Deletes a key from the memcache.

        @return: Nonzero on success.
        @param time: number of seconds any subsequent set / update commands should fail. Defaults to 0 for no delay.
        @rtype: int
        '''
        check_key(key)
        # NOTE(review): raises KeyError if the key is absent.
        del self.data[key]
        return 1

    def incr(self, key, delta=1):
        raise NotImplementedError()

    def decr(self, key, delta=1):
        raise NotImplementedError()

    def add(self, key, val, time = 0, min_compress_len = 0):
        raise NotImplementedError()

    def append(self, key, val, time=0, min_compress_len=0):
        raise NotImplementedError()

    def prepend(self, key, val, time=0, min_compress_len=0):
        raise NotImplementedError()

    def replace(self, key, val, time=0, min_compress_len=0):
        raise NotImplementedError()

    def set(self, key, val, time=0, min_compress_len=0, token=None):
        self._statlog('set')
        return self._set("set", key, val, time, min_compress_len, token=token)

    def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0):
        # Stores each pair individually; compression/expiry args ignored.
        self._statlog('set_multi')
        for key, val in mapping.iteritems():
            key = key_prefix + key
            self._set("set", key, val, time, min_compress_len)
        return []

    def _set(self, cmd, key, val, time, min_compress_len = 0, token=None):
        # Pickle the value and store it with a fresh cas token; when a
        # token is supplied, enforce check-and-set semantics first.
        check_key(key)
        self._statlog(cmd)

        serialized = pickle.dumps(val, pickle.HIGHEST_PROTOCOL)
        if token is not None:
            if self.data.has_key(key):
                stored_val, stored_token = self.data[key]
                if token != stored_token:
                    raise TokenMismatchError(key)

        self.data[key] = (serialized, str(self.token))
        self.token += 1
        return True

    def get(self, key):
        check_key(key)
        self._statlog('get')

        if self.data.has_key(key):
            stored_val, stored_token = self.data[key]
            val = pickle.loads(stored_val)
            return val
        return None

    def gets(self, key):
        # Like get() but also returns the cas token for use with set().
        check_key(key)

        if self.data.has_key(key):
            stored_val, stored_token = self.data[key]
            val = pickle.loads(stored_val)
            return (val, stored_token)
        return (None, None)

    def get_multi(self, keys, key_prefix=''):
        self._statlog('get_multi')
        results = {}
        # NOTE(review): returned keys keep the prefix here, unlike the
        # real client which strips it.
        for key in keys:
            key = key_prefix + key
            val = self.get(key)
            results[key] = val
        return results

    def gets_multi(self, keys, key_prefix=''):
        self._statlog('gets_multi')
        results = {}
        for key in keys:
            key = key_prefix + key
            result = self.gets(key)
            if result[1] is not None:
                results[key] = result
        return results
class _Host:
_DEAD_RETRY = 1 # number of seconds before retrying a dead server.
_SOCKET_TIMEOUT = 3 # number of seconds before sockets timeout.
def __init__(self, host, debugfunc=None):
if isinstance(host, types.TupleType):
host, self.weight = host
else:
self.weight = 1
# parse the connection string
m = re.match(r'^(?Punix):(?P.*)$', host)
if not m:
m = re.match(r'^(?Pinet):'
r'(?P[^:]+)(:(?P[0-9]+))?$', host)
if not m: m = re.match(r'^(?P[^:]+):(?P[0-9]+)$', host)
if not m:
raise ValueError('Unable to parse connection string: "%s"' % host)
hostData = m.groupdict()
if hostData.get('proto') == 'unix':
self.family = socket.AF_UNIX
self.address = hostData['path']
else:
self.family = socket.AF_INET
self.ip = hostData['host']
self.port = int(hostData.get('port', 11211))
self.address = ( self.ip, self.port )
if not debugfunc:
debugfunc = lambda x: x
self.debuglog = debugfunc
self.deaduntil = 0
self.socket = None
self.buffer = ''
def _check_dead(self):
if self.deaduntil and self.deaduntil > time.time():
return 1
self.deaduntil = 0
return 0
def connect(self):
if self._get_socket():
return 1
return 0
def mark_dead(self, reason):
log.error("Memcacheclient socket marked dead (%s)" % (reason,))
self.debuglog("MemCache: %s: %s. Marking dead." % (self, reason))
self.deaduntil = time.time() + _Host._DEAD_RETRY
self.close_socket()
def _get_socket(self):
if self._check_dead():
log.error("Memcacheclient _get_socket() found dead socket")
return None
if self.socket:
return self.socket
s = socket.socket(self.family, socket.SOCK_STREAM)
if hasattr(s, 'settimeout'): s.settimeout(self._SOCKET_TIMEOUT)
try:
s.connect(self.address)
except socket.timeout, msg:
log.error("Memcacheclient _get_socket() connection timed out (%s)" %
(msg,))
self.mark_dead("connect: %s" % msg)
return None
except socket.error, msg:
if type(msg) is types.TupleType: msg = msg[1]
log.error("Memcacheclient _get_socket() connection error (%s)" %
(msg,))
self.mark_dead("connect: %s" % msg[1])
return None
self.socket = s
self.buffer = ''
return s
def close_socket(self):
if self.socket:
self.socket.close()
self.socket = None
def send_cmd(self, cmd):
self.socket.sendall(cmd + '\r\n')
def send_cmds(self, cmds):
""" cmds already has trailing \r\n's applied """
self.socket.sendall(cmds)
def readline(self):
buf = self.buffer
recv = self.socket.recv
while True:
index = buf.find('\r\n')
if index >= 0:
break
data = recv(4096)
if not data:
self.mark_dead('Connection closed while reading from %s'
% repr(self))
break
buf += data
if index >= 0:
self.buffer = buf[index+2:]
buf = buf[:index]
else:
self.buffer = ''
return buf
def expect(self, text):
line = self.readline()
if line != text:
self.debuglog("while expecting '%s', got unexpected response '%s'" % (text, line))
return line
def recv(self, rlen):
self_socket_recv = self.socket.recv
buf = self.buffer
while len(buf) < rlen:
foo = self_socket_recv(4096)
buf += foo
if len(foo) == 0:
raise _Error, ( 'Read %d bytes, expecting %d, '
'read returned 0 length bytes' % ( len(buf), rlen ))
self.buffer = buf[rlen:]
return buf[:rlen]
def __str__(self):
d = ''
if self.deaduntil:
d = " (dead until %d)" % self.deaduntil
if self.family == socket.AF_INET:
return "inet:%s:%d%s" % (self.address[0], self.address[1], d)
else:
return "unix:%s%s" % (self.address, d)
def check_key(key, key_extra_len=0):
    """Checks sanity of key.  Fails if:
        Key length is > SERVER_MAX_KEY_LENGTH (Raises MemcachedKeyLength).
        Contains control characters  (Raises MemcachedKeyCharacterError).
        Is not a string (Raises MemcachedStringEncodingError)
        Is an unicode string (Raises MemcachedStringEncodingError)
        Is not a string (Raises MemcachedKeyError)
        Is None (Raises MemcachedKeyError)
    """
    # NOTE: validation is intentionally disabled as a performance
    # optimization; everything below this return is dead code kept for
    # reference and never executes.
    return # Short-circuit this expensive method
    if type(key) == types.TupleType: key = key[1]
    if not key:
        raise Client.MemcachedKeyNoneError, ("Key is None")
    if isinstance(key, unicode):
        raise Client.MemcachedStringEncodingError, ("Keys must be str()'s, not "
                "unicode.  Convert your unicode strings using "
                "mystring.encode(charset)!")
    if not isinstance(key, str):
        raise Client.MemcachedKeyTypeError, ("Key must be str()'s")

    if isinstance(key, basestring):
        if len(key) + key_extra_len > SERVER_MAX_KEY_LENGTH:
            raise Client.MemcachedKeyLengthError, ("Key length is > %s"
                    % SERVER_MAX_KEY_LENGTH)
        for char in key:
            if ord(char) < 32 or ord(char) == 127:
                raise Client.MemcachedKeyCharacterError, "Control characters not allowed"
def _doctest():
    """Run this module's doctests against a memcached at 127.0.0.1:11211."""
    import doctest, memcacheclient
    client = Client(["127.0.0.1:11211"], debug=1)
    return doctest.testmod(memcacheclient, globs={"mc": client})
# Smoke-test harness: exercises set/get/delete, multi operations, key
# validation, and oversized values against a live local memcached.
if __name__ == "__main__":
    print("Testing docstrings...")
    _doctest()
    print("Running tests:")
    print
    serverList = [["127.0.0.1:11211"]]
    if '--do-unix' in sys.argv:
        serverList.append([os.path.join(os.getcwd(), 'memcached.socket')])

    for servers in serverList:
        mc = Client(servers, debug=1)

        def to_s(val):
            # Render non-strings with their type for readable test output.
            if not isinstance(val, types.StringTypes):
                return "%s (%s)" % (val, type(val))
            return "%s" % val

        def test_setget(key, val):
            # Round-trip one value through set()/get(); report OK/FAIL.
            print("Testing set/get {'%s': %s} ..." % (to_s(key), to_s(val)), end="")
            mc.set(key, val)
            newval = mc.get(key)
            if newval == val:
                print("OK")
                return 1
            else:
                print("FAIL")
                return 0

        class FooStruct:
            # Pickleable value used to exercise the pickle storage path.
            def __init__(self):
                self.bar = "baz"
            def __str__(self):
                return "A FooStruct"
            def __eq__(self, other):
                if isinstance(other, FooStruct):
                    return self.bar == other.bar
                return 0

        test_setget("a_string", "some random string")
        test_setget("an_integer", 42)
        if test_setget("long", long(1<<30)):
            print("Testing delete ...", end="")
            if mc.delete("long"):
                print("OK")
            else:
                print("FAIL")
        print("Testing get_multi ...", end="")
        print(mc.get_multi(["a_string", "an_integer"]))
        print("Testing get(unknown value) ...", end="")
        print(to_s(mc.get("unknown_value")))
        f = FooStruct()
        test_setget("foostruct", f)
        print("Testing incr ...", end="")
        x = mc.incr("an_integer", 1)
        if x == 43:
            print("OK")
        else:
            print("FAIL")
        print("Testing decr ...", end="")
        x = mc.decr("an_integer", 1)
        if x == 42:
            print("OK")
        else:
            print("FAIL")

        # sanity tests: key validation (note check_key is short-circuited,
        # so several of these expect-an-exception cases will report FAIL).
        print("Testing sending spaces...", end="")
        try:
            x = mc.set("this has spaces", 1)
        except Client.MemcachedKeyCharacterError, msg:
            print("OK")
        else:
            print("FAIL")
        print("Testing sending control characters...", end="")
        try:
            x = mc.set("this\x10has\x11control characters\x02", 1)
        except Client.MemcachedKeyCharacterError, msg:
            print("OK")
        else:
            print("FAIL")
        print("Testing using insanely long key...", end="")
        try:
            x = mc.set('a'*SERVER_MAX_KEY_LENGTH + 'aaaa', 1)
        except Client.MemcachedKeyLengthError, msg:
            print("OK")
        else:
            print("FAIL")
        print("Testing sending a unicode-string key...", end="")
        try:
            x = mc.set(u'keyhere', 1)
        except Client.MemcachedStringEncodingError, msg:
            print("OK", end="")
        else:
            print("FAIL", end="")
        try:
            x = mc.set((u'a'*SERVER_MAX_KEY_LENGTH).encode('utf-8'), 1)
        except:
            print("FAIL", end="")
        else:
            print("OK", end="")
        import pickle
        s = pickle.loads('V\\u4f1a\np0\n.')
        try:
            x = mc.set((s*SERVER_MAX_KEY_LENGTH).encode('utf-8'), 1)
        except Client.MemcachedKeyLengthError:
            print("OK")
        else:
            print("FAIL")
        # Values at or over SERVER_MAX_VALUE_LENGTH are silently dropped.
        print("Testing using a value larger than the memcached value limit...", end="")
        x = mc.set('keyhere', 'a'*SERVER_MAX_VALUE_LENGTH)
        if mc.get('keyhere') == None:
            print("OK", end="")
        else:
            print("FAIL", end="")
        x = mc.set('keyhere', 'a'*SERVER_MAX_VALUE_LENGTH + 'aaa')
        if mc.get('keyhere') == None:
            print("OK")
        else:
            print("FAIL")
        print("Testing set_multi() with no memcacheds running", end="")
        mc.disconnect_all()
        errors = mc.set_multi({'keyhere' : 'a', 'keythere' : 'b'})
        if errors != []:
            print("FAIL")
        else:
            print("OK")
        print("Testing delete_multi() with no memcacheds running", end="")
        mc.disconnect_all()
        ret = mc.delete_multi({'keyhere' : 'a', 'keythere' : 'b'})
        if ret != 1:
            print("FAIL")
        else:
            print("OK")

# vim: ts=4 sw=4 et :
calendarserver-5.2+dfsg/twext/python/vcomponent.py 0000644 0001750 0001750 00000001753 12263343324 021555 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
iCalendar utilities
"""
__all__ = [
"VComponent",
"VProperty",
"InvalidICalendarDataError",
]
# FIXME: Move twistedcaldav.ical here, but that module needs some
# cleanup first. Perhaps after porting to libical?
from twistedcaldav.ical import Component as VComponent
from twistedcaldav.ical import Property as VProperty
from twistedcaldav.ical import InvalidICalendarDataError
calendarserver-5.2+dfsg/twext/python/__init__.py 0000644 0001750 0001750 00000001205 12263343324 021114 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Extensions to twisted.python.
"""
calendarserver-5.2+dfsg/twext/python/clsprop.py 0000644 0001750 0001750 00000002616 12263343324 021046 0 ustar rahul rahul ##
# Copyright (c) 2011-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
A small utility for defining static class properties.
"""
class classproperty(object):
    """
    Decorator for a method that wants to behave as a static class property.

    The decorated callable receives the class (not an instance) as its only
    argument.  With caching enabled (the default) it is invoked at most
    once per class; the memoized value is returned thereafter.  Use
    C{@classproperty(cache=False)} to recompute on every access.
    """
    def __init__(self, thunk=None, cache=True):
        self.cache = cache
        self.thunk = thunk
        self._classcache = {}

    def __call__(self, thunk):
        # Parameterized form: classproperty(cache=...) is called with the
        # function afterwards, so build the real descriptor here.
        return self.__class__(thunk, self.cache)

    def __get__(self, instance, owner):
        if not self.cache:
            return self.thunk(owner)
        memo = self._classcache
        try:
            return memo[owner]
        except KeyError:
            value = memo[owner] = self.thunk(owner)
            return value
calendarserver-5.2+dfsg/twext/patches.py 0000644 0001750 0001750 00000005271 12263343324 017472 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Patches for behavior in Twisted which calendarserver requires to be different.
"""
__all__ = []
import sys
from twisted import version
from twisted.python.versions import Version
from twisted.python.modules import getModule
def _hasIPv6ClientSupport():
    """
    Does the loaded version of Twisted have IPv6 client support?

    @return: C{True} if the loaded Twisted can resolve/connect IPv6 clients,
        C{False} otherwise.
    """
    lastVersionWithoutIPv6Clients = Version("twisted", 12, 0, 0)
    if version < lastVersionWithoutIPv6Clients:
        return False
    if version > lastVersionWithoutIPv6Clients:
        return True
    # Exactly 12.0.0: it could be a snapshot of trunk or a branch with this
    # bug fixed.  Grep the module source rather than importing it, as
    # importing would be a bunch of unnecessary work.
    tcpSource = getModule("twisted.internet.tcp").filePath.getContent()
    return "_resolveIPv6" in tcpSource
def _addBackports():
    """
    We currently require 2 backported bugfixes from a future release of
    Twisted, for IPv6 support:

        - IPv6 client support
        - TCP endpoint cancellation

    This function will activate those backports.  (Note it must be run before
    any of the modules in question are imported or it will raise an
    exception.)

    This function, L{_hasIPv6ClientSupport}, and all the associated backports
    (i.e., all of C{twext/backport}) should be removed upon upgrading our
    minimum required Twisted version.

    @raise RuntimeError: if any module to be backported has already been
        imported, since the backport could no longer take effect.
    """
    from twext.backport import internet as bpinternet
    from twisted import internet
    # Let the backport package shadow twisted.internet for the submodules it
    # provides, falling through to the real package for everything else.
    internet.__path__[:] = bpinternet.__path__ + internet.__path__

    # Make sure none of the backports are loaded yet.
    backports = getModule("twext.backport.internet")
    for submod in backports.iterModules():
        subname = submod.name.split(".")[-1]
        tiname = 'twisted.internet.' + subname
        if tiname in sys.modules:
            # BUG FIX: the original message was missing the space after the
            # module name, yielding e.g. "twisted.internet.tcpalready loaded".
            raise RuntimeError(
                tiname + " already loaded, cannot load required backport")
# Activate the IPv6 backports only when the installed Twisted lacks IPv6
# client support.  This must run before any twisted.internet submodules are
# imported (see _addBackports).
if not _hasIPv6ClientSupport():
    _addBackports()

from twisted.mail.imap4 import Command
# Accept an untagged BYE response to any IMAP command; some servers send it
# unsolicited and the stock Command machinery would otherwise reject it.
Command._1_RESPONSES += tuple(['BYE'])
calendarserver-5.2+dfsg/twext/enterprise/ 0000755 0001750 0001750 00000000000 12322625326 017645 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/enterprise/ienterprise.py 0000644 0001750 0001750 00000024237 12263343324 022557 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Interfaces, mostly related to L{twext.enterprise.adbapi2}.
"""
# Public API of this module: the interfaces, exceptions, and dialect
# constants used by twext.enterprise.adbapi2 and its clients.
__all__ = [
    "IAsyncTransaction",
    "ISQLExecutor",
    "ICommandBlock",
    "IQueuer",
    "IDerivedParameter",
    "AlreadyFinishedError",
    "ConnectionError",
    "POSTGRES_DIALECT",
    "SQLITE_DIALECT",
    "ORACLE_DIALECT",
    "ORACLE_TABLE_NAME_MAX",
]
from zope.interface import Interface, Attribute
class AlreadyFinishedError(Exception):
    """
    The transaction was already completed via an C{abort} or C{commit} and
    cannot be aborted or committed again.
    """
class ConnectionError(Exception):
    """
    An error occurred with the underlying database connection.

    NOTE(review): this name shadows the Python 3 builtin C{ConnectionError};
    harmless on Python 2 (which this codebase targets), but worth renaming if
    the code is ever ported.
    """
# SQL dialect identifiers; used as the value of the 'dialect' attribute
# declared on ISQLExecutor below.
POSTGRES_DIALECT = 'postgres-dialect'
ORACLE_DIALECT = 'oracle-dialect'
SQLITE_DIALECT = 'sqlite-dialect'

# Maximum identifier (e.g. table name) length permitted by Oracle.
ORACLE_TABLE_NAME_MAX = 30
class ISQLExecutor(Interface):
    """
    Base SQL-execution interface, for a group of commands or a transaction.
    """

    paramstyle = Attribute(
        """
        A copy of the 'paramstyle' attribute from a DB-API 2.0 module.
        """)

    dialect = Attribute(
        """
        A copy of the 'dialect' attribute from the connection pool.  One of
        the C{*_DIALECT} constants in this module, such as
        C{POSTGRES_DIALECT}.
        """)

    def execSQL(sql, args=(), raiseOnZeroRowCount=None):
        """
        Execute some SQL.

        @param sql: an SQL string.
        @type sql: C{str}

        @param args: C{list} of arguments to interpolate into C{sql}.

        @param raiseOnZeroRowCount: a 0-argument callable which returns an
            exception to raise if the executed SQL does not affect any rows.

        @return: L{Deferred} which fires C{list} of C{tuple}

        @raise: C{raiseOnZeroRowCount} if it was specified and no rows were
            affected.
        """
class IAsyncTransaction(ISQLExecutor):
    """
    Asynchronous execution of SQL.

    Note that there is no C{begin()} method; if an L{IAsyncTransaction} exists
    at all, it is assumed to have been started.
    """

    def commit():
        """
        Commit changes caused by this transaction.

        @return: L{Deferred} which fires with C{None} upon successful
            completion of this transaction, or fails if this transaction could
            not be committed.  It fails with L{AlreadyFinishedError} if the
            transaction has already been committed or rolled back.
        """

    def preCommit(operation):
        """
        Perform the given operation when this L{IAsyncTransaction}'s C{commit}
        method is called, but before the underlying transaction commits.  If
        any exception is raised by this operation, underlying database commit
        will be blocked and rollback run instead.

        @param operation: a 0-argument callable that may return a L{Deferred}.
            If it does, then the subsequent operations added by L{postCommit}
            will not fire until that L{Deferred} fires.
        """

    def postCommit(operation):
        """
        Perform the given operation only after this L{IAsyncTransaction}
        commits.  These will be invoked before the L{Deferred} returned by
        L{IAsyncTransaction.commit} fires.

        @param operation: a 0-argument callable that may return a L{Deferred}.
            If it does, then the subsequent operations added by L{postCommit}
            will not fire until that L{Deferred} fires.
        """

    def abort():
        """
        Roll back changes caused by this transaction.

        @return: L{Deferred} which fires with C{None} upon successful
            rollback of this transaction.
        """

    def postAbort(operation):
        """
        Invoke a callback after abort.

        @see: L{IAsyncTransaction.postCommit}

        @param operation: 0-argument callable, potentially returning a
            L{Deferred}.
        """

    def commandBlock():
        """
        Create an object which will cause the commands executed on it to be
        grouped together.

        This is useful when using database-specific features such as
        sub-transactions where order of execution is important, but where
        application code may need to perform I/O to determine what SQL,
        exactly, it wants to execute.  Consider this fairly contrived example
        for an imaginary database::

            def storeWebPage(url, block):
                block.execSQL("BEGIN SUB TRANSACTION")
                got = getPage(url)
                def gotPage(data):
                    block.execSQL("INSERT INTO PAGES (TEXT) VALUES (?)",
                                  [data])
                    block.execSQL("INSERT INTO INDEX (TOKENS) VALUES (?)",
                                  [tokenize(data)])
                    lastStmt = block.execSQL("END SUB TRANSACTION")
                    block.end()
                    return lastStmt
                return got.addCallback(gotPage)
            gatherResults([storeWebPage(url, txn.commandBlock())
                           for url in urls]).addCallbacks(
                               lambda x: txn.commit(), lambda f: txn.abort()
                           )

        This fires off all the C{getPage} requests in parallel, and prepares
        all the necessary SQL immediately as the results arrive, but executes
        those statements in order.  In the above example, this makes sure to
        store the page and its tokens together, another use for this might be
        to store a computed aggregate (such as a sum) at a particular point in
        a transaction, without sacrificing parallelism.

        @rtype: L{ICommandBlock}
        """
class ICommandBlock(ISQLExecutor):
    """
    This is a block of SQL commands that are grouped together.

    @see: L{IAsyncTransaction.commandBlock}
    """

    def end():
        """
        End this command block, allowing other commands queued on the
        underlying transaction to end.

        @note: This is I{not} the same as either L{IAsyncTransaction.commit}
            or L{IAsyncTransaction.abort}, since it does not denote success or
            failure; merely that the command block has completed and other
            statements may now be executed.  Since sub-transactions are a
            database-specific feature, they must be implemented at a
            higher-level than this facility provides (although this facility
            may be useful in their implementation).  Also note that, unlike
            either of those methods, this does I{not} return a Deferred: if
            you want to know when the block has completed, simply add a
            callback to the last L{ICommandBlock.execSQL} call executed on
            this L{ICommandBlock}.  (This may be changed in a future version
            for the sake of convenience, however.)
        """
class IDerivedParameter(Interface):
    """
    A parameter which needs to be derived from the underlying DB-API cursor;
    implicitly, meaning that this must also interact with the actual thread
    manipulating said cursor.  If a provider of this interface is passed in
    the C{args} argument to L{IAsyncTransaction.execSQL}, it will have its
    C{preQuery} and C{postQuery} methods invoked on it before and after
    executing the SQL query in question, respectively.

    @note: L{IDerivedParameter} providers must also always be I{pickleable},
        because in some cases the actual database cursor objects will be on
        the other end of a network connection.  For an explanation of why
        this might be, see
        L{twext.enterprise.adbapi2.ConnectionPoolConnection}.
    """

    def preQuery(cursor):
        """
        Before running a query, invoke this method with the cursor that the
        query will be run on.

        (This can be used, for example, to allocate a special
        database-specific variable based on the cursor, like an out
        parameter.)

        @param cursor: the DB-API cursor.

        @return: the concrete value which should be passed to the DB-API
            layer.
        """

    def postQuery(cursor):
        """
        After running a query, invoke this method in the DB-API thread.

        (This can be used, for example, to manipulate any state created in
        the preQuery method.)

        @param cursor: the DB-API cursor.

        @return: C{None}
        """
class IQueuer(Interface):
    """
    An L{IQueuer} can enqueue work for later execution.
    """

    def enqueueWork(self, transaction, workItemType, **kw):
        """
        Perform some work, eventually.

        @param transaction: an L{IAsyncTransaction} within which to I{commit}
            to doing the work.  Note that this work will likely be done later
            (but depending on various factors, may actually be done within
            this transaction as well).

        @param workItemType: the type of work item to create.
        @type workItemType: L{type}, specifically, a subtype of
            L{twext.enterprise.queue.WorkItem}

        @param kw: The keyword parameters are relayed to
            C{workItemType.create} to create an appropriately initialized
            item.

        @return: a work proposal that allows tracking of the various phases of
            completion of the work item.
        @rtype: L{twext.enterprise.queue.WorkItem}
        """

    def callWithNewProposals(self, callback):
        """
        Tells the IQueuer to call a callback method whenever a new
        WorkProposal is created.

        @param callback: a callable which accepts a single parameter, a
            L{WorkProposal}
        """

    def transferProposalCallbacks(self, newQueuer):
        """
        Transfer the registered callbacks to the new queuer.
        """
calendarserver-5.2+dfsg/twext/enterprise/test/ 0000755 0001750 0001750 00000000000 12322625326 020624 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/enterprise/test/test_fixtures.py 0000644 0001750 0001750 00000003234 12263343324 024107 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for L{twext.enterprise.fixtures}.
Quis custodiet ipsos custodes? This module, that's who.
"""
from twext.enterprise.fixtures import buildConnectionPool
from twisted.trial.unittest import TestCase
from twisted.trial.reporter import TestResult
from twext.enterprise.adbapi2 import ConnectionPool
class PoolTests(TestCase):
    """
    Tests for fixtures that create a connection pool.
    """

    def test_buildConnectionPool(self):
        """
        L{buildConnectionPool} returns a L{ConnectionPool} which will be
        running only for the duration of the test.
        """
        runningStates = []

        class SampleTest(TestCase):
            # Inner test whose pool fixture lifetime we observe from outside.
            def setUp(self):
                self.pool = buildConnectionPool(self)

            def test_sample(self):
                runningStates.append(self.pool.running)

            def tearDown(self):
                runningStates.append(self.pool.running)

        result = TestResult()
        inner = SampleTest("test_sample")
        inner.run(result)
        self.assertIsInstance(inner.pool, ConnectionPool)
        # Running during the test body, stopped again by teardown time.
        self.assertEqual(runningStates, [True, False])
calendarserver-5.2+dfsg/twext/enterprise/test/test_queue.py 0000644 0001750 0001750 00000072351 12276242656 023402 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for L{twext.enterprise.queue}.
"""
import datetime
# TODO: There should be a store-building utility within twext.enterprise.
from twisted.protocols.amp import Command
from twisted.internet.task import Clock as _Clock
from txdav.common.datastore.test.util import buildStore
from twext.enterprise.dal.syntax import SchemaSyntax, Select
from twext.enterprise.dal.record import fromTable
from twext.enterprise.dal.test.test_parseschema import SchemaTestHelper
from twext.enterprise.queue import (
inTransaction, PeerConnectionPool, WorkItem, astimestamp
)
from twisted.trial.unittest import TestCase
from twisted.python.failure import Failure
from twisted.internet.defer import (
Deferred, inlineCallbacks, gatherResults, passthru#, returnValue
)
from twisted.application.service import Service, MultiService
from twext.enterprise.queue import (
LocalPerformer, _IWorkPerformer, WorkerConnectionPool, SchemaAMP,
TableSyntaxByName
)
from twext.enterprise.dal.record import Record
from twext.enterprise.queue import ConnectionFromPeerNode
from twext.enterprise.fixtures import buildConnectionPool
from zope.interface.verify import verifyObject
from twisted.test.proto_helpers import StringTransport, MemoryReactor
from twext.enterprise.fixtures import SteppablePoolHelper
from twisted.internet.defer import returnValue
from twext.enterprise.queue import LocalQueuer
from twext.enterprise.fixtures import ConnectionPoolHelper
from twext.enterprise.queue import _BaseQueuer, NonPerformingQueuer
import twext.enterprise.queue
class Clock(_Clock):
    """
    More careful L{IReactorTime} fake which mimics the exception behavior of
    the real reactor.
    """

    def callLater(self, _seconds, _f, *args, **kw):
        # The real reactor raises for negative delays; the stock task.Clock
        # silently accepts them, which could mask scheduling bugs in the
        # code under test.
        if _seconds < 0:
            raise ValueError("%s<0: " % (_seconds,))
        return super(Clock, self).callLater(_seconds, _f, *args, **kw)
class MemoryReactorWithClock(MemoryReactor, Clock):
    """
    Simulate a real reactor.
    """
    def __init__(self):
        # Both base classes have non-cooperative __init__ implementations,
        # so initialize each one explicitly rather than via super().
        MemoryReactor.__init__(self)
        Clock.__init__(self)
def transactionally(transactionCreator):
    """
    Perform the decorated function immediately in a transaction, replacing its
    name with a L{Deferred}.

    Use like so::

        @transactionally(connectionPool.connection)
        @inlineCallbacks
        def it(txn):
            yield txn.doSomething()
        it.addCallback(firedWhenDone)

    @param transactionCreator: A 0-arg callable that returns an
        L{IAsyncTransaction}.
    """
    def decorator(operation):
        # Run the decorated operation right away inside a new transaction;
        # the decorated name becomes the resulting Deferred.
        return inTransaction(transactionCreator, operation)
    return decorator
class UtilityTests(TestCase):
    """
    Tests for supporting utilities.
    """

    def test_inTransactionSuccess(self):
        """
        L{inTransaction} invokes its C{transactionCreator} argument, and then
        returns a L{Deferred} which fires with the result of its C{operation}
        argument when it succeeds.
        """
        class faketxn(object):
            # Minimal fake transaction recording the Deferreds it hands out,
            # so the test controls exactly when commit/abort completes.
            def __init__(self):
                self.commits = []
                self.aborts = []
            def commit(self):
                self.commits.append(Deferred())
                return self.commits[-1]
            def abort(self):
                self.aborts.append(Deferred())
                return self.aborts[-1]

        createdTxns = []
        def createTxn():
            createdTxns.append(faketxn())
            return createdTxns[-1]
        dfrs = []
        def operation(t):
            # The operation must be handed the transaction just created.
            self.assertIdentical(t, createdTxns[-1])
            dfrs.append(Deferred())
            return dfrs[-1]
        d = inTransaction(createTxn, operation)
        x = []
        d.addCallback(x.append)
        # The operation's Deferred is still outstanding: no result yet.
        self.assertEquals(x, [])
        self.assertEquals(len(dfrs), 1)
        dfrs[0].callback(35)
        # Commit in progress, so still no result...
        self.assertEquals(x, [])
        createdTxns[0].commits[0].callback(42)
        # Committed, everything's done.  Note the overall result is the
        # operation's result (35), not the commit's (42).
        self.assertEquals(x, [35])
class SimpleSchemaHelper(SchemaTestHelper):
    # Identifier used by SchemaTestHelper when building the test schema.
    # NOTE(review): semantics presumed from the base-class contract; confirm
    # against twext.enterprise.dal.test.test_parseschema.
    def id(self):
        return 'worker'
SQL = passthru
schemaText = SQL("""
create table DUMMY_WORK_ITEM (WORK_ID integer primary key,
NOT_BEFORE timestamp,
A integer, B integer,
DELETE_ON_LOAD integer default 0);
create table DUMMY_WORK_DONE (WORK_ID integer primary key,
A_PLUS_B integer);
""")
nodeSchema = SQL("""
create table NODE_INFO (HOSTNAME varchar(255) not null,
PID integer not null,
PORT integer not null,
TIME timestamp default current_timestamp not null,
primary key (HOSTNAME, PORT));
""")
schema = SchemaSyntax(SimpleSchemaHelper().schemaFromString(schemaText))
dropSQL = ["drop table {name}".format(name=table.model.name)
for table in schema]
class DummyWorkDone(Record, fromTable(schema.DUMMY_WORK_DONE)):
    """
    Work result: one row per completed L{DummyWorkItem}, holding the sum.
    """
class DummyWorkItem(WorkItem, fromTable(schema.DUMMY_WORK_ITEM)):
    """
    Sample L{WorkItem} subclass that adds two integers together and stores
    them in another table.
    """

    def doWork(self):
        # Record the sum in DUMMY_WORK_DONE, keyed by this item's work ID.
        return DummyWorkDone.create(self.transaction, workID=self.workID,
                                    aPlusB=self.a + self.b)


    @classmethod
    @inlineCallbacks
    def load(cls, txn, *a, **kw):
        """
        Load L{DummyWorkItem} as normal... unless the loaded item has
        C{DELETE_ON_LOAD} set, in which case, do a deletion of this same row
        in a concurrent transaction, then commit it.
        """
        self = yield super(DummyWorkItem, cls).load(txn, *a, **kw)
        if self.deleteOnLoad:
            otherTransaction = txn.concurrently()
            # BUG FIX: load the row via the *concurrent* transaction.  The
            # original passed C{txn} here, so the delete happened in the same
            # transaction and C{otherTransaction.commit()} committed nothing,
            # defeating the concurrent-deletion scenario this exists to test.
            otherSelf = yield super(DummyWorkItem, cls).load(
                otherTransaction, *a, **kw)
            yield otherSelf.delete()
            yield otherTransaction.commit()
        returnValue(self)
class SchemaAMPTests(TestCase):
    """
    Tests for L{SchemaAMP} faithfully relaying tables across the wire.
    """

    def test_sendTableWithName(self):
        """
        You can send a reference to a table through a L{SchemaAMP} via
        L{TableSyntaxByName}.
        """
        client = SchemaAMP(schema)

        class SampleCommand(Command):
            # AMP command with a single table-reference argument.
            arguments = [('table', TableSyntaxByName())]

        class Receiver(SchemaAMP):
            @SampleCommand.responder
            def gotIt(self, table):
                # Remember the decoded table for the assertion below.
                self.it = table
                return {}

        server = Receiver(schema)
        clientT = StringTransport()
        serverT = StringTransport()
        client.makeConnection(clientT)
        server.makeConnection(serverT)
        client.callRemote(SampleCommand, table=schema.DUMMY_WORK_ITEM)
        # Feed the client's serialized output straight into the server.
        server.dataReceived(clientT.io.getvalue())
        self.assertEqual(server.it, schema.DUMMY_WORK_ITEM)
class WorkItemTests(TestCase):
    """
    A L{WorkItem} is an item of work that can be executed.
    """

    def test_forTable(self):
        """
        L{WorkItem.forTable} returns L{WorkItem} subclasses mapped to the
        given table.
        """
        self.assertIdentical(WorkItem.forTable(schema.DUMMY_WORK_ITEM),
                             DummyWorkItem)
class WorkerConnectionPoolTests(TestCase):
    """
    A L{WorkerConnectionPool} is responsible for managing, in a node's
    controller (master) process, the collection of worker (slave) processes
    that are capable of executing queue work.
    """
    # NOTE(review): placeholder -- no tests written for this class yet.
class WorkProposalTests(TestCase):
    """
    Tests for L{WorkProposal}.
    """

    def test_whenProposedSuccess(self):
        """
        The L{Deferred} returned by L{WorkProposal.whenProposed} fires when
        the SQL sent to the database has completed.
        """
        cph = ConnectionPoolHelper()
        cph.setUp(test=self)
        # Pause the connection holders so the INSERT cannot complete yet.
        cph.pauseHolders()
        lq = LocalQueuer(cph.createTransaction)
        enqTxn = cph.createTransaction()
        wp = lq.enqueueWork(enqTxn, DummyWorkItem, a=3, b=4)
        d = wp.whenProposed()
        r = cph.resultOf(d)
        self.assertEquals(r, [])
        # Un-pause: the SQL completes and the Deferred fires.
        cph.flushHolders()
        self.assertEquals(len(r), 1)

    def test_whenProposedFailure(self):
        """
        The L{Deferred} returned by L{WorkProposal.whenProposed} fails with an
        errback when the SQL executed to create the WorkItem row fails.
        """
        cph = ConnectionPoolHelper()
        cph.setUp(self)
        cph.pauseHolders()
        firstConnection = cph.factory.willConnectTo()
        enqTxn = cph.createTransaction()
        # Execute some SQL on the connection before enqueueing the work-item
        # so that we don't get the initial-statement.
        enqTxn.execSQL("some sql")
        lq = LocalQueuer(cph.createTransaction)
        cph.flushHolders()
        cph.pauseHolders()
        wp = lq.enqueueWork(enqTxn, DummyWorkItem, a=3, b=4)
        # Arrange for the enqueueing INSERT itself to fail.
        firstConnection.executeWillFail(lambda: RuntimeError("foo"))
        d = wp.whenProposed()
        r = cph.resultOf(d)
        self.assertEquals(r, [])
        cph.flushHolders()
        self.assertEquals(len(r), 1)
        self.assertIsInstance(r[0], Failure)
class PeerConnectionPoolUnitTests(TestCase):
    """
    L{PeerConnectionPool} has many internal components.
    """

    def setUp(self):
        """
        Create a L{PeerConnectionPool} that is just initialized enough.
        """
        self.pcp = PeerConnectionPool(None, None, 4321, schema)

    def checkPerformer(self, cls):
        """
        Verify that the performer returned by
        L{PeerConnectionPool.choosePerformer} is an instance of C{cls} and
        provides L{_IWorkPerformer}.
        """
        performer = self.pcp.choosePerformer()
        self.failUnlessIsInstance(performer, cls)
        verifyObject(_IWorkPerformer, performer)

    def test_choosingPerformerWhenNoPeersAndNoWorkers(self):
        """
        If L{PeerConnectionPool.choosePerformer} is invoked when no workers
        have spawned and no peers have established connections (either
        incoming or outgoing), then it chooses an implementation of
        C{performWork} that simply executes the work locally.
        """
        self.checkPerformer(LocalPerformer)

    def test_choosingPerformerWithLocalCapacity(self):
        """
        If L{PeerConnectionPool.choosePerformer} is invoked when some workers
        have spawned, then it should choose the worker pool as the local
        performer.
        """
        # Give it some local capacity.
        wlf = self.pcp.workerListenerFactory()
        proto = wlf.buildProtocol(None)
        proto.makeConnection(StringTransport())
        # Sanity check.
        self.assertEqual(len(self.pcp.workerPool.workers), 1)
        self.assertEqual(self.pcp.workerPool.hasAvailableCapacity(), True)
        # Now it has some capacity.
        self.checkPerformer(WorkerConnectionPool)

    def test_choosingPerformerFromNetwork(self):
        """
        If L{PeerConnectionPool.choosePerformer} is invoked when no workers
        have spawned but some peers have connected, then it should choose a
        connection from the network to perform it.
        """
        peer = PeerConnectionPool(None, None, 4322, schema)
        local = self.pcp.peerFactory().buildProtocol(None)
        remote = peer.peerFactory().buildProtocol(None)
        connection = Connection(local, remote)
        connection.start()
        self.checkPerformer(ConnectionFromPeerNode)

    def test_performingWorkOnNetwork(self):
        """
        The L{PerformWork} command will get relayed to the remote peer
        controller.
        """
        peer = PeerConnectionPool(None, None, 4322, schema)
        local = self.pcp.peerFactory().buildProtocol(None)
        remote = peer.peerFactory().buildProtocol(None)
        connection = Connection(local, remote)
        connection.start()
        d = Deferred()

        class DummyPerformer(object):
            def performWork(self, table, workID):
                self.table = table
                self.workID = workID
                return d

        # Doing real database I/O in this test would be tedious so fake the
        # first method in the call stack which actually talks to the DB.
        dummy = DummyPerformer()
        def chooseDummy(onlyLocally=False):
            return dummy
        peer.choosePerformer = chooseDummy
        performed = local.performWork(schema.DUMMY_WORK_ITEM, 7384)
        performResult = []
        performed.addCallback(performResult.append)
        # Sanity check.
        self.assertEquals(performResult, [])
        connection.flush()
        # The remote peer's (fake) performer received the work...
        self.assertEquals(dummy.table, schema.DUMMY_WORK_ITEM)
        self.assertEquals(dummy.workID, 7384)
        # ... but the caller's Deferred doesn't fire until the peer's does.
        self.assertEquals(performResult, [])
        d.callback(128374)
        connection.flush()
        self.assertEquals(performResult, [None])

    def test_choosePerformerSorted(self):
        """
        If L{PeerConnectionPool.choosePerformer} is invoked make it
        return the peer with the least load.
        """
        peer = PeerConnectionPool(None, None, 4322, schema)

        class DummyPeer(object):
            def __init__(self, name, load):
                self.name = name
                self.load = load
            def currentLoadEstimate(self):
                return self.load

        apeer = DummyPeer("A", 1)
        bpeer = DummyPeer("B", 0)
        cpeer = DummyPeer("C", 2)
        peer.addPeerConnection(apeer)
        peer.addPeerConnection(bpeer)
        peer.addPeerConnection(cpeer)
        performer = peer.choosePerformer(onlyLocally=False)
        self.assertEqual(performer, bpeer)
        # Raise B's load; A should now be the least-loaded choice.
        bpeer.load = 2
        performer = peer.choosePerformer(onlyLocally=False)
        self.assertEqual(performer, apeer)

    @inlineCallbacks
    def test_notBeforeWhenCheckingForLostWork(self):
        """
        L{PeerConnectionPool._periodicLostWorkCheck} should execute any
        outstanding work items, but only those that are expired.
        """
        dbpool = buildConnectionPool(self, schemaText + nodeSchema)
        # An arbitrary point in time.
        fakeNow = datetime.datetime(2012, 12, 12, 12, 12, 12)
        # *why* does datetime still not have .astimestamp()
        sinceEpoch = astimestamp(fakeNow)
        clock = Clock()
        clock.advance(sinceEpoch)
        qpool = PeerConnectionPool(clock, dbpool.connection, 0, schema)
        # Let's create a couple of work items directly, not via the enqueue
        # method, so that they exist but nobody will try to immediately
        # execute them.
        @transactionally(dbpool.connection)
        @inlineCallbacks
        def setup(txn):
            # First, one that's right now.
            yield DummyWorkItem.create(txn, a=1, b=2, notBefore=fakeNow)
            # Next, create one that's actually far enough into the past to
            # run.
            yield DummyWorkItem.create(
                txn, a=3, b=4, notBefore=(
                    # Schedule it in the past so that it should have already
                    # run.
                    fakeNow - datetime.timedelta(
                        seconds=qpool.queueProcessTimeout + 20
                    )
                )
            )
            # Finally, one that's actually scheduled for the future.
            yield DummyWorkItem.create(
                txn, a=10, b=20, notBefore=fakeNow + datetime.timedelta(1000)
            )
        yield setup
        yield qpool._periodicLostWorkCheck()
        @transactionally(dbpool.connection)
        def check(txn):
            return DummyWorkDone.all(txn)
        every = yield check
        # Only the expired item (3 + 4) should have been executed.
        self.assertEquals([x.aPlusB for x in every], [7])

    @inlineCallbacks
    def test_notBeforeWhenEnqueueing(self):
        """
        L{PeerConnectionPool.enqueueWork} enqueues some work immediately, but
        only executes it when enough time has elapsed to allow the
        C{notBefore} attribute of the given work item to have passed.
        """
        dbpool = buildConnectionPool(self, schemaText + nodeSchema)
        fakeNow = datetime.datetime(2012, 12, 12, 12, 12, 12)
        sinceEpoch = astimestamp(fakeNow)
        clock = Clock()
        clock.advance(sinceEpoch)
        qpool = PeerConnectionPool(clock, dbpool.connection, 0, schema)
        realChoosePerformer = qpool.choosePerformer
        performerChosen = []
        def catchPerformerChoice():
            # Wrap choosePerformer so we can observe when work kicks off.
            result = realChoosePerformer()
            performerChosen.append(True)
            return result
        qpool.choosePerformer = catchPerformerChoice
        @transactionally(dbpool.connection)
        def check(txn):
            return qpool.enqueueWork(
                txn, DummyWorkItem, a=3, b=9,
                notBefore=datetime.datetime(2012, 12, 12, 12, 12, 20)
            ).whenProposed()
        proposal = yield check
        # This is going to schedule the work to happen with some asynchronous
        # I/O in the middle; this is a problem because how do we know when
        # it's time to check to see if the work has started?  We need to
        # intercept the thing that kicks off the work; we can then wait for
        # the work itself.
        self.assertEquals(performerChosen, [])
        # Advance to exactly the appointed second.
        clock.advance(20 - 12)
        self.assertEquals(performerChosen, [True])
        # FIXME: if this fails, it will hang, but that's better than no
        # notification that it is broken at all.
        result = yield proposal.whenExecuted()
        self.assertIdentical(result, proposal)

    @inlineCallbacks
    def test_notBeforeBefore(self):
        """
        L{PeerConnectionPool.enqueueWork} will execute its work immediately
        if the C{notBefore} attribute of the work item in question is in the
        past.
        """
        dbpool = buildConnectionPool(self, schemaText + nodeSchema)
        fakeNow = datetime.datetime(2012, 12, 12, 12, 12, 12)
        sinceEpoch = astimestamp(fakeNow)
        clock = Clock()
        clock.advance(sinceEpoch)
        qpool = PeerConnectionPool(clock, dbpool.connection, 0, schema)
        realChoosePerformer = qpool.choosePerformer
        performerChosen = []
        def catchPerformerChoice():
            result = realChoosePerformer()
            performerChosen.append(True)
            return result
        qpool.choosePerformer = catchPerformerChoice
        @transactionally(dbpool.connection)
        def check(txn):
            return qpool.enqueueWork(
                txn, DummyWorkItem, a=3, b=9,
                notBefore=datetime.datetime(2012, 12, 12, 12, 12, 0)
            ).whenProposed()
        proposal = yield check
        clock.advance(1000)
        # Advance far beyond the given timestamp.
        self.assertEquals(performerChosen, [True])
        result = yield proposal.whenExecuted()
        self.assertIdentical(result, proposal)

    def test_workerConnectionPoolPerformWork(self):
        """
        L{WorkerConnectionPool.performWork} performs work by selecting a
        L{ConnectionFromWorker} and sending it a L{PerformWork} command.
        """
        clock = Clock()
        peerPool = PeerConnectionPool(clock, None, 4322, schema)
        factory = peerPool.workerListenerFactory()
        def peer():
            # Simulate one worker process connecting to the controller.
            p = factory.buildProtocol(None)
            t = StringTransport()
            p.makeConnection(t)
            return p, t
        worker1, _ignore_trans1 = peer()
        worker2, _ignore_trans2 = peer()
        # Ask the worker to do something.
        worker1.performWork(schema.DUMMY_WORK_ITEM, 1)
        self.assertEquals(worker1.currentLoad, 1)
        self.assertEquals(worker2.currentLoad, 0)
        # Now ask the pool to do something; it should pick the idle worker.
        peerPool.workerPool.performWork(schema.DUMMY_WORK_ITEM, 2)
        self.assertEquals(worker1.currentLoad, 1)
        self.assertEquals(worker2.currentLoad, 1)

    def test_poolStartServiceChecksForWork(self):
        """
        L{PeerConnectionPool.startService} kicks off the idle work-check
        loop.
        """
        reactor = MemoryReactorWithClock()
        cph = SteppablePoolHelper(nodeSchema + schemaText)
        then = datetime.datetime(2012, 12, 12, 12, 12, 0)
        reactor.advance(astimestamp(then))
        cph.setUp(self)
        pcp = PeerConnectionPool(reactor, cph.pool.connection, 4321, schema)
        now = then + datetime.timedelta(seconds=pcp.queueProcessTimeout * 2)
        @transactionally(cph.pool.connection)
        def createOldWork(txn):
            # Item 1 is due at startup; item 2 only after one more interval.
            one = DummyWorkItem.create(txn, workID=1, a=3, b=4, notBefore=then)
            two = DummyWorkItem.create(txn, workID=2, a=7, b=9, notBefore=now)
            return gatherResults([one, two])
        pcp.startService()
        cph.flushHolders()
        reactor.advance(pcp.queueProcessTimeout * 2)
        self.assertEquals(
            cph.rows("select * from DUMMY_WORK_DONE"),
            [(1, 7)]
        )
        cph.rows("delete from DUMMY_WORK_DONE")
        reactor.advance(pcp.queueProcessTimeout * 2)
        self.assertEquals(
            cph.rows("select * from DUMMY_WORK_DONE"),
            [(2, 16)]
        )
class HalfConnection(object):
    """
    One directed half of an in-memory protocol connection: a protocol plus
    the string transport it writes its output to.
    """
    def __init__(self, protocol):
        self.protocol = protocol
        self.transport = StringTransport()

    def start(self):
        """
        Hook up the protocol and the transport.
        """
        self.protocol.makeConnection(self.transport)

    def extract(self):
        """
        Extract the data currently present in this protocol's output buffer.
        """
        outputBuffer = self.transport.io
        pending = outputBuffer.getvalue()
        # Reset the buffer so the next extract() only sees new output.
        outputBuffer.seek(0)
        outputBuffer.truncate()
        return pending

    def deliver(self, data):
        """
        Deliver the given data to this L{HalfConnection}'s protocol's
        C{dataReceived} method.

        @return: a boolean indicating whether any data was delivered.
        @rtype: L{bool}
        """
        if not data:
            return False
        self.protocol.dataReceived(data)
        return True
class Connection(object):
    """
    A bidirectional in-memory connection between two protocol instances,
    each wrapped in a L{HalfConnection}.
    """
    def __init__(self, local, remote):
        """
        Connect two protocol instances to each other via string transports.
        """
        self.receiver = HalfConnection(local)
        self.sender = HalfConnection(remote)

    def start(self):
        """
        Start up the connection.
        """
        for half in (self.sender, self.receiver):
            half.start()

    def pump(self):
        """
        Relay data in one direction between the two connections.

        @return: whether any data was delivered.
        """
        delivered = self.receiver.deliver(self.sender.extract())
        # Swap roles so successive calls alternate direction.
        self.sender, self.receiver = self.receiver, self.sender
        return delivered

    def flush(self, turns=10):
        """
        Keep relaying data until there's no more (bounded by C{turns} round
        trips, to guarantee termination).
        """
        for _ignore in range(turns):
            if not (self.pump() or self.pump()):
                return
class PeerConnectionPoolIntegrationTests(TestCase):
    """
    L{PeerConnectionPool} is the service responsible for coordinating
    eventually-consistent task queuing within a cluster.
    """

    @inlineCallbacks
    def setUp(self):
        """
        L{PeerConnectionPool} requires access to a database and the reactor.
        """
        self.store = yield buildStore(self, None)

        def doit(txn):
            # Install the extra DUMMY_WORK_* tables used by these tests.
            return txn.execSQL(schemaText)

        yield inTransaction(lambda: self.store.newTransaction("bonus schema"),
                            doit)

        def indirectedTransactionFactory(*a):
            """
            Allow tests to replace 'self.store.newTransaction' to provide
            fixtures with extra methods on a test-by-test basis.
            """
            return self.store.newTransaction(*a)

        def deschema():
            # Cleanup: drop the bonus schema again so later tests see a
            # pristine database.
            @inlineCallbacks
            def deletestuff(txn):
                for stmt in dropSQL:
                    yield txn.execSQL(stmt)
            return inTransaction(lambda *a: self.store.newTransaction(*a),
                                 deletestuff)
        self.addCleanup(deschema)

        from twisted.internet import reactor
        # Two pools over the same store simulate a two-node cluster; port 0
        # lets each node pick its own listening port.
        self.node1 = PeerConnectionPool(
            reactor, indirectedTransactionFactory, 0, schema)
        self.node2 = PeerConnectionPool(
            reactor, indirectedTransactionFactory, 0, schema)

        class FireMeService(Service, object):
            # Fires the given Deferred when started, so the test can wait
            # until both nodes are actually running.
            def __init__(self, d):
                super(FireMeService, self).__init__()
                self.d = d

            def startService(self):
                self.d.callback(None)

        d1 = Deferred()
        d2 = Deferred()
        FireMeService(d1).setServiceParent(self.node1)
        FireMeService(d2).setServiceParent(self.node2)
        ms = MultiService()
        self.node1.setServiceParent(ms)
        self.node2.setServiceParent(ms)
        ms.startService()
        self.addCleanup(ms.stopService)
        yield gatherResults([d1, d2])
        self.store.queuer = self.node1


    def test_currentNodeInfo(self):
        """
        There will be two C{NODE_INFO} rows in the database, retrievable as two
        L{NodeInfo} objects, once both nodes have started up.
        """
        @inlineCallbacks
        def check(txn):
            # Each node should see both itself and its peer.
            self.assertEquals(len((yield self.node1.activeNodes(txn))), 2)
            self.assertEquals(len((yield self.node2.activeNodes(txn))), 2)
        return inTransaction(self.store.newTransaction, check)


    @inlineCallbacks
    def test_enqueueHappyPath(self):
        """
        When a L{WorkItem} is scheduled for execution via
        L{PeerConnectionPool.enqueueWork} its C{doWork} method will be invoked
        by the time the L{Deferred} returned from the resulting
        L{WorkProposal}'s C{whenExecuted} method has fired.
        """
        # TODO: this exact test should run against LocalQueuer as well.
        def operation(txn):
            # TODO: how does 'enqueue' get associated with the transaction?
            # This is not the fact with a raw t.w.enterprise transaction.
            # Should probably do something with components.
            return txn.enqueue(DummyWorkItem, a=3, b=4, workID=4321,
                               notBefore=datetime.datetime.utcnow())
        result = yield inTransaction(self.store.newTransaction, operation)
        # Wait for it to be executed.  Hopefully this does not time out :-\.
        yield result.whenExecuted()

        def op2(txn):
            return Select([schema.DUMMY_WORK_DONE.WORK_ID,
                           schema.DUMMY_WORK_DONE.A_PLUS_B],
                          From=schema.DUMMY_WORK_DONE).on(txn)
        rows = yield inTransaction(self.store.newTransaction, op2)
        # DummyWorkItem records (workID, a + b) into DUMMY_WORK_DONE.
        self.assertEquals(rows, [[4321, 7]])


    @inlineCallbacks
    def test_noWorkDoneWhenConcurrentlyDeleted(self):
        """
        When a L{WorkItem} is concurrently deleted by another transaction, it
        should I{not} perform its work.
        """
        # Provide access to a method called 'concurrently' everything using
        original = self.store.newTransaction

        def decorate(*a, **k):
            # Give each transaction a 'concurrently' factory so the work
            # item (via deleteOnLoad) can delete itself from a second,
            # concurrent transaction.
            result = original(*a, **k)
            result.concurrently = self.store.newTransaction
            return result
        self.store.newTransaction = decorate

        def operation(txn):
            return txn.enqueue(DummyWorkItem, a=30, b=40, workID=5678,
                               deleteOnLoad=1,
                               notBefore=datetime.datetime.utcnow())
        proposal = yield inTransaction(self.store.newTransaction, operation)
        yield proposal.whenExecuted()

        # Sanity check on the concurrent deletion.
        def op2(txn):
            return Select([schema.DUMMY_WORK_ITEM.WORK_ID],
                          From=schema.DUMMY_WORK_ITEM).on(txn)
        rows = yield inTransaction(self.store.newTransaction, op2)
        self.assertEquals(rows, [])

        def op3(txn):
            return Select([schema.DUMMY_WORK_DONE.WORK_ID,
                           schema.DUMMY_WORK_DONE.A_PLUS_B],
                          From=schema.DUMMY_WORK_DONE).on(txn)
        rows = yield inTransaction(self.store.newTransaction, op3)
        # No result row: the deleted work item must not have run.
        self.assertEquals(rows, [])
class DummyProposal(object):
    """
    Stand-in for a work proposal: swallows any constructor arguments and
    makes C{_start} a no-op, so queuer tests need no real database work.
    """

    def __init__(self, *ignored):
        # Accept and discard whatever the real class would be given.
        pass


    def _start(self):
        # Intentionally do nothing.
        pass
class BaseQueuerTests(TestCase):
    """
    Tests for L{_BaseQueuer}'s proposal-callback machinery.
    """

    def setUp(self):
        self.proposal = None
        # Replace the real WorkProposal with DummyProposal so enqueueWork
        # needs no database or transaction.
        self.patch(twext.enterprise.queue, "WorkProposal", DummyProposal)


    def _proposalCallback(self, proposal):
        # Record the proposal handed to registered callbacks.
        self.proposal = proposal


    def test_proposalCallbacks(self):
        """
        A callback registered with C{callWithNewProposals} is invoked with
        each proposal created by C{enqueueWork}.
        """
        queuer = _BaseQueuer()
        queuer.callWithNewProposals(self._proposalCallback)
        # Registering alone must not invoke the callback...
        self.assertEqual(self.proposal, None)
        queuer.enqueueWork(None, None)
        # ...but enqueueing work does.
        self.assertNotEqual(self.proposal, None)
class NonPerformingQueuerTests(TestCase):
    """
    Tests for L{NonPerformingQueuer}.
    """

    @inlineCallbacks
    def test_choosePerformer(self):
        """
        The performer chosen by L{NonPerformingQueuer} accepts work and
        yields C{None} (nothing observable is performed).
        """
        queuer = NonPerformingQueuer()
        performer = queuer.choosePerformer()
        result = (yield performer.performWork(None, None))
        self.assertEquals(result, None)
calendarserver-5.2+dfsg/twext/enterprise/test/test_util.py 0000644 0001750 0001750 00000002313 12263343324 023210 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
import datetime
from twisted.trial.unittest import TestCase
from twext.enterprise.util import parseSQLTimestamp
class TimestampTests(TestCase):
    """
    Tests for date-related functions.
    """

    def test_parseSQLTimestamp(self):
        """
        L{parseSQLTimestamp} parses the traditional SQL timestamp.
        """
        cases = [
            ("2012-04-04 12:34:56", datetime.datetime(2012, 4, 4, 12, 34, 56)),
            ("2012-12-31 01:01:01", datetime.datetime(2012, 12, 31, 1, 1, 1)),
        ]
        for text, expected in cases:
            self.assertEqual(parseSQLTimestamp(text), expected)
calendarserver-5.2+dfsg/twext/enterprise/test/test_locking.py 0000644 0001750 0001750 00000005366 12263346572 023704 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for mutual exclusion locks.
"""
from twisted.internet.defer import inlineCallbacks
from twisted.trial.unittest import TestCase
from twext.enterprise.fixtures import buildConnectionPool
from twext.enterprise.locking import NamedLock, LockTimeout
from twext.enterprise.dal.syntax import Select
from twext.enterprise.locking import LockSchema
# Minimal schema for these tests: a single table whose unique rows serve as
# the lock records (see test_acquire below).
schemaText = """
create table NAMED_LOCK (LOCK_NAME varchar(255) unique primary key);
"""
class TestLocking(TestCase):
    """
    Test locking and unlocking a database row.
    """

    def setUp(self):
        """
        Build a connection pool for the tests to use.
        """
        self.pool = buildConnectionPool(self, schemaText)


    @inlineCallbacks
    def test_acquire(self):
        """
        Acquiring a lock adds a row in that transaction.
        """
        txn = self.pool.connection()
        yield NamedLock.acquire(txn, u"a test lock")
        rows = yield Select(From=LockSchema.NAMED_LOCK).on(txn)
        self.assertEquals(rows, [tuple([u"a test lock"])])


    @inlineCallbacks
    def test_release(self):
        """
        Releasing an acquired lock removes the row.
        """
        txn = self.pool.connection()
        lck = yield NamedLock.acquire(txn, u"a test lock")
        yield lck.release()
        rows = yield Select(From=LockSchema.NAMED_LOCK).on(txn)
        self.assertEquals(rows, [])


    @inlineCallbacks
    def test_autoRelease(self):
        """
        Committing a transaction automatically releases all of its locks.
        """
        txn = self.pool.connection()
        yield NamedLock.acquire(txn, u"something")
        yield txn.commit()
        # A fresh transaction observes no remaining lock rows.
        txn2 = self.pool.connection()
        rows = yield Select(From=LockSchema.NAMED_LOCK).on(txn2)
        self.assertEquals(rows, [])


    @inlineCallbacks
    def test_timeout(self):
        """
        Trying to acquire second lock times out.
        """
        txn1 = self.pool.connection()
        yield NamedLock.acquire(txn1, u"a test lock")
        # A second transaction contending for the same name fails with
        # LockTimeout rather than blocking forever.
        txn2 = self.pool.connection()
        yield self.assertFailure(NamedLock.acquire(txn2, u"a test lock"), LockTimeout)
        yield txn2.abort()
        self.flushLoggedErrors()
calendarserver-5.2+dfsg/twext/enterprise/test/test_adbapi2.py 0000644 0001750 0001750 00000140377 12263343324 023552 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for L{twext.enterprise.adbapi2}.
"""
import gc
from zope.interface.verify import verifyObject
from twisted.python.failure import Failure
from twisted.trial.unittest import TestCase
from twisted.internet.defer import Deferred, fail, succeed, inlineCallbacks
from twisted.test.proto_helpers import StringTransport
from twext.enterprise.ienterprise import ConnectionError
from twext.enterprise.ienterprise import AlreadyFinishedError
from twext.enterprise.adbapi2 import ConnectionPoolClient
from twext.enterprise.adbapi2 import ConnectionPoolConnection
from twext.enterprise.ienterprise import IAsyncTransaction
from twext.enterprise.ienterprise import ICommandBlock
from twext.enterprise.adbapi2 import FailsafeException
from twext.enterprise.adbapi2 import ConnectionPool
from twext.enterprise.fixtures import ConnectionPoolHelper
from twext.enterprise.fixtures import resultOf
from twext.enterprise.fixtures import ClockWithThreads
from twext.enterprise.fixtures import FakeConnectionError
from twext.enterprise.fixtures import RollbackFail
from twext.enterprise.fixtures import CommitFail
from twext.enterprise.adbapi2 import Commit
from twext.enterprise.adbapi2 import _HookableOperation
class TrashCollector(object):
    """
    Test helper for monitoring gc.garbage.

    Records the size of C{gc.garbage} at construction and, as a cleanup of
    the given test case, fails that test if any new uncollectable garbage
    has appeared since.
    """
    def __init__(self, testCase):
        self.testCase = testCase
        testCase.addCleanup(self.checkTrash)
        self.start()


    def start(self):
        # (Re-)establish the baseline after a full collection pass.
        gc.collect()
        self.garbageStart = len(gc.garbage)


    def checkTrash(self):
        """
        Ensure that the test has added no additional garbage.
        """
        gc.collect()
        fresh = gc.garbage[self.garbageStart:]
        if not fresh:
            return
        # Reset the baseline first so a later check doesn't re-report the
        # same objects.
        self.start()
        self.testCase.fail("New garbage: " + repr(fresh))
class AssertResultHelper(object):
    """
    Mixin for asserting about synchronous Deferred results.
    """
    def assertResultList(self, resultList, expected):
        """
        Assert that a list created with L{resultOf} contains the expected
        result.

        @param resultList: The return value of L{resultOf}.
        @type resultList: L{list}

        @param expected: The expected value that should be present in the
            list; a L{Failure} if an exception is expected to be raised.
        """
        if not resultList:
            self.fail("No result; Deferred didn't fire yet.")
            return
        head = resultList[0]
        if not isinstance(head, Failure):
            # Plain success value: compare directly.
            self.assertEqual(resultList, [expected])
        elif isinstance(expected, Failure):
            # Expected failure: it must be of the anticipated type.
            head.trap(expected.type)
        else:
            # Unexpected failure: surface the underlying exception.
            head.raiseException()
class ConnectionPoolBootTests(TestCase):
    """
    Tests for the start-up phase of L{ConnectionPool}.
    """

    def test_threadCount(self):
        """
        The reactor associated with a L{ConnectionPool} will have its maximum
        thread count adjusted when L{ConnectionPool.startService} is called,
        to accommodate for L{ConnectionPool.maxConnections} additional
        threads.

        Stopping the service should restore it to its original value, so
        that a repeatedly re-started L{ConnectionPool} will not cause the
        thread ceiling to grow without bound.
        """
        defaultMax = 27
        connsMax = 45
        combinedMax = defaultMax + connsMax
        pool = ConnectionPool(None, maxConnections=connsMax)
        pool.reactor = ClockWithThreads()
        threadpool = pool.reactor.getThreadPool()
        pool.reactor.suggestThreadPoolSize(defaultMax)
        self.assertEquals(threadpool.max, defaultMax)
        pool.startService()
        # Starting the service bumps the ceiling by maxConnections.
        self.assertEquals(threadpool.max, combinedMax)
        justChecking = []
        pool.stopService().addCallback(justChecking.append)
        # No SQL run, so no threads started, so this deferred should fire
        # immediately.  If not, we're in big trouble, so sanity check.
        self.assertEquals(justChecking, [None])
        # Stopping the service restores the original ceiling.
        self.assertEquals(threadpool.max, defaultMax)


    def test_isRunning(self):
        """
        L{ConnectionPool.startService} should set its C{running} attribute to
        true.
        """
        pool = ConnectionPool(None)
        pool.reactor = ClockWithThreads()
        self.assertEquals(pool.running, False)
        pool.startService()
        self.assertEquals(pool.running, True)
class ConnectionPoolTests(ConnectionPoolHelper, TestCase, AssertResultHelper):
"""
Tests for L{ConnectionPool}.
"""
def test_tooManyConnections(self):
    """
    When the number of outstanding busy transactions exceeds the number of
    slots specified by L{ConnectionPool.maxConnections},
    L{ConnectionPool.connection} will return a pooled transaction that is
    not backed by any real database connection; this object will queue its
    SQL statements until an existing connection becomes available.
    """
    a = self.createTransaction()
    alphaResult = self.resultOf(a.execSQL("alpha"))
    [[counter, echo]] = alphaResult[0]
    b = self.createTransaction()
    # 'b' should have opened a connection.
    self.assertEquals(len(self.factory.connections), 2)
    betaResult = self.resultOf(b.execSQL("beta"))
    [[bcounter, becho]] = betaResult[0]
    # both 'a' and 'b' are holding open a connection now; let's try to open
    # a third one.  (The ordering will be deterministic even if this fails,
    # because those threads are already busy.)
    c = self.createTransaction()
    gammaResult = self.resultOf(c.execSQL("gamma"))
    # Did 'c' open a connection?  Let's hope not...
    self.assertEquals(len(self.factory.connections), 2)
    # SQL shouldn't be executed too soon...
    self.assertEquals(gammaResult, [])
    commitResult = self.resultOf(b.commit())
    # Now that 'b' has committed, 'c' should be able to complete.
    [[ccounter, cecho]] = gammaResult[0]
    # The connection for 'a' ought to still be busy, so let's make sure
    # we're using the one for 'c'.
    self.assertEquals(ccounter, bcounter)
    # Sanity check: the commit should have succeeded!
    self.assertEquals(commitResult, [None])
def test_stopService(self):
    """
    L{ConnectionPool.stopService} stops all the associated L{ThreadHolder}s
    and thereby frees up the resources it is holding.
    """
    a = self.createTransaction()
    alphaResult = self.resultOf(a.execSQL("alpha"))
    [[[counter, echo]]] = alphaResult
    # One connection and one thread holder exist for 'a', started but not
    # yet stopped.
    self.assertEquals(len(self.factory.connections), 1)
    self.assertEquals(len(self.holders), 1)
    [holder] = self.holders
    self.assertEquals(holder.started, True)
    self.assertEquals(holder.stopped, False)
    self.pool.stopService()
    self.assertEquals(self.pool.running, False)
    # The holder is stopped in place, not discarded.
    self.assertEquals(len(self.holders), 1)
    self.assertEquals(holder.started, True)
    self.assertEquals(holder.stopped, True)
    # Closing fake connections removes them from the list.
    self.assertEquals(len(self.factory.connections), 1)
    self.assertEquals(self.factory.connections[0].closed, True)
def test_retryAfterConnectError(self):
    """
    When the C{connectionFactory} passed to L{ConnectionPool} raises an
    exception, the L{ConnectionPool} will log the exception and delay
    execution of a new connection's SQL methods until an attempt succeeds.
    """
    # Fail twice, then succeed on the third connection attempt.
    self.factory.willFail()
    self.factory.willFail()
    self.factory.willConnect()
    c = self.createTransaction()

    def checkOneFailure():
        errors = self.flushLoggedErrors(FakeConnectionError)
        self.assertEquals(len(errors), 1)

    checkOneFailure()
    d = c.execSQL("alpha")
    happened = []
    d.addBoth(happened.append)
    # Nothing can execute until a connection attempt succeeds.
    self.assertEquals(happened, [])
    self.clock.advance(self.pool.RETRY_TIMEOUT + 0.01)
    checkOneFailure()
    self.assertEquals(happened, [])
    self.clock.advance(self.pool.RETRY_TIMEOUT + 0.01)
    self.flushHolders()
    # The third attempt connected, so the queued SQL finally ran.
    self.assertEquals(happened, [[[1, "alpha"]]])
def test_shutdownDuringRetry(self):
    """
    If a L{ConnectionPool} is attempting to shut down while it's in the
    process of re-trying a connection attempt that received an error, the
    connection attempt should be cancelled and the shutdown should complete
    as normal.
    """
    self.factory.defaultFail()
    self.createTransaction()
    errors = self.flushLoggedErrors(FakeConnectionError)
    self.assertEquals(len(errors), 1)
    stopd = []
    self.pool.stopService().addBoth(stopd.append)
    self.assertResultList(stopd, None)
    # The pending retry was cancelled: no timed calls remain on the clock.
    self.assertEquals(self.clock.calls, [])
    [holder] = self.holders
    self.assertEquals(holder.started, True)
    self.assertEquals(holder.stopped, True)
def test_shutdownDuringAttemptSuccess(self):
    """
    If L{ConnectionPool.stopService} is called while a connection attempt
    is outstanding, the resulting L{Deferred} won't be fired until the
    connection attempt has finished; in this case, succeeded.
    """
    # Pause the thread holders so the connection attempt stays in flight.
    self.pauseHolders()
    self.createTransaction()
    stopd = []
    self.pool.stopService().addBoth(stopd.append)
    # Shutdown must not complete while the attempt is outstanding...
    self.assertEquals(stopd, [])
    self.flushHolders()
    # ...but does once the attempt has finished.
    self.assertResultList(stopd, None)
    [holder] = self.holders
    self.assertEquals(holder.started, True)
    self.assertEquals(holder.stopped, True)
def test_shutdownDuringAttemptFailed(self):
    """
    If L{ConnectionPool.stopService} is called while a connection attempt
    is outstanding, the resulting L{Deferred} won't be fired until the
    connection attempt has finished; in this case, failed.
    """
    self.factory.defaultFail()
    # Pause the thread holders so the (failing) attempt stays in flight.
    self.pauseHolders()
    self.createTransaction()
    stopd = []
    self.pool.stopService().addBoth(stopd.append)
    # Shutdown must not complete while the attempt is outstanding...
    self.assertEquals(stopd, [])
    self.flushHolders()
    # ...the failure is logged, and then shutdown completes.
    errors = self.flushLoggedErrors(FakeConnectionError)
    self.assertEquals(len(errors), 1)
    self.assertResultList(stopd, None)
    [holder] = self.holders
    self.assertEquals(holder.started, True)
    self.assertEquals(holder.stopped, True)
def test_stopServiceMidAbort(self):
    """
    When L{ConnectionPool.stopService} is called with deferreds from
    C{abort} still outstanding, it will wait for the currently-aborting
    transaction to fully abort before firing the L{Deferred} returned from
    C{stopService}.
    """
    # TODO: commit() too?
    self.pauseHolders()
    c = self.createTransaction()
    abortResult = self.resultOf(c.abort())
    # Should abort instantly, as it hasn't managed to unspool anything yet.
    # FIXME: kill all Deferreds associated with this thing, make sure that
    # any outstanding query callback chains get nuked.
    self.assertEquals(abortResult, [None])
    stopResult = self.resultOf(self.pool.stopService())
    # Shutdown is still pending while the holders remain paused.
    self.assertEquals(stopResult, [])
    self.flushHolders()
    #self.assertEquals(abortResult, [None])
    self.assertResultList(stopResult, None)
def test_stopServiceWithSpooled(self):
    """
    When L{ConnectionPool.stopService} is called when spooled transactions
    are outstanding, any pending L{Deferreds} returned by those
    transactions will be failed with L{ConnectionError}.
    """
    # Use up the free slots so we have to spool.
    hold = []
    hold.append(self.createTransaction())
    hold.append(self.createTransaction())
    c = self.createTransaction()
    se = self.resultOf(c.execSQL("alpha"))
    ce = self.resultOf(c.commit())
    # Neither the statement nor the commit has run yet: 'c' is spooled.
    self.assertEquals(se, [])
    self.assertEquals(ce, [])
    self.resultOf(self.pool.stopService())
    # Shutdown failed both pending Deferreds with ConnectionError.
    self.assertEquals(se[0].type, self.translateError(ConnectionError))
    self.assertEquals(ce[0].type, self.translateError(ConnectionError))
def test_repoolSpooled(self):
    """
    Regression test for a somewhat tricky-to-explain bug: when a spooled
    transaction which has already had commit() called on it before it's
    received a real connection to start executing on, it will not leave
    behind any detritus that prevents stopService from working.
    """
    self.pauseHolders()
    c = self.createTransaction()
    c2 = self.createTransaction()
    c3 = self.createTransaction()
    # All three commit before any of them is assigned a real connection.
    c.commit()
    c2.commit()
    c3.commit()
    self.flushHolders()
    self.assertEquals(len(self.factory.connections), 2)
    stopResult = self.resultOf(self.pool.stopService())
    self.assertEquals(stopResult, [None])
    # Both pooled connections were closed cleanly on shutdown.
    self.assertEquals(len(self.factory.connections), 2)
    self.assertEquals(self.factory.connections[0].closed, True)
    self.assertEquals(self.factory.connections[1].closed, True)
def test_connectAfterStop(self):
    """
    Calls to connection() after stopService() result in transactions which
    immediately fail all operations.
    """
    # Stop the pool first; with nothing outstanding this is synchronous.
    stopped = self.resultOf(self.pool.stopService())
    self.assertEquals(stopped, [None])
    self.pauseHolders()
    lateTxn = self.createTransaction()
    failures = self.resultOf(lateTxn.execSQL("hello"))
    # Exactly one immediate failure, of type ConnectionError.
    self.assertEquals(len(failures), 1)
    self.assertEquals(
        failures[0].type, self.translateError(ConnectionError)
    )
def test_connectAfterStartedStopping(self):
    """
    Calls to connection() after stopService() has been called but before it
    has completed will result in transactions which immediately fail all
    operations.
    """
    self.pauseHolders()
    preClose = self.createTransaction()
    preCloseResult = self.resultOf(preClose.execSQL('statement'))
    stopResult = self.resultOf(self.pool.stopService())
    postClose = self.createTransaction()
    queryResult = self.resultOf(postClose.execSQL("hello"))
    # Shutdown cannot finish while the holders are paused...
    self.assertEquals(stopResult, [])
    # ...but the late transaction fails immediately anyway.
    self.assertEquals(len(queryResult), 1)
    self.assertEquals(queryResult[0].type,
                      self.translateError(ConnectionError))
    # The transaction begun before shutdown is failed as well.
    self.assertEquals(len(preCloseResult), 1)
    self.assertEquals(preCloseResult[0].type,
                      self.translateError(ConnectionError))
def test_abortFailsDuringStopService(self):
    """
    L{IAsyncTransaction.abort} might fail, most likely because the
    underlying database connection has already been disconnected.  If this
    happens, shutdown should continue.
    """
    txns = []
    txns.append(self.createTransaction())
    txns.append(self.createTransaction())
    for txn in txns:
        # Make sure rollback will actually be executed.
        results = self.resultOf(txn.execSQL("maybe change something!"))
        [[[counter, echo]]] = results
        self.assertEquals("maybe change something!", echo)
    # Fail one (and only one) call to rollback().
    self.factory.rollbackFail = True
    stopResult = self.resultOf(self.pool.stopService())
    # Shutdown still completed despite the rollback failure...
    self.assertEquals(stopResult, [None])
    # ...the failure was logged, and both connections were still closed.
    self.assertEquals(len(self.flushLoggedErrors(RollbackFail)), 1)
    self.assertEquals(self.factory.connections[0].closed, True)
    self.assertEquals(self.factory.connections[1].closed, True)
def test_abortRecycledTransaction(self):
    """
    L{ConnectionPool.stopService} will shut down if a recycled transaction
    is still pending.
    """
    # Commit a transaction so its connection is recycled into the pool...
    finished = self.createTransaction()
    self.resultOf(finished.commit())
    # ...then start (and keep a reference to) a new, still-pending one.
    pending = [self.createTransaction()]
    self.assertEquals(self.resultOf(self.pool.stopService()), [None])
def test_abortSpooled(self):
    """
    Aborting a still-spooled transaction (one which has no statements being
    executed) will result in all of its Deferreds immediately failing and
    none of the queued statements being executed.
    """
    active = []
    # Use up the available connections ...
    for i in xrange(self.pool.maxConnections):
        active.append(self.createTransaction())
    # ... so that this one has to be spooled.
    spooled = self.createTransaction()
    result = self.resultOf(spooled.execSQL("alpha"))
    # sanity check, it would be bad if this actually executed.
    self.assertEqual(result, [])
    self.resultOf(spooled.abort())
    # The queued statement's Deferred fails with ConnectionError.
    self.assertEqual(result[0].type, self.translateError(ConnectionError))
def test_waitForAlreadyAbortedTransaction(self):
    """
    L{ConnectionPool.stopService} will wait for all transactions to shut
    down before exiting, including those which have already been stopped.
    """
    it = self.createTransaction()
    self.pauseHolders()
    abortResult = self.resultOf(it.abort())
    # steal it from the queue so we can do it out of order
    d, work = self.holders[0]._q.get()
    # that should be the only work unit so don't continue if something else
    # got in there
    self.assertEquals(list(self.holders[0]._q.queue), [])
    self.assertEquals(len(self.holders), 1)
    self.flushHolders()
    stopResult = self.resultOf(self.pool.stopService())
    # Sanity check that we haven't actually stopped it yet
    self.assertEquals(abortResult, [])
    # We haven't fired it yet, so the service had better not have
    # stopped...
    self.assertEquals(stopResult, [])
    # Now fire the stolen work unit; both the abort and the shutdown can
    # finally complete.
    d.callback(None)
    self.flushHolders()
    self.assertEquals(abortResult, [None])
    self.assertEquals(stopResult, [None])
def test_garbageCollectedTransactionAborts(self):
    """
    When an L{IAsyncTransaction} is garbage collected, it ought to abort
    itself.
    """
    t = self.createTransaction()
    self.resultOf(t.execSQL("echo", []))
    conns = self.factory.connections
    self.assertEquals(len(conns), 1)
    self.assertEquals(conns[0]._rollbackCount, 0)
    # Drop the only reference and force collection; the transaction should
    # roll back (not commit) its connection.
    del t
    gc.collect()
    self.flushHolders()
    self.assertEquals(len(conns), 1)
    self.assertEquals(conns[0]._rollbackCount, 1)
    self.assertEquals(conns[0]._commitCount, 0)
def circularReferenceTest(self, finish, hook):
    """
    Collecting a completed (committed or aborted) L{IAsyncTransaction}
    should not leak any circular references.

    @param finish: one-argument callable completing the transaction (e.g.
        invoking C{commit} or C{abort} on it).

    @param hook: two-argument callable registering the given callback on
        the transaction (e.g. via C{preCommit} or C{postAbort}).
    """
    tc = TrashCollector(self)
    commitExecuted = []

    def carefullyManagedScope():
        t = self.createTransaction()

        def holdAReference():
            """
            This is a hook that holds a reference to 't'.
            """
            commitExecuted.append(True)
            return t.execSQL("teardown", [])
        hook(t, holdAReference)
        finish(t)

    # Sanity check before running the scope: the hook cannot have fired
    # merely from being defined/registered.
    self.failIf(commitExecuted, "Commit hook executed.")
    carefullyManagedScope()
    tc.checkTrash()
def test_noGarbageOnCommit(self):
    """
    Committing a transaction does not cause gc garbage.
    """
    def finish(txn):
        return txn.commit()

    def registerHook(txn, callback):
        return txn.preCommit(callback)

    self.circularReferenceTest(finish, registerHook)
def test_noGarbageOnCommitWithAbortHook(self):
    """
    Committing a transaction that has a C{postAbort} hook attached does not
    cause gc garbage.
    """
    self.circularReferenceTest(lambda txn: txn.commit(),
                               lambda txn, hook: txn.postAbort(hook))
def test_noGarbageOnAbort(self):
    """
    Aborting a transaction that has a C{preCommit} hook attached does not
    cause gc garbage.
    """
    self.circularReferenceTest(lambda txn: txn.abort(),
                               lambda txn, hook: txn.preCommit(hook))
def test_noGarbageOnAbortWithPostCommitHook(self):
    """
    Aborting a transaction that has a C{postCommit} hook attached does not
    cause gc garbage.
    """
    self.circularReferenceTest(lambda txn: txn.abort(),
                               lambda txn, hook: txn.postCommit(hook))
def test_tooManyConnectionsWhileOthersFinish(self):
    """
    L{ConnectionPool.connection} will not spawn more than the maximum
    connections if there are finishing transactions outstanding.
    """
    a = self.createTransaction()
    b = self.createTransaction()
    self.pauseHolders()
    a.abort()
    b.abort()
    # Remove the holders for the existing connections, so that the 'extra'
    # connection() call wins the race and gets executed first.
    self.holders[:] = []
    self.createTransaction()
    self.flushHolders()
    # Still only two connections: the maximum was respected.
    self.assertEquals(len(self.factory.connections), 2)
def setParamstyle(self, paramstyle):
    """
    Change the paramstyle of the transaction under test.

    @param paramstyle: the new paramstyle value to expose to executors.
    """
    self.pool.paramstyle = paramstyle
def test_propagateParamstyle(self):
    """
    Each different type of L{ISQLExecutor} relays the C{paramstyle}
    attribute from the L{ConnectionPool}.
    """
    TEST_PARAMSTYLE = "justtesting"
    self.setParamstyle(TEST_PARAMSTYLE)
    # A normal connected transaction, and its command block.
    normaltxn = self.createTransaction()
    self.assertEquals(normaltxn.paramstyle, TEST_PARAMSTYLE)
    self.assertEquals(normaltxn.commandBlock().paramstyle, TEST_PARAMSTYLE)
    self.pauseHolders()
    extra = []
    extra.append(self.createTransaction())
    # A transaction still waiting for a free connection.
    waitingtxn = self.createTransaction()
    self.assertEquals(waitingtxn.paramstyle, TEST_PARAMSTYLE)
    self.flushHolders()
    self.pool.stopService()
    # Even a post-shutdown (immediately failing) transaction relays it.
    notxn = self.createTransaction()
    self.assertEquals(notxn.paramstyle, TEST_PARAMSTYLE)
def setDialect(self, dialect):
    """
    Change the dialect of the transaction under test.

    @param dialect: the new dialect value to expose to executors.
    """
    self.pool.dialect = dialect
def test_propagateDialect(self):
    """
    Each different type of L{ISQLExecutor} relays the C{dialect}
    attribute from the L{ConnectionPool}.
    """
    TEST_DIALECT = "otherdialect"
    self.setDialect(TEST_DIALECT)
    # A normal connected transaction, and its command block.
    normaltxn = self.createTransaction()
    self.assertEquals(normaltxn.dialect, TEST_DIALECT)
    self.assertEquals(normaltxn.commandBlock().dialect, TEST_DIALECT)
    self.pauseHolders()
    extra = []
    extra.append(self.createTransaction())
    # A transaction still waiting for a free connection.
    waitingtxn = self.createTransaction()
    self.assertEquals(waitingtxn.dialect, TEST_DIALECT)
    self.flushHolders()
    self.pool.stopService()
    # Even a post-shutdown (immediately failing) transaction relays it.
    notxn = self.createTransaction()
    self.assertEquals(notxn.dialect, TEST_DIALECT)
def test_reConnectWhenFirstExecFails(self):
    """
    Generally speaking, DB-API 2.0 adapters do not provide information
    about the cause of a failed 'execute' method; they definitely don't
    provide it in a way which can be identified as related to the syntax of
    the query, the state of the database itself, the state of the
    connection, etc.

    Therefore the best general heuristic for whether the connection to the
    database has been lost and needs to be re-established is to catch
    exceptions which are raised by the I{first} statement executed in a
    transaction.
    """
    # Allow 'connect' to succeed.  This should behave basically the same
    # whether connect() happened to succeed in some previous transaction
    # and it's recycling the underlying transaction, or connect() just
    # succeeded.  Either way you just have a _SingleTxn wrapping a
    # _ConnectedTxn.
    txn = self.createTransaction()
    self.assertEquals(len(self.factory.connections), 1,
                      "Sanity check failed.")

    class CustomExecuteFailed(Exception):
        """
        Custom 'execute-failed' exception.
        """
    self.factory.connections[0].executeWillFail(CustomExecuteFailed)
    results = self.resultOf(txn.execSQL("hello, world!"))
    # Despite the first execute() failing, the caller sees a success.
    [[[counter, echo]]] = results
    self.assertEquals("hello, world!", echo)
    # Two execution attempts should have been made, one on each connection.
    # The first failed with a RuntimeError, but that is deliberately
    # obscured, because then we tried again and it succeeded.
    self.assertEquals(len(self.factory.connections), 2,
                      "No new connection opened.")
    self.assertEquals(self.factory.connections[0].executions, 1)
    self.assertEquals(self.factory.connections[1].executions, 1)
    # The broken connection was closed; the replacement stays open.
    self.assertEquals(self.factory.connections[0].closed, True)
    self.assertEquals(self.factory.connections[1].closed, False)
    # Nevertheless, since there is currently no classification of 'safe'
    # errors, we should probably log these messages when they occur.
    self.assertEquals(len(self.flushLoggedErrors(CustomExecuteFailed)), 1)
def test_reConnectWhenFirstExecOnExistingConnectionFails(
        self, moreFailureSetup=lambda factory: None):
    """
    Another situation that might arise is that a connection will be
    successfully connected, executed and recycled into the connection pool;
    then, the database server will shut down and the connections will die,
    but we will be none the wiser until we try to use them.

    @param moreFailureSetup: hook used by
        L{test_closeExceptionDoesntHinderReconnection} to inject an
        additional failure mode into the connection factory.
    """
    txn = self.createTransaction()
    moreFailureSetup(self.factory)
    self.assertEquals(len(self.factory.connections), 1,
                      "Sanity check failed.")
    results = self.resultOf(txn.execSQL("hello, world!"))
    txn.commit()
    [[[counter, echo]]] = results
    self.assertEquals("hello, world!", echo)
    # The committed transaction's connection is recycled; make its first
    # execute() in the next transaction fail.
    txn2 = self.createTransaction()
    self.assertEquals(len(self.factory.connections), 1,
                      "Sanity check failed.")

    class CustomExecFail(Exception):
        """
        Custom 'execute()' failure.
        """
    self.factory.connections[0].executeWillFail(CustomExecFail)
    results = self.resultOf(txn2.execSQL("second try!"))
    txn2.commit()
    # The retry on a fresh connection succeeded transparently...
    [[[counter, echo]]] = results
    self.assertEquals("second try!", echo)
    # ...and the original failure was logged.
    self.assertEquals(len(self.flushLoggedErrors(CustomExecFail)), 1)
    def test_closeExceptionDoesntHinderReconnection(self):
        """
        In some database bindings, if the server closes the connection,
        C{close()} will fail. If C{close} fails, there's not much that could
        mean except that the connection is already closed, so similar to the
        condition described in
        L{test_reConnectWhenFirstExecOnExistingConnectionFails}, the
        failure should be logged, but transparent to application code.
        """
        class BindingSpecificException(Exception):
            """
            Exception that's a placeholder for something that a database
            binding might raise.
            """
        def alsoFailClose(factory):
            # Queue a close() failure in addition to the execute() failure
            # that the base scenario arranges.
            factory.childCloseWillFail(BindingSpecificException())
        # Re-run the reconnection scenario with close() also failing.
        t = self.test_reConnectWhenFirstExecOnExistingConnectionFails(
            alsoFailClose
        )
        # The close() failure must be logged but never propagated to callers.
        errors = self.flushLoggedErrors(BindingSpecificException)
        self.assertEquals(len(errors), 1)
        return t
def test_preCommitSuccess(self):
"""
Callables passed to L{IAsyncTransaction.preCommit} will be invoked upon
commit.
"""
txn = self.createTransaction()
def simple():
simple.done = True
simple.done = False
txn.preCommit(simple)
self.assertEquals(simple.done, False)
result = self.resultOf(txn.commit())
self.assertEquals(len(result), 1)
self.assertEquals(simple.done, True)
    def test_deferPreCommit(self):
        """
        If callables passed to L{IAsyncTransaction.preCommit} return
        L{Deferred}s, they will defer the actual commit operation until it has
        fired.
        """
        txn = self.createTransaction()
        d = Deferred()
        def wait():
            wait.started = True
            def executed(it):
                wait.sqlResult = it
            # To make sure the _underlying_ commit operation was Deferred, we
            # have to execute some SQL to make sure it happens.
            return (d.addCallback(lambda ignored: txn.execSQL("some test sql"))
                    .addCallback(executed))
        wait.started = False
        wait.sqlResult = None
        txn.preCommit(wait)
        result = self.resultOf(txn.commit())
        self.flushHolders()
        # The hook has started but its Deferred hasn't fired, so neither the
        # SQL nor the commit can have happened yet.
        self.assertEquals(wait.started, True)
        self.assertEquals(wait.sqlResult, None)
        self.assertEquals(result, [])
        d.callback(None)
        # allow network I/O for pooled / networked implementation; there should
        # be the commit message now.
        self.flushHolders()
        self.assertEquals(len(result), 1)
        self.assertEquals(wait.sqlResult, [[1, "some test sql"]])
    def test_failPreCommit(self):
        """
        If callables passed to L{IAsyncTransaction.preCommit} raise an
        exception or return a Failure, subsequent callables will not be run,
        and the transaction will be aborted.
        """
        def test(flawedCallable, exc):
            # Set up.
            test.committed = False
            test.aborted = False
            # Create transaction and add monitoring hooks.
            txn = self.createTransaction()
            def didCommit():
                test.committed = True
            def didAbort():
                test.aborted = True
            txn.postCommit(didCommit)
            txn.postAbort(didAbort)
            txn.preCommit(flawedCallable)
            result = self.resultOf(txn.commit())
            self.flushHolders()
            # The commit Deferred fails with the hook's exception, and the
            # transaction must have aborted rather than committed.
            self.assertResultList(result, Failure(exc()))
            self.assertEquals(test.committed, False)
            self.assertEquals(test.aborted, True)
        def failer():
            # Hook that returns a failing Deferred.
            return fail(ZeroDivisionError())
        def raiser():
            # Hook that raises synchronously.
            raise EOFError()
        test(failer, ZeroDivisionError)
        test(raiser, EOFError)
    def test_noOpCommitDoesntHinderReconnection(self):
        """
        Until you've executed a query or performed a statement on an ADBAPI
        connection, the connection is semantically idle (between transactions).
        A .commit() or .rollback() followed immediately by a .commit() is
        therefore pointless, and can be ignored. Furthermore, actually
        executing the commit and propagating a possible connection-oriented
        error causes clients to see errors, when, if those clients had actually
        executed any statements, the connection would have been recycled and
        the statement transparently re-executed by the logic tested by
        L{test_reConnectWhenFirstExecFails}.
        """
        txn = self.createTransaction()
        # Arrange for both commit and rollback to fail if they are attempted.
        self.factory.commitFail = True
        self.factory.rollbackFail = True
        [x] = self.resultOf(txn.commit())
        # No statements have been executed, so 'commit' will *not* be executed.
        self.assertEquals(self.factory.commitFail, True)
        self.assertIdentical(x, None)
        # The pristine connection goes straight back into the free pool.
        self.assertEquals(len(self.pool._free), 1)
        self.assertEquals(self.pool._finishing, [])
        self.assertEquals(len(self.factory.connections), 1)
        self.assertEquals(self.factory.connections[0].closed, False)
    def test_reConnectWhenSecondExecFailsThenFirstExecFails(self):
        """
        Other connection-oriented errors might raise exceptions if they occur
        in the middle of a transaction, but that should cause the error to be
        caught, the transaction to be aborted, and the (closed) connection to
        be recycled, where the next transaction that attempts to do anything
        with it will encounter the error immediately and discover it needs to
        be recycled.
        It would be better if this behavior were invisible, but that could only
        be accomplished with more precise database exceptions. We may come up
        with support in the future for more precisely identifying exceptions,
        but I{unknown} exceptions should continue to be treated in this manner,
        relaying the exception back to application code but attempting a
        re-connection on the next try.
        """
        txn = self.createTransaction()
        # First statement succeeds normally.
        [[[counter, echo]]] = self.resultOf(txn.execSQL("hello, world!", []))
        # Queue a failure for the *second* statement on this connection.
        self.factory.connections[0].executeWillFail(ZeroDivisionError)
        [f] = self.resultOf(txn.execSQL("divide by zero", []))
        f.trap(self.translateError(ZeroDivisionError))
        self.assertEquals(self.factory.connections[0].executions, 2)
        # Reconnection should work exactly as before.
        self.assertEquals(self.factory.connections[0].closed, False)
        # Application code has to roll back its transaction at this point,
        # since it failed (and we don't necessarily know why it failed: not
        # enough information).
        self.resultOf(txn.abort())
        self.factory.connections[0].executions = 0 # re-set for next test
        self.assertEquals(len(self.factory.connections), 1)
        # The pool should now behave exactly as if the connection were fresh.
        self.test_reConnectWhenFirstExecFails()
    def test_disconnectOnFailedRollback(self):
        """
        When C{rollback} fails for any reason on a connection object, then we
        don't know what state it's in. Most likely, it's already been
        disconnected, so the connection should be closed and the transaction
        de-pooled instead of recycled.
        Also, a new connection will immediately be established to keep the pool
        size the same.
        """
        txn = self.createTransaction()
        results = self.resultOf(txn.execSQL("maybe change something!"))
        [[[counter, echo]]] = results
        self.assertEquals("maybe change something!", echo)
        # Make the upcoming rollback fail.
        self.factory.rollbackFail = True
        [x] = self.resultOf(txn.abort())
        # Abort does not propagate the error on, the transaction merely gets
        # disposed of.
        self.assertIdentical(x, None)
        # The broken connection is closed and replaced with a fresh one,
        # keeping the pool at its configured size.
        self.assertEquals(len(self.pool._free), 1)
        self.assertEquals(self.pool._finishing, [])
        self.assertEquals(len(self.factory.connections), 2)
        self.assertEquals(self.factory.connections[0].closed, True)
        self.assertEquals(self.factory.connections[1].closed, False)
        # The rollback failure is logged, not raised.
        self.assertEquals(len(self.flushLoggedErrors(RollbackFail)), 1)
    def test_exceptionPropagatesFailedCommit(self):
        """
        A failed C{rollback} is fine (the premature death of the connection
        without C{commit} means that the changes are surely gone), but a failed
        C{commit} has to be relayed to client code, since that actually means
        some changes didn't hit the database.
        """
        txn = self.createTransaction()
        self.factory.commitFail = True
        results = self.resultOf(txn.execSQL("maybe change something!"))
        [[[counter, echo]]] = results
        self.assertEquals("maybe change something!", echo)
        [x] = self.resultOf(txn.commit())
        # The commit failure reaches the caller...
        x.trap(self.translateError(CommitFail))
        # ...and the broken connection is replaced to maintain the pool size.
        self.assertEquals(len(self.pool._free), 1)
        self.assertEquals(self.pool._finishing, [])
        self.assertEquals(len(self.factory.connections), 2)
        self.assertEquals(self.factory.connections[0].closed, True)
        self.assertEquals(self.factory.connections[1].closed, False)
    def test_commandBlock(self):
        """
        L{IAsyncTransaction.commandBlock} returns an L{IAsyncTransaction}
        provider which ensures that a block of commands are executed together.
        """
        txn = self.createTransaction()
        a = self.resultOf(txn.execSQL("a"))
        cb = txn.commandBlock()
        verifyObject(ICommandBlock, cb)
        b = self.resultOf(cb.execSQL("b"))
        # "d" is issued on the transaction *after* the block was created, so
        # it must wait until the block ends, even though "c" is issued later.
        d = self.resultOf(txn.execSQL("d"))
        c = self.resultOf(cb.execSQL("c"))
        cb.end()
        e = self.resultOf(txn.execSQL("e"))
        # Block contents stay contiguous (b, c) between the statement issued
        # before (a) and those issued after (d, e).
        self.assertEquals(self.factory.connections[0].cursors[0].allExecutions,
                          [("a", []), ("b", []), ("c", []), ("d", []),
                           ("e", [])])
        # Every statement's Deferred has fired with exactly one result.
        self.assertEquals(len(a), 1)
        self.assertEquals(len(b), 1)
        self.assertEquals(len(c), 1)
        self.assertEquals(len(d), 1)
        self.assertEquals(len(e), 1)
    def test_commandBlockWithLatency(self):
        """
        A block returned by L{IAsyncTransaction.commandBlock} won't start
        executing until all SQL statements scheduled before it have completed.
        """
        # Hold back all thread work so ordering is decided purely by queueing.
        self.pauseHolders()
        txn = self.createTransaction()
        a = self.resultOf(txn.execSQL("a"))
        b = self.resultOf(txn.execSQL("b"))
        cb = txn.commandBlock()
        c = self.resultOf(cb.execSQL("c"))
        d = self.resultOf(cb.execSQL("d"))
        # "e" is scheduled on the transaction after the block, so it runs last.
        e = self.resultOf(txn.execSQL("e"))
        cb.end()
        self.flushHolders()
        self.assertEquals(self.factory.connections[0].cursors[0].allExecutions,
                          [("a", []), ("b", []), ("c", []), ("d", []),
                           ("e", [])])
        # All five Deferreds fired with one result each.
        self.assertEquals(len(a), 1)
        self.assertEquals(len(b), 1)
        self.assertEquals(len(c), 1)
        self.assertEquals(len(d), 1)
        self.assertEquals(len(e), 1)
    def test_twoCommandBlocks(self, flush=lambda: None):
        """
        When execution of one command block is complete, it will proceed to the
        next queued block, then to regular SQL executed on the transaction.

        @param flush: hook used by L{test_twoCommandBlocksLatently} to flush
            paused holders before the final assertions.
        """
        txn = self.createTransaction()
        cb1 = txn.commandBlock()
        cb2 = txn.commandBlock()
        # Plain transaction SQL queues behind both blocks...
        txn.execSQL("e")
        # ...and statements are issued interleaved across the two blocks.
        cb1.execSQL("a")
        cb2.execSQL("c")
        cb1.execSQL("b")
        cb2.execSQL("d")
        cb2.end()
        cb1.end()
        flush()
        self.flushHolders()
        # Blocks run in creation order, each contiguously, then the plain SQL.
        self.assertEquals(self.factory.connections[0].cursors[0].allExecutions,
                          [("a", []), ("b", []), ("c", []), ("d", []),
                           ("e", [])])
    def test_twoCommandBlocksLatently(self):
        """
        Same as L{test_twoCommandBlocks}, but with slower callbacks.
        """
        # Pause the thread holders so all work queues up, then let the base
        # test release it at the right moment via its 'flush' hook.
        self.pauseHolders()
        self.test_twoCommandBlocks(self.flushHolders)
def test_commandBlockEndTwice(self):
"""
L{CommandBlock.end} will raise L{AlreadyFinishedError} when called more
than once.
"""
txn = self.createTransaction()
block = txn.commandBlock()
block.end()
self.assertRaises(AlreadyFinishedError, block.end)
    def test_commandBlockDelaysCommit(self):
        """
        Some command blocks need to run asynchronously, without the overall
        transaction-managing code knowing how far they've progressed.
        Therefore when you call {IAsyncTransaction.commit}(), it should not
        actually take effect if there are any pending command blocks.
        """
        txn = self.createTransaction()
        block = txn.commandBlock()
        commitResult = self.resultOf(txn.commit())
        # SQL in the block still executes even though commit() was requested.
        self.resultOf(block.execSQL("in block"))
        # The commit has not fired yet; only the block's SQL has run.
        self.assertEquals(commitResult, [])
        self.assertEquals(self.factory.connections[0].cursors[0].allExecutions,
                          [("in block", [])])
        # Ending the block releases the pending commit.
        block.end()
        self.flushHolders()
        self.assertEquals(commitResult, [None])
    def test_commandBlockDoesntDelayAbort(self):
        """
        A L{CommandBlock} can't possibly have anything interesting to say about
        a transaction that gets rolled back, so C{abort} applies immediately;
        all outstanding C{execSQL}s will fail immediately, on both command
        blocks and on the transaction itself.
        """
        txn = self.createTransaction()
        block = txn.commandBlock()
        block2 = txn.commandBlock()
        abortResult = self.resultOf(txn.abort())
        # The abort fired right away despite two open blocks.
        self.assertEquals(abortResult, [None])
        self.assertRaises(AlreadyFinishedError, block2.execSQL, "bar")
        self.assertRaises(AlreadyFinishedError, block.execSQL, "foo")
        self.assertRaises(AlreadyFinishedError, txn.execSQL, "baz")
        # Nothing ever reached the underlying cursor.
        self.assertEquals(self.factory.connections[0].cursors[0].allExecutions,
                          [])
        # end() should _not_ raise an exception, because this is the sort of
        # thing that might be around a try/finally or try/except; it's just
        # putting the commandBlock itself into a state consistent with the
        # transaction.
        block.end()
        block2.end()
def test_endedBlockDoesntExecuteMoreSQL(self):
"""
Attempting to execute SQL on a L{CommandBlock} which has had C{end}
called on it will result in an L{AlreadyFinishedError}.
"""
txn = self.createTransaction()
block = txn.commandBlock()
block.end()
self.assertRaises(AlreadyFinishedError, block.execSQL, "hello")
self.assertEquals(self.factory.connections[0].cursors[0].allExecutions,
[])
def test_commandBlockAfterCommitRaises(self):
"""
Once an L{IAsyncTransaction} has been committed, L{commandBlock} raises
an exception.
"""
txn = self.createTransaction()
txn.commit()
self.assertRaises(AlreadyFinishedError, txn.commandBlock)
    def test_commandBlockAfterAbortRaises(self):
        """
        Once an L{IAsyncTransaction} has been aborted, L{commandBlock} raises
        an exception.
        """
        txn = self.createTransaction()
        self.resultOf(txn.abort())
        self.assertRaises(AlreadyFinishedError, txn.commandBlock)
    def test_raiseOnZeroRowCount(self):
        """
        L{IAsyncTransaction.execSQL} will return a L{Deferred} failing with the
        exception passed as its raiseOnZeroRowCount argument if the underlying
        query returns no rows.
        """
        # Make the fake cursor report no results at all.
        self.factory.hasResults = False
        txn = self.createTransaction()
        f = self.resultOf(
            txn.execSQL("hello", raiseOnZeroRowCount=ZeroDivisionError)
        )[0]
        self.assertRaises(ZeroDivisionError, f.raiseException)
        txn.commit()
    def test_raiseOnZeroRowCountWithUnreliableRowCount(self):
        """
        As it turns out, some databases can't reliably tell you how many rows
        they're going to fetch via the C{rowcount} attribute before the rows
        have actually been fetched, so the C{raiseOnZeroRowCount} will I{not}
        raise an exception if C{rowcount} is zero but C{description} and
        C{fetchall} indicates the presence of some rows.
        """
        # Rows exist, but rowcount stays 0 (the unreliable-driver case).
        self.factory.hasResults = True
        self.factory.shouldUpdateRowcount = False
        txn = self.createTransaction()
        r = self.resultOf(
            txn.execSQL("some-rows", raiseOnZeroRowCount=RuntimeError)
        )
        # No RuntimeError: the fetched rows override the bogus rowcount.
        [[[counter, echo]]] = r
        self.assertEquals(echo, "some-rows")
class IOPump(object):
    """
    Connect a client and a server protocol over in-memory transports and
    shuttle the bytes they write back and forth.

    @ivar client: a client protocol
    @ivar server: a server protocol
    """
    def __init__(self, client, server):
        self.client = client
        self.server = server
        self.clientTransport = StringTransport()
        self.serverTransport = StringTransport()
        self.client.makeConnection(self.clientTransport)
        self.server.makeConnection(self.serverTransport)
        # (transport to drain, protocol to deliver to) for each direction.
        self.c2s = [self.clientTransport, self.server]
        self.s2c = [self.serverTransport, self.client]

    def moveData(self, pair):
        """
        Move data from a L{StringTransport} to an L{IProtocol}.

        @param pair: a 2-sequence of C{(outTransport, inProtocol)}.  (This
            used to be a Python-2-only tuple-unpacking parameter; the
            unpacking now happens in the body so the definition is valid on
            both Python 2 and Python 3, with an identical call signature.)

        @return: C{True} if any data was moved, C{False} if no data was moved.
        """
        outTransport, inProtocol = pair
        data = outTransport.io.getvalue()
        # Drain the buffer so the same bytes aren't delivered twice on the
        # next pump.
        outTransport.io.seek(0)
        outTransport.io.truncate()
        if data:
            inProtocol.dataReceived(data)
            return True
        else:
            return False

    def pump(self):
        """
        Deliver all input from the client to the server, then from the server
        to the client.

        @return: C{True} if any data was moved in either direction.
        """
        a = self.moveData(self.c2s)
        b = self.moveData(self.s2c)
        return a or b

    def flush(self, maxTurns=100):
        """
        Continue pumping until no more data is flowing.

        @param maxTurns: safety limit turning a livelock into a loud failure.
        @raise RuntimeError: if data is still flowing after C{maxTurns} turns.
        """
        turns = 0
        while self.pump():
            turns += 1
            if turns > maxTurns:
                raise RuntimeError("Ran too long!")
class NetworkedPoolHelper(ConnectionPoolHelper):
    """
    An extension of L{ConnectionPoolHelper} that can set up a
    L{ConnectionPoolClient} and L{ConnectionPoolConnection} attached to each
    other.
    """
    def setUp(self):
        """
        Do the same setup from L{ConnectionPoolBase}, but also establish a
        loopback connection between a L{ConnectionPoolConnection} and a
        L{ConnectionPoolClient}.
        """
        super(NetworkedPoolHelper, self).setUp()
        self.pump = IOPump(ConnectionPoolClient(dialect=self.dialect,
                                                paramstyle=self.paramstyle),
                           ConnectionPoolConnection(self.pool))
    def flushHolders(self):
        """
        In addition to flushing the L{ThreadHolder} stubs, also flush any
        pending network I/O.
        """
        # Flush before and after, since network traffic may both trigger and
        # be triggered by work done in the thread holders.
        self.pump.flush()
        super(NetworkedPoolHelper, self).flushHolders()
        self.pump.flush()
    def createTransaction(self):
        # Transactions are created client-side and relayed over the loopback.
        txn = self.pump.client.newTransaction()
        self.pump.flush()
        return txn
    def translateError(self, err):
        """
        All errors raised locally will unfortunately be translated into
        UnknownRemoteError, since AMP requires specific enumeration of all of
        them. Flush the locally logged error of the given type and return
        L{UnknownRemoteError}.
        """
        if err in Commit.errors:
            # Errors explicitly declared on the AMP command survive the trip.
            return err
        self.flushLoggedErrors(err)
        return FailsafeException
    def resultOf(self, it):
        # Flush network I/O so the Deferred has a chance to fire.
        result = resultOf(it)
        self.pump.flush()
        return result
class NetworkedConnectionPoolTests(NetworkedPoolHelper, ConnectionPoolTests):
    """
    Tests for L{ConnectionPoolConnection} and L{ConnectionPoolClient}
    interacting with each other.
    """
    def setParamstyle(self, paramstyle):
        """
        Change the paramstyle on both the pool and the client.
        """
        super(NetworkedConnectionPoolTests, self).setParamstyle(paramstyle)
        # The client keeps its own copy of the paramstyle; keep it in sync.
        self.pump.client.paramstyle = paramstyle
    def setDialect(self, dialect):
        """
        Change the dialect on both the pool and the client.
        """
        super(NetworkedConnectionPoolTests, self).setDialect(dialect)
        # The client keeps its own copy of the dialect; keep it in sync.
        self.pump.client.dialect = dialect
    def test_newTransaction(self):
        """
        L{ConnectionPoolClient.newTransaction} returns a provider of
        L{IAsyncTransaction}, and creates a new transaction on the server side.
        """
        txn = self.pump.client.newTransaction()
        verifyObject(IAsyncTransaction, txn)
        # Deliver the AMP traffic so the server actually opens a connection.
        self.pump.flush()
        self.assertEquals(len(self.factory.connections), 1)
class HookableOperationTests(TestCase):
    """
    Tests for L{_HookableOperation}.
    """
    @inlineCallbacks
    def test_clearPreventsSubsequentAddHook(self):
        """
        After clear() or runHooks() are called, subsequent calls to addHook()
        are NO-OPs.
        """
        def hook():
            return succeed(None)
        hookOp = _HookableOperation()
        hookOp.addHook(hook)
        self.assertEquals(len(hookOp._hooks), 1)
        hookOp.clear()
        # clear() discards the hook list entirely.
        self.assertEquals(hookOp._hooks, None)
        hookOp = _HookableOperation()
        hookOp.addHook(hook)
        yield hookOp.runHooks()
        # Running the hooks also retires the list...
        self.assertEquals(hookOp._hooks, None)
        hookOp.addHook(hook)
        # ...so a later addHook() is silently ignored.
        self.assertEquals(hookOp._hooks, None)
calendarserver-5.2+dfsg/twext/enterprise/test/__init__.py 0000644 0001750 0001750 00000001207 12263343324 022734 0 ustar rahul rahul
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for L{twext.enterprise}.
"""
calendarserver-5.2+dfsg/twext/enterprise/fixtures.py 0000644 0001750 0001750 00000035254 12263343324 022100 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.test.test_fixtures -*-
##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Fixtures for testing code that uses ADBAPI2.
"""
import sqlite3
from Queue import Empty
from itertools import count
from zope.interface import implementer
from zope.interface.verify import verifyClass
from twisted.internet.interfaces import IReactorThreads
from twisted.python.threadpool import ThreadPool
from twisted.internet.task import Clock
from twext.enterprise.adbapi2 import ConnectionPool
from twext.enterprise.ienterprise import SQLITE_DIALECT
from twext.enterprise.ienterprise import POSTGRES_DIALECT
from twext.enterprise.adbapi2 import DEFAULT_PARAM_STYLE
from twext.internet.threadutils import ThreadHolder
def buildConnectionPool(testCase, schemaText="", dialect=SQLITE_DIALECT):
    """
    Build a L{ConnectionPool} for testing purposes, with the given C{testCase}.

    @param testCase: the test case to attach the resulting L{ConnectionPool}
        to.
    @type testCase: L{twisted.trial.unittest.TestCase}

    @param schemaText: The text of the schema with which to initialize the
        database.
    @type schemaText: L{str}

    @param dialect: the SQL dialect to configure the pool with.  (Previously
        this parameter was accepted but silently ignored; it is now honored.)

    @return: a L{ConnectionPool} service whose C{startService} method has
        already been invoked.
    @rtype: L{ConnectionPool}
    """
    sqlitename = testCase.mktemp()
    seqs = {}
    def connectionFactory(label=testCase.id()):
        conn = sqlite3.connect(sqlitename)
        def nextval(seq):
            # Emulate a SQL sequence on top of SQLite with a shared counter.
            result = seqs[seq] = seqs.get(seq, 0) + 1
            return result
        conn.create_function("nextval", 1, nextval)
        return conn
    # Apply the schema once, before the pool starts handing out connections.
    con = connectionFactory()
    con.executescript(schemaText)
    con.commit()
    pool = ConnectionPool(connectionFactory, paramstyle='numeric',
                          dialect=dialect)
    pool.startService()
    # Make sure the pool's threads are cleaned up when the test ends.
    testCase.addCleanup(pool.stopService)
    return pool
def resultOf(deferred, propagate=False):
    """
    Add a callback and errback which will capture the result of a L{Deferred}
    in a list, and return that list. If 'propagate' is True, pass through the
    results.
    """
    captured = []
    if propagate:
        def capture(result):
            # Record the result but also hand it on down the callback chain.
            captured.append(result)
            return result
    else:
        # list.append returns None, which swallows the result.
        capture = captured.append
    deferred.addBoth(capture)
    return captured
class FakeThreadHolder(ThreadHolder):
    """
    Run things submitted to this ThreadHolder on the main thread, so that
    execution is easier to control.
    """
    def __init__(self, test):
        super(FakeThreadHolder, self).__init__(self)
        self.test = test
        self.started = False
        self.stopped = False
        self._workerIsRunning = False
    def start(self):
        # Record that start was requested, then delegate to the real
        # ThreadHolder machinery.
        self.started = True
        return super(FakeThreadHolder, self).start()
    def stop(self):
        result = super(FakeThreadHolder, self).stop()
        self.stopped = True
        return result
    @property
    def _get_q(self):
        return self._q_
    @_get_q.setter
    def _q(self, newq):
        # NOTE: this decorator binds the combined property under the name
        # '_q', so assignments to self._q (made by ThreadHolder) pass through
        # this setter, while reads return self._q_ via the getter above.
        if newq is not None:
            oget = newq.get
            # Make queue reads non-blocking so a stuck test raises Empty
            # instead of hanging forever.
            newq.get = lambda: oget(timeout=0)
            oput = newq.put
            def putit(x):
                p = oput(x)
                if not self.test.paused:
                    # Unless the test has paused the holders, run queued work
                    # immediately on the calling (main) thread.
                    self.flush()
                return p
            newq.put = putit
        self._q_ = newq
    def callFromThread(self, f, *a, **k):
        # Everything already runs on the main thread; just invoke directly.
        result = f(*a, **k)
        return result
    def callInThread(self, f, *a, **k):
        """
        This should be called only once, to start the worker function that
        dedicates a thread to this L{ThreadHolder}.
        """
        self._workerIsRunning = True
    def flush(self):
        """
        Fire all deferreds previously returned from submit.
        """
        try:
            # Drain queued work until the worker stops; the 'else' clause
            # runs when the loop condition goes false (i.e. the worker's
            # queue signalled completion), marking the worker as stopped.
            while self._workerIsRunning and self._qpull():
                pass
            else:
                self._workerIsRunning = False
        except Empty:
            # Queue exhausted without a stop signal; nothing more to do now.
            pass
@implementer(IReactorThreads)
class ClockWithThreads(Clock):
    """
    A testing reactor that supplies L{IReactorTime} and L{IReactorThreads}.
    """
    def __init__(self):
        super(ClockWithThreads, self).__init__()
        self._pool = ThreadPool()

    def getThreadPool(self):
        """
        Return the single thread pool owned by this fake reactor.
        """
        return self._pool

    def suggestThreadPoolSize(self, size):
        """
        Approximate the behavior of a 'real' reactor by resizing the pool.
        """
        self._pool.adjustPoolsize(maxthreads=size)

    def callInThread(self, thunk, *a, **kw):
        """
        No implementation.
        """

    def callFromThread(self, thunk, *a, **kw):
        """
        No implementation.
        """
# Fail at import time if ClockWithThreads stops satisfying the
# IReactorThreads interface it declares.
verifyClass(IReactorThreads, ClockWithThreads)
class ConnectionPoolHelper(object):
    """
    Connection pool setting-up facilities for tests that need a
    L{ConnectionPool}.
    """
    dialect = POSTGRES_DIALECT
    paramstyle = DEFAULT_PARAM_STYLE
    def setUp(self, test=None, connect=None):
        """
        Support inheritance by L{TestCase} classes.

        @param test: the test case to attach cleanups to; defaults to C{self}
            so this class can also be mixed directly into a TestCase.
        @param connect: an alternate connection factory; defaults to a fresh
            in-memory L{ConnectionFactory}.
        """
        if test is None:
            test = self
        if connect is None:
            self.factory = ConnectionFactory()
            connect = self.factory.connect
        self.connect = connect
        self.paused = False
        self.holders = []
        self.pool = ConnectionPool(connect,
                                   maxConnections=2,
                                   dialect=self.dialect,
                                   paramstyle=self.paramstyle)
        # Substitute fake thread holders and a fake reactor so the tests
        # control all concurrency deterministically.
        self.pool._createHolder = self.makeAHolder
        self.clock = self.pool.reactor = ClockWithThreads()
        self.pool.startService()
        test.addCleanup(self.flushHolders)
    def flushHolders(self):
        """
        Flush all pending C{submit}s since C{pauseHolders} was called. This
        makes sure the service is stopped and the fake ThreadHolders are all
        executing their queues so failed tests can exit cleanly.
        """
        self.paused = False
        for holder in self.holders:
            holder.flush()
    def pauseHolders(self):
        """
        Pause all L{FakeThreadHolder}s, causing C{submit} to return an unfired
        L{Deferred}.
        """
        self.paused = True
    def makeAHolder(self):
        """
        Make a ThreadHolder-alike.
        """
        fth = FakeThreadHolder(self)
        self.holders.append(fth)
        return fth
    def resultOf(self, it):
        # Capture a Deferred's result synchronously.
        return resultOf(it)
    def createTransaction(self):
        return self.pool.connection()
    def translateError(self, err):
        # Locally, errors pass through untranslated; the networked helper
        # overrides this to account for AMP error mapping.
        return err
class SteppablePoolHelper(ConnectionPoolHelper):
    """
    A version of L{ConnectionPoolHelper} that can set up a connection pool
    capable of firing all its L{Deferred}s on demand, synchronously, by using
    SQLite.
    """
    dialect = SQLITE_DIALECT
    paramstyle = sqlite3.paramstyle
    def __init__(self, schema):
        # SQL schema script executed against a fresh database in setUp().
        self.schema = schema
    def setUp(self, test):
        # Initialize the SQLite database with the schema before handing the
        # same connection factory to the base helper.
        connect = synchronousConnectionFactory(test)
        con = connect()
        cur = con.cursor()
        cur.executescript(self.schema)
        con.commit()
        super(SteppablePoolHelper, self).setUp(test, connect)
    def rows(self, sql):
        """
        Get some rows from the database to compare in a test.
        """
        con = self.connect()
        cur = con.cursor()
        cur.execute(sql)
        result = cur.fetchall()
        # Commit to release any locks before discarding the connection.
        con.commit()
        return result
def synchronousConnectionFactory(test):
    """
    Create a zero-argument SQLite connection factory bound to a fresh
    temporary database path allocated via C{test.mktemp()}.
    """
    databasePath = test.mktemp()
    def connect():
        return sqlite3.connect(databasePath)
    return connect
class Child(object):
    """
    An object with a L{Parent}, in its list of C{children}.
    """
    def __init__(self, parent):
        self.closed = False
        self.parent = parent
        self.parent.children.append(self)

    def close(self):
        """
        Mark this child as closed, unless the parent has queued a close
        failure, in which case raise it (leaving C{closed} as C{False}).
        """
        pendingFailures = self.parent._closeFailQueue
        if pendingFailures:
            raise pendingFailures.pop(0)
        self.closed = True
class Parent(object):
    """
    An object with a list of L{Child}ren.
    """
    def __init__(self):
        self.children = []
        # Exceptions queued here are raised, in order, by children's close().
        self._closeFailQueue = []

    def childCloseWillFail(self, exception):
        """
        Closing children of this object will result in the given exception.

        @see: L{ConnectionFactory}
        """
        self._closeFailQueue.append(exception)
class FakeConnection(Parent, Child):
    """
    Fake Stand-in for DB-API 2.0 connection.

    @ivar executions: the number of statements which have been executed.
    """
    executions = 0

    def __init__(self, factory):
        """
        Register with the given factory and set up per-connection state.
        """
        Parent.__init__(self)
        Child.__init__(self, factory)
        self.id = factory.idcounter.next()
        self._executeFailQueue = []
        self._commitCount = 0
        self._rollbackCount = 0

    def executeWillFail(self, thunk):
        """
        The next call to L{FakeCursor.execute} will fail with an exception
        returned from the given callable.
        """
        self._executeFailQueue.append(thunk)

    @property
    def cursors(self):
        "Alias to make tests more readable."
        return self.children

    def cursor(self):
        """
        Produce a new L{FakeCursor} child of this connection.
        """
        return FakeCursor(self)

    def _maybeFailOneShot(self, flagName, errorType):
        # Consume a one-shot failure flag from the owning factory; when set,
        # clear it and raise the corresponding sample exception.
        if getattr(self.parent, flagName):
            setattr(self.parent, flagName, False)
            raise errorType()

    def commit(self):
        """
        Count the commit, failing once if the factory asked for a failure.
        """
        self._commitCount += 1
        self._maybeFailOneShot("commitFail", CommitFail)

    def rollback(self):
        """
        Count the rollback, failing once if the factory asked for a failure.
        """
        self._rollbackCount += 1
        self._maybeFailOneShot("rollbackFail", RollbackFail)
class RollbackFail(Exception):
    """
    Sample exception raised when a fake connection's C{rollback} is made to
    fail.
    """
class CommitFail(Exception):
    """
    Sample exception raised when a fake connection's C{commit} is made to
    fail.
    """
class FakeCursor(Child):
    """
    Fake stand-in for a DB-API 2.0 cursor.
    """
    def __init__(self, connection):
        Child.__init__(self, connection)
        self.rowcount = 0
        # not entirely correct, but all we care about is its truth value.
        self.description = False
        self.variables = []
        self.allExecutions = []
    @property
    def connection(self):
        "Alias to make tests more readable."
        return self.parent
    def execute(self, sql, args=()):
        # Count the attempt *before* consulting the failure queue, so tests
        # can assert how many executions were tried even when they fail.
        self.connection.executions += 1
        if self.connection._executeFailQueue:
            # Pop the queued exception factory and raise its product.
            raise self.connection._executeFailQueue.pop(0)()
        self.allExecutions.append((sql, args))
        self.sql = sql
        factory = self.connection.parent
        self.description = factory.hasResults
        # Honor the factory's "unreliable rowcount" setting.
        if factory.hasResults and factory.shouldUpdateRowcount:
            self.rowcount = 1
        else:
            self.rowcount = 0
        return
    def var(self, type, *args):
        """
        Return a database variable in the style of the cx_Oracle bindings.
        """
        v = FakeVariable(self, type, args)
        self.variables.append(v)
        return v
    def fetchall(self):
        """
        Just echo the SQL that was executed in the last query.
        """
        if self.connection.parent.hasResults:
            return [[self.connection.id, self.sql]]
        if self.description:
            return []
        return None
class FakeVariable(object):
    """
    Stand-in for a cx_Oracle-style output variable bound to a cursor.
    """
    def __init__(self, cursor, type, args):
        self.cursor = cursor
        self.type = type
        self.args = args

    def getvalue(self):
        """
        Pop a value queued on the factory if one is available; otherwise
        derive a deterministic value from this variable's position on its
        cursor.
        """
        queuedValues = self.cursor.connection.parent.varvals
        if queuedValues:
            return queuedValues.pop(0)
        return self.cursor.variables.index(self) + 300

    def __reduce__(self):
        raise RuntimeError("Not pickleable (since oracle vars aren't)")
class ConnectionFactory(Parent):
    """
    A factory for L{FakeConnection} objects.
    @ivar shouldUpdateRowcount: Should C{execute} on cursors produced by
        connections produced by this factory update their C{rowcount} or just
        their C{description} attribute?
    @ivar hasResults: should cursors produced by connections by this factory
        have any results returned by C{fetchall()}?
    """
    # One-shot failure flags consumed by FakeConnection.commit()/rollback().
    rollbackFail = False
    commitFail = False
    def __init__(self, shouldUpdateRowcount=True, hasResults=True):
        Parent.__init__(self)
        # Source of unique ids for connections created by this factory.
        self.idcounter = count(1)
        self._connectResultQueue = []
        self.defaultConnect()
        # Values returned, in order, by FakeVariable.getvalue().
        self.varvals = []
        self.shouldUpdateRowcount = shouldUpdateRowcount
        self.hasResults = hasResults
    @property
    def connections(self):
        "Alias to make tests more readable."
        return self.children
    def connect(self):
        """
        Implement the C{ConnectionFactory} callable expected by
        L{ConnectionPool}.
        """
        # Prefer explicitly queued results; otherwise fall back to the
        # default thunk established by defaultConnect()/defaultFail().
        if self._connectResultQueue:
            thunk = self._connectResultQueue.pop(0)
        else:
            thunk = self._default
        return thunk()
    def willConnect(self):
        """
        Used by tests to queue a successful result for connect().
        """
        def thunk():
            return FakeConnection(self)
        self._connectResultQueue.append(thunk)
    def willConnectTo(self):
        """
        Queue a successful result for connect() and immediately add it as a
        child to this L{ConnectionFactory}.
        @return: a connection object
        @rtype: L{FakeConnection}
        """
        aConnection = FakeConnection(self)
        def thunk():
            return aConnection
        self._connectResultQueue.append(thunk)
        return aConnection
    def willFail(self):
        """
        Used by tests to queue a failed result for connect().
        """
        def thunk():
            raise FakeConnectionError()
        self._connectResultQueue.append(thunk)
    def defaultConnect(self):
        """
        By default, connection attempts will succeed.
        """
        # Queue one success, then steal it to serve as the fallback thunk.
        self.willConnect()
        self._default = self._connectResultQueue.pop()
    def defaultFail(self):
        """
        By default, connection attempts will fail.
        """
        # Queue one failure, then steal it to serve as the fallback thunk.
        self.willFail()
        self._default = self._connectResultQueue.pop()
class FakeConnectionError(Exception):
    """
    Synthetic error raised when a queued connection attempt is made to fail.
    """
calendarserver-5.2+dfsg/twext/enterprise/adbapi2.py 0000644 0001750 0001750 00000152430 12263343324 021525 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.test.test_adbapi2 -*-
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Asynchronous multi-process connection pool.
This is similar to L{twisted.enterprise.adbapi}, but can hold a transaction
(and thereby a thread) open across multiple asynchronous operations, rather
than forcing the transaction to be completed entirely in a thread and/or
entirely in a single SQL statement.
Also, this module includes an AMP protocol for multiplexing connections through
a single choke-point host. This is not currently in use, however, as AMP needs
some optimization before it can be low-impact enough for this to be an
improvement.
"""
import sys
import weakref
from cStringIO import StringIO
from cPickle import dumps, loads
from itertools import count
from zope.interface import implements
from twisted.internet.defer import inlineCallbacks
from twisted.internet.defer import returnValue
from twisted.internet.defer import DeferredList
from twisted.internet.defer import Deferred
from twisted.protocols.amp import Boolean
from twisted.python.failure import Failure
from twisted.protocols.amp import Argument, String, Command, AMP, Integer
from twisted.internet import reactor as _reactor
from twisted.application.service import Service
from twisted.python import log
from twisted.internet.defer import maybeDeferred
from twisted.python.components import proxyForInterface
from twext.internet.threadutils import ThreadHolder
from twisted.internet.defer import succeed
from twext.enterprise.ienterprise import ConnectionError
from twext.enterprise.ienterprise import IDerivedParameter
from twisted.internet.defer import fail
from twext.enterprise.ienterprise import (
AlreadyFinishedError, IAsyncTransaction, POSTGRES_DIALECT, ICommandBlock
)
# FIXME: there should be no defaults for connection metadata, it should be
# discovered dynamically everywhere. Right now it's specified as an explicit
# argument to the ConnectionPool but it should probably be determined
# automatically from the database binding.
DEFAULT_PARAM_STYLE = 'pyformat'
DEFAULT_DIALECT = POSTGRES_DIALECT
def _forward(thunk):
"""
Forward an attribute to the connection pool.
"""
@property
def getter(self):
return getattr(self._pool, thunk.func_name)
return getter
def _destructively(aList):
"""
Destructively iterate a list, popping elements from the beginning.
"""
while aList:
yield aList.pop(0)
def _deriveParameters(cursor, args):
    """
    Give any L{IDerivedParameter} providers in C{args} a chance to call
    extension methods on the cursor before the query executes.

    @param cursor: the DB-API cursor object to derive parameters from.

    @param args: the parameters being specified to C{execSQL}.  This list is
        modified in place so it holds values suitable to pass to the
        C{cursor}'s C{execute} method.

    @return: the list of L{IDerivedParameter} providers that had C{preQuery}
        invoked (so C{postQuery} can be invoked after the query), or C{None}
        if no parameters were derived.

    @see: {IDerivedParameter}
    """
    # TODO: have the underlying connection report whether it has any
    # IDerivedParameters that it knows about, so we can skip even inspecting
    # the arguments if none of them could possibly provide
    # IDerivedParameter.
    found = None
    for index, value in enumerate(args):
        if not IDerivedParameter.providedBy(value):
            continue
        # Allocate the result list lazily; derived parameters are the rare
        # case and we want to avoid the extra allocation otherwise.
        if found is None:
            found = []
        found.append(value)
        args[index] = value.preQuery(cursor)
    return found
def _deriveQueryEnded(cursor, derived):
"""
A query which involved some L{IDerivedParameter}s just ended. Execute any
post-query cleanup or tasks that those parameters have to do.
@param cursor: The DB-API object that derived the query.
@param derived: The L{IDerivedParameter} providers that were given
C{preQuery} notifications when the query started.
@return: C{None}
"""
for arg in derived:
arg.postQuery(cursor)
class _ConnectedTxn(object):
    """
    L{IAsyncTransaction} implementation based on a L{ThreadHolder} in the
    current process.
    """
    implements(IAsyncTransaction)

    # When True, every statement and its results are echoed to stdout.
    noisy = False

    def __init__(self, pool, threadHolder, connection, cursor):
        # The owning L{ConnectionPool}; supplies paramstyle/dialect (via
        # _forward) and connectionFactory for automatic reconnects.
        self._pool = pool
        # Completion state: a truthy descriptive string ("idle", "ended",
        # "released") when finished, or False while actively in use.  The
        # initial "idle" value lets reset() succeed when the pool first
        # places this connection into its free list.
        self._completed = "idle"
        self._cursor = cursor
        self._connection = connection
        self._holder = threadHolder
        # True until the first statement runs; governs the automatic
        # reconnect-and-retry behavior in _reallyExecSQL and lets _end skip
        # database work for transactions that never executed anything.
        self._first = True

    @_forward
    def paramstyle(self):
        """
        The paramstyle attribute is mirrored from the connection pool.
        """

    @_forward
    def dialect(self):
        """
        The dialect attribute is mirrored from the connection pool.
        """

    def _reallyExecSQL(self, sql, args=None, raiseOnZeroRowCount=None):
        """
        Execute the given SQL on a thread, using a DB-API 2.0 cursor.

        This method is invoked internally on a non-reactor thread, one
        dedicated to and associated with the current cursor.  It executes the
        given SQL, re-connecting first if necessary, re-cycling the old
        connection if necessary, and then, if there are results from the
        statement (as determined by the DB-API 2.0 'description' attribute)
        it will fetch all the rows and return them, leaving them to be
        relayed to L{_ConnectedTxn.execSQL} via the L{ThreadHolder}.

        The rules for possibly reconnecting automatically are: if this is the
        very first statement being executed in this transaction, and an error
        occurs in C{execute}, close the connection and try again.  We will
        ignore any errors from C{close()} (or C{rollback()}) and log them
        during this process.  This is OK because adbapi2 always enforces
        transaction discipline: connections are never in autocommit mode, so
        if the first statement in a transaction fails, nothing can have
        happened to the database; as per the ADBAPI spec, a lost connection
        is a rolled-back transaction.  In the cases where some databases fail
        to enforce transaction atomicity (i.e. schema manipulations),
        re-executing the same statement will result, at worst, in a spurious
        and harmless error (like "table already exists"), not corruption.

        @param sql: The SQL string to execute.
        @type sql: C{str}

        @param args: The bind parameters to pass to adbapi, if any.
        @type args: C{list} or C{None}

        @param raiseOnZeroRowCount: If specified, an exception to raise when
            no rows are found.

        @return: all the rows that resulted from execution of the given
            C{sql}, or C{None}, if the statement is one which does not
            produce results.
        @rtype: C{list} of C{tuple}, or C{NoneType}

        @raise Exception: this function may raise any exception raised by the
            underlying C{dbapi.connect}, C{cursor.execute},
            L{IDerivedParameter.preQuery}, C{connection.cursor}, or
            C{cursor.fetchall}.

        @raise raiseOnZeroRowCount: if the argument was specified and no rows
            were returned by the executed statement.
        """
        wasFirst = self._first
        # If this is the first time this cursor has been used in this
        # transaction, remember that, but mark it as now used.
        self._first = False
        if args is None:
            args = []
        # Note: as of this writing, derived parameters are only used to
        # support cx_Oracle's "host variable" feature (i.e. cursor.var()),
        # and creating a host variable will never be a connection-oriented
        # error (a disconnected cursor can happily create variables of all
        # types).  However, this may need to move into the 'try' below if
        # other database features need to compute database arguments based
        # on runtime state.
        derived = _deriveParameters(self._cursor, args)
        try:
            self._cursor.execute(sql, args)
        except:
            # If execute() raised an exception, and this was the first thing
            # to happen in the transaction, then the connection has probably
            # gone bad in the meanwhile, and we should try again.
            if wasFirst:
                # Report the error before doing anything else, since doing
                # other things may cause the traceback stack to be
                # eliminated if they raise exceptions (even internally).
                log.err(
                    Failure(),
                    "Exception from execute() on first statement in "
                    "transaction.  Possibly caused by a database server "
                    "restart.  Automatically reconnecting now."
                )
                try:
                    self._connection.close()
                except:
                    # close() may raise an exception to alert us of an error
                    # as well.  Right now the only type of error we know
                    # about is "the connection is already closed", which
                    # obviously doesn't need to be handled specially.
                    # Unfortunately the reporting of this type of error is
                    # not consistent or predictable across different
                    # databases, or even different bindings to the same
                    # database, so we have to do a catch-all here.  While I
                    # can't imagine another type of error at the moment, bare
                    # 'except:'s are notorious for making debugging
                    # surprising error conditions very difficult, so let's
                    # make sure that the error is logged just in case.
                    log.err(
                        Failure(),
                        "Exception from close() while automatically "
                        "reconnecting.  (Probably not serious.)"
                    )
                # Now, if either of *these* things fail, there's an error
                # here that we cannot workaround or address automatically,
                # so no try:except: for them.
                self._connection = self._pool.connectionFactory()
                self._cursor = self._connection.cursor()
                # Note that although this method is being invoked
                # recursively, the '_first' flag is re-set at the very top,
                # so we will _not_ be re-entering it more than once.
                result = self._reallyExecSQL(sql, args, raiseOnZeroRowCount)
                return result
            else:
                raise
        if derived is not None:
            _deriveQueryEnded(self._cursor, derived)
        if self._cursor.description:
            # see test_raiseOnZeroRowCountWithUnreliableRowCount
            rows = self._cursor.fetchall()
            if not rows:
                if raiseOnZeroRowCount is not None:
                    raise raiseOnZeroRowCount()
            return rows
        else:
            # Statement produced no result set; rowcount is only consulted
            # here, where description is absent (see test referenced above).
            if raiseOnZeroRowCount is not None and self._cursor.rowcount == 0:
                raise raiseOnZeroRowCount()
            return None

    def execSQL(self, *args, **kw):
        """
        Submit the statement to this transaction's dedicated thread via the
        L{ThreadHolder}.  When C{noisy} is set, the statement and its results
        are also written to stdout.
        """
        result = self._holder.submit(
            lambda: self._reallyExecSQL(*args, **kw)
        )
        if self.noisy:
            def reportResult(results):
                sys.stdout.write("\n".join([
                    "",
                    "SQL: %r %r" % (args, kw),
                    "Results: %r" % (results,),
                    "",
                ]))
                return results
            result.addBoth(reportResult)
        return result

    def _end(self, really):
        """
        Common logic for commit or abort.  Executed in the main reactor
        thread.

        @param really: the callable to execute in the cursor thread to
            actually do the commit or rollback.

        @return: a L{Deferred} which fires when the database logic has
            completed.

        @raise: L{AlreadyFinishedError} if the transaction has already been
            committed or aborted.
        """
        if not self._completed:
            self._completed = "ended"
            def reallySomething():
                """
                Do the database work and set appropriate flags.  Executed in
                the cursor thread.
                """
                # Skip the database call if the connection is gone or no
                # statement ever ran (nothing to commit or roll back).
                if self._cursor is None or self._first:
                    return
                really()
                self._first = True
            result = self._holder.submit(reallySomething)
            # Schedule this connection to return to the pool's free list
            # once the commit/rollback completes.
            self._pool._repoolAfter(self, result)
            return result
        else:
            raise AlreadyFinishedError(self._completed)

    def commit(self):
        # Finish the transaction by committing on the cursor thread.
        return self._end(self._connection.commit)

    def abort(self):
        # Finish by rolling back; errors from rollback are logged rather
        # than propagated to the caller.
        return self._end(self._connection.rollback).addErrback(log.err)

    def reset(self):
        """
        Call this when placing this transaction back into the pool.

        @raise RuntimeError: if the transaction has not been committed or
            aborted.
        """
        if not self._completed:
            raise RuntimeError("Attempt to re-set active transaction.")
        self._completed = False

    def _releaseConnection(self):
        """
        Release the thread and database connection associated with this
        transaction.
        """
        self._completed = "released"
        # NOTE(review): _stopped does not appear to be read anywhere in this
        # module -- possibly vestigial; confirm before removing.
        self._stopped = True
        holder = self._holder
        self._holder = None

        def _reallyClose():
            if self._cursor is None:
                return
            self._connection.close()
        holder.submit(_reallyClose)
        return holder.stop()
class _NoTxn(object):
    """
    A stand-in L{IAsyncTransaction} used when a local failure prevented any
    statements (or possibly even any connection attempts) from reaching the
    server; every operation reports that failure.
    """
    implements(IAsyncTransaction)

    def __init__(self, pool, reason):
        # Mirror the pool's SQL metadata so callers can still inspect it.
        self.paramstyle = pool.paramstyle
        self.dialect = pool.dialect
        # Human-readable explanation of why no transaction was available.
        self.reason = reason

    def _everything(self, *a, **kw):
        """
        Any attempted operation immediately fails with a L{ConnectionError}
        carrying C{self.reason}.
        """
        return fail(ConnectionError(self.reason))

    execSQL = _everything
    commit = _everything
    abort = _everything
class _WaitingTxn(object):
    """
    A L{_WaitingTxn} is an implementation of L{IAsyncTransaction} which cannot
    yet actually execute anything, so it waits and spools SQL requests for
    later execution.  When a L{_ConnectedTxn} becomes available later, it can
    be unspooled onto that.
    """
    implements(IAsyncTransaction)

    def __init__(self, pool):
        """
        Initialize a L{_WaitingTxn} based on a L{ConnectionPool}.  (The
        C{pool} is used only to reflect C{dialect} and C{paramstyle}
        attributes; not remembered or modified in any way.)
        """
        # FIFO of (Deferred, method-name, args, kwargs) tuples awaiting a
        # real transaction.
        self._spool = []
        self.paramstyle = pool.paramstyle
        self.dialect = pool.dialect

    def _enspool(self, cmd, a=(), kw=None):
        """
        Record an invocation of method C{cmd} with positional arguments C{a}
        and keyword arguments C{kw}, to be replayed later by C{_unspool}.

        @return: a L{Deferred} which fires with the result of the eventual
            real invocation.
        """
        # Fixed: the keyword-argument default was previously a shared
        # mutable dict ({}); use None and allocate per call instead.
        if kw is None:
            kw = {}
        d = Deferred()
        self._spool.append((d, cmd, a, kw))
        return d

    def _iterDestruct(self):
        """
        Iterate the spool list destructively, while popping items from the
        beginning.  This allows code which executes more SQL in the callback
        of a Deferred to not interfere with the originally submitted order of
        commands.
        """
        return _destructively(self._spool)

    def _unspool(self, other):
        """
        Unspool this transaction onto another transaction.

        @param other: another provider of L{IAsyncTransaction} which will
            actually execute the SQL statements we have been buffering.
        """
        for (d, cmd, a, kw) in self._iterDestruct():
            self._relayCommand(other, d, cmd, a, kw)

    def _relayCommand(self, other, d, cmd, a, kw):
        """
        Relay a single spooled command to another transaction, chaining its
        outcome onto the Deferred originally handed to the caller.
        """
        maybeDeferred(getattr(other, cmd), *a, **kw).chainDeferred(d)

    def execSQL(self, *a, **kw):
        return self._enspool('execSQL', a, kw)

    def commit(self):
        return self._enspool('commit')

    def abort(self):
        """
        Succeed and do nothing.  The actual logic for this method is mostly
        implemented by L{_SingleTxn._stopWaiting}.
        """
        return succeed(None)
class _HookableOperation(object):
    """
    An ordered collection of callbacks to run at one transaction lifecycle
    point (pre-commit, post-commit, or post-abort).
    """

    def __init__(self):
        self._hooks = []

    @inlineCallbacks
    def runHooks(self, ignored=None):
        """
        Run every registered hook in order, then permanently disable this
        operation.  Usable directly as a callback for C{commit} / C{abort}
        L{Deferred}s: the incoming result is passed through unchanged.
        """
        # Consume the list destructively, so hooks appended while running
        # are executed as well.
        for hook in _destructively(self._hooks):
            yield hook()
        self.clear()
        returnValue(ignored)

    def addHook(self, operation):
        """
        Implement L{IAsyncTransaction.postCommit}.  Registration is silently
        ignored once the operation has been cleared.
        """
        if self._hooks is not None:
            self._hooks.append(operation)

    def clear(self):
        """
        Remove all hooks from this operation.  Once this is called, no
        more hooks can be added ever again.
        """
        self._hooks = None
class _CommitAndAbortHooks(object):
    """
    Shared implementation of pre-commit, post-commit and post-abort hooks.
    """
    # FIXME: this functionality needs direct tests, although it's pretty well-
    # covered by txdav's test suite.

    def __init__(self):
        self._preCommit = _HookableOperation()
        self._commit = _HookableOperation()
        self._abort = _HookableOperation()

    def _commitWithHooks(self, doCommit):
        """
        Run the pre-commit hooks; if they succeed, perform the real database
        commit via C{doCommit} followed by the post-commit hooks.  If any
        pre-commit hook fails, abort instead and propagate that failure.
        """
        preHooksDone = self._preCommit.runHooks()

        def succeeded(ignored):
            # Committing, so the post-abort hooks can never run.
            self._abort.clear()
            return doCommit().addCallback(self._commit.runHooks)

        def failedPre(why):
            return self.abort().addCallback(lambda ignored: why)

        return preHooksDone.addCallbacks(succeeded, failedPre)

    def preCommit(self, operation):
        return self._preCommit.addHook(operation)

    def postCommit(self, operation):
        return self._commit.addHook(operation)

    def postAbort(self, operation):
        return self._abort.addHook(operation)
class _SingleTxn(_CommitAndAbortHooks,
                 proxyForInterface(iface=IAsyncTransaction,
                                   originalAttribute='_baseTxn')):
    """
    A L{_SingleTxn} is a single-use wrapper for the longer-lived
    L{_ConnectedTxn}, so that if a badly-behaved API client accidentally
    hangs on to one of these and, for example C{.abort()}s it multiple times
    once another client is using that connection, it will get some harmless
    tracebacks.

    It's a wrapper around a "real" implementation; either a L{_ConnectedTxn},
    L{_NoTxn}, or L{_WaitingTxn} depending on the availability of real
    underlying database connections.

    This is the only L{IAsyncTransaction} implementation exposed to
    application code.

    It's also the only implementor of the C{commandBlock} method for grouping
    commands together.
    """

    def __init__(self, pool, baseTxn):
        super(_SingleTxn, self).__init__()
        self._pool = pool
        # The real transaction this wrapper proxies to; swapped by
        # _unspoolOnto when a connection becomes available.
        self._baseTxn = baseTxn
        self._completed = False
        # The CommandBlock currently executing, if any.
        self._currentBlock = None
        # While a block exists, a _WaitingTxn spooling free (non-block)
        # statements until all blocks finish; otherwise None.
        self._blockedQueue = None
        # CommandBlocks waiting for their turn to execute, in order.
        self._pendingBlocks = []
        # Deferreds for statements submitted but not yet completed.
        self._stillExecuting = []

    def __repr__(self):
        """
        Reveal the backend in the string representation.
        """
        return '_SingleTxn(%r)' % (self._baseTxn,)

    def _unspoolOnto(self, baseTxn):
        """
        Replace my C{_baseTxn}, currently a L{_WaitingTxn}, with a new
        implementation of L{IAsyncTransaction} that will actually do the
        work; either a L{_ConnectedTxn} or a L{_NoTxn}.
        """
        spooledBase = self._baseTxn
        self._baseTxn = baseTxn
        spooledBase._unspool(baseTxn)

    def execSQL(self, sql, args=None, raiseOnZeroRowCount=None):
        # Public entry point: execute outside any command block.
        return self._execSQLForBlock(sql, args, raiseOnZeroRowCount, None)

    def _execSQLForBlock(self, sql, args, raiseOnZeroRowCount, block):
        """
        Execute some SQL for a particular L{CommandBlock}; or, if the given
        C{block} is C{None}, execute it in the outermost transaction context.
        """
        self._checkComplete()
        # Free statements issued while blocks are outstanding are deferred
        # until every block has ended.
        if block is None and self._blockedQueue is not None:
            return self._blockedQueue.execSQL(sql, args, raiseOnZeroRowCount)
        # 'block' should always be _currentBlock at this point.
        d = super(_SingleTxn, self).execSQL(sql, args, raiseOnZeroRowCount)
        self._stillExecuting.append(d)
        def itsDone(result):
            self._stillExecuting.remove(d)
            self._checkNextBlock()
            return result
        d.addBoth(itsDone)
        return d

    def _checkNextBlock(self):
        """
        Check to see if there are any blocks pending statements waiting to
        execute, and execute the next one if there are no outstanding execute
        calls.
        """
        if self._stillExecuting:
            # If we're still executing statements, nevermind.  We'll get
            # called again by the 'itsDone' callback above.
            return
        if self._currentBlock is not None:
            # If there's still a current block, then keep it going.  We'll
            # be called by the '_finishExecuting' callback below.
            return
        # There's no block executing now.  What to do?
        if self._pendingBlocks:
            # If there are pending blocks, start one of them.
            self._currentBlock = self._pendingBlocks.pop(0)
            d = self._currentBlock._startExecuting()
            d.addCallback(self._finishExecuting)
        elif self._blockedQueue is not None:
            # If there aren't any pending blocks any more, and there are
            # spooled statements that aren't part of a block, unspool all the
            # statements that have been held up until this point.
            bq = self._blockedQueue
            self._blockedQueue = None
            bq._unspool(self)

    def _finishExecuting(self, result):
        """
        The active block just finished executing.  Clear it and see if there
        are more blocks to execute, or if all the blocks are done and we
        should execute any queued free statements.
        """
        self._currentBlock = None
        self._checkNextBlock()

    def commit(self):
        if self._blockedQueue is not None:
            # We're in the process of executing a block of commands.  Wait
            # until they're done.  (Commit will be repeated in
            # _checkNextBlock.)
            return self._blockedQueue.commit()
        def reallyCommit():
            self._markComplete()
            return super(_SingleTxn, self).commit()
        return self._commitWithHooks(reallyCommit)

    def abort(self):
        self._markComplete()
        # Aborting: commit-related hooks must never fire.
        self._commit.clear()
        self._preCommit.clear()
        result = super(_SingleTxn, self).abort()
        if self in self._pool._waiting:
            self._stopWaiting()
        result.addCallback(self._abort.runHooks)
        return result

    def _stopWaiting(self):
        """
        Stop waiting for a free transaction and fail.
        """
        self._pool._waiting.remove(self)
        self._completed = True
        self._unspoolOnto(_NoTxn(self._pool,
                                 "connection pool shut down while txn "
                                 "waiting for database connection."))

    def _checkComplete(self):
        """
        If the transaction is complete, raise L{AlreadyFinishedError}
        """
        if self._completed:
            raise AlreadyFinishedError()

    def _markComplete(self):
        """
        Mark the transaction as complete, raising AlreadyFinishedError.
        """
        self._checkComplete()
        self._completed = True

    def commandBlock(self):
        """
        Create a L{CommandBlock} which will wait for all currently spooled
        commands to complete before executing its own.
        """
        self._checkComplete()
        block = CommandBlock(self)
        if self._currentBlock is None:
            self._blockedQueue = _WaitingTxn(self._pool)
            # FIXME: test the case where it's ready immediately.
            self._checkNextBlock()
        return block

    def __del__(self):
        """
        When garbage collected, a L{_SingleTxn} recycles itself.
        """
        try:
            if not self._completed:
                self.abort()
        except AlreadyFinishedError:
            # The underlying transaction might already be completed without
            # us knowing; for example if the service shuts down.
            pass
class _Unspooler(object):
def __init__(self, orig):
self.orig = orig
def execSQL(self, sql, args=None, raiseOnZeroRowCount=None):
"""
Execute some SQL, but don't track a new Deferred.
"""
return self.orig.execSQL(sql, args, raiseOnZeroRowCount, False)
class CommandBlock(object):
    """
    A partial implementation of L{IAsyncTransaction} that will group execSQL
    calls together.

    Does not implement commit() or abort(), because this will simply group
    commands.  In order to implement sub-transactions or checkpoints, some
    understanding of the SQL dialect in use by the underlying connection is
    required.  Instead, it provides 'end'.
    """
    implements(ICommandBlock)

    def __init__(self, singleTxn):
        self._singleTxn = singleTxn
        self.paramstyle = singleTxn.paramstyle
        self.dialect = singleTxn.dialect
        # Statements issued before this block becomes the transaction's
        # current block are spooled here and replayed by _startExecuting.
        self._spool = _WaitingTxn(singleTxn._pool)
        self._started = False
        self._ended = False
        # Deferreds for tracked statements; 'end' completes when all fire.
        self._waitingForEnd = []
        self._endDeferred = Deferred()
        # Queue ourselves for execution on the owning transaction.
        singleTxn._pendingBlocks.append(self)

    def _startExecuting(self):
        # Called by _SingleTxn._checkNextBlock when this block becomes
        # current; replay spooled statements with tracking disabled.
        self._started = True
        self._spool._unspool(_Unspooler(self))
        return self._endDeferred

    def execSQL(self, sql, args=None, raiseOnZeroRowCount=None, track=True):
        """
        Execute some SQL within this command block.

        @param sql: the SQL string to execute.

        @param args: the SQL arguments.

        @param raiseOnZeroRowCount: see L{IAsyncTransaction.execSQL}

        @param track: an internal parameter; was this called by application
            code or as part of unspooling some previously-queued requests?
            True if application code, False if unspooling.
        """
        if track and self._ended:
            raise AlreadyFinishedError()
        self._singleTxn._checkComplete()
        # Run directly only when this block is the one currently executing;
        # otherwise spool for later replay.
        if self._singleTxn._currentBlock is self and self._started:
            d = self._singleTxn._execSQLForBlock(
                sql, args, raiseOnZeroRowCount, self)
        else:
            d = self._spool.execSQL(sql, args, raiseOnZeroRowCount)
        if track:
            self._trackForEnd(d)
        return d

    def _trackForEnd(self, d):
        """
        Watch the following L{Deferred}, since we need to watch it to
        determine when C{end} should be considered done, and the next
        CommandBlock or regular SQL statement should be unqueued.
        """
        self._waitingForEnd.append(d)

    def end(self):
        """
        The block of commands has completed.  Allow other SQL to run on the
        underlying L{IAsyncTransaction}.
        """
        # FIXME: test the case where end() is called when it's not the
        # current executing block.
        if self._ended:
            raise AlreadyFinishedError()
        self._ended = True
        # TODO: maybe this should return a Deferred that's a clone of
        # _endDeferred, so that callers can determine when the block is
        # really complete?  Struggling for an actual use-case on that one.
        DeferredList(self._waitingForEnd).chainDeferred(self._endDeferred)
class _ConnectingPseudoTxn(object):
"""
This is a pseudo-Transaction for bookkeeping purposes.
When a connection has asked to connect, but has not yet completed
connecting, the L{ConnectionPool} still needs a way to shut it down. This
object provides that tracking handle, and will be present in the pool's
C{busy} list while it is populating the list.
"""
_retry = None
def __init__(self, pool, holder):
"""
Initialize the L{_ConnectingPseudoTxn}; get ready to connect.
@param pool: The pool that this connection attempt is participating in.
@type pool: L{ConnectionPool}
@param holder: the L{ThreadHolder} allocated to this connection attempt
and subsequent SQL executions for this connection.
@type holder: L{ThreadHolder}
"""
self._pool = pool
self._holder = holder
self._aborted = False
def abort(self):
"""
Ignore the result of attempting to connect to this database, and
instead simply close the connection and free the L{ThreadHolder}
allocated for it.
"""
self._aborted = True
if self._retry is not None:
self._retry.cancel()
d = self._holder.stop()
def removeme(ignored):
if self in self._pool._busy:
self._pool._busy.remove(self)
d.addCallback(removeme)
return d
def _fork(x):
    """
    Produce a new L{Deferred} that fires with C{x}'s result, while leaving
    the result flowing down C{x}'s own callback chain undisturbed.
    """
    forked = Deferred()

    def relay(outcome):
        forked.callback(outcome)
        return outcome

    x.addBoth(relay)
    return forked
class ConnectionPool(Service, object):
    """
    This is a central service that has a threadpool and executes SQL
    statements asynchronously, in a pool.

    @ivar connectionFactory: a 0-or-1-argument callable that returns a DB-API
        connection.  The optional argument can be used as a label for
        diagnostic purposes.

    @ivar maxConnections: The connection pool will not attempt to make more
        than this many concurrent connections to the database.
    @type maxConnections: C{int}

    @ivar reactor: The reactor used for scheduling threads as well as retries
        for failed connect() attempts.
    @type reactor: L{IReactorTime} and L{IReactorThreads} provider.

    @ivar _free: The list of free L{_ConnectedTxn} objects which are not
        currently attached to a L{_SingleTxn} object, and have active
        connections ready for processing a new transaction.

    @ivar _busy: The list of busy L{_ConnectedTxn} objects; those currently
        servicing an unfinished L{_SingleTxn} object.

    @ivar _finishing: The list of 2-tuples of L{_ConnectedTxn} objects which
        have had C{abort} or C{commit} called on them, but are not done
        executing that method, and the L{Deferred} returned from that method
        that will be fired when its execution has completed.

    @ivar _waiting: The list of L{_SingleTxn} objects attached to a
        L{_WaitingTxn}; i.e. those which are awaiting a connection to become
        free so that they can be executed.

    @ivar _stopping: Is this L{ConnectionPool} in the process of shutting
        down?  (If so, new connections will not be established.)
    """
    reactor = _reactor

    # Seconds to wait before re-trying a failed connection attempt.
    RETRY_TIMEOUT = 10.0

    def __init__(self, connectionFactory, maxConnections=10,
                 paramstyle=DEFAULT_PARAM_STYLE, dialect=DEFAULT_DIALECT):
        super(ConnectionPool, self).__init__()
        self.connectionFactory = connectionFactory
        self.maxConnections = maxConnections
        self.paramstyle = paramstyle
        self.dialect = dialect
        self._free = []
        self._busy = []
        self._waiting = []
        self._finishing = []
        self._stopping = False

    def startService(self):
        """
        Increase the thread pool size of the reactor by the number of threads
        that this service may consume.  This is important because unlike most
        L{IReactorThreads} users, the connection work units are very
        long-lived and block until this service has been stopped.
        """
        super(ConnectionPool, self).startService()
        tp = self.reactor.getThreadPool()
        self.reactor.suggestThreadPoolSize(tp.max + self.maxConnections)

    @inlineCallbacks
    def stopService(self):
        """
        Forcibly abort any outstanding transactions, and release all
        resources (notably, threads).
        """
        super(ConnectionPool, self).stopService()
        self._stopping = True
        # Phase 1: Cancel any transactions that are waiting so they won't
        # try to eagerly acquire new connections as they flow into the
        # free-list.
        while self._waiting:
            # _stopWaiting removes the entry from self._waiting.
            waiting = self._waiting[0]
            waiting._stopWaiting()
        # Phase 2: Wait for all the Deferreds from the L{_ConnectedTxn}s
        # that have *already* been stopped.
        while self._finishing:
            yield _fork(self._finishing[0][1])
        # Phase 3: All of the busy transactions must be aborted first.  As
        # each one is aborted, it will remove itself from the list.
        while self._busy:
            yield self._busy[0].abort()
        # Phase 4: All transactions should now be in the free list, since
        # 'abort()' will have put them there.  Shut down all the associated
        # ThreadHolders.
        while self._free:
            # Releasing a L{_ConnectedTxn} doesn't automatically recycle it /
            # remove it the way aborting a _SingleTxn does, so we need to
            # .pop() here.  L{_ConnectedTxn.stop} really shouldn't be able to
            # fail, as it's just stopping the thread, and the holder's stop()
            # is independently submitted from .abort() / .close().
            yield self._free.pop()._releaseConnection()
        # Give back the reactor threads we claimed in startService.
        tp = self.reactor.getThreadPool()
        self.reactor.suggestThreadPoolSize(tp.max - self.maxConnections)

    def _createHolder(self):
        """
        Create a L{ThreadHolder}.  (Test hook.)
        """
        return ThreadHolder(self.reactor)

    def connection(self, label=""):
        """
        Find and immediately return an L{IAsyncTransaction} object.
        Execution of statements, commit and abort on that transaction may be
        delayed until a real underlying database connection is available.

        @return: an L{IAsyncTransaction}
        """
        if self._stopping:
            # FIXME: should be wrapping a _SingleTxn around this to get
            # .commandBlock()
            return _NoTxn(self, "txn created while DB pool shutting down")
        if self._free:
            basetxn = self._free.pop(0)
            self._busy.append(basetxn)
            txn = _SingleTxn(self, basetxn)
        else:
            # No connection available yet: spool statements until one is.
            txn = _SingleTxn(self, _WaitingTxn(self))
            self._waiting.append(txn)
            # FIXME/TESTME: should be len(self._busy) + len(self._finishing)
            # (free doesn't need to be considered, as it's tested above)
            if self._activeConnectionCount() < self.maxConnections:
                self._startOneMore()
        return txn

    def _activeConnectionCount(self):
        """
        @return: the number of active outgoing connections to the database.
        """
        return len(self._busy) + len(self._finishing)

    def _startOneMore(self):
        """
        Start one more _ConnectedTxn.
        """
        holder = self._createHolder()
        holder.start()
        txn = _ConnectingPseudoTxn(self, holder)
        # take up a slot in the 'busy' list, sit there so we can be aborted.
        self._busy.append(txn)

        def initCursor():
            # support threadlevel=1; we can't necessarily cursor() in a
            # different thread than we do transactions in.
            connection = self.connectionFactory()
            cursor = connection.cursor()
            return (connection, cursor)

        # NOTE: Python 2 tuple-parameter syntax below.
        def finishInit((connection, cursor)):
            # If the attempt was aborted while connecting, simply discard
            # the result; the pseudo-txn's abort() handles cleanup.
            if txn._aborted:
                return
            baseTxn = _ConnectedTxn(
                pool=self,
                threadHolder=holder,
                connection=connection,
                cursor=cursor
            )
            self._busy.remove(txn)
            self._repoolNow(baseTxn)

        def maybeTryAgain(f):
            log.err(f, "Re-trying connection due to connection failure")
            txn._retry = self.reactor.callLater(self.RETRY_TIMEOUT, resubmit)

        def resubmit():
            d = holder.submit(initCursor)
            d.addCallbacks(finishInit, maybeTryAgain)
        resubmit()

    def _repoolAfter(self, txn, d):
        """
        Re-pool the given L{_ConnectedTxn} after the given L{Deferred} has
        fired.
        """
        self._busy.remove(txn)
        finishRecord = (txn, d)
        self._finishing.append(finishRecord)

        def repool(result):
            self._finishing.remove(finishRecord)
            self._repoolNow(txn)
            return result

        def discard(result):
            # Commit/abort failed: drop this connection entirely and spin
            # up a replacement.
            self._finishing.remove(finishRecord)
            txn._releaseConnection()
            self._startOneMore()
            return result
        return d.addCallbacks(repool, discard)

    def _repoolNow(self, txn):
        """
        Recycle a L{_ConnectedTxn} into the free list.
        """
        txn.reset()
        if self._waiting:
            # Hand the connection straight to the longest-waiting
            # transaction rather than parking it in the free list.
            waiting = self._waiting.pop(0)
            self._busy.append(txn)
            waiting._unspoolOnto(txn)
        else:
            self._free.append(txn)
def txnarg():
    """
    Return the AMP argument-list fragment identifying a transaction: a
    single integer C{transactionID}.
    """
    return [("transactionID", Integer())]
# Largest payload, in bytes, carried under a single AMP key; values bigger
# than this must be split across multiple keys (see L{BigArgument}).
CHUNK_MAX = 0xffff
class BigArgument(Argument):
    """
    An AMP argument whose serialized form may exceed L{CHUNK_MAX} bytes; the
    payload is split across multiple numbered AMP keys and reassembled on
    receipt.
    """

    def fromBox(self, name, strings, objects, proto):
        """
        Reassemble the consecutively-numbered chunks for C{name} into one
        string and decode it with C{fromString}.
        """
        assembled = StringIO()
        for index in count():
            piece = strings.get("%s.%d" % (name, index))
            if piece is None:
                break
            assembled.write(piece)
        objects[name] = self.fromString(assembled.getvalue())

    def toBox(self, name, strings, objects, proto):
        """
        Encode the object with C{toString} and emit it as a sequence of
        numbered keys, each holding at most L{CHUNK_MAX} bytes.
        """
        serialized = StringIO(self.toString(objects[name]))
        for index in count():
            piece = serialized.read(CHUNK_MAX)
            if not piece:
                break
            strings["%s.%d" % (name, index)] = piece
class Pickle(BigArgument):
    """
    A pickle sent over AMP, used to serialize the dynamically-typed 'args'
    list handed to a DB-API C{execute} function, and the equally dynamic
    result rows coming back.

    This should be cleaned up into a nicer structure, but this is not a
    network protocol, so we can be a little relaxed about security.

    It is a L{BigArgument} rather than a plain L{Argument} because individual
    arguments and query results may contain entire vCard or iCalendar
    documents, which can easily be greater than 64k.
    """

    def toString(self, inObject):
        return dumps(inObject)

    def fromString(self, inString):
        return loads(inString)
class FailsafeException(Exception):
    """
    Catch-all exception raised by responders when an error could not be
    classified more specifically.
    """
# Exception-type to AMP error-code mapping shared as the C{errors} attribute
# of every command below, so these exceptions are relayed to the peer rather
# than dropping the connection.
_quashErrors = {
    FailsafeException: "SOMETHING_UNKNOWN",
    AlreadyFinishedError: "ALREADY_FINISHED",
    ConnectionError: "CONNECTION_ERROR",
}
def failsafeResponder(command):
    """
    Wrap an AMP command responder in fail-safe logic: exceptions listed in
    C{command.errors} propagate normally, while anything unexpected is
    logged and converted into a L{FailsafeException}, so unknown errors
    won't drop the connection as AMP's default behavior would.
    """
    def decorate(inner):
        @inlineCallbacks
        def failsafe(*args, **kwargs):
            try:
                result = yield inner(*args, **kwargs)
            except:
                reason = Failure()
                if reason.type not in command.errors:
                    # Unknown error: log it and substitute the catch-all
                    # exception, which the command knows how to transmit.
                    log.err(Failure(), "shared database connection pool error")
                    raise FailsafeException()
                returnValue(reason)
            else:
                returnValue(result)
        return command.responder(failsafe)
    return decorate
class StartTxn(Command):
    """
    Start a transaction, identified with an ID generated by the client.
    """
    # Only the client-generated transactionID is needed to begin.
    arguments = txnarg()
    errors = _quashErrors
class ExecSQL(Command):
    """
    Execute an SQL statement.
    """
    # sql: statement text; queryID: identifier used to correlate Row /
    # QueryComplete responses with this request; args: pickled bind
    # parameters; blockID: the command block to run within (presumably an
    # empty string means "no block" -- confirm against the client side);
    # reportZeroRowCount: whether a zero-row result should be reported.
    arguments = [('sql', String()),
                 ('queryID', String()),
                 ('args', Pickle()),
                 ('blockID', String()),
                 ('reportZeroRowCount', Boolean())] + txnarg()
    errors = _quashErrors
class StartBlock(Command):
    """
    Create a new SQL command block within the identified transaction.
    """
    arguments = [("blockID", String())] + txnarg()
    errors = _quashErrors
class EndBlock(Command):
    """
    End an existing SQL command block.  (The original docstring here was a
    copy-paste of L{StartBlock}'s; the responder calls C{end()} on the
    identified block.)
    """
    arguments = [("blockID", String())] + txnarg()
    errors = _quashErrors
class Row(Command):
    """
    A row has been returned. Sent from server to client in response to
    L{ExecSQL}; one L{Row} per result row, correlated by C{queryID}.
    """
    arguments = [('queryID', String()),
                 ('row', Pickle())]
    errors = _quashErrors
class QueryComplete(Command):
    """
    A query issued with L{ExecSQL} is complete.
    """
    # 'norows' - the query matched zero rows; 'derived' - post-query state of
    # any IDerivedParameter arguments (or None); 'noneResult' - the overall
    # result should be None rather than a (possibly empty) row list.
    arguments = [('queryID', String()),
                 ('norows', Boolean()),
                 ('derived', Pickle()),
                 ('noneResult', Boolean())]
    errors = _quashErrors
class Commit(Command):
    """
    Commit the identified transaction on the server side.
    """
    arguments = txnarg()
    errors = _quashErrors
class Abort(Command):
    """
    Roll back the identified transaction on the server side.
    """
    arguments = txnarg()
    errors = _quashErrors
class _NoRows(Exception):
    """
    Placeholder exception to report zero rows.  Passed as the
    raiseOnZeroRowCount callable and caught inside
    L{ConnectionPoolConnection.receivedSQL}; it never crosses the wire.
    """
class ConnectionPoolConnection(AMP):
    """
    A L{ConnectionPoolConnection} is a single connection to a
    L{ConnectionPool}. This is the server side of the connection-pool-sharing
    protocol; it implements all the AMP responders necessary.
    """

    def __init__(self, pool):
        """
        Initialize a mapping of transaction IDs to transaction objects.

        @param pool: the underlying local pool that actually runs the SQL.
        @type pool: L{ConnectionPool}
        """
        super(ConnectionPoolConnection, self).__init__()
        self.pool = pool
        # Client-generated transaction ID -> transaction object.
        self._txns = {}
        # Client-generated block ID -> command block object.
        self._blocks = {}

    def stopReceivingBoxes(self, why):
        # The AMP connection has gone away; just record why.
        log.msg("(S) Stopped receiving boxes: " + why.getTraceback())

    def unhandledError(self, failure):
        """
        An unhandled error has occurred. Since we can't really classify errors
        well on this protocol, log it and forget it.
        """
        log.err(failure, "Shared connection pool server encountered an error.")

    @failsafeResponder(StartTxn)
    def start(self, transactionID):
        # Allocate a new transaction from the local pool, remembered under
        # the client-assigned identifier.
        self._txns[transactionID] = self.pool.connection()
        return {}

    @failsafeResponder(StartBlock)
    def startBlock(self, transactionID, blockID):
        # Create a command block on the identified transaction.
        self._blocks[blockID] = self._txns[transactionID].commandBlock()
        return {}

    @failsafeResponder(EndBlock)
    def endBlock(self, transactionID, blockID):
        # End the identified command block.  NOTE(review): the entry is not
        # removed from self._blocks here - confirm whether that retention is
        # deliberate or a small leak.
        self._blocks[blockID].end()
        return {}

    @failsafeResponder(ExecSQL)
    @inlineCallbacks
    def receivedSQL(self, transactionID, queryID, sql, args, blockID,
                    reportZeroRowCount):
        """
        Execute the given SQL on the identified transaction (or command
        block), streaming each result row back via L{Row} and finishing with
        a single L{QueryComplete}.
        """
        derived = None
        noneResult = False
        # Collect any IDerivedParameter arguments; their post-query state is
        # shipped back with QueryComplete so the client can mirror it (see
        # _Query.done).
        for param in args:
            if IDerivedParameter.providedBy(param):
                if derived is None:
                    derived = []
                derived.append(param)
        if blockID:
            txn = self._blocks[blockID]
        else:
            txn = self._txns[transactionID]
        if reportZeroRowCount:
            rozrc = _NoRows
        else:
            rozrc = None
        try:
            rows = yield txn.execSQL(sql, args, rozrc)
        except _NoRows:
            norows = True
        else:
            norows = False
            if rows is not None:
                for row in rows:
                    # Either this should be yielded or it should be
                    # requiresAnswer=False
                    self.callRemote(Row, queryID=queryID, row=row)
            else:
                # A None row list means the cursor had no description (e.g. a
                # DDL/DML statement); tell the client to produce None.
                noneResult = True
        self.callRemote(QueryComplete, queryID=queryID, norows=norows,
                        derived=derived, noneResult=noneResult)
        returnValue({})

    def _complete(self, transactionID, thunk):
        # Remove the transaction and finish it via thunk (commit or abort);
        # AMP responders must return a dict, hence the trailing callback.
        txn = self._txns.pop(transactionID)
        return thunk(txn).addCallback(lambda ignored: {})

    @failsafeResponder(Commit)
    def commit(self, transactionID):
        """
        Successfully complete the given transaction.
        """
        def commitme(x):
            return x.commit()
        return self._complete(transactionID, commitme)

    @failsafeResponder(Abort)
    def abort(self, transactionID):
        """
        Roll back the given transaction.
        """
        def abortme(x):
            return x.abort()
        return self._complete(transactionID, abortme)
class ConnectionPoolClient(AMP):
    """
    A client which can execute SQL through a remote L{ConnectionPool}, via
    the AMP commands defined in this module.
    """
    def __init__(self, dialect=POSTGRES_DIALECT,
                 paramstyle=DEFAULT_PARAM_STYLE):
        # See DEFAULT_PARAM_STYLE FIXME above.
        super(ConnectionPoolClient, self).__init__()
        # Generator of client-unique transaction / query / block IDs.
        self._nextID = count().next
        # Weak values so an abandoned _NetTransaction can be garbage
        # collected (triggering its __del__-based abort) despite being
        # registered here.
        self._txns = weakref.WeakValueDictionary()
        # queryID -> in-flight _Query awaiting rows / completion.
        self._queries = {}
        self.dialect = dialect
        self.paramstyle = paramstyle

    def unhandledError(self, failure):
        """
        An unhandled error has occurred. Since we can't really classify errors
        well on this protocol, log it and forget it.
        """
        log.err(failure, "Shared connection pool client encountered an error.")

    def stopReceivingBoxes(self, why):
        # Connection lost; just record the reason.
        log.msg("(C) Stopped receiving boxes: " + why.getTraceback())

    def newTransaction(self):
        """
        Create a new networked provider of L{IAsyncTransaction}.

        (This will ultimately call L{ConnectionPool.connection} on the other
        end of the wire.)

        @rtype: L{IAsyncTransaction}
        """
        txnid = str(self._nextID())
        txn = _NetTransaction(client=self, transactionID=txnid)
        self._txns[txnid] = txn
        # Fire-and-forget; failures surface through unhandledError.
        self.callRemote(StartTxn, transactionID=txnid)
        return txn

    @failsafeResponder(Row)
    def row(self, queryID, row):
        # The server streamed one result row for an outstanding query.
        self._queries[queryID].row(row)
        return {}

    @failsafeResponder(QueryComplete)
    def complete(self, queryID, norows, derived, noneResult):
        # The server says the query is done; fire its deferred and drop it.
        self._queries.pop(queryID).done(norows, derived, noneResult)
        return {}
class _Query(object):
    """
    Client-side record of one in-flight C{ExecSQL}: accumulates streamed
    rows and fires its deferred when the server reports completion.
    """

    def __init__(self, sql, raiseOnZeroRowCount, args):
        self.sql = sql
        self.args = args
        self.results = []
        self.deferred = Deferred()
        self.raiseOnZeroRowCount = raiseOnZeroRowCount

    def row(self, row):
        """
        Record one result row streamed from the server.
        """
        self.results.append(row)

    def done(self, norows, derived, noneResult):
        """
        The query is complete.

        @param norows: A boolean. True if there were not any rows.

        @param derived: either C{None} or a C{list} of L{IDerivedParameter}
            providers mirroring those originally passed into the C{execSQL}
            that started this query.  Their attribute dictionaries are
            copied onto the local parameter objects, so although
            L{IDerivedParameter.preQuery} and L{IDerivedParameter.postQuery}
            ran on the other end of the wire, the local objects appear as
            though they were called here.

        @param noneResult: should the result of the query be C{None} (i.e.
            did it not have a C{description} on the cursor).
        """
        results = None if (noneResult and not self.results) else self.results
        if derived is not None:
            # 1) Bleecchh.
            # 2) FIXME: add some direct tests in test_adbapi2, the unit test
            #    for this crosses some abstraction boundaries so it's a
            #    little integration-y and in the tests for
            #    twext.enterprise.dal
            for remote, local in zip(derived, self._deriveDerived()):
                local.__dict__ = remote.__dict__
        if norows and (self.raiseOnZeroRowCount is not None):
            self.deferred.errback(Failure(self.raiseOnZeroRowCount()))
        else:
            self.deferred.callback(results)

    def _deriveDerived(self):
        """
        Return the subset of C{self.args} providing L{IDerivedParameter},
        or C{None} when there are none.
        """
        providers = [param for param in self.args
                     if IDerivedParameter.providedBy(param)]
        return providers if providers else None
class _NetTransaction(_CommitAndAbortHooks):
    """
    A L{_NetTransaction} is an L{AMP}-protocol-based provider of the
    L{IAsyncTransaction} interface. It sends SQL statements, query results,
    and commit/abort commands via an AMP socket to a pooling process.
    """

    implements(IAsyncTransaction)

    def __init__(self, client, transactionID):
        """
        Initialize a transaction with a L{ConnectionPoolClient} and a unique
        transaction identifier.
        """
        super(_NetTransaction, self).__init__()
        self._client = client
        self._transactionID = transactionID
        # True once commit/abort has been issued; no more direct statements.
        self._completed = False
        # True while a commit is in progress; command blocks may still run
        # SQL during this window (see _NetCommandBlock.execSQL).
        self._committing = False
        # True once the remote commit has actually finished.
        self._committed = False

    @property
    def paramstyle(self):
        """
        Forward 'paramstyle' attribute to the client.
        """
        return self._client.paramstyle

    @property
    def dialect(self):
        """
        Forward 'dialect' attribute to the client.
        """
        return self._client.dialect

    def execSQL(self, sql, args=None, raiseOnZeroRowCount=None, blockID=""):
        """
        Implement L{IAsyncTransaction.execSQL}: register a local L{_Query}
        and send an L{ExecSQL} command; the returned Deferred fires with the
        query's eventual result.
        """
        if not blockID:
            # Statements routed through a command block are permitted even
            # after completion has started; direct statements are not.
            if self._completed:
                raise AlreadyFinishedError()
        if args is None:
            args = []
        client = self._client
        queryID = str(client._nextID())
        query = client._queries[queryID] = _Query(sql, raiseOnZeroRowCount,
                                                  args)
        result = (
            client.callRemote(
                ExecSQL, queryID=queryID, sql=sql, args=args,
                transactionID=self._transactionID, blockID=blockID,
                reportZeroRowCount=raiseOnZeroRowCount is not None,
            )
            .addCallback(lambda nothing: query.deferred)
        )
        return result

    def _complete(self, command):
        # Issue Commit or Abort exactly once; fire with None to satisfy the
        # IAsyncTransaction contract.
        if self._completed:
            raise AlreadyFinishedError()
        self._completed = True
        return self._client.callRemote(
            command, transactionID=self._transactionID
        ).addCallback(lambda x: None)

    def commit(self):
        """
        Commit remotely, bracketed by the pre-/post-commit hooks managed by
        L{_CommitAndAbortHooks._commitWithHooks}.
        """
        def reallyCommit():
            self._committing = True
            def done(whatever):
                self._committed = True
                return whatever
            return self._complete(Commit).addBoth(done)
        return self._commitWithHooks(reallyCommit)

    def abort(self):
        """
        Roll back remotely; pending commit hooks are discarded and abort
        hooks run after the remote abort completes.
        """
        self._commit.clear()
        self._preCommit.clear()
        return self._complete(Abort).addCallback(self._abort.runHooks)

    def commandBlock(self):
        """
        Create a L{_NetCommandBlock} whose statements will be serialized on
        the server side under a fresh block ID.
        """
        if self._completed:
            raise AlreadyFinishedError()
        blockID = str(self._client._nextID())
        self._client.callRemote(
            StartBlock, blockID=blockID, transactionID=self._transactionID
        )
        return _NetCommandBlock(self, blockID)

    def __del__(self):
        """
        When a L{_NetTransaction} is garbage collected, it aborts itself.
        """
        if not self._completed:
            def shush(f):
                # These errors just mean the connection or transaction is
                # already gone; nothing useful to report.
                f.trap(ConnectionError, AlreadyFinishedError)
            self.abort().addErrback(shush)
class _NetCommandBlock(object):
    """
    Client-side proxy for a server-side SQL command block belonging to a
    L{_NetTransaction}.
    """

    implements(ICommandBlock)

    def __init__(self, transaction, blockID):
        self._transaction = transaction
        self._blockID = blockID
        self._ended = False

    @property
    def paramstyle(self):
        """
        Forward 'paramstyle' attribute to the transaction.
        """
        return self._transaction.paramstyle

    @property
    def dialect(self):
        """
        Forward 'dialect' attribute to the transaction.
        """
        return self._transaction.dialect

    def execSQL(self, sql, args=None, raiseOnZeroRowCount=None):
        """
        Execute some SQL on this command block.
        """
        txn = self._transaction
        # Refuse once this block has ended, or once the transaction is done
        # - except during the commit itself (pre-commit hooks may still run
        # SQL through this block), and then only until the commit lands.
        finished = (
            self._ended or
            (txn._completed and not txn._committing) or
            txn._committed
        )
        if finished:
            raise AlreadyFinishedError()
        return txn.execSQL(sql, args, raiseOnZeroRowCount, self._blockID)

    def end(self):
        """
        End this block; no further SQL may be executed through it.
        """
        if self._ended:
            raise AlreadyFinishedError()
        self._ended = True
        self._transaction._client.callRemote(
            EndBlock, blockID=self._blockID,
            transactionID=self._transaction._transactionID
        )
calendarserver-5.2+dfsg/twext/enterprise/locking.py 0000644 0001750 0001750 00000006546 12263346572 021667 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.test.test_locking -*-
##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Utilities to restrict concurrency based on mutual exclusion.
"""
from twext.enterprise.dal.model import Table
from twext.enterprise.dal.model import SQLType
from twext.enterprise.dal.model import Constraint
from twext.enterprise.dal.syntax import SchemaSyntax
from twext.enterprise.dal.model import Schema
from twext.enterprise.dal.record import Record
from twext.enterprise.dal.record import fromTable
class AlreadyUnlocked(Exception):
    """
    The lock you were trying to unlock was already unlocked.
    """
class LockTimeout(Exception):
    """
    The lock you were trying to lock was already locked causing a timeout.
    """
def makeLockSchema(inSchema):
    """
    Create a self-contained schema just for L{Locker} use, in C{inSchema}.

    @param inSchema: a L{Schema} to add the locks table to.
    @type inSchema: L{Schema}

    @return: inSchema
    """
    table = Table(inSchema, 'NAMED_LOCK')
    table.addColumn("LOCK_NAME", SQLType("varchar", 255))
    # The lock name is the sole key: mandatory and unique.
    for kind in (Constraint.NOT_NULL, Constraint.UNIQUE):
        table.tableConstraint(kind, ["LOCK_NAME"])
    table.primaryKey = [table.columnNamed("LOCK_NAME")]
    return inSchema
# Shared syntax wrapper for the NAMED_LOCK table, usable in DAL queries.
LockSchema = SchemaSyntax(makeLockSchema(Schema(__file__)))
class NamedLock(Record, fromTable(LockSchema.NAMED_LOCK)):
    """
    An L{AcquiredLock} lock against a shared data store that the current
    process holds via the referenced transaction.
    """

    @classmethod
    def acquire(cls, txn, name):
        """
        Acquire a lock with the given name by inserting a row into the
        NAMED_LOCK table; the table's uniqueness constraint ensures that no
        two locks with the same name may be held at once.

        @param name: The name of the lock to acquire.
        @type name: L{unicode}

        @return: a L{Deferred} that fires with an L{AcquiredLock} when the
            lock has been acquired, or fails with L{LockTimeout} when it has
            not.  NOTE(review): I{any} failure of the underlying create is
            converted to L{LockTimeout}, not just contention - confirm this
            is intended.
        """
        def autoRelease(self):
            # Bound the lock's lifetime by the transaction: delete the row
            # just before the transaction commits.
            txn.preCommit(lambda: self.release(True))
            return self

        def lockFailed(f):
            raise LockTimeout(name)

        return cls.create(txn, lockName=name).addCallback(autoRelease).addErrback(lockFailed)

    def release(self, ignoreAlreadyUnlocked=False):
        """
        Release this lock.

        @param ignoreAlreadyUnlocked: If you don't care about the current
            status of this lock, and just want to release it if it is still
            acquired, pass this parameter as L{True}. Otherwise this method
            will raise an exception if it is invoked when the lock has already
            been released.  NOTE(review): the current implementation ignores
            this flag and unconditionally deletes the row; L{AlreadyUnlocked}
            is never raised here - confirm.

        @raise: L{AlreadyUnlocked}

        @return: A L{Deferred} that fires with L{None} when the lock has been
            unlocked.
        """
        return self.delete()
calendarserver-5.2+dfsg/twext/enterprise/queue.py 0000644 0001750 0001750 00000150165 12276242656 021364 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.test.test_queue -*-
##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
L{twext.enterprise.queue} is an U{eventually consistent
} task-queueing system for
use by applications with multiple front-end servers talking to a single
database instance, that want to defer and parallelize work that involves
storing the results of computation.
By enqueuing with L{twext.enterprise.queue}, you may guarantee that the work
will I{eventually} be done, and reliably commit to doing it in the future, but
defer it if it does not need to be done I{now}.
To pick a hypothetical example, let's say that you have a store which wants to
issue a promotional coupon based on a customer loyalty program, in response to
an administrator clicking on a button. Determining the list of customers to
send the coupon to is quick: a simple query will get you all their names.
However, analyzing each user's historical purchase data is (A) time consuming
and (B) relatively isolated, so it would be good to do that in parallel, and it
would also be acceptable to have that happen at a later time, outside the
critical path.
Such an application might be implemented with this queueing system like so::
from twext.enterprise.queue import WorkItem, queueFromTransaction
from twext.enterprise.dal.parseschema import addSQLToSchema
from twext.enterprise.dal.syntax import SchemaSyntax
schemaModel = Schema()
addSQLToSchema('''
create table CUSTOMER (NAME varchar(255), ID integer primary key);
create table PRODUCT (NAME varchar(255), ID integer primary key);
create table PURCHASE (NAME varchar(255), WHEN timestamp,
CUSTOMER_ID integer references CUSTOMER,
PRODUCT_ID integer references PRODUCT);
create table COUPON_WORK (WORK_ID integer primary key,
CUSTOMER_ID integer references CUSTOMER);
create table COUPON (ID integer primary key,
CUSTOMER_ID integer references customer,
AMOUNT integer);
''')
schema = SchemaSyntax(schemaModel)
class Coupon(Record, fromTable(schema.COUPON)):
pass
class CouponWork(WorkItem, fromTable(schema.COUPON_WORK)):
@inlineCallbacks
def doWork(self):
purchases = yield Select(schema.PURCHASE,
Where=schema.PURCHASE.CUSTOMER_ID
== self.customerID).on(self.transaction)
couponAmount = yield doSomeMathThatTakesAWhile(purchases)
yield Coupon.create(customerID=self.customerID,
amount=couponAmount)
@inlineCallbacks
def makeSomeCoupons(txn):
# Note, txn was started before, will be committed later...
for customerID in (yield Select([schema.CUSTOMER.CUSTOMER_ID],
From=schema.CUSTOMER).on(txn)):
# queuer is a provider of IQueuer, of which there are several
# implementations in this module.
queuer.enqueueWork(txn, CouponWork, customerID=customerID)
"""
from functools import wraps
from datetime import datetime
from zope.interface import implements
from twisted.application.service import MultiService
from twisted.internet.protocol import Factory
from twisted.internet.defer import (
inlineCallbacks, returnValue, Deferred, passthru, succeed
)
from twisted.internet.endpoints import TCP4ClientEndpoint
from twisted.protocols.amp import AMP, Command, Integer, Argument, String
from twisted.python.reflect import qual
from twisted.python import log
from twext.enterprise.dal.syntax import SchemaSyntax, Lock, NamedValue
from twext.enterprise.dal.model import ProcedureCall
from twext.enterprise.dal.record import Record, fromTable, NoSuchRecord
from twisted.python.failure import Failure
from twext.enterprise.dal.model import Table, Schema, SQLType, Constraint
from twisted.internet.endpoints import TCP4ServerEndpoint
from twext.enterprise.ienterprise import IQueuer
from zope.interface.interface import Interface
from twext.enterprise.locking import NamedLock
class _IWorkPerformer(Interface):
    """
    An object that can perform work.

    Internal interface; implemented by several classes here since work has to
    (in the worst case) pass from worker->controller->controller->worker.
    """

    # zope.interface methods take no 'self'; the marker quiets IDE warnings.
    def performWork(table, workID): #@NoSelf
        """
        @param table: The table where work is waiting.
        @type table: L{TableSyntax}

        @param workID: The primary key identifier of the given work.
        @type workID: L{int}

        @return: a L{Deferred} firing with an empty dictionary when the work
            is complete.
        @rtype: L{Deferred} firing L{dict}
        """
def makeNodeSchema(inSchema):
    """
    Create a self-contained schema for L{NodeInfo} to use, in C{inSchema}.

    @param inSchema: a L{Schema} to add the node-info table to.
    @type inSchema: L{Schema}

    @return: a schema with just the one table.
    """
    # Initializing this duplicate schema avoids a circular dependency, but
    # this should really be accomplished with independent schema objects
    # that the transaction is made aware of somehow.
    table = Table(inSchema, 'NODE_INFO')
    table.addColumn("HOSTNAME", SQLType("varchar", 255))
    table.addColumn("PID", SQLType("integer", None))
    table.addColumn("PORT", SQLType("integer", None))
    table.addColumn("TIME", SQLType("timestamp", None)).setDefaultValue(
        # Note: in the real data structure, this is actually a
        # not-cleaned-up sqlparse internal data structure, but it *should*
        # look closer to this.
        ProcedureCall("timezone", ["UTC", NamedValue('CURRENT_TIMESTAMP')])
    )
    # Every column is mandatory: a node row is useless if incomplete.
    for column in table.columns:
        table.tableConstraint(Constraint.NOT_NULL, [column.name])
    # A node is identified by where it can be reached.
    table.primaryKey = [
        table.columnNamed("HOSTNAME"),
        table.columnNamed("PORT"),
    ]
    return inSchema
# Shared syntax wrapper for the NODE_INFO table, used by the NodeInfo record.
NodeInfoSchema = SchemaSyntax(makeNodeSchema(Schema(__file__)))
@inlineCallbacks
def inTransaction(transactionCreator, operation):
    """
    Perform the given operation in a transaction, committing or aborting as
    required.

    @param transactionCreator: a 0-arg callable that returns an
        L{IAsyncTransaction}

    @param operation: a 1-arg callable that takes an L{IAsyncTransaction} and
        returns a value.

    @return: a L{Deferred} that fires with C{operation}'s result or fails with
        its error, unless there is an error creating, aborting or committing
        the transaction.
    """
    txn = transactionCreator()
    try:
        result = yield operation(txn)
    except:
        f = Failure()
        yield txn.abort()
        # Returning a Failure from an inlineCallbacks function errbacks the
        # resulting Deferred with the operation's original error.
        returnValue(f)
    else:
        yield txn.commit()
        returnValue(result)
def astimestamp(v):
    """
    Convert the given naive-UTC datetime to a POSIX timestamp: seconds since
    the epoch, as a float.
    """
    epoch = datetime.utcfromtimestamp(0)
    return (v - epoch).total_seconds()
class TableSyntaxByName(Argument):
    """
    Serialize and deserialize L{TableSyntax} objects for an AMP protocol with
    an attached schema: only the table's name travels over the wire.
    """

    def fromStringProto(self, inString, proto):
        """
        Convert the name of the table into a table, given a C{proto} with an
        attached C{schema}.

        @param inString: the name of a table, as utf-8 encoded bytes
        @type inString: L{bytes}

        @param proto: an L{SchemaAMP}
        """
        tableName = inString.decode("UTF-8")
        return getattr(proto.schema, tableName)

    def toString(self, inObject):
        """
        Convert a L{TableSyntax} object into just its name for wire
        transport.

        @param inObject: a table.
        @type inObject: L{TableSyntax}

        @return: the name of that table
        @rtype: L{bytes}
        """
        return inObject.model.name.encode("UTF-8")
class NodeInfo(Record, fromTable(NodeInfoSchema.NODE_INFO)):
    """
    A L{NodeInfo} is information about a currently-active Node process.
    """

    def endpoint(self, reactor):
        """
        Create an endpoint that will talk to the node process that is
        described by this L{NodeInfo}.

        @return: an endpoint that will connect to this host.
        @rtype: NOTE(review): the original docstring claimed
            L{IStreamServerEndpoint}, but a L{TCP4ClientEndpoint} (a client
            endpoint) is constructed - confirm intent.
        """
        return TCP4ClientEndpoint(reactor, self.hostname, self.port)
def abstract(thunk):
    """
    The decorated function is abstract: calling it on any class raises
    L{NotImplementedError} naming the class and the method.

    @note: only methods are currently supported.
    """
    @classmethod
    @wraps(thunk)
    def placeholder(cls, *args, **kwargs):
        message = "%s does not implement %s" % (qual(cls), thunk.func_name)
        raise NotImplementedError(message)
    return placeholder
class WorkItem(Record):
    """
    A L{WorkItem} is an item of work which may be stored in a database, then
    executed later.

    L{WorkItem} is an abstract class, since it is a L{Record} with no table
    associated via L{fromTable}. Concrete subclasses must associate a specific
    table by inheriting like so::

        class MyWorkItem(WorkItem, fromTable(schema.MY_TABLE)):

    Concrete L{WorkItem}s should generally not be created directly; they are
    both created and thereby implicitly scheduled to be executed by calling
    L{enqueueWork} with the appropriate L{WorkItem} concrete subclass. There
    are different queue implementations (L{PeerConnectionPool} and
    L{LocalQueuer}, for example), so the exact timing and location of the work
    execution may differ.

    L{WorkItem}s may be constrained in the ordering and timing of their
    execution, to control concurrency and for performance reasons
    respectively.

    Although all the usual database mutual-exclusion rules apply to work
    executed in L{WorkItem.doWork}, implicit database row locking is not
    always the best way to manage concurrency. They have some problems,
    including:

        - implicit locks are easy to accidentally acquire out of order, which
          can lead to deadlocks

        - implicit locks are easy to forget to acquire correctly - for
          example, any read operation which subsequently turns into a write
          operation must have been acquired with
          C{Select(..., ForUpdate=True)}, but it is difficult to consistently
          indicate that methods which abstract out read operations must pass
          this flag in certain cases and not others.

        - implicit locks are held until the transaction ends, which means
          that if expensive (long-running) queue operations share the same
          lock with cheap (short-running) queue operations or user
          interactions, the cheap operations all have to wait for the
          expensive ones to complete, but continue to consume whatever
          database resources they were using.

    In order to ameliorate these problems with potentially concurrent work
    that uses the same resources, L{WorkItem} provides a database-wide mutex
    that is automatically acquired at the beginning of the transaction and
    released at the end. To use it, simply align the C{group} attribute on
    your L{WorkItem} subclass with a column holding a string (varchar).
    L{WorkItem} subclasses with the same value for C{group} will not execute
    their C{doWork} methods concurrently. Furthermore, if the lock cannot be
    quickly acquired, database resources associated with the transaction
    attempting it will be released, and the transaction rolled back until a
    future transaction I{can} acquire it quickly. If you do not want any
    limits to concurrency, simply leave it set to C{None}.

    In some applications it's possible to coalesce work together; to grab
    multiple L{WorkItem}s in one C{doWork} transaction. All you need to do is
    to delete the rows which back other L{WorkItem}s from the database, and
    they won't be processed. Using the C{group} attribute, you can easily
    prevent concurrency so that you can easily group these items together and
    remove them as a set (otherwise, other workers might be attempting to
    concurrently work on them and you'll get deletion errors).

    However, if doing more work at once is less expensive, and you want to
    avoid processing lots of individual rows in tiny transactions, you may
    also delay the execution of a L{WorkItem} by setting its C{notBefore}
    attribute. This must be backed by a database timestamp, so that processes
    which happen to be restarting and examining the work to be done in the
    database don't jump the gun and do it too early.

    @cvar workID: the unique identifier (primary key) for items of this type.
        On an instance of a concrete L{WorkItem} subclass, this attribute must
        be an integer; on the concrete L{WorkItem} subclass itself, this
        attribute must be a L{twext.enterprise.dal.syntax.ColumnSyntax}. Note
        that this is automatically taken care of if you simply have a
        corresponding C{work_id} column in the associated L{fromTable} on your
        L{WorkItem} subclass. This column must be unique, and it must be an
        integer. In almost all cases, this column really ought to be filled
        out by a database-defined sequence; if not, you need some other
        mechanism for establishing a cluster-wide sequence.
    @type workID: L{int} on instance,
        L{twext.enterprise.dal.syntax.ColumnSyntax} on class.

    @cvar notBefore: the timestamp before which this item should I{not} be
        processed. If unspecified, this should be the date and time of the
        creation of the L{WorkItem}.
    @type notBefore: L{datetime.datetime} on instance,
        L{twext.enterprise.dal.syntax.ColumnSyntax} on class.

    @ivar group: If not C{None}, a unique-to-the-database identifier for which
        only one L{WorkItem} will execute at a time.
    @type group: L{unicode} or L{NoneType}
    """

    # Default: no concurrency constraint; see the class docstring.
    group = None

    @abstract
    def doWork(self):
        """
        Subclasses must implement this to actually perform the queued work.

        This method will be invoked in a worker process.

        This method does I{not} need to delete the row referencing it; that
        will be taken care of by the job queueing machinery.
        """

    @classmethod
    def forTable(cls, table):
        """
        Look up a work-item class given a particular L{TableSyntax}. Factoring
        this correctly may place it into L{twext.enterprise.record.Record}
        instead; it is probably generally useful to be able to look up a
        mapped class from a table.

        @param table: the table to look up
        @type table: L{twext.enterprise.dal.model.Table}

        @return: the relevant subclass
        @rtype: L{type}

        @raise KeyError: if no L{WorkItem} subclass maps the given table.
        """
        tableName = table.model.name
        for subcls in cls.__subclasses__():
            clstable = getattr(subcls, "table", None)
            if table == clstable:
                return subcls
        raise KeyError("No mapped {0} class for {1}.".format(
            cls, tableName
        ))
class PerformWork(Command):
    """
    Notify another process that it must do some work that has been persisted
    to the database, by informing it of the table and the ID where said work
    has been persisted.
    """
    arguments = [
        ("table", TableSyntaxByName()),
        ("workID", Integer()),
    ]
    # Empty response: completion is signalled by the answer box itself.
    response = []
class ReportLoad(Command):
    """
    Notify another node of the total, current load for this whole node (all
    of its workers).
    """
    arguments = [
        ("load", Integer())
    ]
    # Empty response: this is a one-way notification.
    response = []
class IdentifyNode(Command):
    """
    Identify this node to its peer. The connector knows which hostname it's
    looking for, and which hostname it considers itself to be; only the
    initiator (not the listener) issues this command. This command is
    necessary because we don't want to rely on DNS; if reverse DNS weren't
    set up perfectly, the listener would not be able to identify its peer,
    and it is easier to modify local configuration so that
    L{socket.getfqdn} returns the right value than to ensure that DNS does.
    """
    arguments = [
        ("host", String()),
        ("port", Integer()),
    ]
class SchemaAMP(AMP):
    """
    An AMP instance which also has a L{Schema} attached to it.

    @ivar schema: The schema to look up L{TableSyntaxByName} arguments in.
    @type schema: L{Schema}
    """

    def __init__(self, schema, boxReceiver=None, locator=None):
        """
        @param schema: the schema used to resolve table names received over
            the wire.
        @see: L{AMP.__init__} for C{boxReceiver} and C{locator}.
        """
        self.schema = schema
        super(SchemaAMP, self).__init__(boxReceiver, locator)
class ConnectionFromPeerNode(SchemaAMP):
    """
    A connection to a peer node. Symmetric; since the 'client' and the
    'server' both serve the same role, the logic is the same in every node.

    @ivar localWorkerPool: the pool of local worker processes that can
        process queue work.
    @type localWorkerPool: L{WorkerConnectionPool}

    @ivar _reportedLoad: The number of outstanding requests being processed
        by the peer of this connection, from all requestors (both the host of
        this connection and others), as last reported by the most recent
        L{ReportLoad} message received from the peer.
    @type _reportedLoad: L{int}

    @ivar _bonusLoad: The number of additional outstanding requests being
        processed by the peer of this connection; the number of requests made
        by the host of this connection since the last L{ReportLoad} message.
    @type _bonusLoad: L{int}
    """
    implements(_IWorkPerformer)

    def __init__(self, peerPool, boxReceiver=None, locator=None):
        """
        Initialize this L{ConnectionFromPeerNode} with a reference to a
        L{PeerConnectionPool}, as well as required initialization arguments
        for L{AMP}.

        @param peerPool: the connection pool within which this
            L{ConnectionFromPeerNode} is a participant.
        @type peerPool: L{PeerConnectionPool}

        @see: L{AMP.__init__}
        """
        self.peerPool = peerPool
        self._bonusLoad = 0
        self._reportedLoad = 0
        super(ConnectionFromPeerNode, self).__init__(peerPool.schema,
                                                     boxReceiver, locator)

    def reportCurrentLoad(self):
        """
        Report the current load for the local worker pool to this peer.
        """
        return self.callRemote(ReportLoad, load=self.totalLoad())

    @ReportLoad.responder
    def reportedLoad(self, load):
        """
        The peer reports its load.
        """
        # Subtract our own not-yet-acknowledged requests so they aren't
        # counted twice in currentLoadEstimate.
        self._reportedLoad = (load - self._bonusLoad)
        return {}

    def startReceivingBoxes(self, sender):
        """
        Connection is up and running; add this to the list of active peers.
        """
        r = super(ConnectionFromPeerNode, self).startReceivingBoxes(sender)
        self.peerPool.addPeerConnection(self)
        return r

    def stopReceivingBoxes(self, reason):
        """
        The connection has shut down; remove this from the list of active
        peers.
        """
        self.peerPool.removePeerConnection(self)
        r = super(ConnectionFromPeerNode, self).stopReceivingBoxes(reason)
        return r

    def currentLoadEstimate(self):
        """
        What is the current load estimate for this peer?

        @return: The number of full "slots", i.e. currently-being-processed
            queue items (and other items which may contribute to this
            process's load, such as currently-being-processed client
            requests).
        @rtype: L{int}
        """
        return self._reportedLoad + self._bonusLoad

    def performWork(self, table, workID):
        """
        A local worker connection is asking this specific peer
        node-controller process to perform some work, having already
        determined that it's appropriate.

        @see: L{_IWorkPerformer.performWork}
        """
        d = self.callRemote(PerformWork, table=table, workID=workID)
        # Provisionally count this request against the peer's load until the
        # answer (or error) arrives.
        self._bonusLoad += 1

        @d.addBoth
        def performed(result):
            self._bonusLoad -= 1
            return result

        @d.addCallback
        def success(result):
            # Discard the (empty) response dictionary.
            return None

        return d

    @PerformWork.responder
    def dispatchToWorker(self, table, workID):
        """
        A remote peer node has asked this node to do some work; dispatch it
        to a local worker on this node.

        @param table: the table to work on.
        @type table: L{TableSyntax}

        @param workID: the identifier within the table.
        @type workID: L{int}

        @return: a L{Deferred} that fires when the work has been completed.
        """
        return self.peerPool.performWorkForPeer(table, workID).addCallback(
            lambda ignored: {}
        )

    @IdentifyNode.responder
    def identifyPeer(self, host, port):
        """
        The initiating peer has announced its canonical host and port; record
        the mapping in the pool.
        """
        self.peerPool.mapPeer(host, port, self)
        return {}
class WorkerConnectionPool(object):
    """
    A pool of L{ConnectionFromWorker}s.

    L{WorkerConnectionPool} also implements the same implicit protocol as a
    L{ConnectionFromPeerNode}, but one that dispenses work to the local
    worker processes rather than to a remote connection pool.
    """
    implements(_IWorkPerformer)

    def __init__(self, maximumLoadPerWorker=5):
        """
        @param maximumLoadPerWorker: how many concurrent work items a single
            worker process may carry before it is considered full.
        @type maximumLoadPerWorker: L{int}
        """
        self.workers = []
        self.maximumLoadPerWorker = maximumLoadPerWorker

    def addWorker(self, worker):
        """
        Add a L{ConnectionFromWorker} to this L{WorkerConnectionPool} so that
        it can be selected.
        """
        self.workers.append(worker)

    def removeWorker(self, worker):
        """
        Remove a L{ConnectionFromWorker} from this L{WorkerConnectionPool}
        that was previously added.
        """
        self.workers.remove(worker)

    def hasAvailableCapacity(self):
        """
        Does this worker connection pool have any local workers who have
        spare capacity to process another queue item?
        """
        for worker in self.workers:
            if worker.currentLoad < self.maximumLoadPerWorker:
                return True
        return False

    def allWorkerLoad(self):
        """
        The total load of all currently connected workers.
        """
        return sum(worker.currentLoad for worker in self.workers)

    def _selectLowestLoadWorker(self):
        """
        Select the local connection with the lowest current load.

        @return: a worker connection with the lowest current load.
        @rtype: L{ConnectionFromWorker}
        """
        # min() is O(n) and allocates nothing, unlike the previous
        # redundant-copy-then-sort approach (sorted(self.workers[:])[0]).
        return min(self.workers, key=lambda w: w.currentLoad)

    def performWork(self, table, workID):
        """
        Select a local worker that is idle enough to perform the given work,
        then ask them to perform it.

        @param table: The table where work is waiting.
        @type table: L{TableSyntax}

        @param workID: The primary key identifier of the given work.
        @type workID: L{int}

        @return: a L{Deferred} firing with an empty dictionary when the work
            is complete.
        @rtype: L{Deferred} firing L{dict}
        """
        preferredWorker = self._selectLowestLoadWorker()
        result = preferredWorker.performWork(table, workID)
        return result
class ConnectionFromWorker(SchemaAMP):
    """
    An individual connection from a worker, as seen from the master's
    perspective.  L{ConnectionFromWorker}s go into a L{WorkerConnectionPool}.
    """
    def __init__(self, peerPool, boxReceiver=None, locator=None):
        super(ConnectionFromWorker, self).__init__(peerPool.schema,
                                                   boxReceiver, locator)
        self.peerPool = peerPool
        self._load = 0

    @property
    def currentLoad(self):
        """
        How many outstanding work items has this worker been handed?
        """
        return self._load

    def startReceivingBoxes(self, sender):
        """
        Start receiving AMP boxes from the worker; register with the pool so
        this worker can be selected for work.
        """
        upcall = super(ConnectionFromWorker, self).startReceivingBoxes(sender)
        self.peerPool.workerPool.addWorker(self)
        return upcall

    def stopReceivingBoxes(self, reason):
        """
        AMP boxes will no longer be received; deregister from the pool.
        """
        upcall = super(ConnectionFromWorker, self).stopReceivingBoxes(reason)
        self.peerPool.workerPool.removeWorker(self)
        return upcall

    @PerformWork.responder
    def performWork(self, table, workID):
        """
        Dispatch work to this worker.

        @see: The responder for this should always be
            L{ConnectionFromController.actuallyReallyExecuteWorkHere}.
        """
        request = self.callRemote(PerformWork, table=table, workID=workID)
        self._load += 1

        def _workDone(result):
            # Whatever happened, this worker's slot is free again.
            self._load -= 1
            return result

        request.addBoth(_workDone)
        return request
class ConnectionFromController(SchemaAMP):
    """
    A L{ConnectionFromController} is the connection to a node-controller
    process, in a worker process.  It processes requests from its own
    controller to do work.  It is the opposite end of the connection from
    L{ConnectionFromWorker}.
    """
    implements(IQueuer)

    def __init__(self, transactionFactory, schema, whenConnected,
                 boxReceiver=None, locator=None):
        """
        @param transactionFactory: a 0- or 1-argument callable producing an
            L{IAsyncTransaction}; used when work is executed in this process.
        @param schema: the schema shared with the controller.
        @param whenConnected: a 1-argument callable invoked with this
            connection once boxes start flowing.
        """
        super(ConnectionFromController, self).__init__(schema,
                                                       boxReceiver, locator)
        self.transactionFactory = transactionFactory
        self.whenConnected = whenConnected
        # FIXME: Glyph it appears WorkProposal expects this to have reactor...
        from twisted.internet import reactor
        self.reactor = reactor

    def startReceivingBoxes(self, sender):
        """
        Boxes have started flowing; notify the C{whenConnected} callback.
        """
        # Fix: propagate the superclass's return value, consistent with
        # every other SchemaAMP subclass in this module; it was previously
        # dropped.
        result = super(ConnectionFromController, self).startReceivingBoxes(
            sender)
        self.whenConnected(self)
        return result

    def choosePerformer(self):
        """
        To conform with L{WorkProposal}'s expectations, which may run in
        either a controller (against a L{PeerConnectionPool}) or in a worker
        (against a L{ConnectionFromController}), this is implemented to
        always return C{self}, since C{self} is also an object that has a
        C{performWork} method.
        """
        return self

    def performWork(self, table, workID):
        """
        Ask the controller to perform some work on our behalf.
        """
        return self.callRemote(PerformWork, table=table, workID=workID)

    def enqueueWork(self, txn, workItemType, **kw):
        """
        There is some work to do.  Do it, ideally someplace else, ideally in
        parallel.  Later, let the caller know that the work has been
        completed by firing a L{Deferred}.

        @param workItemType: The type of work item to be enqueued.
        @type workItemType: A subtype of L{WorkItem}

        @param kw: The parameters to construct a work item.
        @type kw: keyword parameters to C{workItemType.create}, i.e.
            C{workItemType.__init__}

        @return: an object that can track the enqueuing and remote execution
            of this work.
        @rtype: L{WorkProposal}
        """
        wp = WorkProposal(self, txn, workItemType, kw)
        wp._start()
        return wp

    @PerformWork.responder
    def actuallyReallyExecuteWorkHere(self, table, workID):
        """
        This is where it's time to actually do the work.  The controller
        process has instructed this worker to do it; so, look up the data in
        the row, and do it.
        """
        return (ultimatelyPerform(self.transactionFactory, table, workID)
                .addCallback(lambda ignored: {}))
def ultimatelyPerform(txnFactory, table, workID):
    """
    Eventually, after routing the work to the appropriate place, somebody
    actually has to I{do} it.

    @param txnFactory: a 0- or 1-argument callable that creates an
        L{IAsyncTransaction}
    @type txnFactory: L{callable}

    @param table: the table object that corresponds to the necessary work
        item
    @type table: L{twext.enterprise.dal.syntax.TableSyntax}

    @param workID: the ID of the work to be performed
    @type workID: L{int}

    @return: a L{Deferred} which fires with C{None} when the work has been
        performed, or fails if the work can't be performed.
    """
    @inlineCallbacks
    def work(txn):
        workItemClass = WorkItem.forTable(table)
        try:
            workItem = yield workItemClass.load(txn, workID)
            if workItem.group is not None:
                # Serialize against other work items in the same group.
                yield NamedLock.acquire(txn, workItem.group)
            # TODO: what if we fail? error-handling should be recorded
            # someplace, the row should probably be marked, re-tries should be
            # triggerable administratively.
            yield workItem.delete()
            # NOTE: the row is deleted *before* doWork() runs, all within the
            # same transaction, so a failure in doWork() rolls the delete back.
            # TODO: verify that workID is the primary key someplace.
            yield workItem.doWork()
        except NoSuchRecord:
            # The record has already been removed
            pass
    return inTransaction(txnFactory, work)
class LocalPerformer(object):
    """
    Implementor of C{performWork} that does its work in the local process,
    regardless of other conditions.
    """
    implements(_IWorkPerformer)

    def __init__(self, txnFactory):
        """
        Create this L{LocalPerformer} with a transaction factory.
        """
        self.txnFactory = txnFactory

    def performWork(self, table, workID):
        """
        Perform the given work item immediately, in this process, using a
        transaction from C{txnFactory}.
        """
        return ultimatelyPerform(self.txnFactory, table, workID)
class WorkerFactory(Factory, object):
    """
    Factory used as the client to connect from the worker to the controller.
    """
    def __init__(self, transactionFactory, schema, whenConnected):
        """
        Create a L{WorkerFactory} with a transaction factory, a schema, and
        a connection callback.
        """
        self.transactionFactory = transactionFactory
        self.schema = schema
        self.whenConnected = whenConnected

    def buildProtocol(self, addr):
        """
        Create a L{ConnectionFromController} connected to the
        transactionFactory and store.
        """
        return ConnectionFromController(
            self.transactionFactory, self.schema, self.whenConnected)
class TransactionFailed(Exception):
    """
    Raised (as an errback) when the transaction in which work was proposed
    aborts instead of committing.
    """
def _cloneDeferred(d):
    """
    Make a new Deferred, adding callbacks to C{d}.

    @return: another L{Deferred} that fires with C{d's} result when C{d}
        fires.
    @rtype: L{Deferred}
    """
    clone = Deferred()
    d.chainDeferred(clone)
    return clone
class WorkProposal(object):
    """
    A L{WorkProposal} is a proposal for work that will be executed, perhaps on
    another node, perhaps in the future.

    @ivar _chooser: The object which will choose where the work in this
        proposal gets performed.  This must have both a C{choosePerformer}
        method and a C{reactor} attribute, providing an L{IReactorTime}.
    @type _chooser: L{PeerConnectionPool} or L{LocalQueuer}

    @ivar txn: The transaction where the work will be enqueued.
    @type txn: L{IAsyncTransaction}

    @ivar workItemType: The type of work to be enqueued by this L{WorkProposal}
    @type workItemType: L{WorkItem} subclass

    @ivar kw: The keyword arguments to pass to C{self.workItemType.create} to
        construct it.
    @type kw: L{dict}
    """
    def __init__(self, chooser, txn, workItemType, kw):
        self._chooser = chooser
        self.txn = txn
        self.workItemType = workItemType
        self.kw = kw
        # One Deferred per lifecycle stage; cloned out to callers by the
        # whenProposed / whenCommitted / whenExecuted accessors below.
        self._whenProposed = Deferred()
        self._whenExecuted = Deferred()
        self._whenCommitted = Deferred()

    def _start(self):
        """
        Execute this L{WorkProposal} by creating the work item in the database,
        waiting for the transaction where that addition was completed to
        commit, and asking the local node controller process to do the work.
        """
        created = self.workItemType.create(self.txn, **self.kw)
        def whenCreated(item):
            # Stage 1: the INSERT has been issued on the transaction.
            self._whenProposed.callback(self)
            @self.txn.postCommit
            def whenDone():
                # Stage 2: the enclosing transaction committed.
                self._whenCommitted.callback(self)
                def maybeLater():
                    performer = self._chooser.choosePerformer()
                    @passthru(performer.performWork(item.table, item.workID)
                              .addCallback)
                    def performed(result):
                        # Stage 3: the work actually ran (possibly remotely).
                        self._whenExecuted.callback(self)
                    @performed.addErrback
                    def notPerformed(why):
                        self._whenExecuted.errback(why)
                # Delay dispatch until the item's notBefore timestamp.
                reactor = self._chooser.reactor
                when = max(0, astimestamp(item.notBefore) - reactor.seconds())
                # TODO: Track the returned DelayedCall so it can be stopped
                # when the service stops.
                self._chooser.reactor.callLater(when, maybeLater)
            @self.txn.postAbort
            def whenFailed():
                self._whenCommitted.errback(TransactionFailed)
        def whenNotCreated(failure):
            self._whenProposed.errback(failure)
        created.addCallbacks(whenCreated, whenNotCreated)

    def whenExecuted(self):
        """
        Let the caller know when the proposed work has been fully executed.

        @note: The L{Deferred} returned by C{whenExecuted} should be used with
            extreme caution.  If an application decides to do any
            database-persistent work as a result of this L{Deferred} firing,
            that work I{may be lost} as a result of a service being normally
            shut down between the time that the work is scheduled and the time
            that it is executed.  So, the only things that should be added as
            callbacks to this L{Deferred} are those which are ephemeral, in
            memory, and reflect only presentation state associated with the
            user's perception of the completion of work, not logical chains of
            work which need to be completed in sequence; those should all be
            completed within the transaction of the L{WorkItem.doWork} that
            gets executed.

        @return: a L{Deferred} that fires with this L{WorkProposal} when the
            work has been completed remotely.
        """
        return _cloneDeferred(self._whenExecuted)

    def whenProposed(self):
        """
        Let the caller know when the work has been proposed; i.e. when the work
        is first transmitted to the database.

        @return: a L{Deferred} that fires with this L{WorkProposal} when the
            relevant commands have been sent to the database to create the
            L{WorkItem}, and fails if those commands do not succeed for some
            reason.
        """
        return _cloneDeferred(self._whenProposed)

    def whenCommitted(self):
        """
        Let the caller know when the work has been committed to; i.e. when the
        transaction where the work was proposed has been committed to the
        database.

        @return: a L{Deferred} that fires with this L{WorkProposal} when the
            relevant transaction has been committed, or fails if the
            transaction is not committed for any reason.
        """
        return _cloneDeferred(self._whenCommitted)
class _BaseQueuer(object):
    """
    Common behaviour shared by the concrete L{IQueuer} implementations in
    this module: proposal creation plus observer notification.
    """
    implements(IQueuer)

    def __init__(self):
        super(_BaseQueuer, self).__init__()
        self.proposalCallbacks = set()

    def callWithNewProposals(self, callback):
        """
        Register C{callback} to be invoked with each new L{WorkProposal}
        this queuer creates.
        """
        self.proposalCallbacks.add(callback)

    def transferProposalCallbacks(self, newQueuer):
        """
        Hand this queuer's registered proposal callbacks over to
        C{newQueuer}, and return it.
        """
        newQueuer.proposalCallbacks = self.proposalCallbacks
        return newQueuer

    def enqueueWork(self, txn, workItemType, **kw):
        """
        There is some work to do.  Do it, someplace else, ideally in
        parallel.  Later, let the caller know that the work has been
        completed by firing a L{Deferred}.

        @param workItemType: The type of work item to be enqueued.
        @type workItemType: A subtype of L{WorkItem}

        @param kw: The parameters to construct a work item.
        @type kw: keyword parameters to C{workItemType.create}, i.e.
            C{workItemType.__init__}

        @return: an object that can track the enqueuing and remote execution
            of this work.
        @rtype: L{WorkProposal}
        """
        proposal = WorkProposal(self, txn, workItemType, kw)
        proposal._start()
        for observer in self.proposalCallbacks:
            observer(proposal)
        return proposal
class PeerConnectionPool(_BaseQueuer, MultiService, object):
    """
    Each node has a L{PeerConnectionPool} connecting it to all the other nodes
    currently active on the same database.

    @ivar hostname: The hostname where this node process is running, as
        reported by the local host's configuration.  Possibly this should be
        obtained via C{config.ServerHostName} instead of C{socket.getfqdn()};
        although hosts within a cluster may be configured with the same
        C{ServerHostName}; TODO need to confirm.
    @type hostname: L{bytes}

    @ivar thisProcess: a L{NodeInfo} representing this process, which is
        initialized when this L{PeerConnectionPool} service is started via
        C{startService}.  May be C{None} if this service is not fully started
        up or if it is shutting down.
    @type thisProcess: L{NodeInfo}

    @ivar queueProcessTimeout: The amount of time after a L{WorkItem} is
        scheduled to be processed (its C{notBefore} attribute) that it is
        considered to be "orphaned" and will be run by a lost-work check
        rather than waiting for it to be requested.  By default, 10 minutes.
    @type queueProcessTimeout: L{float} (in seconds)

    @ivar queueDelayedProcessInterval: The amount of time between database
        pings, i.e. checks for over-due queue items that might have been
        orphaned by a controller process that died mid-transaction.  This is
        how often the shared database should be pinged by I{all} nodes (i.e.,
        all controller processes, or each instance of L{PeerConnectionPool});
        each individual node will ping commensurately less often as more
        nodes join the database.
    @type queueDelayedProcessInterval: L{float} (in seconds)

    @ivar reactor: The reactor used for scheduling timed events.
    @type reactor: L{IReactorTime} provider.

    @ivar peers: The list of currently connected peers.
    @type peers: L{list} of L{PeerConnectionPool}
    """
    implements(IQueuer)

    # Bound as staticmethods so they can be overridden (e.g. by tests).
    from socket import getfqdn
    from os import getpid
    getfqdn = staticmethod(getfqdn)
    getpid = staticmethod(getpid)

    queueProcessTimeout = (10.0 * 60.0)
    queueDelayedProcessInterval = (60.0)

    def __init__(self, reactor, transactionFactory, ampPort, schema):
        """
        Initialize a L{PeerConnectionPool}.

        @param ampPort: The AMP TCP port number to listen on for inter-host
            communication.  This must be an integer (and not, say, an
            endpoint, or an endpoint description) because we need to
            communicate it to the other peers in the cluster in a way that
            will be meaningful to them as clients.
        @type ampPort: L{int}

        @param transactionFactory: a 0- or 1-argument callable that produces
            an L{IAsyncTransaction}

        @param schema: The schema which contains all the tables associated
            with the L{WorkItem}s that this L{PeerConnectionPool} will
            process.
        @type schema: L{Schema}
        """
        super(PeerConnectionPool, self).__init__()
        self.reactor = reactor
        self.transactionFactory = transactionFactory
        self.hostname = self.getfqdn()
        self.pid = self.getpid()
        self.ampPort = ampPort
        self.thisProcess = None
        self.workerPool = WorkerConnectionPool()
        self.peers = []
        self.mappedPeers = {}
        self.schema = schema
        self._startingUp = None
        self._listeningPort = None
        # Cached cluster-membership numbers, refreshed by the periodic
        # lost-work check; see totalNumberOfNodes() / nodeIndex().
        self._lastSeenTotalNodes = 1
        self._lastSeenNodeIndex = 1

    def addPeerConnection(self, peer):
        """
        Add a L{ConnectionFromPeerNode} to the active list of peers.
        """
        self.peers.append(peer)

    def totalLoad(self):
        # Aggregate load across all local workers; reported to peers so
        # they can pick the least-loaded node.
        return self.workerPool.allWorkerLoad()

    def workerListenerFactory(self):
        """
        Factory that listens for connections from workers.
        """
        f = Factory()
        f.buildProtocol = lambda addr: ConnectionFromWorker(self)
        return f

    def removePeerConnection(self, peer):
        """
        Remove a L{ConnectionFromPeerNode} to the active list of peers.
        """
        self.peers.remove(peer)

    def choosePerformer(self, onlyLocally=False):
        """
        Choose a peer to distribute work to based on the current known slot
        occupancy of the other nodes.  Note that this will prefer
        distributing work to local workers until the current node is full,
        because that should be lower-latency.  Also, if no peers are
        available, work will be submitted locally even if the worker pool is
        already over-subscribed.

        @return: the chosen peer.
        @rtype: L{_IWorkPerformer} L{ConnectionFromPeerNode} or
            L{WorkerConnectionPool}
        """
        if self.workerPool.hasAvailableCapacity():
            return self.workerPool
        if self.peers and not onlyLocally:
            return sorted(self.peers, key=lambda p: p.currentLoadEstimate())[0]
        else:
            return LocalPerformer(self.transactionFactory)

    def performWorkForPeer(self, table, workID):
        """
        A peer has requested us to perform some work; choose a work
        performer local to this node, and then execute it.
        """
        performer = self.choosePerformer(onlyLocally=True)
        return performer.performWork(table, workID)

    def allWorkItemTypes(self):
        """
        Load all the L{WorkItem} types that this node can process and return
        them.

        @return: L{list} of L{type}
        """
        # TODO: For completeness, this may need to involve a plugin query to
        # make sure that all WorkItem subclasses are imported first.
        for workItemSubclass in WorkItem.__subclasses__():
            # TODO: It might be a good idea to offload this table-filtering
            # to SchemaSyntax.__contains__, adding in some more structure-
            # comparison of similarly-named tables.  For now a name check is
            # sufficient.
            if workItemSubclass.table.model.name in set([x.model.name for x in
                                                         self.schema]):
                yield workItemSubclass

    def totalNumberOfNodes(self):
        """
        How many nodes are there, total?

        @return: the maximum number of other L{PeerConnectionPool} instances
            that may be connected to the database described by
            C{self.transactionFactory}.  Note that this is not the current
            count by connectivity, but the count according to the database.
        @rtype: L{int}
        """
        # TODO
        return self._lastSeenTotalNodes

    def nodeIndex(self):
        """
        What ordinal does this node, i.e. this instance of
        L{PeerConnectionPool}, occupy within the ordered set of all nodes
        connected to the database described by C{self.transactionFactory}?

        @return: the index of this node within the total collection.  For
            example, if this L{PeerConnectionPool} is 6 out of 30, this
            method will return C{6}.
        @rtype: L{int}
        """
        # TODO
        return self._lastSeenNodeIndex

    def _periodicLostWorkCheck(self):
        """
        Periodically, every node controller has to check to make sure that
        work hasn't been dropped on the floor by someone.  In order to do
        that it queries each work-item table.
        """
        @inlineCallbacks
        def workCheck(txn):
            if self.thisProcess:
                # Refresh the cached node count / index while we're here.
                nodes = [(node.hostname, node.port) for node in
                         (yield self.activeNodes(txn))]
                nodes.sort()
                self._lastSeenTotalNodes = len(nodes)
                self._lastSeenNodeIndex = nodes.index(
                    (self.thisProcess.hostname, self.thisProcess.port)
                )
            for itemType in self.allWorkItemTypes():
                # Anything scheduled more than queueProcessTimeout ago is
                # considered orphaned and is re-dispatched here.
                tooLate = datetime.utcfromtimestamp(
                    self.reactor.seconds() - self.queueProcessTimeout
                )
                overdueItems = (yield itemType.query(
                    txn, (itemType.notBefore < tooLate))
                )
                for overdueItem in overdueItems:
                    peer = self.choosePerformer()
                    yield peer.performWork(overdueItem.table,
                                           overdueItem.workID)
        return inTransaction(self.transactionFactory, workCheck)

    _currentWorkDeferred = None
    _lostWorkCheckCall = None

    def _lostWorkCheckLoop(self):
        """
        While the service is running, keep checking for any overdue / lost
        work items and re-submit them to the cluster for processing.  Space
        out those checks in time based on the size of the cluster.
        """
        self._lostWorkCheckCall = None
        @passthru(self._periodicLostWorkCheck().addErrback(log.err)
                  .addCallback)
        def scheduleNext(result):
            self._currentWorkDeferred = None
            if not self.running:
                return
            index = self.nodeIndex()
            now = self.reactor.seconds()
            interval = self.queueDelayedProcessInterval
            count = self.totalNumberOfNodes()
            # Stagger checks so each node pings at a distinct offset within
            # the shared interval.
            when = (now - (now % interval)) + (interval * (count + index))
            delay = when - now
            self._lostWorkCheckCall = self.reactor.callLater(
                delay, self._lostWorkCheckLoop
            )
        self._currentWorkDeferred = scheduleNext

    def startService(self):
        """
        Register ourselves with the database and establish all outgoing
        connections to other servers in the cluster.
        """
        @inlineCallbacks
        def startup(txn):
            endpoint = TCP4ServerEndpoint(self.reactor, self.ampPort)
            # If this fails, the failure mode is going to be ugly, just like
            # all conflicted-port failures.  But, at least it won't proceed.
            self._listeningPort = yield endpoint.listen(self.peerFactory())
            self.ampPort = self._listeningPort.getHost().port
            yield Lock.exclusive(NodeInfo.table).on(txn)
            nodes = yield self.activeNodes(txn)
            selves = [node for node in nodes
                      if ((node.hostname == self.hostname) and
                          (node.port == self.ampPort))]
            if selves:
                self.thisProcess = selves[0]
                nodes.remove(self.thisProcess)
                yield self.thisProcess.update(pid=self.pid,
                                              time=datetime.now())
            else:
                self.thisProcess = yield NodeInfo.create(
                    txn, hostname=self.hostname, port=self.ampPort,
                    pid=self.pid, time=datetime.now()
                )
            for node in nodes:
                self._startConnectingTo(node)
        self._startingUp = inTransaction(self.transactionFactory, startup)
        @self._startingUp.addBoth
        def done(result):
            self._startingUp = None
            super(PeerConnectionPool, self).startService()
            self._lostWorkCheckLoop()
            return result

    @inlineCallbacks
    def stopService(self):
        """
        Stop this service, terminating any incoming or outgoing connections.
        """
        yield super(PeerConnectionPool, self).stopService()
        if self._startingUp is not None:
            yield self._startingUp
        if self._listeningPort is not None:
            yield self._listeningPort.stopListening()
        if self._lostWorkCheckCall is not None:
            self._lostWorkCheckCall.cancel()
        if self._currentWorkDeferred is not None:
            yield self._currentWorkDeferred
        for peer in self.peers:
            peer.transport.abortConnection()

    def activeNodes(self, txn):
        """
        Load information about all other nodes.
        """
        return NodeInfo.all(txn)

    def mapPeer(self, host, port, peer):
        """
        A peer has been identified as belonging to the given host/port
        combination.  Disconnect any other peer that claims to be connected
        for the same peer.
        """
        # if (host, port) in self.mappedPeers:
        # TODO: think about this for race conditions
        # self.mappedPeers.pop((host, port)).transport.loseConnection()
        self.mappedPeers[(host, port)] = peer

    def _startConnectingTo(self, node):
        """
        Start an outgoing connection to another master process.

        @param node: a description of the master to connect to.
        @type node: L{NodeInfo}
        """
        connected = node.endpoint(self.reactor).connect(self.peerFactory())
        def whenConnected(proto):
            self.mapPeer(node.hostname, node.port, proto)
            proto.callRemote(IdentifyNode,
                             host=self.thisProcess.hostname,
                             port=self.thisProcess.port).addErrback(
                                 noted, "identify"
                             )
        def noted(err, x="connect"):
            log.msg("Could not {0} to cluster peer {1} because {2}"
                    .format(x, node, str(err.value)))
        connected.addCallbacks(whenConnected, noted)

    def peerFactory(self):
        """
        Factory for peer connections.

        @return: a L{Factory} that will produce L{ConnectionFromPeerNode}
            protocols attached to this L{PeerConnectionPool}.
        """
        return _PeerPoolFactory(self)
class _PeerPoolFactory(Factory, object):
    """
    Protocol factory responsible for creating L{ConnectionFromPeerNode}
    connections, both client and server.
    """
    def __init__(self, peerConnectionPool):
        self.peerConnectionPool = peerConnectionPool

    def buildProtocol(self, addr):
        """
        Every inbound or outbound connection gets a fresh peer protocol
        bound to the shared pool.
        """
        return ConnectionFromPeerNode(self.peerConnectionPool)
class LocalQueuer(_BaseQueuer):
    """
    When work is enqueued with this queuer, it is just executed locally.
    """
    implements(IQueuer)

    def __init__(self, txnFactory, reactor=None):
        super(LocalQueuer, self).__init__()
        self.txnFactory = txnFactory
        if reactor is None:
            # Fall back to the global reactor when none is supplied.
            from twisted.internet import reactor
        self.reactor = reactor

    def choosePerformer(self):
        """
        Work enqueued here is always executed by a L{LocalPerformer}, in
        this process.
        """
        return LocalPerformer(self.txnFactory)
class NonPerformer(object):
    """
    Implementor of C{performWork} that doesn't actually perform any work.
    This is used in the case where you want to be able to enqueue work for
    someone else to do, but not take on any work yourself (such as a command
    line tool).
    """
    implements(_IWorkPerformer)

    def performWork(self, table, workID):
        """
        Don't perform work; succeed immediately with C{None}.
        """
        return succeed(None)
class NonPerformingQueuer(_BaseQueuer):
    """
    When work is enqueued with this queuer, it is never executed locally.
    It's expected that the polling machinery will find the work and perform
    it.
    """
    implements(IQueuer)

    def __init__(self, reactor=None):
        super(NonPerformingQueuer, self).__init__()
        if reactor is None:
            # Fall back to the global reactor when none is supplied.
            from twisted.internet import reactor
        self.reactor = reactor

    def choosePerformer(self):
        """
        Always choose a L{NonPerformer}, which performs nothing.
        """
        return NonPerformer()
calendarserver-5.2+dfsg/twext/enterprise/__init__.py 0000644 0001750 0001750 00000001333 12263343324 021755 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Extensions in the spirit of Twisted's "enterprise" package; things related to
database connectivity and management.
"""
calendarserver-5.2+dfsg/twext/enterprise/dal/ 0000755 0001750 0001750 00000000000 12322625326 020405 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/enterprise/dal/model.py 0000644 0001750 0001750 00000040171 12263343324 022061 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.dal.test.test_parseschema -*-
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Model classes for SQL.
"""
from twisted.python.util import FancyEqMixin
class SQLType(object):
    """
    A data-type as defined in SQL; like "integer" or "real" or
    "varchar(255)".

    @ivar name: the name of this type.
    @type name: C{str}

    @ivar length: the length of this type, if it is a type like 'varchar' or
        'character' that comes with a parenthetical length.
    @type length: C{int} or C{NoneType}
    """
    def __init__(self, name, length):
        _checkstr(name)
        self.name = name
        self.length = length

    def __eq__(self, other):
        """
        Compare equal to other L{SQLType}s with matching name and length.
        """
        if not isinstance(other, SQLType):
            return NotImplemented
        return (self.name, self.length) == (other.name, other.length)

    def __ne__(self, other):
        """
        (Inequality is the opposite of equality.)
        """
        if not isinstance(other, SQLType):
            return NotImplemented
        return not self.__eq__(other)

    def __repr__(self):
        """
        A useful string representation which includes the name and length if
        present.
        """
        if self.length:
            lendesc = '(%s)' % (self.length)
        else:
            lendesc = ''
        # Fix: the format template was lost in transcription ('' % (...)
        # raises TypeError); restore a name-plus-length representation.
        return '<SQL Type: %r%s>' % (self.name, lendesc)
class Constraint(object):
    """
    A constraint on a set of columns.

    @ivar type: the type of constraint.  Currently, only C{'UNIQUE'} and
        C{'NOT NULL'} are supported.
    @type type: C{str}

    @ivar affectsColumns: Columns affected by this constraint.
    @type affectsColumns: C{list} of L{Column}
    """

    # Values for the 'type' attribute:
    NOT_NULL = 'NOT NULL'
    UNIQUE = 'UNIQUE'

    def __init__(self, type, affectsColumns, name=None):
        # XXX: possibly different constraint types should have different
        # classes?
        self.type = type
        self.name = name
        self.affectsColumns = affectsColumns
class Check(Constraint):
    """
    A 'check' constraint, which evaluates an SQL expression.

    @ivar expression: the expression that should evaluate to True.
    @type expression: L{twext.enterprise.dal.syntax.ExpressionSyntax}
    """
    # XXX TODO: model for expression, rather than
    def __init__(self, syntaxExpression, name=None):
        self.expression = syntaxExpression
        # A CHECK constraint affects every column its expression mentions.
        affected = [col.model for col in self.expression.allColumns()]
        super(Check, self).__init__('CHECK', affected, name)
class ProcedureCall(object):
    """
    An invocation of a stored procedure or built-in function.
    """
    def __init__(self, name, args):
        # name: the procedure/function name; must be a plain str
        # (enforced by _checkstr).  args: the argument expressions.
        _checkstr(name)
        self.name = name
        self.args = args
class NO_DEFAULT(object):
    """
    Sentinel used as a L{Column}'s C{default} when no default was declared.
    (C{None} would not be suitable, as that would imply a default of
    C{NULL}).
    """
def _checkstr(x):
"""
Verify that C{x} is a C{str}. Raise a L{ValueError} if not. This is to
prevent pollution with unicode values.
"""
if not isinstance(x, str):
raise ValueError("%r is not a str." % (x,))
class Column(FancyEqMixin, object):
    """
    A column from a table.

    @ivar table: The L{Table} to which this L{Column} belongs.
    @type table: L{Table}

    @ivar name: The unqualified name of this column.  For example, in the
        case of a column BAR in a table FOO, this would be the string
        C{'BAR'}.
    @type name: C{str}

    @ivar type: The declared type of this column.
    @type type: L{SQLType}

    @ivar references: If this column references a foreign key on another
        table, this will be a reference to that table; otherwise (normally)
        C{None}.
    @type references: L{Table} or C{NoneType}

    @ivar deleteAction: If this column references another table, how will
        this column's row be altered when the matching row in that other
        table is deleted?  Possible values are:
            None - for 'on delete no action'
            'cascade' - for 'on delete cascade'
            'set null' - for 'on delete set null'
            'set default' - for 'on delete set default'
    @type deleteAction: C{str} or C{NoneType}
    """
    # Equality (via FancyEqMixin) is identity of (table, name).
    compareAttributes = 'table name'.split()

    def __init__(self, table, name, type):
        _checkstr(name)
        self.table = table
        self.name = name
        self.type = type
        self.default = NO_DEFAULT
        self.references = None
        self.deleteAction = None

    def __repr__(self):
        # Fix: the format template was lost in transcription ('' % (...)
        # raises TypeError); restore a name-plus-type representation.
        return '<Column (%s %r)>' % (self.name, self.type)

    def compare(self, other):
        """
        Return the differences between two columns.

        @param other: the column to compare with
        @type other: L{Column}
        """
        results = []
        # TODO: sql_dump does not do types right now - so ignore this
        # if self.type != other.type:
        #     results.append("Table: %s, mismatched column type: %s"
        #                    % (self.table.name, self.name))
        # TODO: figure out how to compare default, references and
        # deleteAction
        return results

    def canBeNull(self):
        """
        Can this column ever be C{NULL}, i.e. C{None}?  In other words, is
        it free of any C{NOT NULL} constraints?

        @return: C{True} if so, C{False} if not.
        """
        for constraint in self.table.constraints:
            if self in constraint.affectsColumns:
                if constraint.type is Constraint.NOT_NULL:
                    return False
        return True

    def setDefaultValue(self, value):
        """
        Change the default value of this column.  (Should only be called
        during schema parsing.)
        """
        self.default = value

    def needsValue(self):
        """
        Does this column require a value in C{INSERT} statements which
        create rows?

        @return: C{True} for L{Column}s with no default specified which also
            cannot be NULL, C{False} otherwise.
        @rtype: C{bool}
        """
        return not (self.canBeNull() or
                    (self.default not in (None, NO_DEFAULT)))

    def doesReferenceName(self, name):
        """
        Change this column to refer to a table in the schema.  (Should only
        be called during schema parsing.)

        @param name: the name of a L{Table} in this L{Column}'s L{Schema}.
        @type name: L{str}
        """
        self.references = self.table.schema.tableNamed(name)
class Table(FancyEqMixin, object):
    """
    A set of columns.

    @ivar descriptiveComment: A docstring for the table.  Parsed from a '--'
        comment preceding this table in the SQL schema file that was parsed,
        if any.
    @type descriptiveComment: C{str}

    @ivar schema: a reference to the L{Schema} to which this table belongs.

    @ivar primaryKey: a C{list} of L{Column} objects representing the primary
        key of this table, or C{None} if no primary key has been specified.
    """

    compareAttributes = 'schema name'.split()

    def __init__(self, schema, name):
        """
        Create a table and register it with C{schema.tables}.

        @param schema: the owning L{Schema}.
        @param name: the SQL name of the table.
        @type name: C{str}
        """
        _checkstr(name)
        self.descriptiveComment = ''
        self.schema = schema
        self.name = name
        self.columns = []
        self.constraints = []
        self.schemaRows = []
        self.primaryKey = None
        self.schema.tables.append(self)


    def __repr__(self):
        # Bug fix: the format string here was empty ('' % (...,) raises
        # TypeError); restore the "<Table name:columns>" form that was lost
        # when markup was stripped from the source.
        return '<Table %r:%r>' % (self.name, self.columns)


    def compare(self, other):
        """
        Return the differences between two tables.

        @param other: the table to compare with
        @type other: L{Table}
        @return: a C{list} of C{str} difference descriptions.
        """
        results = []

        myColumns = dict([(item.name.lower(), item) for item in self.columns])
        otherColumns = dict([(item.name.lower(), item) for item in other.columns])

        # Symmetric difference: a column present in only one of the two
        # tables is reported as missing (always labelled with this table's
        # name, whichever side it was missing from).
        for item in set(myColumns.keys()) ^ set(otherColumns.keys()):
            results.append("Table: %s, missing column: %s" % (self.name, item,))

        for name in set(myColumns.keys()) & set(otherColumns.keys()):
            results.extend(myColumns[name].compare(otherColumns[name]))

        # TODO: figure out how to compare schemaRows
        return results


    def columnNamed(self, name):
        """
        Retrieve a column from this table with a given name.

        @raise KeyError: if no such column exists.

        @return: a column

        @rtype: L{Column}
        """
        for column in self.columns:
            if column.name == name:
                return column
        raise KeyError("no such column: %r" % (name,))


    def addColumn(self, name, type):
        """
        A new column was parsed for this table.

        @param name: The unqualified name of the column.
        @type name: C{str}

        @param type: The L{SQLType} describing the column's type.

        @return: the newly created L{Column}.
        """
        column = Column(self, name, type)
        self.columns.append(column)
        return column


    def tableConstraint(self, constraintType, columnNames):
        """
        This table is affected by a constraint.  (Should only be called
        during schema parsing.)

        @param constraintType: the type of constraint; either
            L{Constraint.NOT_NULL} or L{Constraint.UNIQUE}, currently.
        @param columnNames: the names of the columns the constraint covers.
        """
        affectsColumns = []
        for name in columnNames:
            affectsColumns.append(self.columnNamed(name))
        self.constraints.append(Constraint(constraintType, affectsColumns))


    def checkConstraint(self, protoExpression, name=None):
        """
        This table is affected by a 'check' constraint.  (Should only be
        called during schema parsing.)

        @param protoExpression: proto expression.
        @param name: the optional name of the constraint.
        """
        self.constraints.append(Check(protoExpression, name))


    def insertSchemaRow(self, values):
        """
        A statically-defined row was inserted as part of the schema itself.
        This is used for tables that want to track static enumerations, for
        example, but want to be referred to by a foreign key in other tables
        for proper referential integrity.

        Append this data to this L{Table}'s L{Table.schemaRows}.

        (Should only be called during schema parsing.)

        @param values: a C{list} of data items, one for each column in this
            table's current list of L{Column}s.
        """
        row = {}
        for column, value in zip(self.columns, values):
            row[column] = value
        self.schemaRows.append(row)


    def addComment(self, comment):
        """
        Add a comment to C{descriptiveComment}.

        @param comment: some additional descriptive text
        @type comment: C{str}
        """
        self.descriptiveComment = comment


    def uniques(self):
        """
        Get the groups of unique columns for this L{Table}.

        @return: an iterable of C{list}s of C{Column}s which are unique
            within this table.
        """
        for constraint in self.constraints:
            if constraint.type is Constraint.UNIQUE:
                yield list(constraint.affectsColumns)
class Index(object):
    """
    An L{Index} is an SQL index.
    """

    def __init__(self, schema, name, table, unique=False):
        """
        Create an index over C{table} and register it with C{schema}.

        @param schema: the owning schema; this index appends itself to
            C{schema.indexes}.
        @param name: the SQL name of the index.
        @param table: the table the index covers.
        @param unique: whether this is a UNIQUE index.
        """
        schema.indexes.append(self)
        self.name = name
        self.table = table
        self.unique = unique
        self.columns = list()


    def addColumn(self, column):
        """
        Include one more column in the set covered by this index.
        """
        self.columns.append(column)
class PseudoIndex(object):
    """
    A class used to represent explicit and implicit indexes.  An implicit
    index is one the DB creates for primary key and unique columns in a
    table.  An explicit index is one created by a CREATE [UNIQUE] INDEX
    statement.  Because the name of an implicit index is implementation
    defined, instead we create a name based on the table name, uniqueness
    and column names.
    """

    def __init__(self, table, columns, unique=False):
        """
        @param table: the table the (pseudo) index covers.
        @param columns: the covered columns, in order.
        @param unique: whether the index enforces uniqueness.
        """
        # Synthesize a stable, implementation-independent name of the form
        # "TABLE[-unique]:(COL1,COL2,...)".
        uniqueTag = "-unique" if unique else ""
        columnList = ",".join(column.name for column in columns)
        self.name = "%s%s:(%s)" % (table.name, uniqueTag, columnList)
        self.table = table
        self.unique = unique
        self.columns = columns


    def compare(self, other):
        """
        Return the differences between two indexes.

        @param other: the index to compare with
        @type other: L{Index}
        """
        # Nothing to do as name comparison will catch differences
        return []
class Sequence(FancyEqMixin, object):
    """
    A sequence object.
    """

    compareAttributes = 'name'.split()

    def __init__(self, schema, name):
        """
        Create a sequence and register it with C{schema.sequences}.

        @param schema: the owning schema.
        @param name: the SQL name of the sequence.
        """
        _checkstr(name)
        self.name = name
        self.referringColumns = []
        schema.sequences.append(self)


    def __repr__(self):
        # Bug fix: the format string here was empty ('' % (...,) raises
        # TypeError); restore the "<Sequence name>" form that was lost when
        # markup was stripped from the source.
        return '<Sequence %r>' % (self.name,)


    def compare(self, other):
        """
        Return the differences between two sequences.

        @param other: the sequence to compare with
        @type other: L{Sequence}
        """
        # TODO: figure out whether to compare referringColumns attribute
        return []
def _namedFrom(name, sequence):
    """
    Retrieve an item with a given name attribute from a given sequence, or
    raise a L{KeyError}.
    """
    matches = (item for item in sequence if item.name == name)
    try:
        return next(matches)
    except StopIteration:
        raise KeyError(name)
class Schema(object):
    """
    A schema containing tables, indexes, and sequences.
    """

    def __init__(self, filename=''):
        """
        @param filename: the name of the file the schema was parsed from, if
            any; used only for diagnostics.
        """
        self.filename = filename
        self.tables = []
        self.indexes = []
        self.sequences = []


    def __repr__(self):
        # Bug fix: the format string here was empty ('' % (...,) raises
        # TypeError); restore the "<Schema filename>" form that was lost
        # when markup was stripped from the source.
        return '<Schema %r>' % (self.filename,)


    def compare(self, other):
        """
        Return the differences between two schemas.

        @param other: the schema to compare with
        @type other: L{Schema}
        @return: a C{list} of C{str} difference descriptions.
        """
        results = []

        def _compareLists(list1, list2, descriptor):
            # Compare by case-folded names truncated to 63 characters —
            # presumably to match PostgreSQL's identifier length limit
            # (NAMEDATALEN - 1); TODO confirm.
            myItems = dict([(item.name.lower()[:63], item) for item in list1])
            otherItems = dict([(item.name.lower()[:63], item) for item in list2])
            for item in set(myItems.keys()) - set(otherItems.keys()):
                results.append("Schema: %s, missing %s: %s" % (other.filename, descriptor, item,))
            for item in set(otherItems.keys()) - set(myItems.keys()):
                results.append("Schema: %s, missing %s: %s" % (self.filename, descriptor, item,))
            for name in set(myItems.keys()) & set(otherItems.keys()):
                results.extend(myItems[name].compare(otherItems[name]))

        _compareLists(self.tables, other.tables, "table")
        _compareLists(self.pseudoIndexes(), other.pseudoIndexes(), "index")
        _compareLists(self.sequences, other.sequences, "sequence")
        return results


    def pseudoIndexes(self):
        """
        Return a set of indexes that include "implicit" indexes from
        table/column constraints.  The name of the index is formed from the
        table name and then list of columns.
        """
        results = []

        # First add the list of explicit indexes we have
        for index in self.indexes:
            results.append(PseudoIndex(index.table, index.columns, index.unique))

        # Now do implicit index for each table
        for table in self.tables:
            if table.primaryKey is not None:
                results.append(PseudoIndex(table, table.primaryKey, True))
            for constraint in table.constraints:
                if constraint.type == Constraint.UNIQUE:
                    results.append(PseudoIndex(table, constraint.affectsColumns, True))

        return results


    def tableNamed(self, name):
        """
        Retrieve the L{Table} with the given name, or raise L{KeyError}.
        """
        return _namedFrom(name, self.tables)


    def sequenceNamed(self, name):
        """
        Retrieve the L{Sequence} with the given name, or raise L{KeyError}.
        """
        return _namedFrom(name, self.sequences)


    def indexNamed(self, name):
        """
        Retrieve the L{Index} with the given name, or raise L{KeyError}.
        """
        return _namedFrom(name, self.indexes)
calendarserver-5.2+dfsg/twext/enterprise/dal/test/ 0000755 0001750 0001750 00000000000 12322625326 021364 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/enterprise/dal/test/test_parseschema.py 0000644 0001750 0001750 00000032422 12263343324 025272 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for parsing an SQL schema, which cover L{twext.enterprise.dal.model}
and L{twext.enterprise.dal.parseschema}.
"""
from twext.enterprise.dal.model import Schema
from twext.enterprise.dal.syntax import CompoundComparison, ColumnSyntax
from twext.enterprise.dal.parseschema import addSQLToSchema
from twisted.trial.unittest import TestCase
class SchemaTestHelper(object):
    """
    Mix-in that can parse a schema from a string.
    """

    def schemaFromString(self, string):
        """
        Create a L{Schema} by parsing the given SQL C{string}, naming it
        after this test case's id.
        """
        s = Schema(self.id())
        addSQLToSchema(s, string)
        return s
class ParsingExampleTests(TestCase, SchemaTestHelper):
    """
    Tests for parsing some sample schemas.
    """

    def test_simplest(self):
        """
        Parse an extremely simple schema with one table in it.
        """
        s = self.schemaFromString("create table foo (bar integer);")
        self.assertEquals(len(s.tables), 1)
        foo = s.tableNamed('foo')
        self.assertEquals(len(foo.columns), 1)
        bar = foo.columns[0]
        self.assertEquals(bar.name, "bar")
        self.assertEquals(bar.type.name, "integer")


    def test_stringTypes(self):
        """
        Table and column names should be byte strings.
        """
        s = self.schemaFromString("create table foo (bar integer);")
        self.assertEquals(len(s.tables), 1)
        foo = s.tableNamed('foo')
        self.assertIsInstance(foo.name, str)
        self.assertIsInstance(foo.columnNamed('bar').name, str)


    def test_typeWithLength(self):
        """
        Parse a type with a length.
        """
        s = self.schemaFromString("create table foo (bar varchar(6543))")
        bar = s.tableNamed('foo').columnNamed('bar')
        self.assertEquals(bar.type.name, "varchar")
        self.assertEquals(bar.type.length, 6543)


    def test_sequence(self):
        """
        Parsing a 'create sequence' statement adds a L{Sequence} to the
        L{Schema}.
        """
        s = self.schemaFromString("create sequence myseq;")
        self.assertEquals(len(s.sequences), 1)
        self.assertEquals(s.sequences[0].name, "myseq")


    def test_sequenceColumn(self):
        """
        Parsing a 'create sequence' statement adds a L{Sequence} to the
        L{Schema}, and then a table that contains a column which uses the SQL
        C{nextval()} function to retrieve its default value from that
        sequence, will cause the L{Column} object to refer to the L{Sequence}
        and vice versa.
        """
        s = self.schemaFromString(
            """
            create sequence thingy;
            create table thetable (
                thecolumn integer default nextval('thingy')
            );
            """)
        self.assertEquals(len(s.sequences), 1)
        self.assertEquals(s.sequences[0].name, "thingy")
        self.assertEquals(s.tables[0].columns[0].default, s.sequences[0])
        self.assertEquals(s.sequences[0].referringColumns,
                          [s.tables[0].columns[0]])


    def test_sequenceDefault(self):
        """
        Default sequence column.
        """
        s = self.schemaFromString(
            """
            create sequence alpha;
            create table foo (
                bar integer default nextval('alpha') not null,
                qux integer not null
            );
            """)
        self.assertEquals(s.tableNamed("foo").columnNamed("bar").needsValue(),
                          False)


    def test_sequenceDefaultWithParens(self):
        """
        SQLite requires 'default' expression to be in parentheses, and that
        should be equivalent on other databases; we should be able to parse
        that too.
        """
        s = self.schemaFromString(
            """
            create sequence alpha;
            create table foo (
                bar integer default (nextval('alpha')) not null,
                qux integer not null
            );
            """
        )
        self.assertEquals(s.tableNamed("foo").columnNamed("bar").needsValue(),
                          False)


    def test_defaultConstantColumns(self):
        """
        Parsing a 'default' column with an appropriate type in it will return
        that type as the 'default' attribute of the Column object.
        """
        s = self.schemaFromString(
            """
            create table a (
                b integer default 4321,
                c boolean default false,
                d boolean default true,
                e varchar(255) default 'sample value',
                f varchar(255) default null
            );
            """)
        table = s.tableNamed("a")
        self.assertEquals(table.columnNamed("b").default, 4321)
        self.assertEquals(table.columnNamed("c").default, False)
        self.assertEquals(table.columnNamed("d").default, True)
        self.assertEquals(table.columnNamed("e").default, 'sample value')
        self.assertEquals(table.columnNamed("f").default, None)


    def test_needsValue(self):
        """
        Columns with defaults, or with a 'not null' constraint don't need a
        value; columns without one don't.
        """
        s = self.schemaFromString(
            """
            create table a (
                b integer default 4321 not null,
                c boolean default false,
                d integer not null,
                e integer
            )
            """)
        table = s.tableNamed("a")
        # Has a default, NOT NULL.
        self.assertEquals(table.columnNamed("b").needsValue(), False)
        # Has a default _and_ nullable.
        self.assertEquals(table.columnNamed("c").needsValue(), False)
        # No default, not nullable.
        self.assertEquals(table.columnNamed("d").needsValue(), True)
        # Just nullable.
        self.assertEquals(table.columnNamed("e").needsValue(), False)


    def test_notNull(self):
        """
        A column with a NOT NULL constraint in SQL will be parsed as a
        constraint which returns False from its C{canBeNull()} method.
        """
        s = self.schemaFromString(
            "create table alpha (beta integer, gamma integer not null);"
        )
        t = s.tableNamed('alpha')
        self.assertEquals(True, t.columnNamed('beta').canBeNull())
        self.assertEquals(False, t.columnNamed('gamma').canBeNull())


    def test_unique(self):
        """
        A column with a UNIQUE constraint in SQL will result in the table
        listing that column as a unique set.
        """
        for identicalSchema in [
                "create table sample (example integer unique);",
                "create table sample (example integer, unique (example));",
                "create table sample "
                "(example integer, constraint unique_example unique (example))"]:
            s = self.schemaFromString(identicalSchema)
            table = s.tableNamed('sample')
            column = table.columnNamed('example')
            self.assertEquals(list(table.uniques()), [[column]])


    def test_checkExpressionConstraint(self):
        """
        A column with a CHECK constraint in SQL that uses an inequality will
        result in a L{Check} constraint being added to the L{Table} object.
        """
        def checkOneConstraint(sqlText, checkName=None):
            s = self.schemaFromString(sqlText)
            table = s.tableNamed('sample')
            self.assertEquals(len(table.constraints), 1)
            constraint = table.constraints[0]
            expr = constraint.expression
            self.assertIsInstance(expr, CompoundComparison)
            self.assertEqual(expr.a.model, table.columnNamed('example'))
            self.assertEqual(expr.b.value, 5)
            self.assertEqual(expr.op, '>')
            self.assertEqual(constraint.name, checkName)
        checkOneConstraint(
            "create table sample (example integer check (example > 5));"
        )
        checkOneConstraint(
            "create table sample (example integer, check (example > 5));"
        )
        checkOneConstraint(
            "create table sample "
            "(example integer, constraint gt_5 check (example>5))", "gt_5"
        )


    def test_checkKeywordConstraint(self):
        """
        A column with a CHECK constraint in SQL that compares with a keyword
        expression such as 'lower' will result in a L{Check} constraint being
        added to the L{Table} object.
        """
        def checkOneConstraint(sqlText):
            s = self.schemaFromString(sqlText)
            table = s.tableNamed('sample')
            self.assertEquals(len(table.constraints), 1)
            expr = table.constraints[0].expression
            self.assertEquals(expr.a.model, table.columnNamed("example"))
            self.assertEquals(expr.op, "=")
            self.assertEquals(expr.b.function.name, "lower")
            self.assertEquals(
                expr.b.args,
                tuple([ColumnSyntax(table.columnNamed("example"))])
            )
        checkOneConstraint(
            "create table sample "
            "(example integer check (example = lower (example)));"
        )


    def test_multiUnique(self):
        """
        Multiple UNIQUE constraints, including a multi-column one, will each
        result in the table listing those columns as a unique set.
        """
        s = self.schemaFromString(
            "create table a (b integer, c integer, unique (b, c), unique (c));"
        )
        a = s.tableNamed('a')
        b = a.columnNamed('b')
        c = a.columnNamed('c')
        self.assertEquals(list(a.uniques()), [[b, c], [c]])


    def test_singlePrimaryKey(self):
        """
        A table with a single-column PRIMARY KEY clause will be parsed as a
        list of a single L{Column} object and stored as a C{primaryKey}
        attribute on the L{Table} object.
        """
        s = self.schemaFromString(
            "create table a (b integer primary key, c integer)"
        )
        a = s.tableNamed("a")
        self.assertEquals(a.primaryKey, [a.columnNamed("b")])


    def test_multiPrimaryKey(self):
        """
        A table with a multi-column PRIMARY KEY clause will be parsed as a
        list of L{Column} objects and stored as the C{primaryKey} attribute
        on the L{Table} object.
        """
        s = self.schemaFromString(
            "create table a (b integer, c integer, primary key (b, c))"
        )
        a = s.tableNamed("a")
        self.assertEquals(
            a.primaryKey, [a.columnNamed("b"), a.columnNamed("c")]
        )


    def test_deleteAction(self):
        """
        A column's 'on delete' referential action is parsed into its
        C{deleteAction} attribute: C{'cascade'}, C{'set null'} or
        C{'set default'}, or C{None} when no action was given.
        """
        s = self.schemaFromString(
            """
            create table a1 (b1 integer primary key);
            create table c2 (d2 integer references a1 on delete cascade);
            create table e3 (f3 integer references a1 on delete set null);
            create table g4 (h4 integer references a1 on delete set default);
            """)
        self.assertEquals(s.tableNamed("a1").columnNamed("b1").deleteAction, None)
        self.assertEquals(s.tableNamed("c2").columnNamed("d2").deleteAction, "cascade")
        self.assertEquals(s.tableNamed("e3").columnNamed("f3").deleteAction, "set null")
        self.assertEquals(s.tableNamed("g4").columnNamed("h4").deleteAction, "set default")


    def test_indexes(self):
        """
        A 'create index' statement will add an L{Index} object to a
        L{Schema}'s C{indexes} list.
        """
        s = self.schemaFromString(
            """
            create table q (b integer); -- noise
            create table a (b integer primary key, c integer);
            create table z (c integer); -- make sure we get the right table
            create index idx_a_b on a(b);
            create index idx_a_b_c on a (c, b);
            create index idx_c on z using btree (c);
            """)
        a = s.tableNamed("a")
        b = s.indexNamed("idx_a_b")
        bc = s.indexNamed('idx_a_b_c')
        self.assertEquals(b.table, a)
        self.assertEquals(b.columns, [a.columnNamed("b")])
        self.assertEquals(bc.table, a)
        self.assertEquals(bc.columns, [a.columnNamed("c"), a.columnNamed("b")])


    def test_pseudoIndexes(self):
        """
        Implicit (primary key / unique constraint) and explicit indexes are
        both listed by L{Schema.pseudoIndexes}.
        """
        s = self.schemaFromString(
            """
            create table q (b integer); -- noise
            create table a (b integer primary key, c integer);
            create table z (c integer, unique(c) );
            create unique index idx_a_c on a(c);
            create index idx_a_b_c on a (c, b);
            """)
        self.assertEqual(set([pseudo.name for pseudo in s.pseudoIndexes()]), set((
            "a-unique:(c)",
            "a:(c,b)",
            "a-unique:(b)",
            "z-unique:(c)",
        )))
calendarserver-5.2+dfsg/twext/enterprise/dal/test/test_record.py 0000644 0001750 0001750 00000030601 12263343324 024252 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Test cases for L{twext.enterprise.dal.record}.
"""
import datetime
from twisted.internet.defer import inlineCallbacks
from twisted.trial.unittest import TestCase
from twext.enterprise.dal.record import (
Record, fromTable, ReadOnly, NoSuchRecord
)
from twext.enterprise.dal.test.test_parseschema import SchemaTestHelper
from twext.enterprise.dal.syntax import SchemaSyntax
from twisted.internet.defer import gatherResults
from twisted.internet.defer import returnValue
from twext.enterprise.fixtures import buildConnectionPool
# from twext.enterprise.dal.syntax import
# Build a module-level schema-parsing helper; the Record classes below are
# derived from the parsed test schema.
sth = SchemaTestHelper()
sth.id = lambda : __name__
schemaString = """
create table ALPHA (BETA integer primary key, GAMMA text);
create table DELTA (PHI integer primary key default (nextval('myseq')),
EPSILON text not null,
ZETA timestamp not null default '2012-12-12 12:12:12' );
"""

# sqlite can be made to support nextval() as a function, but 'create sequence'
# is syntax and can't.
parseableSchemaString = """
create sequence myseq;
""" + schemaString

testSchema = SchemaSyntax(sth.schemaFromString(parseableSchemaString))
class TestRecord(Record, fromTable(testSchema.ALPHA)):
    """
    A sample test record, mapped to the ALPHA table (columns BETA, GAMMA).
    """
class TestAutoRecord(Record, fromTable(testSchema.DELTA)):
    """
    A sample test record with default values specified, mapped to the DELTA
    table (columns PHI, EPSILON, ZETA).
    """
class TestCRUD(TestCase):
    """
    Tests for creation, mutation, and deletion operations.
    """

    def setUp(self):
        # Fresh in-memory pool per test, initialized with the test schema.
        self.pool = buildConnectionPool(self, schemaString)


    @inlineCallbacks
    def test_simpleLoad(self):
        """
        Loading an existing row from the database by its primary key will
        populate its attributes from columns of the corresponding row in the
        database.
        """
        txn = self.pool.connection()
        yield txn.execSQL("insert into ALPHA values (:1, :2)", [234, "one"])
        yield txn.execSQL("insert into ALPHA values (:1, :2)", [456, "two"])
        rec = yield TestRecord.load(txn, 456)
        self.assertIsInstance(rec, TestRecord)
        self.assertEquals(rec.beta, 456)
        self.assertEquals(rec.gamma, "two")
        rec2 = yield TestRecord.load(txn, 234)
        self.assertIsInstance(rec2, TestRecord)
        self.assertEqual(rec2.beta, 234)
        self.assertEqual(rec2.gamma, "one")


    @inlineCallbacks
    def test_missingLoad(self):
        """
        Try loading a row which doesn't exist.
        """
        txn = self.pool.connection()
        yield txn.execSQL("insert into ALPHA values (:1, :2)", [234, "one"])
        # NOTE(review): the Deferred from assertFailure is not yielded here —
        # confirm trial still waits for it before ending the test.
        self.assertFailure(TestRecord.load(txn, 456), NoSuchRecord)


    @inlineCallbacks
    def test_simpleCreate(self):
        """
        When a record object is created, a row with matching column values
        will be created in the database.
        """
        txn = self.pool.connection()
        rec = yield TestRecord.create(txn, beta=3, gamma=u'epsilon')
        self.assertEquals(rec.beta, 3)
        self.assertEqual(rec.gamma, u'epsilon')
        rows = yield txn.execSQL("select BETA, GAMMA from ALPHA")
        self.assertEqual(rows, [tuple([3, u'epsilon'])])


    @inlineCallbacks
    def test_simpleDelete(self):
        """
        When a record object is deleted, a row with a matching primary key
        will be deleted in the database.
        """
        txn = self.pool.connection()
        def mkrow(beta, gamma):
            return txn.execSQL("insert into ALPHA values (:1, :2)",
                               [beta, gamma])
        yield gatherResults([mkrow(123, u"one"), mkrow(234, u"two"),
                             mkrow(345, u"three")])
        tr = yield TestRecord.load(txn, 234)
        yield tr.delete()
        rows = yield txn.execSQL("select BETA, GAMMA from ALPHA order by BETA")
        self.assertEqual(rows, [(123, u"one"), (345, u"three")])


    @inlineCallbacks
    def oneRowCommitted(self, beta=123, gamma=u'456'):
        """
        Create, commit, and return one L{TestRecord}.
        """
        txn = self.pool.connection(self.id())
        row = yield TestRecord.create(txn, beta=beta, gamma=gamma)
        yield txn.commit()
        returnValue(row)


    @inlineCallbacks
    def test_deleteWhenDeleted(self):
        """
        When a record object is deleted, if it's already been deleted, it
        will raise L{NoSuchRecord}.
        """
        row = yield self.oneRowCommitted()
        txn = self.pool.connection(self.id())
        newRow = yield TestRecord.load(txn, row.beta)
        yield newRow.delete()
        # NOTE(review): the Deferred from failUnlessFailure is not yielded —
        # confirm trial still waits for it before ending the test.
        self.failUnlessFailure(newRow.delete(), NoSuchRecord)


    @inlineCallbacks
    def test_cantCreateWithoutRequiredValues(self):
        """
        When a L{Record} object is created without required values, it raises
        a L{TypeError}.
        """
        txn = self.pool.connection()
        te = yield self.failUnlessFailure(TestAutoRecord.create(txn),
                                          TypeError)
        self.assertIn("required attribute 'epsilon' not passed", str(te))


    @inlineCallbacks
    def test_datetimeType(self):
        """
        When a L{Record} references a timestamp column, it retrieves the date
        as UTC.
        """
        txn = self.pool.connection()
        # Create ...
        rec = yield TestAutoRecord.create(txn, epsilon=1)
        self.assertEquals(rec.zeta, datetime.datetime(2012, 12, 12, 12, 12, 12))
        yield txn.commit()
        # ... should have the same effect as loading.
        txn = self.pool.connection()
        rec = (yield TestAutoRecord.all(txn))[0]
        self.assertEquals(rec.zeta, datetime.datetime(2012, 12, 12, 12, 12, 12))


    @inlineCallbacks
    def test_tooManyAttributes(self):
        """
        When a L{Record} object is created with unknown attributes (those
        which don't map to any column), it raises a L{TypeError}.
        """
        txn = self.pool.connection()
        te = yield self.failUnlessFailure(TestRecord.create(
            txn, beta=3, gamma=u'three',
            extraBonusAttribute=u'nope',
            otherBonusAttribute=4321,
        ), TypeError)
        self.assertIn("extraBonusAttribute, otherBonusAttribute", str(te))


    @inlineCallbacks
    def test_createFillsInPKey(self):
        """
        If L{Record.create} is called without an auto-generated primary key
        value for its row, that value will be generated and set on the
        returned object.
        """
        txn = self.pool.connection()
        tr = yield TestAutoRecord.create(txn, epsilon=u'specified')
        tr2 = yield TestAutoRecord.create(txn, epsilon=u'also specified')
        self.assertEquals(tr.phi, 1)
        self.assertEquals(tr2.phi, 2)


    @inlineCallbacks
    def test_attributesArentMutableYet(self):
        """
        Changing attributes on a database object is not supported yet,
        because it's not entirely clear when to flush the SQL to the
        database.  Instead, for the time being, use C{.update}.  When you
        attempt to set an attribute, an error will be raised informing you of
        this fact, so that the error is clear.
        """
        txn = self.pool.connection()
        rec = yield TestRecord.create(txn, beta=7, gamma=u'what')
        def setit():
            rec.beta = 12
        ro = self.assertRaises(ReadOnly, setit)
        self.assertEqual(rec.beta, 7)
        self.assertIn("SQL-backed attribute 'TestRecord.beta' is read-only. "
                      "Use '.update(...)' to modify attributes.", str(ro))


    @inlineCallbacks
    def test_simpleUpdate(self):
        """
        L{Record.update} will change the values on the record and in the
        database.
        """
        txn = self.pool.connection()
        rec = yield TestRecord.create(txn, beta=3, gamma=u'epsilon')
        yield rec.update(gamma=u'otherwise')
        self.assertEqual(rec.gamma, u'otherwise')
        yield txn.commit()
        # Make sure that it persists.
        txn = self.pool.connection()
        rec = yield TestRecord.load(txn, 3)
        self.assertEqual(rec.gamma, u'otherwise')


    @inlineCallbacks
    def test_simpleQuery(self):
        """
        L{Record.query} will allow you to query for a record by its class
        attributes as columns.
        """
        txn = self.pool.connection()
        for beta, gamma in [(123, u"one"), (234, u"two"), (345, u"three"),
                            (356, u"three"), (456, u"four")]:
            yield txn.execSQL("insert into ALPHA values (:1, :2)",
                              [beta, gamma])
        records = yield TestRecord.query(txn, TestRecord.gamma == u"three")
        self.assertEqual(len(records), 2)
        records.sort(key=lambda x: x.beta)
        self.assertEqual(records[0].beta, 345)
        self.assertEqual(records[1].beta, 356)


    @inlineCallbacks
    def test_all(self):
        """
        L{Record.all} will return all instances of the record, sorted by
        primary key.
        """
        txn = self.pool.connection()
        data = [(123, u"one"), (456, u"four"), (345, u"three"),
                (234, u"two"), (356, u"three")]
        for beta, gamma in data:
            yield txn.execSQL("insert into ALPHA values (:1, :2)",
                              [beta, gamma])
        self.assertEqual(
            [(x.beta, x.gamma) for x in (yield TestRecord.all(txn))],
            sorted(data)
        )


    @inlineCallbacks
    def test_repr(self):
        """
        The C{repr} of a L{Record} presents all its values.
        """
        txn = self.pool.connection()
        yield txn.execSQL("insert into ALPHA values (:1, :2)", [789, u'nine'])
        rec = list((yield TestRecord.all(txn)))[0]
        self.assertIn(" beta=789", repr(rec))
        self.assertIn(" gamma=u'nine'", repr(rec))


    @inlineCallbacks
    def test_orderedQuery(self):
        """
        L{Record.query} takes an 'order' argument which will allow the
        objects returned to be ordered.
        """
        txn = self.pool.connection()
        for beta, gamma in [(123, u"one"), (234, u"two"), (345, u"three"),
                            (356, u"three"), (456, u"four")]:
            yield txn.execSQL("insert into ALPHA values (:1, :2)",
                              [beta, gamma])
        records = yield TestRecord.query(txn, TestRecord.gamma == u"three",
                                         TestRecord.beta)
        self.assertEqual([record.beta for record in records], [345, 356])
        records = yield TestRecord.query(txn, TestRecord.gamma == u"three",
                                         TestRecord.beta, ascending=False)
        self.assertEqual([record.beta for record in records], [356, 345])


    @inlineCallbacks
    def test_pop(self):
        """
        A L{Record} may be loaded and deleted atomically, with L{Record.pop}.
        """
        txn = self.pool.connection()
        for beta, gamma in [(123, u"one"), (234, u"two"), (345, u"three"),
                            (356, u"three"), (456, u"four")]:
            yield txn.execSQL("insert into ALPHA values (:1, :2)",
                              [beta, gamma])
        rec = yield TestRecord.pop(txn, 234)
        self.assertEqual(rec.gamma, u'two')
        self.assertEqual((yield txn.execSQL("select count(*) from ALPHA "
                                            "where BETA = :1", [234])),
                         [tuple([0])])
        yield self.failUnlessFailure(TestRecord.pop(txn, 234), NoSuchRecord)


    def test_columnNamingConvention(self):
        """
        The naming convention maps columns C{LIKE_THIS} to be attributes
        C{likeThis}.
        """
        self.assertEqual(Record.namingConvention(u"like_this"), "likeThis")
        self.assertEqual(Record.namingConvention(u"LIKE_THIS"), "likeThis")
        self.assertEqual(Record.namingConvention(u"LIKE_THIS_ID"), "likeThisID")
calendarserver-5.2+dfsg/twext/enterprise/dal/test/test_sqlsyntax.py 0000644 0001750 0001750 00000175242 12263343324 025055 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for L{twext.enterprise.dal.syntax}
"""
from twext.enterprise.dal import syntax
from twext.enterprise.dal.parseschema import addSQLToSchema
from twext.enterprise.dal.syntax import (
Select, Insert, Update, Delete, Lock, SQLFragment,
TableMismatch, Parameter, Max, Len, NotEnoughValues,
Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction,
Union, Intersect, Except, SetExpression, DALError,
ResultAliasSyntax, Count, QueryGenerator, ALL_COLUMNS,
DatabaseLock, DatabaseUnlock)
from twext.enterprise.dal.syntax import FixedPlaceholder, NumericPlaceholder
from twext.enterprise.dal.syntax import Function
from twext.enterprise.dal.syntax import SchemaSyntax
from twext.enterprise.dal.test.test_parseschema import SchemaTestHelper
from twext.enterprise.ienterprise import (POSTGRES_DIALECT, ORACLE_DIALECT,
SQLITE_DIALECT)
from twext.enterprise.test.test_adbapi2 import ConnectionPoolHelper
from twext.enterprise.test.test_adbapi2 import NetworkedPoolHelper
from twext.enterprise.test.test_adbapi2 import resultOf, AssertResultHelper
from twisted.internet.defer import succeed
from twisted.trial.unittest import TestCase
from twext.enterprise.dal.syntax import Tuple
from twext.enterprise.dal.syntax import Constant
class _FakeTransaction(object):
    """
    An L{IAsyncTransaction} that provides the relevant metadata for SQL
    generation.
    """

    def __init__(self, paramstyle):
        """
        @param paramstyle: the DB-API paramstyle this fake transaction should
            report (e.g. 'qmark', 'numeric').
        """
        # Bug fix: the constructor previously ignored its argument and always
        # hardcoded 'qmark', making the parameter a no-op; honor the
        # requested paramstyle.
        self.paramstyle = paramstyle
class FakeCXOracleModule(object):
    """
    Stand-in for the C{cx_Oracle} module: provides just the type-constant
    attributes the DAL looks up, as distinctive sentinel strings.
    """
    NUMBER = 'the NUMBER type'
    STRING = 'a string type (for varchars)'
    NCLOB = 'the NCLOB type. (for text)'
    TIMESTAMP = 'for timestamps!'
class CatchSQL(object):
    """
    L{IAsyncTransaction} emulator that records the SQL executed on it.
    """
    # Class-level counter default; incrementing in execSQL creates a
    # per-instance attribute.
    counter = 0

    def __init__(self, dialect=SQLITE_DIALECT, paramstyle='numeric'):
        """
        @param dialect: the SQL dialect to report to the DAL.
        @param paramstyle: the DB-API paramstyle to report to the DAL.
        """
        self.execed = []
        self.pendingResults = []
        # Bug fix: the constructor previously ignored both keyword arguments
        # and hardcoded SQLITE_DIALECT / 'numeric'; honor the caller's
        # choices (the defaults preserve the old behavior).
        self.dialect = dialect
        self.paramstyle = paramstyle


    def nextResult(self, result):
        """
        Make it so that the next result from L{execSQL} will be the argument.
        """
        self.pendingResults.append(result)


    def execSQL(self, sql, args, rozrc):
        """
        Implement L{IAsyncTransaction} by recording C{sql} and C{args} in
        C{self.execed}, and return a L{Deferred} firing either an integer or
        a value pre-supplied by L{CatchSQL.nextResult}.
        """
        self.execed.append([sql, args])
        self.counter += 1
        if self.pendingResults:
            result = self.pendingResults.pop(0)
        else:
            result = self.counter
        return succeed(result)
class NullTestingOracleTxn(object):
    """
    Fake transaction for testing oracle NULL behavior.
    """
    dialect = ORACLE_DIALECT
    paramstyle = 'numeric'

    def execSQL(self, text, params, exc):
        # Always report a single row of two NULL (None) values.
        return succeed([[None, None]])
EXAMPLE_SCHEMA = """
create sequence A_SEQ;
create table FOO (BAR integer, BAZ varchar(255));
create table BOZ (QUX integer, QUUX integer);
create table OTHER (BAR integer,
FOO_BAR integer not null);
create table TEXTUAL (MYTEXT varchar(255));
create table LEVELS (ACCESS integer,
USERNAME varchar(255));
create table NULLCHECK (ASTRING varchar(255) not null,
ANUMBER integer);
"""
class ExampleSchemaHelper(SchemaTestHelper):
    """
    Test mixin whose C{setUp} parses L{EXAMPLE_SCHEMA} and binds the
    resulting L{SchemaSyntax} to C{self.schema} for use by test methods.
    """
    def setUp(self):
        parsed = self.schemaFromString(EXAMPLE_SCHEMA)
        self.schema = SchemaSyntax(parsed)
class GenerationTests(ExampleSchemaHelper, TestCase, AssertResultHelper):
"""
Tests for syntactic helpers to generate SQL queries.
"""
def test_simplestSelect(self):
"""
L{Select} generates a 'select' statement, by default, asking for all
rows in a table.
"""
self.assertEquals(Select(From=self.schema.FOO).toSQL(),
SQLFragment("select * from FOO", []))
def test_tableSyntaxFromSchemaSyntaxCompare(self):
"""
One L{TableSyntax} is equivalent to another wrapping the same table;
one wrapping a different table is different.
"""
self.assertEquals(self.schema.FOO, self.schema.FOO)
self.assertNotEquals(self.schema.FOO, self.schema.BOZ)
def test_simpleWhereClause(self):
"""
L{Select} generates a 'select' statement with a 'where' clause
containing an expression.
"""
self.assertEquals(Select(From=self.schema.FOO,
Where=self.schema.FOO.BAR == 1).toSQL(),
SQLFragment("select * from FOO where BAR = ?", [1]))
def test_alternateMetadata(self):
"""
L{Select} generates a 'select' statement with the specified placeholder
syntax when explicitly given L{ConnectionMetadata} which specifies a
placeholder.
"""
self.assertEquals(Select(From=self.schema.FOO,
Where=self.schema.FOO.BAR == 1).toSQL(
QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("$$"))),
SQLFragment("select * from FOO where BAR = $$", [1]))
def test_columnComparison(self):
"""
L{Select} generates a 'select' statement which compares columns.
"""
self.assertEquals(Select(From=self.schema.FOO,
Where=self.schema.FOO.BAR ==
self.schema.FOO.BAZ).toSQL(),
SQLFragment("select * from FOO where BAR = BAZ", []))
def test_comparisonTestErrorPrevention(self):
"""
The comparison object between SQL expressions raises an exception when
compared for a truth value, so that code will not accidentally operate
on SQL objects and get a truth value.
(Note that this has a caveat, in test_columnsAsDictKeys and
test_columnEqualityTruth.)
"""
def sampleComparison():
if self.schema.FOO.BAR > self.schema.FOO.BAZ:
return 'comparison should not succeed'
self.assertRaises(DALError, sampleComparison)
def test_compareWithNULL(self):
"""
Comparing a column with None results in the generation of an 'is null'
or 'is not null' SQL statement.
"""
self.assertEquals(Select(From=self.schema.FOO,
Where=self.schema.FOO.BAR ==
None).toSQL(),
SQLFragment(
"select * from FOO where BAR is null", []))
self.assertEquals(Select(From=self.schema.FOO,
Where=self.schema.FOO.BAR !=
None).toSQL(),
SQLFragment(
"select * from FOO where BAR is not null", []))
def test_compareWithEmptyStringOracleSpecialCase(self):
"""
Oracle considers the empty string to be a NULL value, so comparisons
with the empty string should be 'is NULL' comparisons.
"""
# Sanity check: let's make sure that the non-oracle case looks normal.
self.assertEquals(Select(
From=self.schema.FOO,
Where=self.schema.FOO.BAR == '').toSQL(),
SQLFragment(
"select * from FOO where BAR = ?", [""]))
self.assertEquals(Select(
From=self.schema.FOO,
Where=self.schema.FOO.BAR != '').toSQL(),
SQLFragment(
"select * from FOO where BAR != ?", [""]))
self.assertEquals(Select(
From=self.schema.FOO,
Where=self.schema.FOO.BAR == ''
).toSQL(QueryGenerator(ORACLE_DIALECT, NumericPlaceholder())),
SQLFragment(
"select * from FOO where BAR is null", []))
self.assertEquals(Select(
From=self.schema.FOO,
Where=self.schema.FOO.BAR != ''
).toSQL(QueryGenerator(ORACLE_DIALECT, NumericPlaceholder())),
SQLFragment(
"select * from FOO where BAR is not null", []))
def test_compoundWhere(self):
"""
L{Select.And} and L{Select.Or} will return compound columns.
"""
self.assertEquals(
Select(From=self.schema.FOO,
Where=(self.schema.FOO.BAR < 2).Or(
self.schema.FOO.BAR > 5)).toSQL(),
SQLFragment("select * from FOO where BAR < ? or BAR > ?", [2, 5]))
def test_orderBy(self):
"""
L{Select}'s L{OrderBy} parameter generates an 'order by' clause for a
'select' statement.
"""
self.assertEquals(
Select(From=self.schema.FOO,
OrderBy=self.schema.FOO.BAR).toSQL(),
SQLFragment("select * from FOO order by BAR")
)
def test_orderByOrder(self):
"""
L{Select}'s L{Ascending} parameter specifies an ascending/descending
order for query results with an OrderBy clause.
"""
self.assertEquals(
Select(From=self.schema.FOO,
OrderBy=self.schema.FOO.BAR,
Ascending=False).toSQL(),
SQLFragment("select * from FOO order by BAR desc")
)
self.assertEquals(
Select(From=self.schema.FOO,
OrderBy=self.schema.FOO.BAR,
Ascending=True).toSQL(),
SQLFragment("select * from FOO order by BAR asc")
)
self.assertEquals(
Select(From=self.schema.FOO,
OrderBy=[self.schema.FOO.BAR, self.schema.FOO.BAZ],
Ascending=True).toSQL(),
SQLFragment("select * from FOO order by BAR, BAZ asc")
)
def test_orderByParens(self):
    """
    L{Select}'s L{OrderBy} parameter, if specified as a L{Tuple}, generates
    an SQL expression I{without} parentheses, since the standard format
    does not allow an arbitrary sort expression but rather a list of
    columns.
    """
    self.assertEquals(
        Select(From=self.schema.FOO,
               OrderBy=Tuple([self.schema.FOO.BAR,
                              self.schema.FOO.BAZ])).toSQL(),
        SQLFragment("select * from FOO order by BAR, BAZ")
    )
def test_forUpdate(self):
"""
L{Select}'s L{ForUpdate} parameter generates a 'for update' clause at
the end of the query.
"""
self.assertEquals(
Select(From=self.schema.FOO, ForUpdate=True).toSQL(),
SQLFragment("select * from FOO for update")
)
def test_groupBy(self):
"""
L{Select}'s L{GroupBy} parameter generates a 'group by' clause for a
'select' statement.
"""
self.assertEquals(
Select(From=self.schema.FOO,
GroupBy=self.schema.FOO.BAR).toSQL(),
SQLFragment("select * from FOO group by BAR")
)
def test_groupByMulti(self):
"""
L{Select}'s L{GroupBy} parameter can accept multiple columns in a list.
"""
self.assertEquals(
Select(From=self.schema.FOO,
GroupBy=[self.schema.FOO.BAR,
self.schema.FOO.BAZ]).toSQL(),
SQLFragment("select * from FOO group by BAR, BAZ")
)
def test_joinClause(self):
"""
A table's .join() method returns a join statement in a SELECT.
"""
self.assertEquals(
Select(From=self.schema.FOO.join(
self.schema.BOZ, self.schema.FOO.BAR ==
self.schema.BOZ.QUX)).toSQL(),
SQLFragment("select * from FOO join BOZ on BAR = QUX", [])
)
def test_crossJoin(self):
"""
A join with no clause specified will generate a cross join. (This is an
explicit synonym for an implicit join: i.e. 'select * from FOO, BAR'.)
"""
self.assertEquals(
Select(From=self.schema.FOO.join(self.schema.BOZ)).toSQL(),
SQLFragment("select * from FOO cross join BOZ")
)
def test_joinJoin(self):
"""
L{Join.join} will result in a multi-table join.
"""
self.assertEquals(
Select([self.schema.FOO.BAR,
self.schema.BOZ.QUX],
From=self.schema.FOO
.join(self.schema.BOZ).join(self.schema.OTHER)).toSQL(),
SQLFragment(
"select FOO.BAR, QUX from FOO "
"cross join BOZ cross join OTHER")
)
def test_multiJoin(self):
"""
L{Join.join} has the same signature as L{TableSyntax.join} and supports
the same 'on' and 'type' arguments.
"""
self.assertEquals(
Select([self.schema.FOO.BAR],
From=self.schema.FOO.join(
self.schema.BOZ).join(
self.schema.OTHER,
self.schema.OTHER.BAR == self.schema.FOO.BAR,
'left outer')).toSQL(),
SQLFragment(
"select FOO.BAR from FOO cross join BOZ left outer join OTHER "
"on OTHER.BAR = FOO.BAR")
)
def test_tableAliasing(self):
"""
Tables may be given aliases, in order to facilitate self-joins.
"""
sfoo = self.schema.FOO
sfoo2 = sfoo.alias()
self.assertEqual(
Select(From=self.schema.FOO.join(sfoo2)).toSQL(),
SQLFragment("select * from FOO cross join FOO alias1")
)
def test_columnsOfAliasedTable(self):
"""
The columns of aliased tables will always be prefixed with their alias
in the generated SQL.
"""
sfoo = self.schema.FOO
sfoo2 = sfoo.alias()
self.assertEquals(
Select([sfoo2.BAR], From=sfoo2).toSQL(),
SQLFragment("select alias1.BAR from FOO alias1")
)
def test_multipleTableAliases(self):
"""
When multiple aliases are used for the same table, they will be unique
within the query.
"""
foo = self.schema.FOO
fooPrime = foo.alias()
fooPrimePrime = foo.alias()
self.assertEquals(
Select([fooPrime.BAR, fooPrimePrime.BAR],
From=fooPrime.join(fooPrimePrime)).toSQL(),
SQLFragment("select alias1.BAR, alias2.BAR "
"from FOO alias1 cross join FOO alias2")
)
def test_columnSelection(self):
"""
If a column is specified by the argument to L{Select}, those will be
output by the SQL statement rather than the all-columns wildcard.
"""
self.assertEquals(
Select([self.schema.FOO.BAR],
From=self.schema.FOO).toSQL(),
SQLFragment("select BAR from FOO")
)
def test_tableIteration(self):
"""
Iterating a L{TableSyntax} iterates its columns, in the order that they
are defined.
"""
self.assertEquals(list(self.schema.FOO),
[self.schema.FOO.BAR, self.schema.FOO.BAZ])
def test_noColumn(self):
"""
Accessing an attribute that is not a defined column on a L{TableSyntax}
raises an L{AttributeError}.
"""
self.assertRaises(AttributeError,
lambda : self.schema.FOO.NOT_A_COLUMN)
def test_columnAliases(self):
    """
    When attributes are set on a L{TableSyntax}, they will be remembered as
    column aliases, and their alias names may be retrieved via the
    L{TableSyntax.columnAliases} method.
    """
    self.assertEquals(self.schema.FOO.columnAliases(), {})
    self.schema.FOO.ALIAS = self.schema.FOO.BAR
    # Comparing ColumnSyntax objects yields a ColumnComparison, which cannot
    # be tested for truth, so compare the underlying model objects instead.
    fixedForEquality = dict([(k, v.model) for k, v in
                             self.schema.FOO.columnAliases().items()])
    self.assertEquals(fixedForEquality,
                      {'ALIAS': self.schema.FOO.BAR.model})
    self.assertIdentical(self.schema.FOO.ALIAS.model,
                         self.schema.FOO.BAR.model)
def test_multiColumnSelection(self):
"""
If multiple columns are specified by the argument to L{Select}, those
will be output by the SQL statement rather than the all-columns
wildcard.
"""
self.assertEquals(
Select([self.schema.FOO.BAZ,
self.schema.FOO.BAR],
From=self.schema.FOO).toSQL(),
SQLFragment("select BAZ, BAR from FOO")
)
def test_joinColumnSelection(self):
"""
If multiple columns are specified by the argument to L{Select} that uses
a L{TableSyntax.join}, those will be output by the SQL statement.
"""
self.assertEquals(
Select([self.schema.FOO.BAZ,
self.schema.BOZ.QUX],
From=self.schema.FOO.join(self.schema.BOZ,
self.schema.FOO.BAR ==
self.schema.BOZ.QUX)).toSQL(),
SQLFragment("select BAZ, QUX from FOO join BOZ on BAR = QUX")
)
def test_tableMismatch(self):
"""
When a column in the 'columns' argument does not match the table from
the 'From' argument, L{Select} raises a L{TableMismatch}.
"""
self.assertRaises(TableMismatch, Select, [self.schema.BOZ.QUX],
From=self.schema.FOO)
def test_qualifyNames(self):
"""
When two columns in the FROM clause requested from different tables have
the same name, the emitted SQL should explicitly disambiguate them.
"""
self.assertEquals(
Select([self.schema.FOO.BAR,
self.schema.OTHER.BAR],
From=self.schema.FOO.join(self.schema.OTHER,
self.schema.OTHER.FOO_BAR ==
self.schema.FOO.BAR)).toSQL(),
SQLFragment(
"select FOO.BAR, OTHER.BAR from FOO "
"join OTHER on FOO_BAR = FOO.BAR"))
def test_bindParameters(self):
"""
L{SQLFragment.bind} returns a copy of that L{SQLFragment} with the
L{Parameter} objects in its parameter list replaced with the keyword
arguments to C{bind}.
"""
self.assertEquals(
Select(From=self.schema.FOO,
Where=(self.schema.FOO.BAR > Parameter("testing")).And(
self.schema.FOO.BAZ < 7)).toSQL().bind(testing=173),
SQLFragment("select * from FOO where BAR > ? and BAZ < ?",
[173, 7]))
def test_rightHandSideExpression(self):
"""
Arbitrary expressions may be used as the right-hand side of a
comparison operation.
"""
self.assertEquals(
Select(From=self.schema.FOO,
Where=self.schema.FOO.BAR >
(self.schema.FOO.BAZ + 3)).toSQL(),
SQLFragment("select * from FOO where BAR > (BAZ + ?)", [3])
)
def test_setSelects(self):
"""
L{SetExpression} produces set operation on selects.
"""
# Simple UNION
self.assertEquals(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 1),
SetExpression=Union(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 2),
),
),
).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
SQLFragment(
"(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?)", [1, 2]))
# Simple INTERSECT ALL
self.assertEquals(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 1),
SetExpression=Intersect(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 2),
),
optype=SetExpression.OPTYPE_ALL
),
).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
SQLFragment(
"(select * from FOO where BAR = ?) INTERSECT ALL (select * from FOO where BAR = ?)", [1, 2]))
# Multiple EXCEPTs, not nested, Postgres dialect
self.assertEquals(
Select(
From=self.schema.FOO,
SetExpression=Except(
(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 2),
),
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 3),
),
),
optype=SetExpression.OPTYPE_DISTINCT,
),
).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
SQLFragment(
"(select * from FOO) EXCEPT DISTINCT (select * from FOO where BAR = ?) EXCEPT DISTINCT (select * from FOO where BAR = ?)", [2, 3]))
# Nested EXCEPTs, Oracle dialect
self.assertEquals(
Select(
From=self.schema.FOO,
SetExpression=Except(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 2),
SetExpression=Except(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 3),
),
),
),
),
).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
SQLFragment(
"(select * from FOO) MINUS ((select * from FOO where BAR = ?) MINUS (select * from FOO where BAR = ?))", [2, 3]))
# UNION with order by
self.assertEquals(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 1),
SetExpression=Union(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 2),
),
),
OrderBy=self.schema.FOO.BAR,
).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
SQLFragment(
"(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?) order by BAR", [1, 2]))
def test_simpleSubSelects(self):
    """
    A L{Select} may be used as the C{From} clause of another L{Select},
    rendering as a parenthesized sub-select followed by an alias that is
    either generated (C{genid_N}) or supplied explicitly via C{As}.
    """
    self.assertEquals(
        Select(
            [Max(self.schema.BOZ.QUX)],
            From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ))
        ).toSQL(),
        SQLFragment(
            "select max(QUX) from (select QUX from BOZ) genid_1"))
    self.assertEquals(
        Select(
            [Count(self.schema.BOZ.QUX)],
            From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ))
        ).toSQL(),
        SQLFragment(
            "select count(QUX) from (select QUX from BOZ) genid_1"))
    self.assertEquals(
        Select(
            [Max(self.schema.BOZ.QUX)],
            From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ, As="alias_BAR")),
        ).toSQL(),
        SQLFragment(
            "select max(QUX) from (select QUX from BOZ) alias_BAR"))
def test_setSubSelects(self):
"""
L{SetExpression} in a From sub-select.
"""
# Simple UNION
self.assertEquals(
Select(
[Max(self.schema.FOO.BAR)],
From=Select(
[self.schema.FOO.BAR],
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 1),
SetExpression=Union(
Select(
[self.schema.FOO.BAR],
From=self.schema.FOO,
Where=(self.schema.FOO.BAR == 2),
),
),
)
).toSQL(),
SQLFragment(
"select max(BAR) from ((select BAR from FOO where BAR = ?) UNION (select BAR from FOO where BAR = ?)) genid_1", [1, 2]))
def test_selectColumnAliases(self):
"""
L{Select} works with aliased columns.
"""
self.assertEquals(
Select(
[ResultAliasSyntax(self.schema.BOZ.QUX, "BOZ_QUX")],
From=self.schema.BOZ
).toSQL(),
SQLFragment("select QUX BOZ_QUX from BOZ"))
self.assertEquals(
Select(
[ResultAliasSyntax(Max(self.schema.BOZ.QUX))],
From=self.schema.BOZ
).toSQL(),
SQLFragment("select max(QUX) genid_1 from BOZ"))
alias = ResultAliasSyntax(Max(self.schema.BOZ.QUX))
self.assertEquals(
Select([alias.columnReference()],
From=Select(
[alias],
From=self.schema.BOZ)
).toSQL(),
SQLFragment("select genid_1 from (select max(QUX) genid_1 from BOZ) genid_2"))
alias = ResultAliasSyntax(Len(self.schema.BOZ.QUX))
self.assertEquals(
Select([alias.columnReference()],
From=Select(
[alias],
From=self.schema.BOZ)
).toSQL(),
SQLFragment("select genid_1 from (select character_length(QUX) genid_1 from BOZ) genid_2"))
def test_inSubSelect(self):
"""
L{ColumnSyntax.In} returns a sub-expression using the SQL 'in' syntax
with a sub-select.
"""
wherein = (self.schema.FOO.BAR.In(
Select([self.schema.BOZ.QUX], From=self.schema.BOZ)))
self.assertEquals(
Select(From=self.schema.FOO, Where=wherein).toSQL(),
SQLFragment(
"select * from FOO where BAR in (select QUX from BOZ)"))
def test_inParameter(self):
"""
L{ColumnSyntax.In} returns a sub-expression using the SQL 'in' syntax
with parameter list.
"""
# One item with IN only
items = set(('A',))
self.assertEquals(
Select(From=self.schema.FOO, Where=self.schema.FOO.BAR.In(Parameter("names", len(items)))).toSQL().bind(names=items),
SQLFragment(
"select * from FOO where BAR in (?)", ['A']))
# Two items with IN only
items = set(('A', 'B'))
self.assertEquals(
Select(From=self.schema.FOO, Where=self.schema.FOO.BAR.In(Parameter("names", len(items)))).toSQL().bind(names=items),
SQLFragment(
"select * from FOO where BAR in (?, ?)", ['A', 'B']))
# Two items with preceding AND
self.assertEquals(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAZ == Parameter('P1')).And(
self.schema.FOO.BAR.In(Parameter("names", len(items))
))
).toSQL().bind(P1="P1", names=items),
SQLFragment(
"select * from FOO where BAZ = ? and BAR in (?, ?)", ['P1', 'A', 'B']),
)
# Two items with following AND
self.assertEquals(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAR.In(Parameter("names", len(items))).And(
self.schema.FOO.BAZ == Parameter('P2')
))
).toSQL().bind(P2="P2", names=items),
SQLFragment(
"select * from FOO where BAR in (?, ?) and BAZ = ?", ['A', 'B', 'P2']),
)
# Two items with preceding OR and following AND
self.assertEquals(
Select(
From=self.schema.FOO,
Where=(self.schema.FOO.BAZ == Parameter('P1')).Or(
self.schema.FOO.BAR.In(Parameter("names", len(items))).And(
self.schema.FOO.BAZ == Parameter('P2')
))
).toSQL().bind(P1="P1", P2="P2", names=items),
SQLFragment(
"select * from FOO where BAZ = ? or BAR in (?, ?) and BAZ = ?", ['P1', 'A', 'B', 'P2']),
)
# Check various error situations
# No count not allowed
self.assertRaises(DALError, self.schema.FOO.BAR.In, Parameter("names"))
# count=0 not allowed
self.assertRaises(DALError, Parameter, "names", 0)
# Mismatched count and len(items)
self.assertRaises(
DALError,
Select(From=self.schema.FOO, Where=self.schema.FOO.BAR.In(Parameter("names", len(items)))).toSQL().bind,
names=["a", "b", "c", ]
)
def test_max(self):
"""
L{Max}C{(column)} produces an object in the 'columns' clause that
renders the 'max' aggregate in SQL.
"""
self.assertEquals(
Select([Max(self.schema.BOZ.QUX)], From=self.schema.BOZ).toSQL(),
SQLFragment(
"select max(QUX) from BOZ"))
def test_countAllCoumns(self):
"""
L{Count}C{(ALL_COLUMNS)} produces an object in the 'columns' clause that
renders the 'count' in SQL.
"""
self.assertEquals(
Select([Count(ALL_COLUMNS)], From=self.schema.BOZ).toSQL(),
SQLFragment(
"select count(*) from BOZ"))
def test_aggregateComparison(self):
"""
L{Max}C{(column) > constant} produces an object in the 'columns' clause
that renders a comparison to the 'max' aggregate in SQL.
"""
self.assertEquals(Select([Max(self.schema.BOZ.QUX) + 12],
From=self.schema.BOZ).toSQL(),
SQLFragment("select max(QUX) + ? from BOZ", [12]))
def test_multiColumnExpression(self):
"""
Multiple columns may be provided in an expression in the 'columns'
portion of a Select() statement. All arithmetic operators are
supported.
"""
self.assertEquals(
Select([((self.schema.FOO.BAR + self.schema.FOO.BAZ) / 3) * 7],
From=self.schema.FOO).toSQL(),
SQLFragment("select ((BAR + BAZ) / ?) * ? from FOO", [3, 7])
)
def test_len(self):
"""
Test for the 'Len' function for determining character length of a
column. (Note that this should be updated to use different techniques
as necessary in different databases.)
"""
self.assertEquals(
Select([Len(self.schema.TEXTUAL.MYTEXT)],
From=self.schema.TEXTUAL).toSQL(),
SQLFragment(
"select character_length(MYTEXT) from TEXTUAL"))
def test_startswith(self):
"""
Test for the string starts with comparison.
(Note that this should be updated to use different techniques
as necessary in different databases.)
"""
self.assertEquals(
Select([
self.schema.TEXTUAL.MYTEXT],
From=self.schema.TEXTUAL,
Where=self.schema.TEXTUAL.MYTEXT.StartsWith("test"),
).toSQL(),
SQLFragment(
"select MYTEXT from TEXTUAL where MYTEXT like (? || ?)",
["test", "%"]
)
)
def test_endswith(self):
    """
    Test for the string ends with comparison.
    (Note that this should be updated to use different techniques
    as necessary in different databases.)
    """
    self.assertEquals(
        Select([
            self.schema.TEXTUAL.MYTEXT],
            From=self.schema.TEXTUAL,
            Where=self.schema.TEXTUAL.MYTEXT.EndsWith("test"),
        ).toSQL(),
        SQLFragment(
            "select MYTEXT from TEXTUAL where MYTEXT like (? || ?)",
            ["%", "test"]
        )
    )
def test_contains(self):
    """
    Test for the string contains comparison.
    (Note that this should be updated to use different techniques
    as necessary in different databases.)
    """
    self.assertEquals(
        Select([
            self.schema.TEXTUAL.MYTEXT],
            From=self.schema.TEXTUAL,
            Where=self.schema.TEXTUAL.MYTEXT.Contains("test"),
        ).toSQL(),
        SQLFragment(
            "select MYTEXT from TEXTUAL where MYTEXT like (? || (? || ?))",
            ["%", "test", "%"]
        )
    )
def test_insert(self):
"""
L{Insert.toSQL} generates an 'insert' statement with all the relevant
columns.
"""
self.assertEquals(
Insert({self.schema.FOO.BAR: 23,
self.schema.FOO.BAZ: 9}).toSQL(),
SQLFragment("insert into FOO (BAR, BAZ) values (?, ?)", [23, 9]))
def test_insertNotEnough(self):
"""
L{Insert}'s constructor will raise L{NotEnoughValues} if columns have
not been specified.
"""
notEnough = self.assertRaises(
NotEnoughValues, Insert, {self.schema.OTHER.BAR: 9}
)
self.assertEquals(str(notEnough), "Columns [FOO_BAR] required.")
def test_insertReturning(self):
"""
L{Insert}'s C{Return} argument will insert an SQL 'returning' clause.
"""
self.assertEquals(
Insert({self.schema.FOO.BAR: 23,
self.schema.FOO.BAZ: 9},
Return=self.schema.FOO.BAR).toSQL(),
SQLFragment(
"insert into FOO (BAR, BAZ) values (?, ?) returning BAR",
[23, 9])
)
def test_insertMultiReturn(self):
"""
L{Insert}'s C{Return} argument can also be a C{tuple}, which will insert
an SQL 'returning' clause with multiple columns.
"""
self.assertEquals(
Insert({self.schema.FOO.BAR: 23,
self.schema.FOO.BAZ: 9},
Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ)).toSQL(),
SQLFragment(
"insert into FOO (BAR, BAZ) values (?, ?) returning BAR, BAZ",
[23, 9])
)
def test_insertMultiReturnOracle(self):
"""
In Oracle's SQL dialect, the 'returning' clause requires an 'into'
clause indicating where to put the results, as they can't be simply
relayed to the cursor. Further, additional bound variables are required
to capture the output parameters.
"""
self.assertEquals(
Insert({self.schema.FOO.BAR: 40,
self.schema.FOO.BAZ: 50},
Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ)).toSQL(
QueryGenerator(ORACLE_DIALECT, NumericPlaceholder())
),
SQLFragment(
"insert into FOO (BAR, BAZ) values (:1, :2) returning BAR, BAZ"
" into :3, :4",
[40, 50, Parameter("oracle_out_0"), Parameter("oracle_out_1")]
)
)
def test_insertMultiReturnSQLite(self):
"""
In SQLite's SQL dialect, there is no 'returning' clause, but given that
SQLite serializes all SQL transactions, you can rely upon 'select'
after a write operation to reliably give you exactly what was just
modified. Therefore, although 'toSQL' won't include any indication of
the return value, the 'on' method will execute a 'select' statement
following the insert to retrieve the value.
"""
insertStatement = Insert({self.schema.FOO.BAR: 39,
self.schema.FOO.BAZ: 82},
Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ)
)
qg = lambda : QueryGenerator(SQLITE_DIALECT, NumericPlaceholder())
self.assertEquals(insertStatement.toSQL(qg()),
SQLFragment("insert into FOO (BAR, BAZ) values (:1, :2)",
[39, 82])
)
result = []
csql = CatchSQL()
insertStatement.on(csql).addCallback(result.append)
self.assertEqual(result, [2])
self.assertEqual(
csql.execed,
[["insert into FOO (BAR, BAZ) values (:1, :2)", [39, 82]],
["select BAR, BAZ from FOO where rowid = last_insert_rowid()", []]]
)
def test_insertNoReturnSQLite(self):
"""
Insert a row I{without} a C{Return=} parameter should also work as
normal in sqlite.
"""
statement = Insert({self.schema.FOO.BAR: 12,
self.schema.FOO.BAZ: 48})
csql = CatchSQL()
statement.on(csql)
self.assertEqual(
csql.execed,
[["insert into FOO (BAR, BAZ) values (:1, :2)", [12, 48]]]
)
def test_updateReturningSQLite(self):
"""
Since SQLite does not support the SQL 'returning' syntax extension, in
order to preserve the rows that will be modified during an UPDATE
statement, we must first find the rows that will be affected, then
update them, then return the rows that were affected. Since we might
be changing even part of the primary key, we use the internal 'rowid'
column to uniquely and reliably identify rows in the sqlite database
that have been modified.
"""
csql = CatchSQL()
stmt = Update({self.schema.FOO.BAR: 4321},
Where=self.schema.FOO.BAZ == 1234,
Return=self.schema.FOO.BAR)
csql.nextResult([["sample row id"]])
result = resultOf(stmt.on(csql))
# Three statements were executed; make sure that the result returned was
# the result of executing the 3rd (and final) one.
self.assertResultList(result, 3)
# Check that they were the right statements.
self.assertEqual(len(csql.execed), 3)
self.assertEqual(
csql.execed[0],
["select rowid from FOO where BAZ = :1", [1234]]
)
self.assertEqual(
csql.execed[1],
["update FOO set BAR = :1 where BAZ = :2", [4321, 1234]]
)
self.assertEqual(
csql.execed[2],
["select BAR from FOO where rowid = :1", ["sample row id"]]
)
def test_updateReturningMultipleValuesSQLite(self):
"""
When SQLite updates multiple values, it must embed the row ID of each
subsequent value into its second 'where' clause, as there is no way to
pass a list of values to a single statement..
"""
csql = CatchSQL()
stmt = Update({self.schema.FOO.BAR: 4321},
Where=self.schema.FOO.BAZ == 1234,
Return=self.schema.FOO.BAR)
csql.nextResult([["one row id"], ["and another"], ["and one more"]])
result = resultOf(stmt.on(csql))
# Three statements were executed; make sure that the result returned was
# the result of executing the 3rd (and final) one.
self.assertResultList(result, 3)
# Check that they were the right statements.
self.assertEqual(len(csql.execed), 3)
self.assertEqual(
csql.execed[0],
["select rowid from FOO where BAZ = :1", [1234]]
)
self.assertEqual(
csql.execed[1],
["update FOO set BAR = :1 where BAZ = :2", [4321, 1234]]
)
self.assertEqual(
csql.execed[2],
["select BAR from FOO where rowid = :1 or rowid = :2 or rowid = :3",
["one row id", "and another", "and one more"]]
)
def test_deleteReturningSQLite(self):
"""
When SQLite deletes a value, ...
"""
csql = CatchSQL()
stmt = Delete(From=self.schema.FOO, Where=self.schema.FOO.BAZ == 1234,
Return=self.schema.FOO.BAR)
result = resultOf(stmt.on(csql))
self.assertResultList(result, 1)
self.assertEqual(len(csql.execed), 2)
self.assertEqual(
csql.execed[0],
["select BAR from FOO where BAZ = :1", [1234]]
)
self.assertEqual(
csql.execed[1],
["delete from FOO where BAZ = :1", [1234]]
)
def test_insertMismatch(self):
"""
L{Insert} raises L{TableMismatch} if the columns specified aren't all
from the same table.
"""
self.assertRaises(
TableMismatch,
Insert, {self.schema.FOO.BAR: 23,
self.schema.FOO.BAZ: 9,
self.schema.TEXTUAL.MYTEXT: 'hello'}
)
def test_quotingOnKeywordConflict(self):
"""
'access' is a keyword, so although our schema parser will leniently
accept it, it must be quoted in any outgoing SQL. (This is only done in
the Oracle dialect, because it isn't necessary in postgres, and
idiosyncratic case-folding rules make it challenging to do it in both.)
"""
self.assertEquals(
Insert({self.schema.LEVELS.ACCESS: 1,
self.schema.LEVELS.USERNAME:
"hi"}).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
SQLFragment(
'insert into LEVELS ("ACCESS", USERNAME) values (?, ?)',
[1, "hi"])
)
self.assertEquals(
Insert({self.schema.LEVELS.ACCESS: 1,
self.schema.LEVELS.USERNAME:
"hi"}).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
SQLFragment(
'insert into LEVELS (ACCESS, USERNAME) values (?, ?)',
[1, "hi"])
)
def test_updateReturning(self):
"""
L{update}'s C{Return} argument will update an SQL 'returning' clause.
"""
self.assertEquals(
Update({self.schema.FOO.BAR: 23},
self.schema.FOO.BAZ == 43,
Return=self.schema.FOO.BAR).toSQL(),
SQLFragment(
"update FOO set BAR = ? where BAZ = ? returning BAR",
[23, 43])
)
def test_updateMismatch(self):
"""
L{Update} raises L{TableMismatch} if the columns specified aren't all
from the same table.
"""
self.assertRaises(
TableMismatch,
Update, {self.schema.FOO.BAR: 23,
self.schema.FOO.BAZ: 9,
self.schema.TEXTUAL.MYTEXT: 'hello'},
Where=self.schema.FOO.BAZ == 9
)
def test_updateFunction(self):
"""
L{Update} values may be L{FunctionInvocation}s, to update to computed
values in the database.
"""
sqlfunc = Function("hello")
self.assertEquals(
Update(
{self.schema.FOO.BAR: 23,
self.schema.FOO.BAZ: sqlfunc()},
Where=self.schema.FOO.BAZ == 9
).toSQL(),
SQLFragment("update FOO set BAR = ?, BAZ = hello() "
"where BAZ = ?", [23, 9])
)
def test_insertFunction(self):
"""
L{Update} values may be L{FunctionInvocation}s, to update to computed
values in the database.
"""
sqlfunc = Function("hello")
self.assertEquals(
Insert(
{self.schema.FOO.BAR: 23,
self.schema.FOO.BAZ: sqlfunc()},
).toSQL(),
SQLFragment("insert into FOO (BAR, BAZ) "
"values (?, hello())", [23])
)
def test_deleteReturning(self):
"""
L{Delete}'s C{Return} argument will delete an SQL 'returning' clause.
"""
self.assertEquals(
Delete(self.schema.FOO,
Where=self.schema.FOO.BAR == 7,
Return=self.schema.FOO.BAZ).toSQL(),
SQLFragment(
"delete from FOO where BAR = ? returning BAZ", [7])
)
def test_update(self):
"""
L{Update.toSQL} generates an 'update' statement.
"""
self.assertEquals(
Update({self.schema.FOO.BAR: 4321},
self.schema.FOO.BAZ == 1234).toSQL(),
SQLFragment("update FOO set BAR = ? where BAZ = ?", [4321, 1234]))
def test_delete(self):
"""
L{Delete} generates an SQL 'delete' statement.
"""
self.assertEquals(
Delete(self.schema.FOO,
Where=self.schema.FOO.BAR == 12).toSQL(),
SQLFragment(
"delete from FOO where BAR = ?", [12])
)
self.assertEquals(
Delete(self.schema.FOO,
Where=None).toSQL(),
SQLFragment("delete from FOO")
)
def test_lock(self):
"""
L{Lock.exclusive} generates a ('lock table') statement, locking the
table in the specified mode.
"""
self.assertEquals(Lock.exclusive(self.schema.FOO).toSQL(),
SQLFragment("lock table FOO in exclusive mode"))
def test_databaseLock(self):
"""
L{DatabaseLock} generates a ('pg_advisory_lock') statement
"""
self.assertEquals(DatabaseLock().toSQL(),
SQLFragment("select pg_advisory_lock(1)"))
def test_databaseUnlock(self):
"""
L{DatabaseUnlock} generates a ('pg_advisory_unlock') statement
"""
self.assertEquals(DatabaseUnlock().toSQL(),
SQLFragment("select pg_advisory_unlock(1)"))
def test_savepoint(self):
"""
L{Savepoint} generates a ('savepoint') statement.
"""
self.assertEquals(Savepoint("test").toSQL(),
SQLFragment("savepoint test"))
def test_rollbacktosavepoint(self):
"""
L{RollbackToSavepoint} generates a ('rollback to savepoint') statement.
"""
self.assertEquals(RollbackToSavepoint("test").toSQL(),
SQLFragment("rollback to savepoint test"))
def test_releasesavepoint(self):
"""
L{ReleaseSavepoint} generates a ('release savepoint') statement.
"""
self.assertEquals(ReleaseSavepoint("test").toSQL(),
SQLFragment("release savepoint test"))
def test_savepointaction(self):
"""
L{SavepointAction} generates a ('savepoint') statement.
"""
self.assertEquals(SavepointAction("test")._name, "test")
def test_limit(self):
"""
A L{Select} object with a 'Limit' keyword parameter will generate
a SQL statement with a 'limit' clause.
"""
self.assertEquals(
Select([self.schema.FOO.BAR],
From=self.schema.FOO,
Limit=123).toSQL(),
SQLFragment(
"select BAR from FOO limit ?", [123]))
def test_limitOracle(self):
"""
A L{Select} object with a 'Limit' keyword parameter will generate a SQL
statement using a ROWNUM subquery for Oracle.
See U{this "ask tom" article from 2006 for more
information
}.
"""
self.assertEquals(
Select([self.schema.FOO.BAR],
From=self.schema.FOO,
Limit=123).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
SQLFragment(
"select * from (select BAR from FOO) "
"where ROWNUM <= ?", [123])
)
def test_having(self):
"""
A L{Select} object with a 'Having' keyword parameter will generate
a SQL statement with a 'having' expression.
"""
self.assertEquals(
Select([self.schema.FOO.BAR],
From=self.schema.FOO,
Having=Max(self.schema.FOO.BAZ) < 7).toSQL(),
SQLFragment("select BAR from FOO having max(BAZ) < ?", [7])
)
def test_distinct(self):
"""
A L{Select} object with a 'Disinct' keyword parameter with a value of
C{True} will generate a SQL statement with a 'distinct' keyword
preceding its list of columns.
"""
self.assertEquals(
Select([self.schema.FOO.BAR], From=self.schema.FOO,
Distinct=True).toSQL(),
SQLFragment("select distinct BAR from FOO")
)
def test_nextSequenceValue(self):
"""
When a sequence is used as a value in an expression, it renders as the
call to 'nextval' that will produce its next value.
"""
self.assertEquals(
Insert({self.schema.BOZ.QUX:
self.schema.A_SEQ}).toSQL(),
SQLFragment("insert into BOZ (QUX) values (nextval('A_SEQ'))", []))
def test_nextSequenceValueOracle(self):
"""
When a sequence is used as a value in an expression in the Oracle
dialect, it renders as the 'nextval' attribute of the appropriate
sequence.
"""
self.assertEquals(
Insert({self.schema.BOZ.QUX:
self.schema.A_SEQ}).toSQL(
QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
SQLFragment("insert into BOZ (QUX) values (A_SEQ.nextval)", []))
    def test_nextSequenceDefaultImplicitExplicitOracle(self):
        """
        In Oracle's dialect, sequence defaults can't be implemented without
        using triggers, so instead we just explicitly always include the
        sequence default value.
        """
        # Define a table whose 'b' column defaults to the next sequence
        # value; the Oracle SQL generator must emit that default itself.
        addSQLToSchema(
            schema=self.schema.model,
            schemaData="create table DFLTR (a varchar(255), "
            "b integer default nextval('A_SEQ'));"
        )
        self.assertEquals(
            Insert({self.schema.DFLTR.a: 'hello'}).toSQL(
                QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))
            ),
            SQLFragment("insert into DFLTR (a, b) values "
                        "(?, A_SEQ.nextval)", ['hello']),
        )
        # Should be the same if it's explicitly specified.
        self.assertEquals(
            Insert({self.schema.DFLTR.a: 'hello',
                    self.schema.DFLTR.b: self.schema.A_SEQ}).toSQL(
                QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))
            ),
            SQLFragment("insert into DFLTR (a, b) values "
                        "(?, A_SEQ.nextval)", ['hello']),
        )
    def test_numericParams(self):
        """
        An L{IAsyncTransaction} with the 'numeric' paramstyle attribute will
        cause statements to be generated with parameters in the style of :1 :2
        :3, as per the DB-API.
        """
        stmts = []
        class FakeOracleTxn(object):
            # Minimal transaction stand-in: records each statement and its
            # parameter list instead of executing anything.
            def execSQL(self, text, params, exc):
                stmts.append((text, params))
            dialect = ORACLE_DIALECT
            paramstyle = 'numeric'
        Select([self.schema.FOO.BAR],
               From=self.schema.FOO,
               Where=(self.schema.FOO.BAR == 7).And(
                   self.schema.FOO.BAZ == 9)
               ).on(FakeOracleTxn())
        self.assertEquals(
            stmts, [("select BAR from FOO where BAR = :1 and BAZ = :2",
                     [7, 9])]
        )
    def test_rewriteOracleNULLs_Select(self):
        """
        Oracle databases cannot distinguish between the empty string and
        C{NULL}.  When you insert an empty string, C{cx_Oracle} therefore
        treats it as a C{None} and will return that when you select it back
        again.  We address this in the schema by dropping 'not null'
        constraints.
        Therefore, when executing a statement which includes a string column,
        'on' should rewrite None return values from C{cx_Oracle} to be empty
        bytestrings, but only for string columns.
        """
        # NULLCHECK has one string and one numeric column; only the string
        # column's None should come back rewritten as ''.
        rows = resultOf(
            Select([self.schema.NULLCHECK.ASTRING,
                    self.schema.NULLCHECK.ANUMBER],
                   From=self.schema.NULLCHECK).on(NullTestingOracleTxn()))[0]
        self.assertEquals(rows, [['', None]])
def test_rewriteOracleNULLs_SelectAllColumns(self):
"""
Same as L{test_rewriteOracleNULLs_Select}, but with the L{ALL_COLUMNS}
shortcut.
"""
rows = resultOf(
Select(From=self.schema.NULLCHECK).on(NullTestingOracleTxn())
)[0]
self.assertEquals(rows, [['', None]])
    def test_nestedLogicalExpressions(self):
        """
        Make sure that logical operator precedence inserts proper parenthesis
        when needed.  e.g. 'a.And(b.Or(c))' needs to be 'a and (b or c)' not
        'a and b or c'.
        """
        # An 'or' nested under an 'and' must be parenthesized.
        self.assertEquals(
            Select(
                From=self.schema.FOO,
                Where=(self.schema.FOO.BAR != 7).
                And(self.schema.FOO.BAZ != 8).
                And((self.schema.FOO.BAR == 8).Or(self.schema.FOO.BAZ == 0))
            ).toSQL(),
            SQLFragment("select * from FOO where BAR != ? and BAZ != ? and "
                        "(BAR = ? or BAZ = ?)", [7, 8, 8, 0]))
        # An 'and' nested under an 'or' needs no parentheses, since 'and'
        # binds tighter anyway.
        self.assertEquals(
            Select(
                From=self.schema.FOO,
                Where=(self.schema.FOO.BAR != 7).
                Or(self.schema.FOO.BAZ != 8).
                Or((self.schema.FOO.BAR == 8).And(self.schema.FOO.BAZ == 0))
            ).toSQL(),
            SQLFragment("select * from FOO where BAR != ? or BAZ != ? or "
                        "BAR = ? and BAZ = ?", [7, 8, 8, 0]))
        # Mixed case: the leading 'or' chain must be parenthesized once an
        # 'and' is applied to it.
        self.assertEquals(
            Select(
                From=self.schema.FOO,
                Where=(self.schema.FOO.BAR != 7).
                Or(self.schema.FOO.BAZ != 8).
                And((self.schema.FOO.BAR == 8).Or(self.schema.FOO.BAZ == 0))
            ).toSQL(),
            SQLFragment("select * from FOO where (BAR != ? or BAZ != ?) and "
                        "(BAR = ? or BAZ = ?)", [7, 8, 8, 0]))
def test_updateWithNULL(self):
"""
As per the DB-API specification, "SQL NULL values are represented by the
Python None singleton on input and output." When a C{None} is provided
as a value to an L{Update}, it will be relayed to the database as a
parameter.
"""
self.assertEquals(
Update({self.schema.BOZ.QUX: None},
Where=self.schema.BOZ.QUX == 7).toSQL(),
SQLFragment("update BOZ set QUX = ? where QUX = ?", [None, 7])
)
    def test_subSelectComparison(self):
        """
        A comparison of a column to a sub-select in a where clause will result
        in a parenthetical 'Where' clause.
        """
        self.assertEquals(
            Update(
                {self.schema.BOZ.QUX: 9},
                Where=self.schema.BOZ.QUX ==
                Select([self.schema.FOO.BAR], From=self.schema.FOO,
                       Where=self.schema.FOO.BAZ == 12)).toSQL(),
            SQLFragment(
                # NOTE: it's very important that the comparison _always_ go in
                # this order (column from the UPDATE first, inner SELECT
                # second) as the other order will be considered a syntax
                # error.
                "update BOZ set QUX = ? where QUX = ("
                "select BAR from FOO where BAZ = ?)", [9, 12]
            )
        )
    def test_tupleComparison(self):
        """
        A L{Tuple} allows for simultaneous comparison of multiple values in a
        C{Where} clause.  This feature is particularly useful when issuing an
        L{Update} or L{Delete}, where the comparison is with values from a
        subselect.  (A L{Tuple} will be automatically generated upon
        comparison to a C{tuple} or C{list}.)
        """
        self.assertEquals(
            Update(
                {self.schema.BOZ.QUX: 1},
                Where=(self.schema.BOZ.QUX, self.schema.BOZ.QUUX) ==
                Select([self.schema.FOO.BAR, self.schema.FOO.BAZ],
                       From=self.schema.FOO,
                       Where=self.schema.FOO.BAZ == 2)).toSQL(),
            SQLFragment(
                # NOTE: it's very important that the comparison _always_ go in
                # this order (tuple of columns from the UPDATE first, inner
                # SELECT second) as the other order will be considered a
                # syntax error.
                "update BOZ set QUX = ? where (QUX, QUUX) = ("
                "select BAR, BAZ from FOO where BAZ = ?)", [1, 2]
            )
        )
def test_tupleOfConstantsComparison(self):
"""
For some reason Oracle requires multiple parentheses for comparisons.
"""
self.assertEquals(
Select(
[self.schema.FOO.BAR],
From=self.schema.FOO,
Where=(Tuple([self.schema.FOO.BAR, self.schema.FOO.BAZ]) ==
Tuple([Constant(7), Constant(9)]))
).toSQL(),
SQLFragment(
"select BAR from FOO where (BAR, BAZ) = ((?, ?))", [7, 9]
)
)
    def test_oracleTableTruncation(self):
        """
        L{Table}'s SQL generation logic will truncate table names if the dialect
        (i.e. Oracle) demands it.  (See txdav.common.datastore.sql_tables for
        the schema translator and enforcement of name uniqueness in the derived
        schema.)
        """
        addSQLToSchema(
            self.schema.model,
            "create table veryveryveryveryveryveryveryverylong "
            "(foo integer);"
        )
        vvl = self.schema.veryveryveryveryveryveryveryverylong
        self.assertEquals(
            Insert({vvl.foo: 1}).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
            SQLFragment(
                # The expected name below is the original cut to 30
                # characters -- presumably Oracle's identifier length limit;
                # confirm against the schema translator.
                "insert into veryveryveryveryveryveryveryve (foo) values "
                "(?)", [1]
            )
        )
def test_columnEqualityTruth(self):
"""
Mostly in support of test_columnsAsDictKeys, the 'same' column should
compare True to itself and False to other values.
"""
s = self.schema
self.assertEquals(bool(s.FOO.BAR == s.FOO.BAR), True)
self.assertEquals(bool(s.FOO.BAR != s.FOO.BAR), False)
self.assertEquals(bool(s.FOO.BAZ != s.FOO.BAR), True)
def test_columnsAsDictKeys(self):
"""
An odd corner of the syntactic sugar provided by the DAL is that the
column objects have to participate both in augmented equality comparison
("==" returns an expression object) as well as dictionary keys (for
Insert and Update statement objects). Therefore it should be possible
to I{manipulate} dictionaries of keys as well.
"""
values = {self.schema.FOO.BAR: 1}
self.assertEquals(values, {self.schema.FOO.BAR: 1})
values.pop(self.schema.FOO.BAR)
self.assertEquals(values, {})
class OracleConnectionMethods(object):
    """
    Mixin of tests exercising Oracle-specific connection behavior; mixed in
    to the test cases below with different pool helpers.
    """
    def test_rewriteOracleNULLs_Insert(self):
        """
        The behavior described in L{test_rewriteOracleNULLs_Select} applies to
        other statement types as well, specifically those with 'returning'
        clauses.
        """
        # Add 2 cursor variable values so that these will be used by
        # FakeVariable.getvalue.
        self.factory.varvals.extend([None, None])
        rows = self.resultOf(
            Insert({self.schema.NULLCHECK.ASTRING: '',
                    self.schema.NULLCHECK.ANUMBER: None},
                   Return=[self.schema.NULLCHECK.ASTRING,
                           self.schema.NULLCHECK.ANUMBER]
            ).on(self.createTransaction()))[0]
        self.assertEquals(rows, [['', None]])
    def test_insertMultiReturnOnOracleTxn(self):
        """
        As described in L{test_insertMultiReturnOracle}, Oracle deals with
        'returning' clauses by using out parameters.  However, this is not
        quite enough, as the code needs to actually retrieve the values from
        the out parameters.
        """
        i = Insert({self.schema.FOO.BAR: 40,
                    self.schema.FOO.BAZ: 50},
                   Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ))
        self.factory.varvals.extend(["first val!", "second val!"])
        result = self.resultOf(i.on(self.createTransaction()))
        self.assertEquals(result, [[["first val!", "second val!"]]])
        # The fake cursor records the out-parameter variables it allocated;
        # verify that the two cx_Oracle types were requested.
        curvars = self.factory.connections[0].cursors[0].variables
        self.assertEquals(len(curvars), 2)
        self.assertEquals(curvars[0].type, FakeCXOracleModule.NUMBER)
        self.assertEquals(curvars[1].type, FakeCXOracleModule.STRING)
    def test_insertNoReturnOracle(self):
        """
        In addition to being able to execute insert statements with a Return
        attribute, oracle also ought to be able to execute insert statements
        with no Return at all.
        """
        # This statement should return nothing from .fetchall(), so...
        self.factory.hasResults = False
        i = Insert({self.schema.FOO.BAR: 40,
                    self.schema.FOO.BAZ: 50})
        result = self.resultOf(i.on(self.createTransaction()))
        self.assertEquals(result, [None])
class OracleConnectionTests(ConnectionPoolHelper, ExampleSchemaHelper,
                            OracleConnectionMethods, TestCase):
    """
    Tests which use an oracle connection.
    """
    dialect = ORACLE_DIALECT
    def setUp(self):
        """
        Create a fake oracle-ish connection pool without using real threads or
        a real database.
        """
        # Replace the real cx_Oracle driver with the in-memory fake before
        # the pool helper is brought up.
        self.patch(syntax, 'cx_Oracle', FakeCXOracleModule)
        super(OracleConnectionTests, self).setUp()
        ExampleSchemaHelper.setUp(self)
class OracleNetConnectionTests(NetworkedPoolHelper, ExampleSchemaHelper,
                               OracleConnectionMethods, TestCase):
    """
    The same Oracle connection tests, run against the networked pool helper
    instead of the in-process one.
    """
    dialect = ORACLE_DIALECT
    def setUp(self):
        # Swap in the fake driver, bring up the networked pool and example
        # schema, and make sure the remote client agrees on the dialect.
        self.patch(syntax, 'cx_Oracle', FakeCXOracleModule)
        super(OracleNetConnectionTests, self).setUp()
        ExampleSchemaHelper.setUp(self)
        self.pump.client.dialect = ORACLE_DIALECT
calendarserver-5.2+dfsg/twext/enterprise/dal/test/__init__.py 0000644 0001750 0001750 00000001207 12263343324 023474 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for twext.enterprise.dal.
"""
calendarserver-5.2+dfsg/twext/enterprise/dal/record.py 0000644 0001750 0001750 00000030753 12263343324 022244 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.dal.test.test_record -*-
##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
RECORD: Relational Entity Creation from Objects Representing Data.
This is an asynchronous object-relational mapper based on
L{twext.enterprise.dal.syntax}.
"""
from twisted.internet.defer import inlineCallbacks, returnValue
from twext.enterprise.dal.syntax import (
Select, Tuple, Constant, ColumnSyntax, Insert, Update, Delete
)
from twext.enterprise.util import parseSQLTimestamp
# from twext.enterprise.dal.syntax import ExpressionSyntax
class ReadOnly(AttributeError):
    """
    A caller attempted to set an attribute on a database-backed record,
    rather than updating it through L{Record.update}.
    """
    def __init__(self, className, attributeName):
        """
        @param className: name of the record class whose attribute was set.
        @param attributeName: name of the attribute that was set.
        """
        self.className = className
        self.attributeName = attributeName
        message = (
            "SQL-backed attribute '{0}.{1}' is read-only. Use '.update(...)'"
            " to modify attributes.".format(className, attributeName)
        )
        super(ReadOnly, self).__init__(message)
class NoSuchRecord(Exception):
    """
    No matching record could be found in the database for the requested
    query or primary key.
    """
class _RecordMeta(type):
    """
    Metaclass for associating a L{fromTable} with a L{Record} at inheritance
    time.
    """
    def __new__(cls, name, bases, ns):
        """
        Create a new instance of this meta-type.

        @param name: name of the class being created.
        @param bases: declared base classes; any L{fromTable} markers are
            consumed here and removed from the real base list.
        @param ns: the class namespace (attribute dictionary), augmented in
            place with the column/attribute maps when a table is present.
        """
        newbases = []
        table = None
        namer = None
        for base in bases:
            if isinstance(base, fromTable):
                # A table marker: remember its table, but do not actually
                # inherit from the marker object.
                if table is not None:
                    raise RuntimeError(
                        "Can't define a class from two or more tables at once."
                    )
                table = base.table
            elif getattr(base, "table", None) is not None:
                raise RuntimeError(
                    "Can't define a record class by inheriting one already "
                    "mapped to a table."
                    # TODO: more info
                )
            else:
                if namer is None:
                    if isinstance(base, _RecordMeta):
                        # The first Record-ish base supplies the
                        # namingConvention used to derive attribute names.
                        namer = base
                newbases.append(base)
        if table is not None:
            # Build bidirectional maps between attribute names and columns,
            # and expose each column directly as a class attribute.
            attrmap = {}
            colmap = {}
            allColumns = list(table)
            for column in allColumns:
                attrname = namer.namingConvention(column.model.name)
                attrmap[attrname] = column
                colmap[column] = attrname
            ns.update(table=table, __attrmap__=attrmap, __colmap__=colmap)
            ns.update(attrmap)
        return super(_RecordMeta, cls).__new__(cls, name, tuple(newbases), ns)
class fromTable(object):
    """
    Inherit from this after L{Record} to specify which table your L{Record}
    subclass is mapped to.
    """
    def __init__(self, aTable):
        """
        @param aTable: The table to map to.
        @type aTable: L{twext.enterprise.dal.syntax.TableSyntax}
        """
        self.table = aTable
class Record(object):
    """
    Superclass for all database-backed record classes.  (i.e. an object
    mapped from a database record).

    @cvar table: the table that represents this L{Record} in the database.
    @type table: L{TableSyntax}

    @ivar transaction: The L{IAsyncTransaction} where this record is being
        loaded.  This may be C{None} if this L{Record} is not participating
        in a transaction, which may be true if it was instantiated but never
        saved.

    @cvar __colmap__: map of L{ColumnSyntax} objects to attribute names.
    @type __colmap__: L{dict}

    @cvar __attrmap__: map of attribute names to L{ColumnSyntax} objects.
    @type __attrmap__: L{dict}
    """
    __metaclass__ = _RecordMeta

    transaction = None

    def __setattr__(self, name, value):
        """
        Once the transaction is initialized, this object is immutable.  If
        you want to change it, use L{Record.update}.
        """
        if self.transaction is not None:
            raise ReadOnly(self.__class__.__name__, name)
        return super(Record, self).__setattr__(name, value)

    def __repr__(self):
        """
        Render this record for debugging: its table name plus every mapped
        attribute, in sorted attribute order.
        """
        r = "<{0} record from table {1}".format(self.__class__.__name__,
                                                self.table.model.name)
        for k in sorted(self.__attrmap__.keys()):
            r += " {0}={1}".format(k, repr(getattr(self, k)))
        r += ">"
        return r

    @staticmethod
    def namingConvention(columnName):
        """
        Implement the convention for naming-conversion between column names
        (typically, upper-case database names map to lower-case attribute
        names).
        """
        words = columnName.lower().split("_")
        def cap(word):
            # 'id' is special-cased to stay fully upper-case ('ID'); other
            # words are capitalized normally (camelCase).
            if word.lower() == 'id':
                return word.upper()
            else:
                return word.capitalize()
        return words[0] + "".join(map(cap, words[1:]))

    @classmethod
    def _primaryKeyExpression(cls):
        # A Tuple of this table's primary-key columns, usable in comparisons.
        return Tuple([ColumnSyntax(c) for c in cls.table.model.primaryKey])

    def _primaryKeyValue(self):
        # The current values of this record's primary-key attributes, in
        # primary-key column order.
        val = []
        for col in self._primaryKeyExpression().columns:
            val.append(getattr(self, self.__class__.__colmap__[col]))
        return val

    @classmethod
    def _primaryKeyComparison(cls, primaryKey):
        # Build the '(pk columns) = (constant values)' comparison used by
        # load, delete, update and pop.
        return (cls._primaryKeyExpression() ==
                Tuple(map(Constant, primaryKey)))

    @classmethod
    @inlineCallbacks
    def load(cls, transaction, *primaryKey):
        """
        Load a single record by its primary key.

        @raise NoSuchRecord: if the number of matching rows is not exactly
            one.
        """
        results = (yield cls.query(transaction,
                                   cls._primaryKeyComparison(primaryKey)))
        if len(results) != 1:
            raise NoSuchRecord()
        else:
            returnValue(results[0])

    @classmethod
    @inlineCallbacks
    def create(cls, transaction, **k):
        """
        Create a row.

        Used like this::

            MyRecord.create(transaction, column1=1, column2=u'two')
        """
        self = cls()
        colmap = {}
        attrtocol = cls.__attrmap__
        needsCols = []
        needsAttrs = []
        for attr in attrtocol:
            col = attrtocol[attr]
            if attr in k:
                setattr(self, attr, k[attr])
                colmap[col] = k.pop(attr)
            else:
                if col.model.needsValue():
                    raise TypeError("required attribute " + repr(attr) +
                                    " not passed")
                else:
                    # Column has a database-side default; ask the INSERT to
                    # return it so the attribute can be filled in afterwards.
                    needsCols.append(col)
                    needsAttrs.append(attr)
        if k:
            # Anything left over was not a mapped attribute.
            raise TypeError("received unknown attribute{0}: {1}".format(
                "s" if len(k) > 1 else "", ", ".join(sorted(k))
            ))
        result = yield (Insert(colmap, Return=needsCols if needsCols else None)
                        .on(transaction))
        if needsCols:
            self._attributesFromRow(zip(needsAttrs, result[0]))
        self.transaction = transaction
        returnValue(self)

    def _attributesFromRow(self, attributeList):
        """
        Take some data loaded from a row and apply it to this instance,
        converting types as necessary.

        @param attributeList: a C{list} of 2-C{tuples} of C{(attributeName,
            attributeValue)}.
        """
        for setAttribute, setValue in attributeList:
            setColumn = self.__attrmap__[setAttribute]
            if setColumn.model.type.name == "timestamp":
                # Timestamp columns are converted from their SQL string form.
                setValue = parseSQLTimestamp(setValue)
            setattr(self, setAttribute, setValue)

    def delete(self):
        """
        Delete this row from the database.

        @return: a L{Deferred} which fires with C{None} when the underlying
            row has been deleted, or fails with L{NoSuchRecord} if the
            underlying row was already deleted.
        """
        return Delete(From=self.table,
                      Where=self._primaryKeyComparison(self._primaryKeyValue())
                      ).on(self.transaction, raiseOnZeroRowCount=NoSuchRecord)

    @inlineCallbacks
    def update(self, **kw):
        """
        Modify the given attributes in the database.

        @return: a L{Deferred} that fires when the updates have been sent to
            the database.
        """
        colmap = {}
        for k, v in kw.iteritems():
            colmap[self.__attrmap__[k]] = v
        yield (Update(colmap,
                      Where=self._primaryKeyComparison(self._primaryKeyValue()))
               .on(self.transaction))
        # Reflect the change locally only after the UPDATE is issued; go
        # through __dict__ to bypass __setattr__'s immutability check.
        self.__dict__.update(kw)

    @classmethod
    def pop(cls, transaction, *primaryKey):
        """
        Atomically retrieve and remove a row from this L{Record}'s table
        with a primary key value of C{primaryKey}.

        @return: a L{Deferred} that fires with an instance of C{cls}, or
            fails with L{NoSuchRecord} if there were no records in the
            database.
        @rtype: L{Deferred}
        """
        # DELETE ... RETURNING gives back the removed row in one statement.
        return cls._rowsFromQuery(
            transaction, Delete(Where=cls._primaryKeyComparison(primaryKey),
                                From=cls.table, Return=list(cls.table)),
            lambda : NoSuchRecord()
        ).addCallback(lambda x: x[0])

    @classmethod
    def query(cls, transaction, expr, order=None, ascending=True, group=None):
        """
        Query the table that corresponds to C{cls}, and return instances of
        C{cls} corresponding to the rows that are returned from that table.

        @param expr: An L{ExpressionSyntax} that constraints the results of
            the query.  This is most easily produced by accessing attributes
            on the class; for example,
            C{MyRecordType.query((MyRecordType.col1 >
            MyRecordType.col2).And(MyRecordType.col3 == 7))}

        @param order: A L{ColumnSyntax} to order the resulting record objects
            by.

        @param ascending: A boolean; if C{order} is not C{None}, whether to
            sort in ascending or descending order.

        @param group: a L{ColumnSyntax} to group the resulting record objects
            by.
        """
        kw = {}
        if order is not None:
            kw.update(OrderBy=order, Ascending=ascending)
        if group is not None:
            kw.update(GroupBy=group)
        return cls._rowsFromQuery(transaction, Select(list(cls.table),
                                                      From=cls.table,
                                                      Where=expr, **kw), None)

    @classmethod
    def all(cls, transaction):
        """
        Load all rows from the table that corresponds to C{cls} and return
        instances of C{cls} corresponding to all.
        """
        return cls._rowsFromQuery(transaction,
                                  Select(list(cls.table),
                                         From=cls.table,
                                         OrderBy=cls._primaryKeyExpression()),
                                  None)

    @classmethod
    @inlineCallbacks
    def _rowsFromQuery(cls, transaction, qry, rozrc):
        """
        Execute the given query, and transform its results into instances of
        C{cls}.

        @param transaction: an L{IAsyncTransaction} to execute the query on.

        @param qry: a L{_DMLStatement} (XXX: maybe _DMLStatement or some
            interface that defines 'on' should be public?) whose results are
            the list of columns in C{self.table}.

        @param rozrc: The C{raiseOnZeroRowCount} argument.

        @return: a L{Deferred} that succeeds with a C{list} of instances of
            C{cls} or fails with an exception produced by C{rozrc}.
        """
        rows = yield qry.on(transaction, raiseOnZeroRowCount=rozrc)
        selves = []
        names = [cls.__colmap__[column] for column in list(cls.table)]
        for row in rows:
            self = cls()
            self._attributesFromRow(zip(names, row))
            self.transaction = transaction
            selves.append(self)
        returnValue(selves)
# Public API of this module.  'Record' is the primary entry point for
# defining database-backed record classes and was previously (erroneously)
# omitted from this list.
__all__ = [
    "ReadOnly",
    "fromTable",
    "NoSuchRecord",
    "Record",
]
calendarserver-5.2+dfsg/twext/enterprise/dal/parseschema.py 0000644 0001750 0001750 00000052270 12263343324 023257 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.dal.test.test_parseschema -*-
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
from __future__ import print_function
"""
Parser for SQL schema.
"""
from itertools import chain
from sqlparse import parse, keywords
from sqlparse.tokens import (
Keyword, Punctuation, Number, String, Name, Comparison as CompTok
)
from sqlparse.sql import (Comment, Identifier, Parenthesis, IdentifierList,
Function, Comparison)
from twext.enterprise.dal.model import (
Schema, Table, SQLType, ProcedureCall, Constraint, Sequence, Index)
from twext.enterprise.dal.syntax import (
ColumnSyntax, CompoundComparison, Constant, Function as FunctionSyntax
)
def _fixKeywords():
    """
    Work around bugs in SQLParse, adding SEQUENCE as a keyword (since it is
    treated as one in postgres) and removing ACCESS and SIZE (since we use
    those as column names).  Technically those are keywords in SQL, but they
    aren't treated as such by postgres's parser.
    """
    keywords.KEYWORDS['SEQUENCE'] = Keyword
    for columnNameKeyword in ['ACCESS', 'SIZE']:
        del keywords.KEYWORDS[columnNameKeyword]
# Patch sqlparse's global keyword table once, at import time.
_fixKeywords()
def tableFromCreateStatement(schema, stmt):
    """
    Add a table from a CREATE TABLE sqlparse statement object.

    @param schema: The schema to add the table statement to.
    @type schema: L{Schema}

    @param stmt: The C{CREATE TABLE} statement object.
    @type stmt: L{Statement}

    @return: the newly created L{Table}.
    """
    i = iterSignificant(stmt)
    # Consume the leading 'CREATE TABLE' keywords.
    expect(i, ttype=Keyword.DDL, value='CREATE')
    expect(i, ttype=Keyword, value='TABLE')
    # sqlparse groups 'NAME ( ... )' as a Function token.
    function = expect(i, cls=Function)
    i = iterSignificant(function)
    name = expect(i, cls=Identifier).get_name().encode('utf-8')
    self = Table(schema, name)
    # Hand the parenthesized column/constraint list over to the stateful
    # column parser.
    parens = expect(i, cls=Parenthesis)
    cp = _ColumnParser(self, iterSignificant(parens), parens)
    cp.parse()
    return self
def schemaFromPath(path):
    """
    Get a L{Schema}.

    @param path: a L{FilePath}-like object containing SQL.

    @return: a L{Schema} object with the contents of the given C{path} parsed
        and added to it as L{Table} objects.
    """
    result = Schema(path.basename())
    addSQLToSchema(result, path.getContent())
    return result
def schemaFromString(data):
    """
    Get a L{Schema}.

    @param data: a C{str} containing SQL.

    @return: a L{Schema} object with the contents of the given C{str} parsed
        and added to it as L{Table} objects.
    """
    # addSQLToSchema returns the schema it was given, so this is equivalent
    # to creating an empty schema, populating it, and returning it.
    return addSQLToSchema(Schema(), data)
def addSQLToSchema(schema, schemaData):
    """
    Add new SQL to an existing schema.

    @param schema: The schema to add the new SQL to.
    @type schema: L{Schema}

    @param schemaData: A string containing some SQL statements.
    @type schemaData: C{str}

    @return: the C{schema} argument
    """
    parsed = parse(schemaData)
    for stmt in parsed:
        # Accumulate leading insignificant tokens (whitespace, comments) so
        # they can be attached to a created table as its comment preface.
        preface = ''
        while stmt.tokens and not significant(stmt.tokens[0]):
            preface += str(stmt.tokens.pop(0))
        if not stmt.tokens:
            continue
        if stmt.get_type() == 'CREATE':
            createType = stmt.token_next(1, True).value.upper()
            if createType == u'TABLE':
                t = tableFromCreateStatement(schema, stmt)
                t.addComment(preface)
            elif createType == u'SEQUENCE':
                Sequence(schema,
                         stmt.token_next(2, True).get_name().encode('utf-8'))
            elif createType in (u'INDEX', u'UNIQUE'):
                signifindex = iterSignificant(stmt)
                expect(signifindex, ttype=Keyword.DDL, value='CREATE')
                token = signifindex.next()
                unique = False
                if token.match(Keyword, "UNIQUE"):
                    unique = True
                    token = signifindex.next()
                if not token.match(Keyword, "INDEX"):
                    # BUGFIX: the expectation message used to read "UNQIUE".
                    raise ViolatedExpectation("INDEX or UNIQUE", token.value)
                indexName = nameOrIdentifier(signifindex.next())
                expect(signifindex, ttype=Keyword, value='ON')
                token = signifindex.next()
                if isinstance(token, Function):
                    # 'ON table (col, ...)' parses as a Function group.
                    [tableName, columnArgs] = iterSignificant(token)
                else:
                    # 'ON table USING method(col, ...)' form.
                    tableName = token
                    token = signifindex.next()
                    if token.match(Keyword, "USING"):
                        [_ignore, columnArgs] = iterSignificant(expect(signifindex, cls=Function))
                    else:
                        raise ViolatedExpectation('USING', token)
                tableName = nameOrIdentifier(tableName)
                arggetter = iterSignificant(columnArgs)
                expect(arggetter, ttype=Punctuation, value=u'(')
                valueOrValues = arggetter.next()
                if isinstance(valueOrValues, IdentifierList):
                    valuelist = valueOrValues.get_identifiers()
                else:
                    valuelist = [valueOrValues]
                expect(arggetter, ttype=Punctuation, value=u')')
                idx = Index(schema, indexName, schema.tableNamed(tableName), unique)
                for token in valuelist:
                    columnName = nameOrIdentifier(token)
                    idx.addColumn(idx.table.columnNamed(columnName))
        elif stmt.get_type() == 'INSERT':
            insertTokens = iterSignificant(stmt)
            expect(insertTokens, ttype=Keyword.DML, value='INSERT')
            expect(insertTokens, ttype=Keyword, value='INTO')
            tableName = expect(insertTokens, cls=Identifier).get_name()
            expect(insertTokens, ttype=Keyword, value='VALUES')
            values = expect(insertTokens, cls=Parenthesis)
            vals = iterSignificant(values)
            expect(vals, ttype=Punctuation, value='(')
            valuelist = expect(vals, cls=IdentifierList)
            expect(vals, ttype=Punctuation, value=')')
            rowData = []
            for ident in valuelist.get_identifiers():
                # Convert each literal based on its token type; only integer
                # and single-quoted string literals are supported here.
                rowData.append(
                    {Number.Integer: int,
                     String.Single: _destringify}
                    [ident.ttype](ident.value)
                )
            schema.tableNamed(tableName).insertSchemaRow(rowData)
        else:
            print('unknown type:', stmt.get_type())
    return schema
class _ColumnParser(object):
"""
Stateful parser for the things between commas.
"""
    def __init__(self, table, parenIter, parens):
        """
        @param table: the L{Table} to add data to.
        @param parenIter: the iterator of significant tokens inside the
            parenthesized column list.
        @param parens: the L{Parenthesis} token group being parsed.
        """
        self.parens = parens
        self.iter = parenIter
        self.table = table
    def __iter__(self):
        """
        This object is an iterator; return itself.
        """
        return self
    def next(self):
        """
        Get the next token, transparently flattening any L{IdentifierList}
        into its individual significant tokens before returning the first of
        them.
        """
        result = self.iter.next()
        if isinstance(result, IdentifierList):
            # Expand out all identifier lists, since they seem to pop up
            # incorrectly.  We should never see one in a column list anyway.
            # http://code.google.com/p/python-sqlparse/issues/detail?id=25
            while result.tokens:
                # Pop from the end so pushback (which prepends) preserves
                # the original token order.
                it = result.tokens.pop()
                if significant(it):
                    self.pushback(it)
            return self.next()
        return result
def pushback(self, value):
"""
Push the value back onto this iterator so it will be returned by the
next call to C{next}.
"""
self.iter = chain(iter((value,)), self.iter)
    def parse(self):
        """
        Parse everything: consume the opening parenthesis, then columns and
        constraints until the list is exhausted.
        """
        expect(self.iter, ttype=Punctuation, value=u"(")
        while self.nextColumn():
            pass
    def nextColumn(self):
        """
        Parse the next column or constraint, depending on the next token.

        @return: C{True} if more entries follow in the list, C{False} at the
            end (as reported by the parsing helpers).
        """
        maybeIdent = self.next()
        if maybeIdent.ttype == Name:
            # A bare name token starts a column definition.
            return self.parseColumn(maybeIdent.value)
        elif isinstance(maybeIdent, Identifier):
            return self.parseColumn(maybeIdent.get_name())
        else:
            # Anything else begins a table-level constraint.
            return self.parseConstraint(maybeIdent)
    def namesInParens(self, parens):
        """
        Parse a parenthesized, comma-separated list of identifiers and return
        their names.

        @return: a C{list} of identifier name strings.
        @raise ViolatedExpectation: if the parenthesis contents are neither a
            single identifier nor an identifier list.
        """
        parens = iterSignificant(parens)
        expect(parens, ttype=Punctuation, value="(")
        idorids = parens.next()
        if isinstance(idorids, Identifier):
            idnames = [idorids.get_name()]
        elif isinstance(idorids, IdentifierList):
            idnames = [x.get_name() for x in idorids.get_identifiers()]
        else:
            raise ViolatedExpectation("identifier or list", repr(idorids))
        expect(parens, ttype=Punctuation, value=")")
        return idnames
    def readExpression(self, parens):
        """
        Read a given expression from a Parenthesis object.  (This is currently
        a limited parser in support of simple CHECK constraints, not something
        suitable for a full WHERE Clause.)
        """
        parens = iterSignificant(parens)
        expect(parens, ttype=Punctuation, value="(")
        nexttok = parens.next()
        if isinstance(nexttok, Comparison):
            # Simple 'lhs <op> rhs' comparison parsed as one token group.
            lhs, op, rhs = list(iterSignificant(nexttok))
            result = CompoundComparison(self.nameOrValue(lhs),
                                        op.value.encode("ascii"),
                                        self.nameOrValue(rhs))
        elif isinstance(nexttok, Identifier):
            # our version of SQLParse seems to break down and not create a
            # nice "Comparison" object when a keyword is present.  This is
            # just a simple workaround: read lhs, operator, and a
            # keyword-named function call with column arguments.
            lhs = self.nameOrValue(nexttok)
            op = expect(parens, ttype=CompTok).value.encode("ascii")
            funcName = expect(parens, ttype=Keyword).value.encode("ascii")
            rhs = FunctionSyntax(funcName)(*[
                ColumnSyntax(self.table.columnNamed(x)) for x in
                self.namesInParens(expect(parens, cls=Parenthesis))
            ])
            result = CompoundComparison(lhs, op, rhs)
        expect(parens, ttype=Punctuation, value=")")
        return result
    def nameOrValue(self, tok):
        """
        Inspecting a token present in an expression (for a CHECK constraint on
        this table), return a L{twext.enterprise.dal.syntax} object for that
        value.
        """
        # NOTE(review): tokens that are neither an Identifier nor an integer
        # literal fall through and implicitly return None -- confirm this is
        # intentional for unsupported literal types.
        if isinstance(tok, Identifier):
            return ColumnSyntax(self.table.columnNamed(tok.get_name()))
        elif tok.ttype == Number.Integer:
            return Constant(int(tok.value))
    def parseConstraint(self, constraintType):
        """
        Parse a 'free' constraint, described explicitly in the table as opposed
        to being implicitly associated with a column by being placed after it.

        @param constraintType: the token that opened the constraint: either
            the CONSTRAINT keyword (for a named constraint) or the constraint
            kind itself (PRIMARY / UNIQUE / CHECK).
        @return: the result of L{checkEnd}: C{True} if another column or
            constraint follows, C{False} at the end of the CREATE body.
        """
        ident = None
        # TODO: make use of identifier in tableConstraint, currently only used
        # for checkConstraint.
        if constraintType.match(Keyword, 'CONSTRAINT'):
            # Named constraint: capture the name, then read the real kind.
            ident = expect(self, cls=Identifier).get_name()
            constraintType = expect(self, ttype=Keyword)
        if constraintType.match(Keyword, 'PRIMARY'):
            expect(self, ttype=Keyword, value='KEY')
            names = self.namesInParens(expect(self, cls=Parenthesis))
            self.table.primaryKey = [self.table.columnNamed(n) for n in names]
        elif constraintType.match(Keyword, 'UNIQUE'):
            names = self.namesInParens(expect(self, cls=Parenthesis))
            self.table.tableConstraint(Constraint.UNIQUE, names)
        elif constraintType.match(Keyword, 'CHECK'):
            self.table.checkConstraint(self.readExpression(self.next()), ident)
        else:
            raise ViolatedExpectation('PRIMARY or UNIQUE', constraintType)
        return self.checkEnd(self.next())
def checkEnd(self, val):
"""
After a column or constraint, check the end.
"""
if val.value == u",":
return True
elif val.value == u")":
return False
else:
raise ViolatedExpectation(", or )", val)
    def parseColumn(self, name):
        """
        Parse a column with the given name.

        Reads the column's type (with optional length), then consumes any
        trailing column constraints (PRIMARY KEY, UNIQUE, NOT NULL, CHECK,
        DEFAULT, REFERENCES, ON DELETE ...) until a ',' or ')' punctuation
        token ends the column definition.

        @param name: the column name, as text.
        @return: the result of L{checkEnd} (True if more definitions follow),
            or 0 if an unexpected token was encountered.
        """
        typeName = self.next()
        if isinstance(typeName, Function):
            # sqlparse parsed 'VARCHAR(255)' as a Function call: the function
            # identifier is the type name, its argument is the length.
            [funcIdent, args] = iterSignificant(typeName)
            typeName = funcIdent
            arggetter = iterSignificant(args)
            expect(arggetter, value=u'(')
            typeLength = int(expect(arggetter,
                                    ttype=Number.Integer).value.encode('utf-8'))
        else:
            maybeTypeArgs = self.next()
            if isinstance(maybeTypeArgs, Parenthesis):
                # type arguments
                significant = iterSignificant(maybeTypeArgs)
                expect(significant, value=u"(")
                typeLength = int(significant.next().value)
            else:
                # something else
                typeLength = None
                self.pushback(maybeTypeArgs)
        theType = SQLType(typeName.value.encode("utf-8"), typeLength)
        theColumn = self.table.addColumn(
            name=name.encode("utf-8"), type=theType
        )
        for val in self:
            if val.ttype == Punctuation:
                # ',' or ')' terminates this column's definition.
                return self.checkEnd(val)
            else:
                expected = True
                # Helper: attach a single-column table constraint.
                def oneConstraint(t):
                    self.table.tableConstraint(t, [theColumn.name])
                if val.match(Keyword, 'PRIMARY'):
                    expect(self, ttype=Keyword, value='KEY')
                    # XXX check to make sure there's no other primary key yet
                    self.table.primaryKey = [theColumn]
                elif val.match(Keyword, 'UNIQUE'):
                    # XXX add UNIQUE constraint
                    oneConstraint(Constraint.UNIQUE)
                elif val.match(Keyword, 'NOT'):
                    # possibly not necessary, as 'NOT NULL' is a single keyword
                    # in sqlparse as of 0.1.2
                    expect(self, ttype=Keyword, value='NULL')
                    oneConstraint(Constraint.NOT_NULL)
                elif val.match(Keyword, 'NOT NULL'):
                    oneConstraint(Constraint.NOT_NULL)
                elif val.match(Keyword, 'CHECK'):
                    self.table.checkConstraint(self.readExpression(self.next()))
                elif val.match(Keyword, 'DEFAULT'):
                    theDefault = self.next()
                    if isinstance(theDefault, Parenthesis):
                        # Unwrap a parenthesized default: DEFAULT (expr).
                        iDefault = iterSignificant(theDefault)
                        expect(iDefault, ttype=Punctuation, value="(")
                        theDefault = iDefault.next()
                    if isinstance(theDefault, Function):
                        # Function-call default: either a sequence via
                        # nextval('seq') or an arbitrary procedure call.
                        thingo = theDefault.tokens[0].get_name()
                        parens = expectSingle(
                            theDefault.tokens[-1], cls=Parenthesis
                        )
                        pareniter = iterSignificant(parens)
                        if thingo.upper() == 'NEXTVAL':
                            expect(pareniter, ttype=Punctuation, value="(")
                            seqname = _destringify(
                                expect(pareniter, ttype=String.Single).value)
                            defaultValue = self.table.schema.sequenceNamed(
                                seqname
                            )
                            # Track which columns draw from this sequence.
                            defaultValue.referringColumns.append(theColumn)
                        else:
                            defaultValue = ProcedureCall(thingo.encode('utf-8'),
                                                         parens)
                    elif theDefault.ttype == Number.Integer:
                        defaultValue = int(theDefault.value)
                    elif (theDefault.ttype == Keyword and
                          theDefault.value.lower() == 'false'):
                        defaultValue = False
                    elif (theDefault.ttype == Keyword and
                          theDefault.value.lower() == 'true'):
                        defaultValue = True
                    elif (theDefault.ttype == Keyword and
                          theDefault.value.lower() == 'null'):
                        defaultValue = None
                    elif theDefault.ttype == String.Single:
                        defaultValue = _destringify(theDefault.value)
                    else:
                        raise RuntimeError(
                            "not sure what to do: default %r" % (
                                theDefault))
                    theColumn.setDefaultValue(defaultValue)
                elif val.match(Keyword, 'REFERENCES'):
                    # Foreign key: resolved by name, since the target table
                    # may not have been parsed yet.
                    target = nameOrIdentifier(self.next())
                    theColumn.doesReferenceName(target)
                elif val.match(Keyword, 'ON'):
                    # ON DELETE CASCADE | SET NULL | SET DEFAULT
                    expect(self, ttype=Keyword.DML, value='DELETE')
                    refAction = self.next()
                    if refAction.ttype == Keyword and refAction.value.upper() == 'CASCADE':
                        theColumn.deleteAction = 'cascade'
                    elif refAction.ttype == Keyword and refAction.value.upper() == 'SET':
                        setAction = self.next()
                        if setAction.ttype == Keyword and setAction.value.upper() == 'NULL':
                            theColumn.deleteAction = 'set null'
                        elif setAction.ttype == Keyword and setAction.value.upper() == 'DEFAULT':
                            theColumn.deleteAction = 'set default'
                        else:
                            raise RuntimeError("Invalid on delete set %r" % (setAction.value,))
                    else:
                        raise RuntimeError("Invalid on delete %r" % (refAction.value,))
                else:
                    expected = False
                if not expected:
                    # Debugging aid for unrecognized column annotations.
                    print('UNEXPECTED TOKEN:', repr(val), theColumn)
                    print(self.parens)
                    import pprint
                    pprint.pprint(self.parens.tokens)
                    return 0
class ViolatedExpectation(Exception):
    """
    An expectation about the structure of the SQL syntax was violated.

    @ivar expected: a description of what was expected.
    @ivar got: the token (or description) actually encountered.
    """
    def __init__(self, expected, got):
        self.expected = expected
        self.got = got
        message = "Expected %r got %s" % (expected, got)
        super(ViolatedExpectation, self).__init__(message)
def nameOrIdentifier(token):
    """
    Determine if the given object is a name or an identifier, and return the
    textual value of that name or identifier.

    @rtype: L{str}
    @raise ViolatedExpectation: if the token is neither.
    """
    if isinstance(token, Identifier):
        return token.get_name()
    if token.ttype == Name:
        return token.value
    raise ViolatedExpectation("identifier or name", repr(token))
def expectSingle(nextval, ttype=None, value=None, cls=None):
    """
    Expect some properties from retrieved value.

    @param ttype: A token type to compare against.
    @param value: A value to compare against (case-insensitively).
    @param cls: A class to check if the value is an instance of.
    @raise ViolatedExpectation: if an unexpected token is found.
    @return: C{nextval}, if it matches.
    """
    if ttype is not None and nextval.ttype != ttype:
        raise ViolatedExpectation(ttype, '%s:%r' % (nextval.ttype, nextval))
    if value is not None and nextval.value.upper() != value.upper():
        raise ViolatedExpectation(value, nextval.value)
    if cls is not None and nextval.__class__ != cls:
        raise ViolatedExpectation(cls, '%s:%r' %
                                  (nextval.__class__.__name__, nextval))
    return nextval
def expect(iterator, **kw):
    """
    Retrieve a value from an iterator and check its properties.  Same
    signature as L{expectSingle}, except it takes an iterator instead of a
    value.

    @see: L{expectSingle}
    """
    return expectSingle(iterator.next(), **kw)
def significant(token):
    """
    Determine if the token is 'significant', i.e. that it is not a comment and
    not whitespace.
    """
    # comment has 'None' is_whitespace() result. intentional?
    if isinstance(token, Comment):
        return False
    return not token.is_whitespace()
def iterSignificant(tokenList):
    """
    Iterate tokens that pass the test given by L{significant}, from a given
    L{TokenList}.
    """
    return (token for token in tokenList.tokens if significant(token))
def _destringify(strval):
"""
Convert a single-quoted SQL string into its actual repsresented value.
(Assumes standards compliance, since we should be controlling all the input
here. The only quoting syntax respected is "''".)
"""
return strval[1:-1].replace("''", "'")
calendarserver-5.2+dfsg/twext/enterprise/dal/syntax.py 0000644 0001750 0001750 00000160047 12263343324 022314 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.dal.test.test_sqlsyntax -*-
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Syntax wrappers and generators for SQL.
"""
from itertools import count, repeat
from functools import partial
from operator import eq, ne
from zope.interface import implements
from twisted.internet.defer import succeed
from twext.enterprise.dal.model import Schema, Table, Column, Sequence, SQLType
from twext.enterprise.ienterprise import (
POSTGRES_DIALECT, ORACLE_DIALECT, SQLITE_DIALECT, IDerivedParameter
)
from twext.enterprise.util import mapOracleOutputType
from twisted.internet.defer import inlineCallbacks, returnValue
try:
import cx_Oracle
cx_Oracle
except ImportError:
cx_Oracle = None
class DALError(Exception):
    """
    Base class for exceptions raised by this module.  This can be raised
    directly for API violations.  This exception represents a serious
    programming error and should normally never be caught or ignored.
    """
class QueryPlaceholder(object):
    """
    Representation of the placeholders required to generate some SQL, for a
    single statement.  Contains information necessary to generate place
    holder strings based on the database dialect.
    """
    def placeholder(self):
        """
        Return the next placeholder string; implemented by subclasses.
        """
        raise NotImplementedError("See subclasses.")
class FixedPlaceholder(QueryPlaceholder):
    """
    A placeholder that is always the same fixed string, e.g. '?' or '%s'.
    """
    def __init__(self, placeholder):
        self._placeholder = placeholder

    def placeholder(self):
        return self._placeholder
class NumericPlaceholder(QueryPlaceholder):
    """
    Numeric counter used as the place holder: successive calls to
    L{placeholder} yield ':1', ':2', and so on, for the 'numeric' DB-API
    paramstyle.
    """
    def __init__(self):
        # Use partial(next, counter) rather than binding the Python-2-only
        # '.next' bound method; equivalent on Python 2.6+ and also valid on
        # Python 3.
        self._next = partial(next, count(1))

    def placeholder(self):
        # Each call consumes one number from the counter.
        return ':' + str(self._next())
def defaultPlaceholder():
    """
    Generate a default L{QueryPlaceholder}: the fixed '?' used by the
    'qmark' paramstyle.
    """
    return FixedPlaceholder('?')
class QueryGenerator(object):
    """
    Maintains various pieces of transient information needed when building a
    query.  This includes the SQL dialect, the format of the place holder and
    an automated id generator.

    @ivar dialect: the SQL dialect in use (defaults to POSTGRES_DIALECT).
    @ivar placeholder: the L{QueryPlaceholder} used to spell bound values.
    @ivar generatedID: a 0-argument callable returning successive integers,
        used by L{nextGeneratedID}.
    """
    def __init__(self, dialect=None, placeholder=None):
        self.dialect = dialect if dialect else POSTGRES_DIALECT
        if placeholder is None:
            placeholder = defaultPlaceholder()
        self.placeholder = placeholder
        # partial(next, counter) instead of the Python-2-only '.next' bound
        # method; equivalent on Python 2.6+ and also valid on Python 3.
        self.generatedID = partial(next, count(1))

    def nextGeneratedID(self):
        """
        Return a new identifier unique within this query, e.g. 'genid_1'.
        """
        return "genid_%d" % (self.generatedID(),)

    def shouldQuote(self, name):
        """
        Whether the given identifier must be quoted in the current dialect.
        """
        return (self.dialect == ORACLE_DIALECT and name.lower() in _KEYWORDS)
class TableMismatch(Exception):
    """
    A table in a statement did not match with a column.
    """
class NotEnoughValues(DALError):
    """
    Raised when too few values were supplied for an L{Insert}.
    """
class _Statement(object):
    """
    An SQL statement that may be executed.  (An abstract base class, must
    implement several methods.)
    """
    # Map of DB-API paramstyle name to the placeholder factory that spells
    # bound-value markers for that style.
    _paramstyles = {
        'pyformat': partial(FixedPlaceholder, "%s"),
        'numeric': NumericPlaceholder,
        'qmark': defaultPlaceholder,
    }

    def toSQL(self, queryGenerator=None):
        """
        Convert this statement to an SQL fragment, using the given
        L{QueryGenerator} (or a default one if none is supplied).
        """
        if queryGenerator is None:
            queryGenerator = QueryGenerator()
        return self._toSQL(queryGenerator)

    def _extraVars(self, txn, queryGenerator):
        """
        A hook for subclasses to provide additional keyword arguments to the
        C{bind} call when L{_Statement.on} is executed. Currently this is used
        only for 'out' parameters to capture results when executing statements
        that do not normally have a result (L{Insert}, L{Delete}, L{Update}).
        """
        return {}

    def _extraResult(self, result, outvars, queryGenerator):
        """
        A hook for subclasses to manipulate the results of 'on', after they've
        been retrieved by the database but before they've been given to
        application code.
        @param result: a L{Deferred} that will fire with the rows as returned by
            the database.
        @type result: C{list} of rows, which are C{list}s or C{tuple}s.
        @param outvars: a dictionary of extra variables returned by
            C{self._extraVars}.
        @param queryGenerator: information about the connection where the statement
            was executed.
        @type queryGenerator: L{QueryGenerator} (a subclass thereof)
        @return: the result to be returned from L{_Statement.on}.
        @rtype: L{Deferred} firing result rows
        """
        return result

    def on(self, txn, raiseOnZeroRowCount=None, **kw):
        """
        Execute this statement on a given L{IAsyncTransaction} and return the
        resulting L{Deferred}.
        @param txn: the L{IAsyncTransaction} to execute this on.
        @param raiseOnZeroRowCount: a 0-argument callable which returns an
            exception to raise if the executed SQL does not affect any rows.
        @param kw: keyword arguments, mapping names of L{Parameter} objects
            located somewhere in C{self}
        @return: results from the database.
        @rtype: a L{Deferred} firing a C{list} of records (C{tuple}s or
            C{list}s)
        """
        # Pick the placeholder style from the transaction's paramstyle.
        queryGenerator = QueryGenerator(txn.dialect, self._paramstyles[txn.paramstyle]())
        outvars = self._extraVars(txn, queryGenerator)
        kw.update(outvars)
        fragment = self.toSQL(queryGenerator).bind(**kw)
        result = txn.execSQL(fragment.text, fragment.parameters,
                             raiseOnZeroRowCount)
        result = self._extraResult(result, outvars, queryGenerator)
        # Oracle turns empty strings into NULLs; undo that for text columns.
        if queryGenerator.dialect == ORACLE_DIALECT and result:
            result.addCallback(self._fixOracleNulls)
        return result

    def _resultColumns(self):
        """
        Subclasses must implement this to return a description of the columns
        expected to be returned. This is a list of L{ColumnSyntax} objects, and
        possibly other expression syntaxes which will be converted to C{None}.
        """
        raise NotImplementedError(
            "Each statement subclass must describe its result"
        )

    def _resultShape(self):
        """
        Process the result of the subclass's C{_resultColumns}, as described in
        the docstring above.
        """
        for expectation in self._resultColumns():
            if isinstance(expectation, ColumnSyntax):
                yield expectation.model
            else:
                yield None

    def _fixOracleNulls(self, rows):
        """
        Oracle treats empty strings as C{NULL}. Fix this by looking at the
        columns we expect to have returned, and replacing any C{None}s with
        empty strings in the appropriate position.
        """
        if rows is None:
            return None
        newRows = []
        for row in rows:
            newRow = []
            for column, description in zip(row, self._resultShape()):
                if ((description is not None and
                     # FIXME: "is the python type str" is what I mean; this list
                     # should be more centrally maintained
                     description.type.name in ('varchar', 'text', 'char') and
                     column is None
                     )):
                    column = ''
                newRow.append(column)
            newRows.append(newRow)
        return newRows
class Syntax(object):
    """
    Base class for syntactic convenience.

    This class will define dynamic attribute access to represent its
    underlying model as a Python namespace.

    You can access the underlying model as '.model'.
    """
    modelType = None
    model = None

    def __init__(self, model):
        if not isinstance(model, self.modelType):
            # make sure we don't get a misleading repr().  Format the message
            # eagerly: passing the format string and its arguments separately
            # to the exception constructor would leave it unformatted.
            raise DALError("type mismatch: %r %r" % (type(self), model))
        self.model = model

    def __repr__(self):
        # The original format string here was lost to markup stripping
        # ('' % (self.model,) raises TypeError); identify the concrete
        # syntax class along with the wrapped model.
        if self.model is not None:
            return '<%s model=%r>' % (self.__class__.__name__, self.model)
        return super(Syntax, self).__repr__()
def comparison(comparator):
    """
    Build a binary-operator method implementing the given SQL C{comparator}
    for L{ExpressionSyntax} subclasses, dispatching on the type of the
    right-hand operand.
    """
    def __(self, other):
        if other is None:
            return NullComparison(self, comparator)
        if isinstance(other, Select):
            return NotImplemented
        if isinstance(other, ColumnSyntax):
            return ColumnComparison(self, comparator, other)
        if isinstance(other, ExpressionSyntax):
            return CompoundComparison(self, comparator, other)
        return CompoundComparison(self, comparator, Constant(other))
    return __
class ExpressionSyntax(Syntax):
    """
    Syntax for an expression usable in SQL; Python comparison and arithmetic
    operators produce comparison objects rather than values.
    """
    __eq__ = comparison('=')
    __ne__ = comparison('!=')
    # NB: these operators "cannot be used with lists" (see ORA-01796)
    __gt__ = comparison('>')
    __ge__ = comparison('>=')
    __lt__ = comparison('<')
    __le__ = comparison('<=')
    # TODO: operators aren't really comparisons; these should behave slightly
    # differently. (For example; in Oracle, 'select 3 = 4 from dual' doesn't
    # work, but 'select 3 + 4 from dual' does; similarly, you can't do 'select *
    # from foo where 3 + 4', but you can do 'select * from foo where 3 + 4 >
    # 0'.)
    __add__ = comparison("+")
    __sub__ = comparison("-")
    __div__ = comparison("/")
    __mul__ = comparison("*")

    def __nonzero__(self):
        raise DALError(
            "SQL expressions should not be tested for truth value in Python.")

    def In(self, other):
        """
        We support two forms of the SQL "IN" syntax: one where a list of
        values is supplied, the other where a sub-select is used to provide a
        set of values.

        @param other: a constant parameter or sub-select
        @type other: L{Parameter} or L{Select}
        """
        if isinstance(other, Parameter):
            if other.count is None:
                raise DALError("IN expression needs an explicit count of parameters")
            return CompoundComparison(self, 'in', Constant(other))
        # Can't be Select.__contains__ because __contains__ gets __nonzero__
        # called on its result by the 'in' syntax.
        return CompoundComparison(self, 'in', other)

    def StartsWith(self, other):
        suffixWild = CompoundComparison(Constant(other), '||', Constant('%'))
        return CompoundComparison(self, "like", suffixWild)

    def EndsWith(self, other):
        prefixWild = CompoundComparison(Constant('%'), '||', Constant(other))
        return CompoundComparison(self, "like", prefixWild)

    def Contains(self, other):
        inner = CompoundComparison(Constant(other), '||', Constant('%'))
        bothWild = CompoundComparison(Constant('%'), '||', inner)
        return CompoundComparison(self, "like", bothWild)
class FunctionInvocation(ExpressionSyntax):
    """
    An invocation of an SQL function, such as C{count(x)}, applied to zero or
    more argument expressions.
    """
    def __init__(self, function, *args):
        self.function = function
        self.args = args

    def allColumns(self):
        """
        All of the columns in all of the arguments' columns.
        """
        return [column
                for arg in self.args
                for column in arg.allColumns()]

    def subSQL(self, queryGenerator, allTables):
        argList = _commaJoined(_convert(arg).subSQL(queryGenerator, allTables)
                               for arg in self.args)
        fragment = SQLFragment(self.function.nameFor(queryGenerator))
        fragment.append(_inParens(argList))
        return fragment
class Constant(ExpressionSyntax):
    """
    A value bound to the query via a placeholder.  If the value is a
    L{Parameter} with C{count} > 1, a parenthesized, comma-separated list of
    placeholders is generated instead of a single one.
    """
    def __init__(self, value):
        self.value = value

    def allColumns(self):
        return []

    def subSQL(self, queryGenerator, allTables):
        value = self.value
        if isinstance(value, Parameter) and value.count is not None:
            # One placeholder per expected value; the Parameter object itself
            # rides along only with the first fragment.
            fragments = []
            for index in range(value.count):
                params = [value] if index == 0 else []
                fragments.append(
                    SQLFragment(queryGenerator.placeholder.placeholder(),
                                params))
            return _inParens(
                _CommaList(fragments).subSQL(queryGenerator, allTables))
        return SQLFragment(queryGenerator.placeholder.placeholder(), [value])
class NamedValue(ExpressionSyntax):
    """
    A constant within the database; something predefined, such as
    CURRENT_TIMESTAMP.
    """
    def __init__(self, name):
        self.name = name

    def subSQL(self, queryGenerator, allTables):
        # Emitted literally; no parameters are bound.
        return SQLFragment(self.name)
class Function(object):
    """
    A representation of an SQL function, optionally with a different spelling
    under the Oracle dialect.
    """
    def __init__(self, name, oracleName=None):
        self.name = name
        self.oracleName = oracleName

    def nameFor(self, queryGenerator):
        """
        Return this function's name as spelled in the generator's dialect.
        """
        useOracle = (queryGenerator.dialect == ORACLE_DIALECT and
                     self.oracleName is not None)
        return self.oracleName if useOracle else self.name

    def __call__(self, *args):
        """
        Produce an L{FunctionInvocation}
        """
        return FunctionInvocation(self, *args)
# Ready-made L{Function} instances for the standard SQL aggregate and string
# functions used by this module.
Count = Function("count")
Sum = Function("sum")
Max = Function("max")
# Oracle spells CHARACTER_LENGTH as LENGTH.
Len = Function("character_length", "length")
Upper = Function("upper")
Lower = Function("lower")
# SQLite-specific: rowid of the most recent INSERT on this connection.
_sqliteLastInsertRowID = Function("last_insert_rowid")
# Use a specific value here for "the convention for case-insensitive values in
# the database" so we don't need to keep remembering whether it's upper or
# lowercase.
CaseFold = Lower
class SchemaSyntax(Syntax):
    """
    Syntactic convenience for L{Schema}.
    """
    modelType = Schema

    def __getattr__(self, attr):
        # Tables take priority over sequences of the same name.
        try:
            tableModel = self.model.tableNamed(attr)
        except KeyError:
            pass
        else:
            syntax = TableSyntax(tableModel)
            # Cache on the instance so that subsequent access (and aliasing
            # via attribute assignment) keeps working with the same object.
            setattr(self, attr, syntax)
            return syntax
        try:
            seqModel = self.model.sequenceNamed(attr)
        except KeyError:
            raise AttributeError("schema has no table or sequence %r" % (attr,))
        return SequenceSyntax(seqModel)

    def __iter__(self):
        return (TableSyntax(table) for table in self.model.tables)
class SequenceSyntax(ExpressionSyntax):
    """
    Syntactic convenience for L{Sequence}.
    """
    modelType = Sequence

    def subSQL(self, queryGenerator, allTables):
        """
        Convert to an SQL fragment that yields the sequence's next value,
        spelled per dialect.
        """
        if queryGenerator.dialect == ORACLE_DIALECT:
            return SQLFragment("%s.nextval" % (self.model.name,))
        return SQLFragment("nextval('%s')" % (self.model.name,))
def _nameForDialect(name, dialect):
    """
    Compute the identifier to emit for the given dialect; Oracle identifiers
    are truncated to 30 characters.
    """
    if dialect == ORACLE_DIALECT:
        return name[:30]
    return name
class TableSyntax(Syntax):
    """
    Syntactic convenience for L{Table}.
    """
    modelType = Table

    def alias(self):
        """
        Return an alias for this L{TableSyntax} so that it might be joined
        against itself.
        As in SQL, C{someTable.join(someTable)} is an error; you can't join a
        table against itself. However, C{t = someTable.alias();
        someTable.join(t)} is usable as a 'from' clause.
        """
        return TableAlias(self.model)

    def join(self, otherTableSyntax, on=None, type=''):
        """
        Create a L{Join}, representing a join between two tables.
        @param on: the 'on' expression; when C{None}, a cross join is
            produced regardless of C{type}.
        @param type: the SQL join type, e.g. C{'left outer'}.
        """
        if on is None:
            type = 'cross'
        return Join(self, type, otherTableSyntax, on)

    def subSQL(self, queryGenerator, allTables):
        """
        Generate the L{SQLFragment} for this table's identification; this is
        for use in a 'from' clause.
        """
        # XXX maybe there should be a specific method which is only invoked
        # from the FROM clause, that only tables and joins would implement?
        return SQLFragment(_nameForDialect(self.model.name, queryGenerator.dialect))

    def __getattr__(self, attr):
        """
        Attributes named after columns on a L{TableSyntax} are returned by
        accessing their names as attributes. For example, if there is a schema
        syntax object created from SQL equivalent to 'create table foo (bar
        integer, baz integer)', 'schemaSyntax.foo.bar' and
        'schemaSyntax.foo.baz' are L{ColumnSyntax} objects for those columns.
        @raise AttributeError: if no such column exists.
        """
        try:
            column = self.model.columnNamed(attr)
        except KeyError:
            raise AttributeError("table {0} has no column {1}".format(
                self.model.name, attr
            ))
        else:
            return ColumnSyntax(column)

    def __iter__(self):
        """
        Yield a L{ColumnSyntax} for each L{Column} in this L{TableSyntax}'s
        model's table.
        """
        for column in self.model.columns:
            yield ColumnSyntax(column)

    def tables(self):
        """
        Return a C{list} of tables involved in the query by this table. (This
        method is expected by anything that can act as the C{From} clause: see
        L{Join.tables})
        """
        return [self]

    def columnAliases(self):
        """
        Inspect the Python aliases for this table in the given schema. Python
        aliases for a table are created by setting an attribute on the schema.
        For example, in a schema which had "schema.MYTABLE.ID =
        schema.MYTABLE.MYTABLE_ID" applied to it,
        schema.MYTABLE.columnAliases() would return C{{"ID":
        schema.MYTABLE.MYTABLE_ID}}.
        @return: a C{dict} mapping alias name (C{str}) to aliased column
            (C{ColumnSyntax}), enumerating all of the Python aliases provided.
            (Despite what older documentation claimed, this is a dictionary,
            not a list of 2-tuples.)
        """
        result = {}
        for k, v in self.__dict__.items():
            if isinstance(v, ColumnSyntax):
                result[k] = v
        return result

    def __contains__(self, columnSyntax):
        # NOTE(review): this reads 'columnSyntax.arg', but FunctionInvocation
        # stores its arguments as 'args'; presumably only objects that carry
        # a single 'arg' attribute reach this branch -- confirm with callers.
        if isinstance(columnSyntax, FunctionInvocation):
            columnSyntax = columnSyntax.arg
        return (columnSyntax.model.table is self.model)
class TableAlias(TableSyntax):
    """
    An alias for a table, under a different name, for the purpose of doing a
    self-join.
    """
    def subSQL(self, queryGenerator, allTables):
        """
        Return an L{SQLFragment} of the form C{'mytable myalias'} suitable
        for use in a FROM clause.
        """
        fragment = super(TableAlias, self).subSQL(queryGenerator, allTables)
        fragment.append(SQLFragment(" " + self._aliasName(allTables)))
        return fragment

    def _aliasName(self, allTables):
        """
        The alias under which this table will be known in the query.

        @param allTables: a C{list}, as passed to a C{subSQL} method during
            SQL generation.
        @return: a string naming this alias, a unique identifier, albeit one
            which is only stable within the query which populated
            C{allTables}.
        @rtype: C{str}
        """
        aliases = [t for t in allTables if isinstance(t, TableAlias)]
        return 'alias%d' % (aliases.index(self) + 1,)

    def __getattr__(self, attr):
        # Columns accessed through an alias are qualified by the alias name.
        return AliasedColumnSyntax(self, self.model.columnNamed(attr))
class Join(object):
    """
    A DAL object representing an SQL 'join' statement.

    @ivar leftSide: a L{Join} or L{TableSyntax} representing the left side of
        this join.
    @ivar rightSide: a L{TableSyntax} representing the right side of this join.
    @ivar type: the type of join this is.  For example, for a left outer
        join, this would be C{'left outer'}.
    @type type: C{str}
    @ivar on: the 'on' clause of this table.
    @type on: L{ExpressionSyntax}
    """
    def __init__(self, leftSide, type, rightSide, on):
        self.leftSide = leftSide
        self.type = type
        self.rightSide = rightSide
        self.on = on

    def subSQL(self, queryGenerator, allTables):
        # Render as: <left> [<type>] join <right> [on <condition>]
        fragment = SQLFragment()
        fragment.append(self.leftSide.subSQL(queryGenerator, allTables))
        fragment.text += ' '
        if self.type:
            fragment.text += self.type
        fragment.text += ' join '
        fragment.append(self.rightSide.subSQL(queryGenerator, allTables))
        if self.type != 'cross':
            fragment.text += ' on '
            fragment.append(self.on.subSQL(queryGenerator, allTables))
        return fragment

    def tables(self):
        """
        Return a C{list} of tables which this L{Join} will involve in a query:
        all those present on the left side, as well as all those present on
        the right side.
        """
        return self.leftSide.tables() + self.rightSide.tables()

    def join(self, otherTable, on=None, type=None):
        if on is None:
            type = 'cross'
        return Join(self, type, otherTable, on)
# Identifiers which must be quoted when emitted in the Oracle dialect,
# because they are reserved words or collide with built-ins there; consulted
# by QueryGenerator.shouldQuote.
_KEYWORDS = ["access",
             # SQL keyword, but we have a column with this name
             "path",
             # Not actually a standard keyword, but a function in oracle, and we
             # have a column with this name.
             "size",
             # not actually sure what this is; only experimentally determined
             # that not quoting it causes an issue.
             ]
class ColumnSyntax(ExpressionSyntax):
    """
    Syntactic convenience for L{Column}.

    @ivar _alwaysQualified: a boolean indicating whether to always qualify the
        column name in generated SQL, regardless of whether the column name is
        specific enough even when unqualified.
    @type _alwaysQualified: C{bool}
    """
    modelType = Column
    _alwaysQualified = False

    def allColumns(self):
        return [self]

    def subSQL(self, queryGenerator, allTables):
        # XXX This, and 'model', could in principle conflict with column
        # names.  Maybe do something about that.
        name = self.model.name
        if queryGenerator.shouldQuote(name):
            name = '"%s"' % (name,)
        if self._needsQualification(allTables):
            return SQLFragment(self._qualify(name, allTables))
        return SQLFragment(name)

    def _needsQualification(self, allTables):
        # Qualify whenever another table in the query has a column with the
        # same name, since the bare name would then be ambiguous.
        if self._alwaysQualified:
            return True
        for tableSyntax in allTables:
            if self.model.table is not tableSyntax.model:
                otherNames = (c.name for c in tableSyntax.model.columns)
                if self.model.name in otherNames:
                    return True
        return False

    def __hash__(self):
        return hash(self.model) + 10

    def _qualify(self, name, allTables):
        return self.model.table.name + '.' + name
class ResultAliasSyntax(ExpressionSyntax):
    """
    A result column aliased under another name, rendered as
    C{expression alias} in a SELECT's column list.
    """
    def __init__(self, expression, alias=None):
        self.expression = expression
        self.alias = alias

    def aliasName(self, queryGenerator):
        # Lazily pick a generated name the first time one is required.
        if self.alias is None:
            self.alias = queryGenerator.nextGeneratedID()
        return self.alias

    def columnReference(self):
        return AliasReferenceSyntax(self)

    def allColumns(self):
        return self.expression.allColumns()

    def subSQL(self, queryGenerator, allTables):
        fragment = SQLFragment()
        fragment.append(self.expression.subSQL(queryGenerator, allTables))
        fragment.append(SQLFragment(" %s" % (self.aliasName(queryGenerator),)))
        return fragment
class AliasReferenceSyntax(ExpressionSyntax):
    """
    A by-name reference to a L{ResultAliasSyntax} defined elsewhere in the
    same query.
    """
    def __init__(self, resultAlias):
        self.resultAlias = resultAlias

    def allColumns(self):
        return self.resultAlias.allColumns()

    def subSQL(self, queryGenerator, allTables):
        return SQLFragment(self.resultAlias.aliasName(queryGenerator))
class AliasedColumnSyntax(ColumnSyntax):
    """
    An L{AliasedColumnSyntax} is like a L{ColumnSyntax}, but it generates SQL
    for a column of a table under an alias, rather than directly.  I.e. this
    is used for C{'something.col'} in C{'select something.col from tablename
    something'} rather than the 'col' in C{'select col from tablename'}.

    @see: L{TableSyntax.alias}
    """
    # Alias-qualified columns are always written with their qualifier.
    _alwaysQualified = True

    def __init__(self, tableAlias, model):
        super(AliasedColumnSyntax, self).__init__(model)
        self._tableAlias = tableAlias

    def _qualify(self, name, allTables):
        return self._tableAlias._aliasName(allTables) + '.' + name
class Comparison(ExpressionSyntax):
    """
    An abstract binary comparison of two expressions: C{a op b}.
    """
    def __init__(self, a, op, b):
        self.a = a
        self.op = op
        self.b = b

    def _subexpression(self, expr, queryGenerator, allTables):
        # Nested comparisons get parenthesized, except when this node merely
        # combines them with 'and'/'or', which render flat.
        rendered = expr.subSQL(queryGenerator, allTables)
        if self.op not in ('and', 'or') and isinstance(expr, Comparison):
            rendered = _inParens(rendered)
        return rendered

    def booleanOp(self, operand, other):
        return CompoundComparison(self, operand, other)

    def And(self, other):
        return self.booleanOp('and', other)

    def Or(self, other):
        return self.booleanOp('or', other)
class NullComparison(Comparison):
    """
    A L{NullComparison} is a comparison of a column or expression with None,
    rendered as SQL C{is null} / C{is not null}.
    """
    def __init__(self, a, op):
        # 'b' is always None for this comparison type
        super(NullComparison, self).__init__(a, op, None)

    def subSQL(self, queryGenerator, allTables):
        fragment = SQLFragment()
        fragment.append(self.a.subSQL(queryGenerator, allTables))
        if self.op == "=":
            fragment.text += " is null"
        else:
            # Any non-equality operator renders as 'is not null'.
            fragment.text += " is not null"
        return fragment
class CompoundComparison(Comparison):
    """
    A compound comparison; two or more constraints, joined by an operation
    (currently only AND or OR).
    """
    def allColumns(self):
        return self.a.allColumns() + self.b.allColumns()

    def subSQL(self, queryGenerator, allTables):
        # Oracle represents the empty string as NULL, so (in)equality tests
        # against '' must be rendered as IS (NOT) NULL there.
        if (queryGenerator.dialect == ORACLE_DIALECT
            and isinstance(self.b, Constant) and self.b.value == ''
            and self.op in ('=', '!=')):
            return NullComparison(self.a, self.op).subSQL(queryGenerator, allTables)
        stmt = SQLFragment()
        result = self._subexpression(self.a, queryGenerator, allTables)
        # Parenthesize an OR nested directly under an AND to preserve the
        # intended precedence.
        if (isinstance(self.a, CompoundComparison)
            and self.a.op == 'or' and self.op == 'and'):
            result = _inParens(result)
        stmt.append(result)
        stmt.text += ' %s ' % (self.op,)
        result = self._subexpression(self.b, queryGenerator, allTables)
        if (isinstance(self.b, CompoundComparison)
            and self.b.op == 'or' and self.op == 'and'):
            result = _inParens(result)
        if isinstance(self.b, Tuple):
            # If the right-hand side of the comparison is a Tuple, it needs to
            # be double-parenthesized in Oracle, as per
            # http://docs.oracle.com/cd/B28359_01/server.111/b28286/expressions015.htm#i1033664
            # because it is an expression list.
            result = _inParens(result)
        stmt.append(result)
        return stmt
# Comparison operators that have a direct Python-level equivalent, used by
# ColumnComparison.__nonzero__ to compute a real truth value.
_operators = {"=": eq, "!=": ne}


class ColumnComparison(CompoundComparison):
    """
    Comparing two columns is the same as comparing any other two expressions,
    except that Python can retrieve a truth value, so that columns may be
    compared for value equality in scripts that want to interrogate schemas.
    """
    def __nonzero__(self):
        pythonOp = _operators.get(self.op)
        if pythonOp is not None:
            return pythonOp(self.a.model, self.b.model)
        # Unsupported operator: fall back to the base behavior (which raises).
        return super(ColumnComparison, self).__nonzero__()
class _AllColumns(NamedValue):
    """
    The special '*' column list, selecting every column.
    """
    def __init__(self):
        self.name = "*"

    def allColumns(self):
        # '*' names no specific columns, so table-matching checks skip it.
        return []

# Singleton used to spell 'select * ...'.
ALL_COLUMNS = _AllColumns()
class _SomeColumns(object):
def __init__(self, columns):
self.columns = columns
def subSQL(self, queryGenerator, allTables):
first = True
cstatement = SQLFragment()
for column in self.columns:
if first:
first = False
else:
cstatement.append(SQLFragment(", "))
cstatement.append(column.subSQL(queryGenerator, allTables))
return cstatement
def _checkColumnsMatchTables(columns, tables):
"""
Verify that the given C{columns} match the given C{tables}; that is, that
every L{TableSyntax} referenced by every L{ColumnSyntax} referenced by
every L{ExpressionSyntax} in the given C{columns} list is present in the
given C{tables} list.
@param columns: a L{list} of L{ExpressionSyntax}, each of which references
some set of L{ColumnSyntax}es via its C{allColumns} method.
@param tables: a L{list} of L{TableSyntax}
@return: L{None}
@rtype: L{NoneType}
@raise TableMismatch: if any table referenced by a column is I{not} found
in C{tables}
"""
for expression in columns:
for column in expression.allColumns():
for table in tables:
if column in table:
break
else:
raise TableMismatch("{} not found in {}".format(
column, tables
))
return None
class Tuple(ExpressionSyntax):
    """
    A parenthesized, comma-separated list of expressions, as in
    C{WHERE (A, B) = (?, ?)}.
    """
    def __init__(self, columns):
        self.columns = columns

    def __iter__(self):
        return iter(self.columns)

    def subSQL(self, queryGenerator, allTables):
        parts = (c.subSQL(queryGenerator, allTables) for c in self.columns)
        return _inParens(_commaJoined(parts))

    def allColumns(self):
        return self.columns
class SetExpression(object):
    """
    A UNION, INTERSECT, or EXCEPT construct used inside a SELECT.
    """

    OPTYPE_ALL = "all"
    OPTYPE_DISTINCT = "distinct"

    def __init__(self, selects, optype=None):
        """
        @param selects: a single Select or a list of Selects
        @type selects: C{list} or L{Select}
        @param optype: whether to use the ALL, DISTINCT constructs: C{None} use neither, OPTYPE_ALL, or OPTYPE_DISTINCT
        @type optype: C{str}
        """
        # Normalize a bare Select into a one-element sequence.
        if isinstance(selects, Select):
            selects = (selects,)
        self.selects = selects
        self.optype = optype
        if any(not isinstance(sel, Select) for sel in self.selects):
            raise DALError("Must have SELECT statements in a set expression")
        if self.optype not in (None, SetExpression.OPTYPE_ALL, SetExpression.OPTYPE_DISTINCT,):
            raise DALError("Must have either 'all' or 'distinct' in a set expression")

    def subSQL(self, queryGenerator, allTables):
        fragment = SQLFragment()
        for sel in self.selects:
            # Each embedded select is prefixed by the set operator keyword
            # (and the optional ALL/DISTINCT modifier).
            fragment.append(self.setOpSQL(queryGenerator))
            keyword = {
                SetExpression.OPTYPE_ALL: "ALL ",
                SetExpression.OPTYPE_DISTINCT: "DISTINCT ",
            }.get(self.optype)
            if keyword is not None:
                fragment.append(SQLFragment(keyword))
            fragment.append(sel.subSQL(queryGenerator, allTables))
        return fragment

    def allColumns(self):
        return []
class Union(SetExpression):
    """
    A UNION construct used inside a SELECT.
    """
    def setOpSQL(self, queryGenerator):
        # UNION is spelled the same in all supported dialects.
        return SQLFragment(" UNION ")
class Intersect(SetExpression):
    """
    An INTERSECT construct used inside a SELECT.
    """
    def setOpSQL(self, queryGenerator):
        # INTERSECT is spelled the same in all supported dialects.
        return SQLFragment(" INTERSECT ")
class Except(SetExpression):
    """
    An EXCEPT construct used inside a SELECT.
    """
    def setOpSQL(self, queryGenerator):
        # Oracle spells relational difference as MINUS rather than EXCEPT.
        keywords = {
            POSTGRES_DIALECT: " EXCEPT ",
            ORACLE_DIALECT: " MINUS ",
        }
        try:
            return SQLFragment(keywords[queryGenerator.dialect])
        except KeyError:
            raise NotImplementedError("Unsupported dialect")
class Select(_Statement):
    """
    'select' statement.
    """
    def __init__(self, columns=None, Where=None, From=None, OrderBy=None,
                 GroupBy=None, Limit=None, ForUpdate=False, NoWait=False, Ascending=None,
                 Having=None, Distinct=False, As=None,
                 SetExpression=None):
        """
        @param columns: a list of column expressions to select, or C{None}
            for 'select *'.
        @param Where: expression for the WHERE clause, or C{None}.
        @param From: the table (or joined/sub-select) to select from.
        @param OrderBy: a single expression or list of expressions for ORDER BY.
        @param GroupBy: a single expression or list of expressions for GROUP BY.
        @param Limit: maximum number of rows, or C{None} for no limit.
        @param ForUpdate: if true, emit 'for update'.
        @param NoWait: if true, emit 'nowait' (used with ForUpdate).
        @param Ascending: C{True} for 'asc', C{False} for 'desc', C{None} for
            neither.
        @param Having: expression for the HAVING clause, or C{None}.
        @param Distinct: if true, emit 'distinct'.
        @param As: alias name used when this select appears as a sub-select.
        @param SetExpression: an optional L{SetExpression} (UNION etc.) to
            combine with this select.
        """
        self.From = From
        self.Where = Where
        self.Distinct = Distinct
        # Normalize single expressions into lists for uniform iteration in
        # _toSQL.
        if not isinstance(OrderBy, (Tuple, list, tuple, type(None))):
            OrderBy = [OrderBy]
        self.OrderBy = OrderBy
        if not isinstance(GroupBy, (list, tuple, type(None))):
            GroupBy = [GroupBy]
        self.GroupBy = GroupBy
        self.Limit = Limit
        self.Having = Having
        self.SetExpression = SetExpression

        if columns is None:
            columns = ALL_COLUMNS
        else:
            # Explicit column lists must reference only tables in From.
            _checkColumnsMatchTables(columns, From.tables())
            columns = _SomeColumns(columns)
        self.columns = columns

        self.ForUpdate = ForUpdate
        self.NoWait = NoWait
        self.Ascending = Ascending
        self.As = As

        # A FROM that uses a sub-select will need the AS alias name
        if isinstance(self.From, Select):
            if self.From.As is None:
                # Empty string means "generate one later" in subSQL.
                self.From.As = ""

    def __eq__(self, other):
        """
        Create a comparison.

        NOTE(review): unlike normal __eq__, this builds a SQL comparison AST
        node (with operands reversed), it does not test equality.
        """
        if isinstance(other, (list, tuple)):
            other = Tuple(other)
        return CompoundComparison(other, '=', self)

    def _toSQL(self, queryGenerator):
        """
        @return: a 'select' statement with placeholders and arguments
        @rtype: L{SQLFragment}
        """
        if self.SetExpression is not None:
            # The select is parenthesized so the set operator binds correctly.
            stmt = SQLFragment("(")
        else:
            stmt = SQLFragment()
        stmt.append(SQLFragment("select "))
        if self.Distinct:
            stmt.text += "distinct "
        allTables = self.From.tables()
        stmt.append(self.columns.subSQL(queryGenerator, allTables))
        stmt.text += " from "
        stmt.append(self.From.subSQL(queryGenerator, allTables))
        if self.Where is not None:
            wherestmt = self.Where.subSQL(queryGenerator, allTables)
            stmt.text += " where "
            stmt.append(wherestmt)
        if self.GroupBy is not None:
            stmt.text += " group by "
            fst = True
            for subthing in self.GroupBy:
                if fst:
                    fst = False
                else:
                    stmt.text += ', '
                stmt.append(subthing.subSQL(queryGenerator, allTables))
        if self.Having is not None:
            havingstmt = self.Having.subSQL(queryGenerator, allTables)
            stmt.text += " having "
            stmt.append(havingstmt)
        if self.SetExpression is not None:
            # Close the parenthesis opened above, then emit the UNION /
            # INTERSECT / EXCEPT tail.
            stmt.append(SQLFragment(")"))
            stmt.append(self.SetExpression.subSQL(queryGenerator, allTables))
        if self.OrderBy is not None:
            stmt.text += " order by "
            fst = True
            for subthing in self.OrderBy:
                if fst:
                    fst = False
                else:
                    stmt.text += ', '
                stmt.append(subthing.subSQL(queryGenerator, allTables))
            if self.Ascending is not None:
                if self.Ascending:
                    kw = " asc"
                else:
                    kw = " desc"
                stmt.append(SQLFragment(kw))
        if self.ForUpdate:
            stmt.text += " for update"
            if self.NoWait:
                stmt.text += " nowait"
        if self.Limit is not None:
            limitConst = Constant(self.Limit).subSQL(queryGenerator, allTables)
            if queryGenerator.dialect == ORACLE_DIALECT:
                # Oracle (pre-12c) has no LIMIT; wrap in a ROWNUM filter.
                wrapper = SQLFragment("select * from (")
                wrapper.append(stmt)
                wrapper.append(SQLFragment(") where ROWNUM <= "))
                stmt = wrapper
            else:
                stmt.text += " limit "
            stmt.append(limitConst)
        return stmt

    def subSQL(self, queryGenerator, allTables):
        # As a sub-select: parenthesize and append the alias, generating a
        # fresh alias name if one was requested but not supplied.
        result = SQLFragment("(")
        result.append(self.toSQL(queryGenerator))
        result.append(SQLFragment(")"))
        if self.As is not None:
            if self.As == "":
                self.As = queryGenerator.nextGeneratedID()
            result.append(SQLFragment(" %s" % (self.As,)))
        return result

    def _resultColumns(self):
        """
        Determine the list of L{ColumnSyntax} objects that will represent the
        result.  Normally just the list of selected columns; if wildcard syntax
        is used though, determine the ordering from the database.
        """
        if self.columns is ALL_COLUMNS:
            # TODO: Possibly this rewriting should always be done, before even
            # executing the query, so that if we develop a schema mismatch with
            # the database (additional columns), the application will still see
            # the right rows.
            for table in self.From.tables():
                for column in table:
                    yield column
        else:
            for column in self.columns.columns:
                yield column

    def tables(self):
        """
        Determine the tables used by the result columns.
        """
        if self.columns is ALL_COLUMNS:
            # TODO: Possibly this rewriting should always be done, before even
            # executing the query, so that if we develop a schema mismatch with
            # the database (additional columns), the application will still see
            # the right rows.
            return self.From.tables()
        else:
            # Tables referenced by explicit columns, plus everything in FROM.
            tables = set([column.model.table for column in self.columns.columns if isinstance(column, ColumnSyntax)])
            for table in self.From.tables():
                tables.add(table.model)
            return [TableSyntax(table) for table in tables]
def _commaJoined(stmts):
    """
    Concatenate SQL fragments, inserting ", " between consecutive ones.
    """
    joined = SQLFragment()
    for index, fragment in enumerate(stmts):
        if index:
            joined.append(SQLFragment(", "))
        joined.append(fragment)
    return joined
def _inParens(stmt):
    """
    Wrap a SQL fragment in parentheses.
    """
    wrapped = SQLFragment("(")
    wrapped.append(stmt)
    # SQLFragment.append returns self, so this returns the whole fragment.
    return wrapped.append(SQLFragment(")"))
def _fromSameTable(columns):
"""
Extract the common table used by a list of L{Column} objects, raising
L{TableMismatch}.
"""
table = columns[0].table
for column in columns:
if table is not column.table:
raise TableMismatch("Columns must all be from the same table.")
return table
def _modelsFromMap(columnMap):
"""
Get the L{Column} objects from a mapping of L{ColumnSyntax} to values.
"""
return [c.model for c in columnMap.keys()]
class _CommaList(object):
    """
    A sequence of SQL sub-fragments rendered comma-separated, e.g. the list
    of expressions after a RETURNING clause.
    """

    def __init__(self, subfragments):
        self.subfragments = subfragments

    def subSQL(self, queryGenerator, allTables):
        rendered = [sub.subSQL(queryGenerator, allTables)
                    for sub in self.subfragments]
        return _commaJoined(rendered)
class _DMLStatement(_Statement):
    """
    Common functionality of Insert/Update/Delete statements.
    """

    def _returningClause(self, queryGenerator, stmt, allTables):
        """
        Add a dialect-appropriate 'returning' clause to the end of the given
        SQL statement.

        @param queryGenerator: describes the database we are generating the
            statement for.
        @type queryGenerator: L{QueryGenerator}
        @param stmt: the SQL fragment generated without the 'returning' clause
        @type stmt: L{SQLFragment}
        @param allTables: all tables involved in the query; see any C{subSQL}
            method.
        @return: the C{stmt} parameter.
        """
        retclause = self.Return
        if retclause is None:
            return stmt
        if isinstance(retclause, (tuple, list)):
            retclause = _CommaList(retclause)
        if queryGenerator.dialect == SQLITE_DIALECT:
            # sqlite does this another way (a follow-up SELECT in on()).
            return stmt
        elif retclause is not None:
            # NOTE(review): this condition is always true here (the None case
            # returned above); kept for byte-compatibility.
            stmt.text += ' returning '
            stmt.append(retclause.subSQL(queryGenerator, allTables))
            if queryGenerator.dialect == ORACLE_DIALECT:
                # Oracle requires 'returning ... into :out0, :out1, ...' with
                # host variables; see _OracleOutParam.
                stmt.text += ' into '
                params = []
                retvals = self._returnAsList()
                for n, _ignore_v in enumerate(retvals):
                    params.append(
                        Constant(Parameter("oracle_out_" + str(n)))
                        .subSQL(queryGenerator, allTables)
                    )
                stmt.append(_commaJoined(params))
        return stmt

    def _returnAsList(self):
        # Normalize self.Return to a list.
        if not isinstance(self.Return, (tuple, list)):
            return [self.Return]
        else:
            return self.Return

    def _extraVars(self, txn, queryGenerator):
        # Build (name, out-parameter) pairs for Oracle host variables; other
        # dialects need none.
        if self.Return is None:
            return []
        result = []
        rvars = self._returnAsList()
        if queryGenerator.dialect == ORACLE_DIALECT:
            for n, v in enumerate(rvars):
                result.append(("oracle_out_" + str(n), _OracleOutParam(v)))
        return result

    def _extraResult(self, result, outvars, queryGenerator):
        # On Oracle, replace the (empty) statement result with a single row
        # built from the out-parameter values collected after execution.
        if queryGenerator.dialect == ORACLE_DIALECT and self.Return is not None:
            def processIt(shouldBeNone):
                result = [[v.value for _ignore_k, v in outvars]]
                return result
            return result.addCallback(processIt)
        else:
            return result

    def _resultColumns(self):
        return self._returnAsList()
class _OracleOutParam(object):
    """
    A parameter that will be populated using the cx_Oracle API for host
    variables.
    """
    implements(IDerivedParameter)

    def __init__(self, columnSyntax):
        # Remember the column's SQL type name to pick a cx_Oracle var type.
        self.typeID = columnSyntax.model.type.name.lower()

    def preQuery(self, cursor):
        # Allocate a cx_Oracle output variable of the matching type and hand
        # it to the execution layer for binding.
        typeMap = {'integer': cx_Oracle.NUMBER,
                   'text': cx_Oracle.NCLOB,
                   'varchar': cx_Oracle.STRING,
                   'timestamp': cx_Oracle.TIMESTAMP}
        self.var = cursor.var(typeMap[self.typeID])
        return self.var

    def postQuery(self, cursor):
        # Convert the raw cx_Oracle value to pgdb-compatible form and drop
        # the variable reference.
        self.value = mapOracleOutputType(self.var.getvalue())
        self.var = None
class Insert(_DMLStatement):
    """
    'insert' statement.
    """

    def __init__(self, columnMap, Return=None):
        """
        @param columnMap: mapping of L{ColumnSyntax} to values to insert.
        @param Return: optional column(s) whose post-insert values should be
            returned.
        @raise NotEnoughValues: if a required (non-defaulted) column of the
            target table is missing from C{columnMap}.
        """
        self.columnMap = columnMap
        self.Return = Return
        columns = _modelsFromMap(columnMap)
        table = _fromSameTable(columns)
        required = [column for column in table.columns if column.needsValue()]
        unspecified = [column for column in required
                       if column not in columns]
        if unspecified:
            raise NotEnoughValues(
                'Columns [%s] required.' %
                (', '.join([c.name for c in unspecified])))

    def _toSQL(self, queryGenerator):
        """
        @return: a 'insert' statement with placeholders and arguments
        @rtype: L{SQLFragment}
        """
        # NOTE: Python 2 — .items() returns a list, which is appended to
        # below for the Oracle sequence-default case.
        columnsAndValues = self.columnMap.items()
        tableModel = columnsAndValues[0][0].model.table
        specifiedColumnModels = [x.model for x in self.columnMap.keys()]
        if queryGenerator.dialect == ORACLE_DIALECT:
            # Oracle does not implicitly apply sequence defaults, so emit
            # them explicitly for any sequence-defaulted column not given.
            # See test_nextSequenceDefaultImplicitExplicitOracle.
            for column in tableModel.columns:
                if isinstance(column.default, Sequence):
                    columnSyntax = ColumnSyntax(column)
                    if column not in specifiedColumnModels:
                        columnsAndValues.append(
                            (columnSyntax, SequenceSyntax(column.default))
                        )
        # Sort by column name so generated SQL is deterministic.
        # (Python 2 tuple-unpacking lambda.)
        sortedColumns = sorted(columnsAndValues,
                               key=lambda (c, v): c.model.name)
        allTables = []
        stmt = SQLFragment('insert into ')
        stmt.append(TableSyntax(tableModel).subSQL(queryGenerator, allTables))
        stmt.append(SQLFragment(" "))
        stmt.append(_inParens(_commaJoined(
            [c.subSQL(queryGenerator, allTables) for (c, _ignore_v) in
             sortedColumns])))
        stmt.append(SQLFragment(" values "))
        stmt.append(_inParens(_commaJoined(
            [_convert(v).subSQL(queryGenerator, allTables)
             for (c, v) in sortedColumns])))
        return self._returningClause(queryGenerator, stmt, allTables)

    def on(self, txn, *a, **kw):
        """
        Override to provide extra logic for L{Insert}s that return values on
        databases that don't provide return values as part of their C{INSERT}
        behavior.
        """
        # NOTE(review): super(_DMLStatement, ...) deliberately (presumably)
        # skips to _Statement.on in the MRO.
        result = super(_DMLStatement, self).on(txn, *a, **kw)
        if self.Return is not None and txn.dialect == SQLITE_DIALECT:
            # SQLite has no RETURNING clause: re-select the freshly inserted
            # row by its rowid.
            table = self._returnAsList()[0].model.table
            return Select(self._returnAsList(),
                          # TODO: error reporting when 'return' includes columns
                          # foreign to the primary table.
                          From=TableSyntax(table),
                          Where=ColumnSyntax(Column(table, "rowid",
                                                    SQLType("integer", None))) ==
                                _sqliteLastInsertRowID()
                          ).on(txn, *a, **kw)
        return result
def _convert(x):
    """
    Convert a value to an appropriate SQL AST node.  (Currently a simple
    isinstance, could be promoted to use adaptation if we want to get fancy.)
    """
    return x if isinstance(x, ExpressionSyntax) else Constant(x)
class Update(_DMLStatement):
    """
    'update' statement

    @ivar columnMap: A L{dict} mapping L{ColumnSyntax} objects to values to
        change; values may be simple database values (such as L{str},
        L{unicode}, L{datetime.datetime}, L{float}, L{int} etc) or L{Parameter}
        instances.
    @type columnMap: L{dict}
    """

    def __init__(self, columnMap, Where, Return=None):
        super(Update, self).__init__()
        # Validate that all columns being set belong to a single table.
        _fromSameTable(_modelsFromMap(columnMap))
        self.columnMap = columnMap
        self.Where = Where
        self.Return = Return

    @inlineCallbacks
    def on(self, txn, *a, **kw):
        """
        Override to provide extra logic for L{Update}s that return values on
        databases that don't provide return values as part of their C{UPDATE}
        behavior.
        """
        doExtra = self.Return is not None and txn.dialect == SQLITE_DIALECT
        upcall = lambda: super(_DMLStatement, self).on(txn, *a, **kw)

        if doExtra:
            # SQLite cannot RETURN the updated rows directly: remember the
            # rowids matched by the WHERE clause first, run the update, then
            # re-select those rows to produce the 'returned' values.
            table = self._returnAsList()[0].model.table
            rowidcol = ColumnSyntax(Column(table, "rowid",
                                           SQLType("integer", None)))
            prequery = Select([rowidcol], From=TableSyntax(table),
                              Where=self.Where)
            preresult = prequery.on(txn, *a, **kw)
            before = yield preresult
            yield upcall()
            result = (yield Select(self._returnAsList(),
                                   # TODO: error reporting when 'return' includes
                                   # columns foreign to the primary table.
                                   From=TableSyntax(table),
                                   Where=reduce(lambda left, right: left.Or(right),
                                                ((rowidcol == x) for [x] in before))
                                   ).on(txn, *a, **kw))
            returnValue(result)
        else:
            returnValue((yield upcall()))

    def _toSQL(self, queryGenerator):
        """
        @return: an 'update' statement with placeholders and arguments
        @rtype: L{SQLFragment}
        """
        # Sort by column name for deterministic SQL output.
        # (Python 2 tuple-unpacking lambda.)
        sortedColumns = sorted(self.columnMap.items(),
                               key=lambda (c, v): c.model.name)
        allTables = []
        result = SQLFragment('update ')
        result.append(
            TableSyntax(sortedColumns[0][0].model.table).subSQL(
                queryGenerator, allTables)
        )
        result.text += ' set '
        result.append(
            _commaJoined(
                [c.subSQL(queryGenerator, allTables).append(
                    SQLFragment(" = ").subSQL(queryGenerator, allTables)
                ).append(_convert(v).subSQL(queryGenerator, allTables))
                    for (c, v) in sortedColumns]
            )
        )
        if self.Where is not None:
            result.append(SQLFragment(' where '))
            result.append(self.Where.subSQL(queryGenerator, allTables))
        return self._returningClause(queryGenerator, result, allTables)
class Delete(_DMLStatement):
    """
    'delete' statement.
    """

    def __init__(self, From, Where, Return=None):
        """
        If Where is None then all rows will be deleted.
        """
        self.From = From
        self.Where = Where
        self.Return = Return

    def _toSQL(self, queryGenerator):
        # @return: a 'delete' statement with placeholders and arguments,
        # as an L{SQLFragment}.
        result = SQLFragment()
        allTables = self.From.tables()
        result.text += 'delete from '
        result.append(self.From.subSQL(queryGenerator, allTables))
        if self.Where is not None:
            result.text += ' where '
            result.append(self.Where.subSQL(queryGenerator, allTables))
        return self._returningClause(queryGenerator, result, allTables)

    @inlineCallbacks
    def on(self, txn, *a, **kw):
        # SQLite cannot RETURN deleted rows: select the matching rows before
        # deleting them and use that as the result.
        upcall = lambda: super(Delete, self).on(txn, *a, **kw)
        if txn.dialect == SQLITE_DIALECT and self.Return is not None:
            result = yield Select(self._returnAsList(), From=self.From,
                                  Where=self.Where).on(txn, *a, **kw)
            yield upcall()
        else:
            result = yield upcall()
        returnValue(result)
class _LockingStatement(_Statement):
    """
    A statement related to lock management, which implicitly has no results.
    """
    def _resultColumns(self):
        """
        No columns should be expected, so return an infinite iterator of None.
        """
        return repeat(None)
class Lock(_LockingStatement):
    """
    An SQL 'lock' statement.
    """

    def __init__(self, table, mode):
        self.table = table
        self.mode = mode

    @classmethod
    def exclusive(cls, table):
        """
        Convenience constructor for an exclusive-mode table lock.
        """
        return cls(table, 'exclusive')

    def _toSQL(self, queryGenerator):
        if queryGenerator.dialect == SQLITE_DIALECT:
            # FIXME - this is only stubbed out for testing right now, actual
            # concurrency would require some kind of locking statement here.
            # BEGIN IMMEDIATE maybe, if that's okay in the middle of a
            # transaction or repeatedly?
            return SQLFragment('select null')
        fragment = SQLFragment('lock table ')
        fragment.append(self.table.subSQL(queryGenerator, [self.table]))
        fragment.append(SQLFragment(' in %s mode' % (self.mode,)))
        return fragment
class DatabaseLock(_LockingStatement):
    """
    An SQL exclusive session level advisory lock
    """
    def _toSQL(self, queryGenerator):
        # Only meaningful on Postgres; on() guards the dialect, so this
        # assert documents the invariant.  (NOTE(review): stripped under -O.)
        assert(queryGenerator.dialect == POSTGRES_DIALECT)
        return SQLFragment('select pg_advisory_lock(1)')

    def on(self, txn, *a, **kw):
        """
        Override on() to only execute on Postgres
        """
        if txn.dialect == POSTGRES_DIALECT:
            return super(DatabaseLock, self).on(txn, *a, **kw)
        # Other dialects: no-op, succeed immediately.
        return succeed(None)
class DatabaseUnlock(_LockingStatement):
    """
    Release the session level advisory lock taken by L{DatabaseLock}.
    """
    def _toSQL(self, queryGenerator):
        # Only meaningful on Postgres; on() guards the dialect.
        assert(queryGenerator.dialect == POSTGRES_DIALECT)
        return SQLFragment('select pg_advisory_unlock(1)')

    def on(self, txn, *a, **kw):
        """
        Override on() to only execute on Postgres
        """
        if txn.dialect == POSTGRES_DIALECT:
            return super(DatabaseUnlock, self).on(txn, *a, **kw)
        # Other dialects: no-op, succeed immediately.
        return succeed(None)
class Savepoint(_LockingStatement):
    """
    An SQL 'savepoint' statement.
    """
    def __init__(self, name):
        # Savepoint name; interpolated directly into the SQL text.
        self.name = name

    def _toSQL(self, queryGenerator):
        return SQLFragment('savepoint %s' % (self.name,))
class RollbackToSavepoint(_LockingStatement):
    """
    An SQL 'rollback to savepoint' statement.
    """
    def __init__(self, name):
        # Savepoint name; interpolated directly into the SQL text.
        self.name = name

    def _toSQL(self, queryGenerator):
        return SQLFragment('rollback to savepoint %s' % (self.name,))
class ReleaseSavepoint(_LockingStatement):
    """
    An SQL 'release savepoint' statement.
    """
    def __init__(self, name):
        # Savepoint name; interpolated directly into the SQL text.
        self.name = name

    def _toSQL(self, queryGenerator):
        return SQLFragment('release savepoint %s' % (self.name,))
class SavepointAction(object):
    """
    Convenience wrapper bundling the acquire / rollback / release statements
    for a single named savepoint.
    """

    def __init__(self, name):
        self._name = name

    def acquire(self, txn):
        return Savepoint(self._name).on(txn)

    def rollback(self, txn):
        return RollbackToSavepoint(self._name).on(txn)

    def release(self, txn):
        if txn.dialect != ORACLE_DIALECT:
            return ReleaseSavepoint(self._name).on(txn)
        # There is no 'release savepoint' statement in oracle, but then, we
        # don't need it because there's no resource to manage.  Just don't
        # do anything.
        return NoOp()
class NoOp(object):
    """
    A statement-like object whose execution does nothing and succeeds
    immediately.
    """
    def on(self, *a, **kw):
        return succeed(None)
class SQLFragment(object):
    """
    Combination of SQL text and arguments; a statement which may be executed
    against a database.
    """

    def __init__(self, text="", parameters=None):
        self.text = text
        self.parameters = [] if parameters is None else parameters

    def bind(self, **kw):
        """
        Produce a new fragment with every L{Parameter} placeholder replaced
        by the keyword argument of the same name.
        """
        bound = []
        for parameter in self.parameters:
            if not isinstance(parameter, Parameter):
                # Literal value: passed through untouched.
                bound.append(parameter)
                continue
            if parameter.count is None:
                bound.append(kw[parameter.name])
            else:
                # Multi-value placeholder: the bound value must supply
                # exactly 'count' items.
                values = kw[parameter.name]
                if parameter.count != len(values):
                    raise DALError("Number of place holders does not match number of items to bind")
                for item in values:
                    bound.append(item)
        return SQLFragment(self.text, bound)

    def append(self, anotherStatement):
        # Concatenate text and parameters in place; returns self so calls
        # can be chained.
        self.text += anotherStatement.text
        self.parameters += anotherStatement.parameters
        return self

    def __eq__(self, stmt):
        if not isinstance(stmt, SQLFragment):
            return NotImplemented
        return self.text == stmt.text and self.parameters == stmt.parameters

    def __ne__(self, stmt):
        if not isinstance(stmt, SQLFragment):
            return NotImplemented
        return not self.__eq__(stmt)

    def __repr__(self):
        return "%s%r" % (self.__class__.__name__, (self.text, self.parameters))

    def subSQL(self, queryGenerator, allTables):
        # A fragment is already rendered; it embeds as itself.
        return self
class Parameter(object):
    """
    Used to represent a place holder for a value to be bound to the query
    at a later date.  If count > 1, then a "set" of parenthesized,
    comma separated place holders will be generated.
    """

    def __init__(self, name, count=None):
        """
        @param name: the keyword under which a value will later be bound.
        @param count: number of values expected for a multi-value
            placeholder, or C{None} for a single value.
        @raise DALError: if C{count} is given but less than 1.
        """
        self.name = name
        self.count = count
        if self.count is not None and self.count < 1:
            raise DALError("Must have Parameter.count > 0")

    def __eq__(self, param):
        if not isinstance(param, Parameter):
            return NotImplemented
        return self.name == param.name and self.count == param.count

    def __ne__(self, param):
        if not isinstance(param, Parameter):
            return NotImplemented
        return not self.__eq__(param)

    def __repr__(self):
        # Fix: include 'count' when set, so multi-value placeholders are
        # distinguishable from single-value ones in debug output.  The
        # count=None form is unchanged for backward compatibility.
        if self.count is None:
            return 'Parameter(%r)' % (self.name,)
        return 'Parameter(%r, count=%r)' % (self.name, self.count)
# Common helpers:

# Current timestamp in UTC format.  Hack to support standard syntax for this,
# rather than the compatibility procedure found in various databases.
utcNowSQL = NamedValue("CURRENT_TIMESTAMP at time zone 'UTC'")

# You can't insert a column with no rows.  In SQL that just isn't valid syntax,
# and in this DAL you need at least one key or we can't tell what table you're
# talking about.  Luckily there's the 'default' keyword to the rescue, which, in
# the context of an INSERT statement means 'use the default value explicitly'.
# (Although this is a special keyword in a CREATE statement, in an INSERT it
# behaves like an expression to the best of my knowledge.)
default = NamedValue('default')
calendarserver-5.2+dfsg/twext/enterprise/dal/__init__.py 0000644 0001750 0001750 00000002051 12263343324 022513 0 ustar rahul rahul ##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Toolkit for building a Data-Access Layer (DAL).
This includes an abstract representation of SQL objects like tables, columns,
sequences and queries, a parser to convert your schema to that representation,
and tools for working with it.
In some ways this is similar to the low levels of something like SQLAlchemy, but
it is designed to be more introspectable, to allow for features like automatic
caching and index detection. NB: work in progress.
"""
calendarserver-5.2+dfsg/twext/enterprise/util.py 0000644 0001750 0001750 00000006240 12263343324 021175 0 ustar rahul rahul # -*- test-case-name: twext.enterprise.test.test_util -*-
##
# Copyright (c) 2010-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Utilities for dealing with different databases.
"""
from datetime import datetime
# strptime/strftime format shared by the SQL timestamp helpers.
SQL_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S.%f"

def parseSQLTimestamp(ts):
    """
    Parse an SQL timestamp string into a C{datetime.datetime}.
    """
    # Timestamps shorter than the format string lack a fractional-seconds
    # part; supply one so the %f directive can match.
    value = ts if len(ts) >= len(SQL_TIMESTAMP_FORMAT) else ts + ".0"
    return datetime.strptime(value, SQL_TIMESTAMP_FORMAT)
def mapOracleOutputType(column):
    """
    Map a single output value from cx_Oracle based on some rules and
    expectations that we have based on the pgdb bindings.

    @param column: a single value from a column.

    @return: a converted value based on the type of the input; oracle CLOBs and
        datetime timestamps will be converted to strings, unicode values will be
        converted to UTF-8 encoded byte sequences (C{str}s), and floating point
        numbers will be converted to integer types if they are integers.  Any
        other types will be left alone.
    """
    if hasattr(column, 'read'):
        # Try to detect large objects and format convert them to
        # strings on the fly.  We need to do this as we read each
        # row, due to the issue described here -
        # http://cx-oracle.sourceforge.net/html/lob.html - in
        # particular, the part where it says "In particular, do not
        # use the fetchall() method".
        column = column.read()
    elif isinstance(column, datetime):
        # cx_Oracle properly maps the type of timestamps to datetime
        # objects.  However, our code is mostly written against
        # PyGreSQL, which just emits strings as results and expects
        # to have to convert them itself..  Since it's easier to
        # just detect the datetimes and stringify them, for now
        # we'll do that.
        return column.strftime(SQL_TIMESTAMP_FORMAT)
    elif isinstance(column, float):
        # cx_Oracle maps _all_ nubmers to float types, which is more consistent,
        # but we expect the database to be able to store integers as integers
        # (in fact almost all the values in our schema are integers), so we map
        # those values which exactly match back into integers.
        if int(column) == column:
            return int(column)
        else:
            return column

    if isinstance(column, unicode):
        # Encode unicode values to UTF-8 bytestrings (str in Python 2)
        # before handing them back to the application, matching PyGreSQL's
        # str results and reducing memory consumption.
        column = column.encode('utf-8')
    return column
calendarserver-5.2+dfsg/twext/protocols/ 0000755 0001750 0001750 00000000000 12322625326 017511 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/protocols/test/ 0000755 0001750 0001750 00000000000 12322625326 020470 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/protocols/test/test_memcache.py 0000644 0001750 0001750 00000044501 11156045201 023637 0 ustar rahul rahul # Copyright (c) 2007-2009 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test the memcache client protocol.
"""
from twext.protocols.memcache import MemCacheProtocol, NoSuchCommand
from twext.protocols.memcache import ClientError, ServerError
from twisted.trial.unittest import TestCase
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.internet.task import Clock
from twisted.internet.defer import Deferred, gatherResults, TimeoutError
class MemCacheTestCase(TestCase):
"""
Test client protocol class L{MemCacheProtocol}.
"""
    def setUp(self):
        """
        Create a memcache client, connect it to a string protocol, and make it
        use a deterministic clock.
        """
        self.proto = MemCacheProtocol()
        # Use a fake clock so tests can advance timeouts deterministically.
        self.clock = Clock()
        self.proto.callLater = self.clock.callLater
        # A string transport captures everything the protocol writes.
        self.transport = StringTransportWithDisconnection()
        self.transport.protocol = self.proto
        self.proto.makeConnection(self.transport)
    def _test(self, d, send, recv, result):
        """
        Shortcut method for classic tests.

        @param d: the resulting deferred from the memcache command.
        @type d: C{Deferred}
        @param send: the expected data to be sent.
        @type send: C{str}
        @param recv: the data to simulate as reception.
        @type recv: C{str}
        @param result: the expected result.
        @type result: C{any}
        """
        def cb(res):
            # Verify both the command's parsed result and the exact bytes
            # written to the transport.
            self.assertEquals(res, result)
            self.assertEquals(self.transport.value(), send)
        d.addCallback(cb)
        # Feed the canned server response to trigger the callback.
        self.proto.dataReceived(recv)
        return d
    # Basic command round-trips: each test drives one protocol method through
    # _test with the expected wire bytes and a canned server response.

    def test_get(self):
        """
        L{MemCacheProtocol.get} should return a L{Deferred} which is
        called back with the value and the flag associated with the given key
        if the server returns a successful result.
        """
        return self._test(self.proto.get("foo"), "get foo\r\n",
                          "VALUE foo 0 3\r\nbar\r\nEND\r\n", (0, "bar"))

    def test_emptyGet(self):
        """
        Test getting a non-available key: it should succeed but return C{None}
        as value and C{0} as flag.
        """
        return self._test(self.proto.get("foo"), "get foo\r\n",
                          "END\r\n", (0, None))

    def test_set(self):
        """
        L{MemCacheProtocol.set} should return a L{Deferred} which is
        called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.set("foo", "bar"),
                          "set foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)

    def test_add(self):
        """
        L{MemCacheProtocol.add} should return a L{Deferred} which is
        called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.add("foo", "bar"),
                          "add foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)

    def test_replace(self):
        """
        L{MemCacheProtocol.replace} should return a L{Deferred} which
        is called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.replace("foo", "bar"),
                          "replace foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)

    def test_errorAdd(self):
        """
        Test an erroneous add: if a L{MemCacheProtocol.add} is called but the
        key already exists on the server, it returns a B{NOT STORED} answer,
        which should callback the resulting L{Deferred} with C{False}.
        """
        return self._test(self.proto.add("foo", "bar"),
                          "add foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)

    def test_errorReplace(self):
        """
        Test an erroneous replace: if a L{MemCacheProtocol.replace} is called
        but the key doesn't exist on the server, it returns a B{NOT STORED}
        answer, which should callback the resulting L{Deferred} with C{False}.
        """
        return self._test(self.proto.replace("foo", "bar"),
                          "replace foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)

    def test_delete(self):
        """
        L{MemCacheProtocol.delete} should return a L{Deferred} which is
        called back with C{True} when the server notifies a success.
        """
        return self._test(self.proto.delete("bar"), "delete bar\r\n",
                          "DELETED\r\n", True)

    def test_errorDelete(self):
        """
        Test a error during a delete: if key doesn't exist on the server, it
        returns a B{NOT FOUND} answer which should callback the resulting
        L{Deferred} with C{False}.
        """
        return self._test(self.proto.delete("bar"), "delete bar\r\n",
                          "NOT FOUND\r\n", False)

    def test_increment(self):
        """
        Test incrementing a variable: L{MemCacheProtocol.increment} should
        return a L{Deferred} which is called back with the incremented value of
        the given key.
        """
        return self._test(self.proto.increment("foo"), "incr foo 1\r\n",
                          "4\r\n", 4)

    def test_decrement(self):
        """
        Test decrementing a variable: L{MemCacheProtocol.decrement} should
        return a L{Deferred} which is called back with the decremented value of
        the given key.
        """
        return self._test(
            self.proto.decrement("foo"), "decr foo 1\r\n", "5\r\n", 5)

    def test_incrementVal(self):
        """
        L{MemCacheProtocol.increment} takes an optional argument C{value} which
        should replace the default value of 1 when specified.
        """
        return self._test(self.proto.increment("foo", 8), "incr foo 8\r\n",
                          "4\r\n", 4)

    def test_decrementVal(self):
        """
        L{MemCacheProtocol.decrement} takes an optional argument C{value} which
        should replace the default value of 1 when specified.
        """
        return self._test(self.proto.decrement("foo", 3), "decr foo 3\r\n",
                          "5\r\n", 5)
    # Informational commands: stats, version, flush_all.

    def test_stats(self):
        """
        Test retrieving server statistics via the L{MemCacheProtocol.stats}
        command: it should parse the data sent by the server and call back the
        resulting L{Deferred} with a dictionary of the received statistics.
        """
        return self._test(self.proto.stats(), "stats\r\n",
                          "STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
                          {"foo": "bar", "egg": "spam"})

    def test_statsWithArgument(self):
        """
        L{MemCacheProtocol.stats} takes an optional C{str} argument which,
        if specified, is sent along with the I{STAT} command.  The I{STAT}
        responses from the server are parsed as key/value pairs and returned
        as a C{dict} (as in the case where the argument is not specified).
        """
        return self._test(self.proto.stats("blah"), "stats blah\r\n",
                          "STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
                          {"foo": "bar", "egg": "spam"})

    def test_version(self):
        """
        Test version retrieval via the L{MemCacheProtocol.version} command: it
        should return a L{Deferred} which is called back with the version sent
        by the server.
        """
        return self._test(self.proto.version(), "version\r\n",
                          "VERSION 1.1\r\n", "1.1")

    def test_flushAll(self):
        """
        L{MemCacheProtocol.flushAll} should return a L{Deferred} which is
        called back with C{True} if the server acknowledges success.
        """
        return self._test(self.proto.flushAll(), "flush_all\r\n",
                          "OK\r\n", True)
def test_invalidGetResponse(self):
"""
If the value returned doesn't match the expected key of the current, we
should get an error in L{MemCacheProtocol.dataReceived}.
"""
self.proto.get("foo")
s = "spamegg"
self.assertRaises(RuntimeError,
self.proto.dataReceived,
"VALUE bar 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))
def test_timeOut(self):
"""
Test the timeout on outgoing requests: when timeout is detected, all
current commands should fail with a L{TimeoutError}, and the
connection should be closed.
"""
d1 = self.proto.get("foo")
d2 = self.proto.get("bar")
d3 = Deferred()
self.proto.connectionLost = d3.callback
self.clock.advance(self.proto.persistentTimeOut)
self.assertFailure(d1, TimeoutError)
self.assertFailure(d2, TimeoutError)
def checkMessage(error):
self.assertEquals(str(error), "Connection timeout")
d1.addCallback(checkMessage)
return gatherResults([d1, d2, d3])
def test_timeoutRemoved(self):
"""
When a request gets a response, no pending timeout call should remain
around.
"""
d = self.proto.get("foo")
self.clock.advance(self.proto.persistentTimeOut - 1)
self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")
def check(result):
self.assertEquals(result, (0, "bar"))
self.assertEquals(len(self.clock.calls), 0)
d.addCallback(check)
return d
def test_timeOutRaw(self):
"""
Test the timeout when raw mode was started: the timeout should not be
reset until all the data has been received, so we can have a
L{TimeoutError} when waiting for raw data.
"""
d1 = self.proto.get("foo")
d2 = Deferred()
self.proto.connectionLost = d2.callback
self.proto.dataReceived("VALUE foo 0 10\r\n12345")
self.clock.advance(self.proto.persistentTimeOut)
self.assertFailure(d1, TimeoutError)
return gatherResults([d1, d2])
def test_timeOutStat(self):
"""
Test the timeout when stat command has started: the timeout should not
be reset until the final B{END} is received.
"""
d1 = self.proto.stats()
d2 = Deferred()
self.proto.connectionLost = d2.callback
self.proto.dataReceived("STAT foo bar\r\n")
self.clock.advance(self.proto.persistentTimeOut)
self.assertFailure(d1, TimeoutError)
return gatherResults([d1, d2])
    def test_timeoutPipelining(self):
        """
        When two requests are sent, a timeout call should remain around for the
        second request, and its timeout time should be correct.
        """
        d1 = self.proto.get("foo")
        d2 = self.proto.get("bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback
        # Answer the first request just before the deadline; the second
        # request is still pending, so a new timeout must be scheduled.
        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")
        def check(result):
            self.assertEquals(result, (0, "bar"))
            # Exactly one timeout call remains, for the second request.
            self.assertEquals(len(self.clock.calls), 1)
            # Advance one second at a time until the second request's
            # timeout fires.
            for i in range(self.proto.persistentTimeOut):
                self.clock.advance(1)
            return self.assertFailure(d2, TimeoutError).addCallback(checkTime)
        def checkTime(ignored):
            # Check that the timeout happened C{self.proto.persistentTimeOut}
            # after the last response
            self.assertEquals(self.clock.seconds(),
                2 * self.proto.persistentTimeOut - 1)
        d1.addCallback(check)
        return d1
def test_timeoutNotReset(self):
"""
Check that timeout is not resetted for every command, but keep the
timeout from the first command without response.
"""
d1 = self.proto.get("foo")
d3 = Deferred()
self.proto.connectionLost = d3.callback
self.clock.advance(self.proto.persistentTimeOut - 1)
d2 = self.proto.get("bar")
self.clock.advance(1)
self.assertFailure(d1, TimeoutError)
self.assertFailure(d2, TimeoutError)
return gatherResults([d1, d2, d3])
def test_tooLongKey(self):
"""
Test that an error is raised when trying to use a too long key: the
called command should return a L{Deferred} which fail with a
L{ClientError}.
"""
d1 = self.assertFailure(self.proto.set("a" * 500, "bar"), ClientError)
d2 = self.assertFailure(self.proto.increment("a" * 500), ClientError)
d3 = self.assertFailure(self.proto.get("a" * 500), ClientError)
d4 = self.assertFailure(self.proto.append("a" * 500, "bar"), ClientError)
d5 = self.assertFailure(self.proto.prepend("a" * 500, "bar"), ClientError)
return gatherResults([d1, d2, d3, d4, d5])
def test_invalidCommand(self):
"""
When an unknown command is sent directly (not through public API), the
server answers with an B{ERROR} token, and the command should fail with
L{NoSuchCommand}.
"""
d = self.proto._set("egg", "foo", "bar", 0, 0, "")
self.assertEquals(self.transport.value(), "egg foo 0 0 3\r\nbar\r\n")
self.assertFailure(d, NoSuchCommand)
self.proto.dataReceived("ERROR\r\n")
return d
def test_clientError(self):
"""
Test the L{ClientError} error: when the server send a B{CLIENT_ERROR}
token, the originating command should fail with L{ClientError}, and the
error should contain the text sent by the server.
"""
a = "eggspamm"
d = self.proto.set("foo", a)
self.assertEquals(self.transport.value(),
"set foo 0 0 8\r\neggspamm\r\n")
self.assertFailure(d, ClientError)
def check(err):
self.assertEquals(str(err), "We don't like egg and spam")
d.addCallback(check)
self.proto.dataReceived("CLIENT_ERROR We don't like egg and spam\r\n")
return d
def test_serverError(self):
"""
Test the L{ServerError} error: when the server send a B{SERVER_ERROR}
token, the originating command should fail with L{ServerError}, and the
error should contain the text sent by the server.
"""
a = "eggspamm"
d = self.proto.set("foo", a)
self.assertEquals(self.transport.value(),
"set foo 0 0 8\r\neggspamm\r\n")
self.assertFailure(d, ServerError)
def check(err):
self.assertEquals(str(err), "zomg")
d.addCallback(check)
self.proto.dataReceived("SERVER_ERROR zomg\r\n")
return d
def test_unicodeKey(self):
"""
Using a non-string key as argument to commands should raise an error.
"""
d1 = self.assertFailure(self.proto.set(u"foo", "bar"), ClientError)
d2 = self.assertFailure(self.proto.increment(u"egg"), ClientError)
d3 = self.assertFailure(self.proto.get(1), ClientError)
d4 = self.assertFailure(self.proto.delete(u"bar"), ClientError)
d5 = self.assertFailure(self.proto.append(u"foo", "bar"), ClientError)
d6 = self.assertFailure(self.proto.prepend(u"foo", "bar"), ClientError)
return gatherResults([d1, d2, d3, d4, d5, d6])
def test_unicodeValue(self):
"""
Using a non-string value should raise an error.
"""
return self.assertFailure(self.proto.set("foo", u"bar"), ClientError)
def test_pipelining(self):
"""
Test that multiple requests can be sent subsequently to the server, and
that the protocol order the responses correctly and dispatch to the
corresponding client command.
"""
d1 = self.proto.get("foo")
d1.addCallback(self.assertEquals, (0, "bar"))
d2 = self.proto.set("bar", "spamspamspam")
d2.addCallback(self.assertEquals, True)
d3 = self.proto.get("egg")
d3.addCallback(self.assertEquals, (0, "spam"))
self.assertEquals(self.transport.value(),
"get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n")
self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n"
"STORED\r\n"
"VALUE egg 0 4\r\nspam\r\nEND\r\n")
return gatherResults([d1, d2, d3])
def test_getInChunks(self):
"""
If the value retrieved by a C{get} arrive in chunks, the protocol
should be able to reconstruct it and to produce the good value.
"""
d = self.proto.get("foo")
d.addCallback(self.assertEquals, (0, "0123456789"))
self.assertEquals(self.transport.value(), "get foo\r\n")
self.proto.dataReceived("VALUE foo 0 10\r\n0123456")
self.proto.dataReceived("789")
self.proto.dataReceived("\r\nEND")
self.proto.dataReceived("\r\n")
return d
def test_append(self):
"""
L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
method: it should return a L{Deferred} which is called back with
C{True} when the operation succeeds.
"""
return self._test(self.proto.append("foo", "bar"),
"append foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
def test_prepend(self):
"""
L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
method: it should return a L{Deferred} which is called back with
C{True} when the operation succeeds.
"""
return self._test(self.proto.prepend("foo", "bar"),
"prepend foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
def test_gets(self):
"""
L{MemCacheProtocol.get} should handle an additional cas result when
C{withIdentifier} is C{True} and forward it in the resulting
L{Deferred}.
"""
return self._test(self.proto.get("foo", True), "gets foo\r\n",
"VALUE foo 0 3 1234\r\nbar\r\nEND\r\n", (0, "1234", "bar"))
def test_emptyGets(self):
"""
Test getting a non-available key with gets: it should succeed but
return C{None} as value, C{0} as flag and an empty cas value.
"""
return self._test(self.proto.get("foo", True), "gets foo\r\n",
"END\r\n", (0, "", None))
def test_checkAndSet(self):
"""
L{MemCacheProtocol.checkAndSet} passes an additional cas identifier that the
server should handle to check if the data has to be updated.
"""
return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"),
"cas foo 0 0 3 1234\r\nbar\r\n", "STORED\r\n", True)
def test_casUnknowKey(self):
"""
When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the resulting
L{Deferred} should fire with C{False}.
"""
return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"),
"cas foo 0 0 3 1234\r\nbar\r\n", "EXISTS\r\n", False)
calendarserver-5.2+dfsg/twext/protocols/test/__init__.py 0000644 0001750 0001750 00000001207 12263343324 022600 0 ustar rahul rahul ##
# Copyright (c) 2009-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Extensions to twisted.protocols
"""
calendarserver-5.2+dfsg/twext/protocols/memcache.py 0000644 0001750 0001750 00000046372 12147725751 021651 0 ustar rahul rahul # -*- test-case-name: twisted.test.test_memcache -*-
# Copyright (c) 2007-2009 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Memcache client protocol. Memcached is a caching server, storing data in the
form of pairs key/value, and memcache is the protocol to talk with it.
To connect to a server, create a factory for L{MemCacheProtocol}::
from twisted.internet import reactor, protocol
from twisted.protocols.memcache import MemCacheProtocol, DEFAULT_PORT
d = protocol.ClientCreator(reactor, MemCacheProtocol
).connectTCP("localhost", DEFAULT_PORT)
def doSomething(proto):
# Here you call the memcache operations
return proto.set("mykey", "a lot of data")
d.addCallback(doSomething)
reactor.run()
All the operations of the memcache protocol are present, but
L{MemCacheProtocol.set} and L{MemCacheProtocol.get} are the more important.
See U{http://code.sixapart.com/svn/memcached/trunk/server/doc/protocol.txt} for
more information about the protocol.
"""
try:
    from collections import deque
except ImportError:
    # Fallback for ancient Pythons without collections.deque: a list that
    # emulates just the popleft() this module needs.
    class deque(list):
        def popleft(self):
            return self.pop(0)
from twisted.protocols.basic import LineReceiver
from twisted.protocols.policies import TimeoutMixin
from twisted.internet.defer import Deferred, fail, TimeoutError
from twext.python.log import Logger
log = Logger()
DEFAULT_PORT = 11211
class NoSuchCommand(Exception):
    """
    Raised when the server answers B{ERROR}: the command sent does not
    exist.
    """
class ClientError(Exception):
    """
    Raised for an invalid client call, either detected locally (bad key or
    value type) or reported by the server as B{CLIENT_ERROR}.
    """
class ServerError(Exception):
    """
    Raised when the server reports a problem on its side
    (B{SERVER_ERROR}).
    """
class Command(object):
    """
    Represent one in-flight request to the server, together with the
    L{Deferred} that will be fired when its response arrives.

    @ivar _deferred: the L{Deferred} object that will be fired when the
        result arrives.
    @type _deferred: L{Deferred}
    @ivar command: name of the command sent to the server.
    @type command: C{str}
    """

    def __init__(self, command, **kwargs):
        """
        Create a command.

        @param command: the name of the command.
        @type command: C{str}
        @param kwargs: stored as attributes of the object, for later use by
            the protocol when the response is parsed.
        """
        self.command = command
        self._deferred = Deferred()
        for name, value in kwargs.items():
            setattr(self, name, value)

    def success(self, value):
        """
        Fire the underlying deferred with C{value}.
        """
        self._deferred.callback(value)

    def fail(self, error):
        """
        Fire the underlying deferred with a failure.
        """
        self._deferred.errback(error)
class MemCacheProtocol(LineReceiver, TimeoutMixin):
    """
    MemCache protocol: connect to a memcached server to store/retrieve values.

    @ivar persistentTimeOut: the timeout period used to wait for a response.
    @type persistentTimeOut: C{int}
    @ivar _current: current list of requests waiting for an answer from the
        server.
    @type _current: C{deque} of L{Command}
    @ivar _lenExpected: amount of data expected in raw mode, when reading for
        a value.
    @type _lenExpected: C{int}
    @ivar _getBuffer: current buffer of data, used to store temporary data
        when reading in raw mode.
    @type _getBuffer: C{list}
    @ivar _bufferLength: the total amount of bytes in C{_getBuffer}.
    @type _bufferLength: C{int}
    """
    # Maximum key length; over-long keys are rejected client-side with
    # ClientError before anything is sent to the server.
    MAX_KEY_LENGTH = 250
def __init__(self, timeOut=60):
"""
Create the protocol.
@param timeOut: the timeout to wait before detecting that the
connection is dead and close it. It's expressed in seconds.
@type timeOut: C{int}
"""
self._current = deque()
self._lenExpected = None
self._getBuffer = None
self._bufferLength = None
self.persistentTimeOut = self.timeOut = timeOut
def timeoutConnection(self):
"""
Close the connection in case of timeout.
"""
for cmd in self._current:
cmd.fail(TimeoutError("Connection timeout"))
self.transport.loseConnection()
def sendLine(self, line):
"""
Override sendLine to add a timeout to response.
"""
if not self._current:
self.setTimeout(self.persistentTimeOut)
LineReceiver.sendLine(self, line)
    def rawDataReceived(self, data):
        """
        Collect data for a get.

        Accumulates chunks until the whole value (plus its trailing CRLF)
        has arrived, attaches it to the pending command and switches back
        to line mode.
        """
        self.resetTimeout()
        self._getBuffer.append(data)
        self._bufferLength += len(data)
        # +2 accounts for the \r\n terminating the value payload.
        if self._bufferLength >= self._lenExpected + 2:
            data = "".join(self._getBuffer)
            buf = data[:self._lenExpected]
            # Anything past the value and its \r\n belongs to the next
            # line-mode response; hand it back to the line parser.
            rem = data[self._lenExpected + 2:]
            val = buf
            self._lenExpected = None
            self._getBuffer = None
            self._bufferLength = None
            cmd = self._current[0]
            cmd.value = val
            self.setLineMode(rem)
def cmd_STORED(self):
"""
Manage a success response to a set operation.
"""
self._current.popleft().success(True)
def cmd_NOT_STORED(self):
"""
Manage a specific 'not stored' response to a set operation: this is not
an error, but some condition wasn't met.
"""
self._current.popleft().success(False)
def cmd_END(self):
"""
This the end token to a get or a stat operation.
"""
cmd = self._current.popleft()
if cmd.command == "get":
cmd.success((cmd.flags, cmd.value))
elif cmd.command == "gets":
cmd.success((cmd.flags, cmd.cas, cmd.value))
elif cmd.command == "stats":
cmd.success(cmd.values)
def cmd_NOT_FOUND(self):
"""
Manage error response for incr/decr/delete.
"""
self._current.popleft().success(False)
    def cmd_VALUE(self, line):
        """
        Prepare the reading a value after a get.

        Parses the C{VALUE <key> <flags> <length> [<cas>]} header and
        switches into raw mode to collect the payload.
        """
        cmd = self._current[0]
        if cmd.command == "get":
            key, flags, length = line.split()
            cas = ""
        else:
            # A "gets" response carries an extra cas identifier.
            key, flags, length, cas = line.split()
        self._lenExpected = int(length)
        self._getBuffer = []
        self._bufferLength = 0
        if cmd.key != key:
            # The server answered for a key we did not ask about.
            raise RuntimeError("Unexpected commands answer.")
        cmd.flags = int(flags)
        cmd.length = self._lenExpected
        cmd.cas = cas
        self.setRawMode()
def cmd_STAT(self, line):
"""
Reception of one stat line.
"""
cmd = self._current[0]
key, val = line.split(" ", 1)
cmd.values[key] = val
def cmd_VERSION(self, versionData):
"""
Read version token.
"""
self._current.popleft().success(versionData)
def cmd_ERROR(self):
"""
An non-existent command has been sent.
"""
log.error("Non-existent command sent.")
cmd = self._current.popleft()
cmd.fail(NoSuchCommand())
def cmd_CLIENT_ERROR(self, errText):
"""
An invalid input as been sent.
"""
log.error("Invalid input: %s" % (errText,))
cmd = self._current.popleft()
cmd.fail(ClientError(errText))
def cmd_SERVER_ERROR(self, errText):
"""
An error has happened server-side.
"""
log.error("Server error: %s" % (errText,))
cmd = self._current.popleft()
cmd.fail(ServerError(errText))
def cmd_DELETED(self):
"""
A delete command has completed successfully.
"""
self._current.popleft().success(True)
def cmd_OK(self):
"""
The last command has been completed.
"""
self._current.popleft().success(True)
def cmd_EXISTS(self):
"""
A C{checkAndSet} update has failed.
"""
self._current.popleft().success(False)
    def lineReceived(self, line):
        """
        Receive line commands from the server.

        Dispatches to the matching C{cmd_*} handler, falling back first to
        space-containing tokens (e.g. "NOT STORED" -> cmd_NOT_STORED) and
        finally to a bare integer (incr/decr responses).
        """
        self.resetTimeout()
        token = line.split(" ", 1)[0]
        # First manage standard commands without space
        cmd = getattr(self, "cmd_%s" % (token,), None)
        if cmd is not None:
            args = line.split(" ", 1)[1:]
            if args:
                cmd(args[0])
            else:
                cmd()
        else:
            # Then manage commands with space in it
            line = line.replace(" ", "_")
            cmd = getattr(self, "cmd_%s" % (line,), None)
            if cmd is not None:
                cmd()
            else:
                # Increment/Decrement response
                # (int() raises if the server sent something unexpected)
                cmd = self._current.popleft()
                val = int(line)
                cmd.success(val)
        if not self._current:
            # No pending request, remove timeout
            self.setTimeout(None)
def increment(self, key, val=1):
"""
Increment the value of C{key} by given value (default to 1).
C{key} must be consistent with an int. Return the new value.
@param key: the key to modify.
@type key: C{str}
@param val: the value to increment.
@type val: C{int}
@return: a deferred with will be called back with the new value
associated with the key (after the increment).
@rtype: L{Deferred}
"""
return self._incrdecr("incr", key, val)
def decrement(self, key, val=1):
"""
Decrement the value of C{key} by given value (default to 1).
C{key} must be consistent with an int. Return the new value, coerced to
0 if negative.
@param key: the key to modify.
@type key: C{str}
@param val: the value to decrement.
@type val: C{int}
@return: a deferred with will be called back with the new value
associated with the key (after the decrement).
@rtype: L{Deferred}
"""
return self._incrdecr("decr", key, val)
def _incrdecr(self, cmd, key, val):
"""
Internal wrapper for incr/decr.
"""
if not isinstance(key, str):
return fail(ClientError(
"Invalid type for key: %s, expecting a string" % (type(key),)))
if len(key) > self.MAX_KEY_LENGTH:
return fail(ClientError("Key too long"))
fullcmd = "%s %s %d" % (cmd, key, int(val))
self.sendLine(fullcmd)
cmdObj = Command(cmd, key=key)
self._current.append(cmdObj)
return cmdObj._deferred
def replace(self, key, val, flags=0, expireTime=0):
"""
Replace the given C{key}. It must already exist in the server.
@param key: the key to replace.
@type key: C{str}
@param val: the new value associated with the key.
@type val: C{str}
@param flags: the flags to store with the key.
@type flags: C{int}
@param expireTime: if different from 0, the relative time in seconds
when the key will be deleted from the store.
@type expireTime: C{int}
@return: a deferred that will fire with C{True} if the operation has
succeeded, and C{False} with the key didn't previously exist.
@rtype: L{Deferred}
"""
return self._set("replace", key, val, flags, expireTime, "")
def add(self, key, val, flags=0, expireTime=0):
"""
Add the given C{key}. It must not exist in the server.
@param key: the key to add.
@type key: C{str}
@param val: the value associated with the key.
@type val: C{str}
@param flags: the flags to store with the key.
@type flags: C{int}
@param expireTime: if different from 0, the relative time in seconds
when the key will be deleted from the store.
@type expireTime: C{int}
@return: a deferred that will fire with C{True} if the operation has
succeeded, and C{False} with the key already exists.
@rtype: L{Deferred}
"""
return self._set("add", key, val, flags, expireTime, "")
def set(self, key, val, flags=0, expireTime=0):
"""
Set the given C{key}.
@param key: the key to set.
@type key: C{str}
@param val: the value associated with the key.
@type val: C{str}
@param flags: the flags to store with the key.
@type flags: C{int}
@param expireTime: if different from 0, the relative time in seconds
when the key will be deleted from the store.
@type expireTime: C{int}
@return: a deferred that will fire with C{True} if the operation has
succeeded.
@rtype: L{Deferred}
"""
return self._set("set", key, val, flags, expireTime, "")
def checkAndSet(self, key, val, cas, flags=0, expireTime=0):
"""
Change the content of C{key} only if the C{cas} value matches the
current one associated with the key. Use this to store a value which
hasn't been modified since last time you fetched it.
@param key: The key to set.
@type key: C{str}
@param val: The value associated with the key.
@type val: C{str}
@param cas: Unique 64-bit value returned by previous call of C{get}.
@type cas: C{str}
@param flags: The flags to store with the key.
@type flags: C{int}
@param expireTime: If different from 0, the relative time in seconds
when the key will be deleted from the store.
@type expireTime: C{int}
@return: A deferred that will fire with C{True} if the operation has
succeeded, C{False} otherwise.
@rtype: L{Deferred}
"""
return self._set("cas", key, val, flags, expireTime, cas)
def _set(self, cmd, key, val, flags, expireTime, cas):
"""
Internal wrapper for setting values.
"""
if not isinstance(key, str):
return fail(ClientError(
"Invalid type for key: %s, expecting a string" % (type(key),)))
if len(key) > self.MAX_KEY_LENGTH:
return fail(ClientError("Key too long"))
if not isinstance(val, str):
return fail(ClientError(
"Invalid type for value: %s, expecting a string" %
(type(val),)))
if cas:
cas = " " + cas
length = len(val)
fullcmd = "%s %s %d %d %d%s" % (
cmd, key, flags, expireTime, length, cas)
self.sendLine(fullcmd)
self.sendLine(val)
cmdObj = Command(cmd, key=key, flags=flags, length=length)
self._current.append(cmdObj)
return cmdObj._deferred
def append(self, key, val):
"""
Append given data to the value of an existing key.
@param key: The key to modify.
@type key: C{str}
@param val: The value to append to the current value associated with
the key.
@type val: C{str}
@return: A deferred that will fire with C{True} if the operation has
succeeded, C{False} otherwise.
@rtype: L{Deferred}
"""
# Even if flags and expTime values are ignored, we have to pass them
return self._set("append", key, val, 0, 0, "")
def prepend(self, key, val):
"""
Prepend given data to the value of an existing key.
@param key: The key to modify.
@type key: C{str}
@param val: The value to prepend to the current value associated with
the key.
@type val: C{str}
@return: A deferred that will fire with C{True} if the operation has
succeeded, C{False} otherwise.
@rtype: L{Deferred}
"""
# Even if flags and expTime values are ignored, we have to pass them
return self._set("prepend", key, val, 0, 0, "")
def get(self, key, withIdentifier=False):
"""
Get the given C{key}. It doesn't support multiple keys. If
C{withIdentifier} is set to C{True}, the command issued is a C{gets},
that will return the current identifier associated with the value. This
identifier has to be used when issuing C{checkAndSet} update later,
using the corresponding method.
@param key: The key to retrieve.
@type key: C{str}
@param withIdentifier: If set to C{True}, retrieve the current
identifier along with the value and the flags.
@type withIdentifier: C{bool}
@return: A deferred that will fire with the tuple (flags, value) if
C{withIdentifier} is C{False}, or (flags, cas identifier, value)
if C{True}.
@rtype: L{Deferred}
"""
if not isinstance(key, str):
return fail(ClientError(
"Invalid type for key: %s, expecting a string" % (type(key),)))
if len(key) > self.MAX_KEY_LENGTH:
return fail(ClientError("Key too long"))
if withIdentifier:
cmd = "gets"
else:
cmd = "get"
fullcmd = "%s %s" % (cmd, key)
self.sendLine(fullcmd)
cmdObj = Command(cmd, key=key, value=None, flags=0, cas="")
self._current.append(cmdObj)
return cmdObj._deferred
def stats(self, arg=None):
"""
Get some stats from the server. It will be available as a dict.
@param arg: An optional additional string which will be sent along
with the I{stats} command. The interpretation of this value by
the server is left undefined by the memcache protocol
specification.
@type arg: L{NoneType} or L{str}
@return: a deferred that will fire with a C{dict} of the available
statistics.
@rtype: L{Deferred}
"""
cmd = "stats"
if arg:
cmd = "stats " + arg
self.sendLine(cmd)
cmdObj = Command("stats", values={})
self._current.append(cmdObj)
return cmdObj._deferred
def version(self):
"""
Get the version of the server.
@return: a deferred that will fire with the string value of the
version.
@rtype: L{Deferred}
"""
self.sendLine("version")
cmdObj = Command("version")
self._current.append(cmdObj)
return cmdObj._deferred
def delete(self, key):
"""
Delete an existing C{key}.
@param key: the key to delete.
@type key: C{str}
@return: a deferred that will be called back with C{True} if the key
was successfully deleted, or C{False} if not.
@rtype: L{Deferred}
"""
if not isinstance(key, str):
return fail(ClientError(
"Invalid type for key: %s, expecting a string" % (type(key),)))
self.sendLine("delete %s" % key)
cmdObj = Command("delete", key=key)
self._current.append(cmdObj)
return cmdObj._deferred
def flushAll(self):
"""
Flush all cached values.
@return: a deferred that will be called back with C{True} when the
operation has succeeded.
@rtype: L{Deferred}
"""
self.sendLine("flush_all")
cmdObj = Command("flush_all")
self._current.append(cmdObj)
return cmdObj._deferred
# Public API of this module.
__all__ = ["MemCacheProtocol", "DEFAULT_PORT", "NoSuchCommand", "ClientError",
    "ServerError"]
calendarserver-5.2+dfsg/twext/protocols/__init__.py 0000644 0001750 0001750 00000001207 12263343324 021621 0 ustar rahul rahul ##
# Copyright (c) 2009-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Extensions to twisted.protocols
"""
calendarserver-5.2+dfsg/twext/backport/ 0000755 0001750 0001750 00000000000 12322625326 017272 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/backport/__init__.py 0000644 0001750 0001750 00000001302 12263343324 021376 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Backports of portions of Twisted.
(Specifically, those required for IPv6 client support).
"""
calendarserver-5.2+dfsg/twext/backport/internet/ 0000755 0001750 0001750 00000000000 12322625326 021122 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/backport/internet/tcp.py 0000644 0001750 0001750 00000117075 11742073632 022277 0 ustar rahul rahul # -*- test-case-name: twisted.test.test_tcp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Various asynchronous TCP/IP classes.
End users shouldn't use this module directly - use the reactor APIs instead.
"""
# System Imports
import types
import socket
import sys
import operator
import struct
from zope.interface import implements
from twisted.python.runtime import platformType
from twisted.python import versions, deprecate
try:
    # Try to get the memory BIO based startTLS implementation, available since
    # pyOpenSSL 0.10
    from twisted.internet._newtls import (
        ConnectionMixin as _TLSConnectionMixin,
        ClientMixin as _TLSClientMixin,
        ServerMixin as _TLSServerMixin)
except ImportError:
    try:
        # Try to get the socket BIO based startTLS implementation, available in
        # all pyOpenSSL versions
        from twisted.internet._oldtls import (
            ConnectionMixin as _TLSConnectionMixin,
            ClientMixin as _TLSClientMixin,
            ServerMixin as _TLSServerMixin)
    except ImportError:
        # There is no version of startTLS available
        # Stub mixins: startTLS support is simply absent; the TLS class
        # attribute advertises that to the rest of the module.
        class _TLSConnectionMixin(object):
            TLS = False
        class _TLSClientMixin(object):
            pass
        class _TLSServerMixin(object):
            pass
if platformType == 'win32':
# no such thing as WSAEPERM or error code 10001 according to winsock.h or MSDN
EPERM = object()
from errno import WSAEINVAL as EINVAL
from errno import WSAEWOULDBLOCK as EWOULDBLOCK
from errno import WSAEINPROGRESS as EINPROGRESS
from errno import WSAEALREADY as EALREADY
from errno import WSAECONNRESET as ECONNRESET
from errno import WSAEISCONN as EISCONN
from errno import WSAENOTCONN as ENOTCONN
from errno import WSAEINTR as EINTR
from errno import WSAENOBUFS as ENOBUFS
from errno import WSAEMFILE as EMFILE
# No such thing as WSAENFILE, either.
ENFILE = object()
# Nor ENOMEM
ENOMEM = object()
EAGAIN = EWOULDBLOCK
from errno import WSAECONNRESET as ECONNABORTED
from twisted.python.win32 import formatError as strerror
else:
from errno import EPERM
from errno import EINVAL
from errno import EWOULDBLOCK
from errno import EINPROGRESS
from errno import EALREADY
from errno import ECONNRESET
from errno import EISCONN
from errno import ENOTCONN
from errno import EINTR
from errno import ENOBUFS
from errno import EMFILE
from errno import ENFILE
from errno import ENOMEM
from errno import EAGAIN
from errno import ECONNABORTED
from os import strerror
from errno import errorcode
# Twisted Imports
from twisted.internet import base, address, fdesc
from twisted.internet.task import deferLater
from twisted.python import log, failure, reflect
from twisted.python.util import unsignedID
from twisted.internet.error import CannotListenError
from twisted.internet import abstract, main, interfaces, error
# Not all platforms have, or support, this flag.
_AI_NUMERICSERV = getattr(socket, "AI_NUMERICSERV", 0)
class _SocketCloser(object):
_socketShutdownMethod = 'shutdown'
def _closeSocket(self, orderly):
# The call to shutdown() before close() isn't really necessary, because
# we set FD_CLOEXEC now, which will ensure this is the only process
# holding the FD, thus ensuring close() really will shutdown the TCP
# socket. However, do it anyways, just to be safe.
skt = self.socket
try:
if orderly:
getattr(skt, self._socketShutdownMethod)(2)
else:
# Set SO_LINGER to 1,0 which, by convention, causes a
# connection reset to be sent when close is called,
# instead of the standard FIN shutdown sequence.
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack("ii", 1, 0))
except socket.error:
pass
try:
skt.close()
except socket.error:
pass
class _AbortingMixin(object):
"""
Common implementation of C{abortConnection}.
@ivar _aborting: Set to C{True} when C{abortConnection} is called.
@type _aborting: C{bool}
"""
_aborting = False
def abortConnection(self):
"""
Aborts the connection immediately, dropping any buffered data.
@since: 11.1
"""
if self.disconnected or self._aborting:
return
self._aborting = True
self.stopReading()
self.stopWriting()
self.doRead = lambda *args, **kwargs: None
self.doWrite = lambda *args, **kwargs: None
self.reactor.callLater(0, self.connectionLost,
failure.Failure(error.ConnectionAborted()))
class Connection(_TLSConnectionMixin, abstract.FileDescriptor, _SocketCloser,
                 _AbortingMixin):
    """
    Superclass of all socket-based FileDescriptors.

    This is an abstract superclass of all objects which represent a TCP/IP
    connection based socket.

    @ivar logstr: prefix used when logging events related to this connection.
    @type logstr: C{str}
    """
    implements(interfaces.ITCPTransport, interfaces.ISystemHandle)

    def __init__(self, skt, protocol, reactor=None):
        """
        @param skt: a connected socket object.
        @param protocol: the protocol instance that will receive data.
        @param reactor: the reactor driving this connection, or C{None} for
            the global reactor.
        """
        abstract.FileDescriptor.__init__(self, reactor=reactor)
        self.socket = skt
        self.socket.setblocking(0)  # the reactor requires non-blocking I/O
        self.fileno = skt.fileno
        self.protocol = protocol

    def getHandle(self):
        """Return the socket for this connection."""
        return self.socket

    def doRead(self):
        """Calls self.protocol.dataReceived with all available data.

        This reads up to self.bufferSize bytes of data from its socket, then
        calls self.dataReceived(data) to process it.  If the connection is not
        lost through an error in the physical recv(), this function will return
        the result of the dataReceived call.
        """
        try:
            data = self.socket.recv(self.bufferSize)
        except socket.error, se:
            if se.args[0] == EWOULDBLOCK:
                # Spurious wakeup; nothing to read right now.
                return
            else:
                return main.CONNECTION_LOST
        if not data:
            # recv() returning an empty string means the peer closed.
            return main.CONNECTION_DONE
        rval = self.protocol.dataReceived(data)
        if rval is not None:
            # Returning a value from dataReceived is deprecated; warn but
            # still propagate the value for backwards compatibility.
            offender = self.protocol.dataReceived
            warningFormat = (
                'Returning a value other than None from %(fqpn)s is '
                'deprecated since %(version)s.')
            warningString = deprecate.getDeprecationWarningString(
                offender, versions.Version('Twisted', 11, 0, 0),
                format=warningFormat)
            deprecate.warnAboutFunction(offender, warningString)
        return rval

    def writeSomeData(self, data):
        """
        Write as much as possible of the given data to this TCP connection.

        This sends up to C{self.SEND_LIMIT} bytes from C{data}.  If the
        connection is lost, an exception is returned.  Otherwise, the number
        of bytes successfully written is returned.
        """
        try:
            # Limit length of buffer to try to send, because some OSes are too
            # stupid to do so themselves (ahem windows)
            return self.socket.send(buffer(data, 0, self.SEND_LIMIT))
        except socket.error, se:
            if se.args[0] == EINTR:
                # Interrupted by a signal; just retry.
                return self.writeSomeData(data)
            elif se.args[0] in (EWOULDBLOCK, ENOBUFS):
                # Kernel buffers full; report zero bytes written.
                return 0
            else:
                return main.CONNECTION_LOST

    def _closeWriteConnection(self):
        # Half-close: shut down only the write side of the socket, then
        # notify the protocol if it supports half-closed connections.
        try:
            getattr(self.socket, self._socketShutdownMethod)(1)
        except socket.error:
            pass
        p = interfaces.IHalfCloseableProtocol(self.protocol, None)
        if p:
            try:
                p.writeConnectionLost()
            except:
                # A broken protocol callback tears down the whole connection.
                f = failure.Failure()
                log.err()
                self.connectionLost(f)

    def readConnectionLost(self, reason):
        # Peer closed its write side; either half-close (if the protocol
        # supports it) or drop the whole connection.
        p = interfaces.IHalfCloseableProtocol(self.protocol, None)
        if p:
            try:
                p.readConnectionLost()
            except:
                log.err()
                self.connectionLost(failure.Failure())
        else:
            self.connectionLost(reason)

    def connectionLost(self, reason):
        """See abstract.FileDescriptor.connectionLost().
        """
        # Make sure we're not called twice, which can happen e.g. if
        # abortConnection() is called from protocol's dataReceived and then
        # code immediately after throws an exception that reaches the
        # reactor. We can't rely on "disconnected" attribute for this check
        # since twisted.internet._oldtls does evil things to it:
        if not hasattr(self, "socket"):
            return
        abstract.FileDescriptor.connectionLost(self, reason)
        # An aborted connection gets a disorderly (RST) close.
        self._closeSocket(not reason.check(error.ConnectionAborted))
        protocol = self.protocol
        del self.protocol
        del self.socket
        del self.fileno
        protocol.connectionLost(reason)

    logstr = "Uninitialized"  # replaced once the connection is set up

    def logPrefix(self):
        """Return the prefix to log with when I own the logging thread.
        """
        return self.logstr

    def getTcpNoDelay(self):
        """Return whether TCP_NODELAY (Nagle disabled) is set."""
        return operator.truth(self.socket.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY))

    def setTcpNoDelay(self, enabled):
        """Enable or disable TCP_NODELAY on the underlying socket."""
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)

    def getTcpKeepAlive(self):
        """Return whether SO_KEEPALIVE is set on the underlying socket."""
        return operator.truth(self.socket.getsockopt(socket.SOL_SOCKET,
                                                     socket.SO_KEEPALIVE))

    def setTcpKeepAlive(self, enabled):
        """Enable or disable SO_KEEPALIVE on the underlying socket."""
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, enabled)
class _BaseBaseClient(object):
    """
    Code shared with other (non-POSIX) reactors for management of general
    outgoing connections.

    Requirements upon subclasses are documented as instance variables rather
    than abstract methods, in order to avoid MRO confusion, since this base is
    mixed in to unfortunately weird and distinctive multiple-inheritance
    hierarchies and many of these attributes are provided by peer classes
    rather than descendant classes in those hierarchies.

    @ivar addressFamily: The address family constant (C{socket.AF_INET},
        C{socket.AF_INET6}, C{socket.AF_UNIX}) of the underlying socket of this
        client connection.
    @type addressFamily: C{int}

    @ivar socketType: The socket type constant (C{socket.SOCK_STREAM} or
        C{socket.SOCK_DGRAM}) of the underlying socket.
    @type socketType: C{int}

    @ivar _requiresResolution: A flag indicating whether the address of this
        client will require name resolution.  C{True} if the hostname of said
        address indicates a name that must be resolved by hostname lookup,
        C{False} if it indicates an IP address literal.
    @type _requiresResolution: C{bool}

    @cvar _commonConnection: Subclasses must provide this attribute, which
        indicates the L{Connection}-alike class to invoke C{__init__} and
        C{connectionLost} on.
    @type _commonConnection: C{type}

    @ivar _stopReadingAndWriting: Subclasses must implement in order to remove
        this transport from its reactor's notifications in response to a
        terminated connection attempt.
    @type _stopReadingAndWriting: 0-argument callable returning C{None}

    @ivar _closeSocket: Subclasses must implement in order to close the socket
        in response to a terminated connection attempt.
    @type _closeSocket: 1-argument callable; see L{_SocketCloser._closeSocket}

    @ivar _collectSocketDetails: Clean up references to the attached socket in
        its underlying OS resource (such as a file descriptor or file handle),
        as part of post connection-failure cleanup.
    @type _collectSocketDetails: 0-argument callable returning C{None}.

    @ivar reactor: The class pointed to by C{_commonConnection} should set this
        attribute in its constructor.
    @type reactor: L{twisted.internet.interfaces.IReactorTime},
        L{twisted.internet.interfaces.IReactorCore},
        L{twisted.internet.interfaces.IReactorFDSet}
    """
    addressFamily = socket.AF_INET
    socketType = socket.SOCK_STREAM

    def _finishInit(self, whenDone, skt, error, reactor):
        """
        Called by subclasses to continue to the stage of initialization where
        the socket connect attempt is made.

        @param whenDone: A 0-argument callable to invoke once the connection is
            set up.  This is C{None} if the connection could not be prepared
            due to a previous error.

        @param skt: The socket object to use to perform the connection.
        @type skt: C{socket._socketobject}

        @param error: The error to fail the connection with.

        @param reactor: The reactor to use for this client.
        @type reactor: L{twisted.internet.interfaces.IReactorTime}
        """
        if whenDone:
            self._commonConnection.__init__(self, skt, None, reactor)
            reactor.callLater(0, whenDone)
        else:
            # Setup failed earlier; report the failure asynchronously so the
            # caller always observes asynchronous behavior.
            reactor.callLater(0, self.failIfNotConnected, error)

    def resolveAddress(self):
        """
        Resolve the name that was passed to this L{_BaseBaseClient}, if
        necessary, and then move on to attempting the connection once an
        address has been determined.  (The connection will be attempted
        immediately within this function if either name resolution can be
        synchronous or the address was an IP address literal.)

        @note: You don't want to call this method from outside, as it won't do
            anything useful; it's just part of the connection bootstrapping
            process.  Also, although this method is on L{_BaseBaseClient} for
            historical reasons, it's not used anywhere except for L{Client}
            itself.

        @return: C{None}
        """
        if self._requiresResolution:
            d = self.reactor.resolve(self.addr[0])
            # Re-attach the non-host parts (port, etc.) to the resolved name.
            d.addCallback(lambda n: (n,) + self.addr[1:])
            d.addCallbacks(self._setRealAddress, self.failIfNotConnected)
        else:
            self._setRealAddress(self.addr)

    def _setRealAddress(self, address):
        """
        Set the resolved address of this L{_BaseBaseClient} and initiate the
        connection attempt.

        @param address: Depending on whether this is an IPv4 or IPv6 connection
            attempt, a 2-tuple of C{(host, port)} or a 4-tuple of C{(host,
            port, flow, scope)}.  At this point it is a fully resolved address,
            and the 'host' portion will always be an IP address, not a DNS
            name.
        """
        self.realAddress = address
        self.doConnect()

    def failIfNotConnected(self, err):
        """
        Generic method called when the attempts to connect failed.  It basically
        cleans everything it can: call connectionFailed, stop read and write,
        delete socket related members.
        """
        if (self.connected or self.disconnected or
            not hasattr(self, "connector")):
            # Already connected, already torn down, or failure was already
            # reported; nothing to do.
            return

        self._stopReadingAndWriting()
        try:
            self._closeSocket(True)
        except AttributeError:
            # The socket was never successfully created.
            pass
        else:
            self._collectSocketDetails()
        self.connector.connectionFailed(failure.Failure(err))
        del self.connector

    def stopConnecting(self):
        """
        If a connection attempt is still outstanding (i.e.  no connection is
        yet established), immediately stop attempting to connect.
        """
        self.failIfNotConnected(error.UserError())

    def connectionLost(self, reason):
        """
        Invoked by lower-level logic when it's time to clean the socket up.
        Depending on the state of the connection, either inform the attached
        L{Connector} that the connection attempt has failed, or inform the
        connected L{IProtocol} that the established connection has been lost.

        @param reason: the reason that the connection was terminated
        @type reason: L{Failure}
        """
        if not self.connected:
            self.failIfNotConnected(error.ConnectError(string=reason))
        else:
            self._commonConnection.connectionLost(self, reason)
            self.connector.connectionLost(reason)
class BaseClient(_BaseBaseClient, _TLSClientMixin, Connection):
    """
    A base class for client TCP (and similar) sockets.

    @ivar realAddress: The address object that will be used for socket.connect;
        this address is an address tuple (the number of elements dependent upon
        the address family) which does not contain any names which need to be
        resolved.
    @type realAddress: C{tuple}

    @ivar _base: L{Connection}, which is the base class of this class which has
        all of the useful file descriptor methods.  This is used by
        L{_TLSServerMixin} to call the right methods to directly manipulate the
        transport, as is necessary for writing TLS-encrypted bytes (whereas
        those methods on L{Server} will go through another layer of TLS if it
        has been enabled).
    """
    _base = Connection
    _commonConnection = Connection

    def _stopReadingAndWriting(self):
        """
        Implement the POSIX-ish (i.e.
        L{twisted.internet.interfaces.IReactorFDSet}) method of detaching this
        socket from the reactor for L{_BaseBaseClient}.
        """
        if hasattr(self, "reactor"):
            # this doesn't happen if we failed in __init__
            self.stopReading()
            self.stopWriting()

    def _collectSocketDetails(self):
        """
        Clean up references to the socket and its file descriptor.

        @see: L{_BaseBaseClient}
        """
        del self.socket, self.fileno

    def createInternetSocket(self):
        """(internal) Create a non-blocking socket using
        self.addressFamily, self.socketType.
        """
        s = socket.socket(self.addressFamily, self.socketType)
        s.setblocking(0)
        # Close-on-exec so child processes don't inherit the connection.
        fdesc._setCloseOnExec(s.fileno())
        return s

    def doConnect(self):
        """
        Initiate the outgoing connection attempt.

        @note: Applications do not need to call this method; it will be invoked
            internally as part of L{IReactorTCP.connectTCP}.
        """
        # While connecting, both readable and writable reactor events mean
        # "check connection progress again".
        self.doWrite = self.doConnect
        self.doRead = self.doConnect
        if not hasattr(self, "connector"):
            # this happens when connection failed but doConnect
            # was scheduled via a callLater in self._finishInit
            return

        # Check for a pending asynchronous connection error first.
        err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        if err:
            self.failIfNotConnected(error.getConnectError((err, strerror(err))))
            return

        # doConnect gets called twice.  The first time we actually need to
        # start the connection attempt.  The second time we don't really
        # want to (SO_ERROR above will have taken care of any errors, and if
        # it reported none, the mere fact that doConnect was called again is
        # sufficient to indicate that the connection has succeeded), but it
        # is not /particularly/ detrimental to do so.  This should get
        # cleaned up some day, though.
        try:
            connectResult = self.socket.connect_ex(self.realAddress)
        except socket.error, se:
            connectResult = se.args[0]
        if connectResult:
            if connectResult == EISCONN:
                # Already connected; fall through to the success path.
                pass
            # on Windows EINVAL means sometimes that we should keep trying:
            # http://msdn.microsoft.com/library/default.asp?url=/library/en-us/winsock/winsock/connect_2.asp
            elif ((connectResult in (EWOULDBLOCK, EINPROGRESS, EALREADY)) or
                  (connectResult == EINVAL and platformType == "win32")):
                # Connection in progress; wait for the socket to become
                # writable and re-enter doConnect.
                self.startReading()
                self.startWriting()
                return
            else:
                self.failIfNotConnected(error.getConnectError((connectResult, strerror(connectResult))))
                return

        # If I have reached this point without raising or returning, that means
        # that the socket is connected.
        del self.doWrite
        del self.doRead
        # we first stop and then start, to reset any references to the old doRead
        self.stopReading()
        self.stopWriting()
        self._connectDone()

    def _connectDone(self):
        """
        This is a hook for when a connection attempt has succeeded.

        Here, we build the protocol from the
        L{twisted.internet.protocol.ClientFactory} that was passed in, compute
        a log string, begin reading so as to send traffic to the newly built
        protocol, and finally hook up the protocol itself.

        This hook is overridden by L{ssl.Client} to initiate the TLS protocol.
        """
        self.protocol = self.connector.buildProtocol(self.getPeer())
        self.connected = 1
        logPrefix = self._getLogPrefix(self.protocol)
        self.logstr = "%s,client" % logPrefix
        self.startReading()
        self.protocol.makeConnection(self)
# Require both host and service to already be numeric; no DNS lookups.
# AI_NUMERICSERV is absent on some platforms, in which case the flag is a
# no-op (0).
_NUMERIC_ONLY = socket.AI_NUMERICHOST | getattr(socket, "AI_NUMERICSERV", 0)


def _resolveIPv6(ip, port):
    """
    Resolve an IPv6 literal into an IPv6 socket address.

    This is necessary to resolve any embedded scope identifiers to the
    relevant C{sin6_scope_id} for use with C{socket.connect()},
    C{socket.listen()}, or C{socket.bind()}; see U{RFC 3493} for more
    information.

    @param ip: An IPv6 address literal.
    @type ip: C{str}

    @param port: A port number.
    @type port: C{int}

    @return: a 4-tuple of C{(host, port, flow, scope)}, suitable for use as an
        IPv6 address.

    @raise socket.gaierror: if either the IP or port is not numeric as it
        should be.
    """
    results = socket.getaddrinfo(ip, port, 0, 0, 0, _NUMERIC_ONLY)
    # getaddrinfo result entries are (family, type, proto, canonname,
    # sockaddr); only the sockaddr of the first entry is needed.
    return results[0][4]
class _BaseTCPClient(object):
    """
    Code shared with other (non-POSIX) reactors for management of outgoing TCP
    connections (both TCPv4 and TCPv6).

    @note: In order to be functional, this class must be mixed into the same
        hierarchy as L{_BaseBaseClient}.  It would subclass L{_BaseBaseClient}
        directly, but the class hierarchy here is divided in strange ways out
        of the need to share code along multiple axes; specifically, with the
        IOCP reactor and also with UNIX clients in other reactors.

    @ivar _addressType: The Twisted _IPAddress implementation for this client
    @type _addressType: L{IPv4Address} or L{IPv6Address}

    @ivar connector: The L{Connector} which is driving this L{_BaseTCPClient}'s
        connection attempt.

    @ivar addr: The address that this socket will be connecting to.
    @type addr: If IPv4, a 2-C{tuple} of C{(str host, int port)}.  If IPv6, a
        4-C{tuple} of (C{str host, int port, int ignored, int scope}).

    @ivar createInternetSocket: Subclasses must implement this as a method to
        create a python socket object of the appropriate address family and
        socket type.
    @type createInternetSocket: 0-argument callable returning
        C{socket._socketobject}.
    """
    _addressType = address.IPv4Address

    def __init__(self, host, port, bindAddress, connector, reactor=None):
        # BaseClient.__init__ is invoked later
        self.connector = connector
        self.addr = (host, port)

        whenDone = self.resolveAddress
        err = None
        skt = None

        # Classify the host: IP literal (v4 or v6) vs. a name needing DNS.
        if abstract.isIPAddress(host):
            self._requiresResolution = False
        elif abstract.isIPv6Address(host):
            self._requiresResolution = False
            # Resolve any scope id embedded in the literal now.
            self.addr = _resolveIPv6(host, port)
            self.addressFamily = socket.AF_INET6
            self._addressType = address.IPv6Address
        else:
            self._requiresResolution = True
        try:
            skt = self.createInternetSocket()
        except socket.error, se:
            err = error.ConnectBindError(se.args[0], se.args[1])
            whenDone = None
        if whenDone and bindAddress is not None:
            # Bind to the requested local address before connecting.
            try:
                if abstract.isIPv6Address(bindAddress[0]):
                    bindinfo = _resolveIPv6(*bindAddress)
                else:
                    bindinfo = bindAddress
                skt.bind(bindinfo)
            except socket.error, se:
                err = error.ConnectBindError(se.args[0], se.args[1])
                whenDone = None
        self._finishInit(whenDone, skt, err, reactor)

    def getHost(self):
        """
        Returns an L{IPv4Address} or L{IPv6Address}.

        This indicates the address from which I am connecting.
        """
        return self._addressType('TCP', *self.socket.getsockname()[:2])

    def getPeer(self):
        """
        Returns an L{IPv4Address} or L{IPv6Address}.

        This indicates the address that I am connected to.
        """
        # an ipv6 realAddress has more than two elements, but the IPv6Address
        # constructor still only takes two.
        return self._addressType('TCP', *self.realAddress[:2])

    def __repr__(self):
        s = '<%s to %s at %x>' % (self.__class__, self.addr, unsignedID(self))
        return s
class Client(_BaseTCPClient, BaseClient):
    """
    A transport for a TCP protocol; either TCPv4 or TCPv6.

    Do not create these directly; use L{IReactorTCP.connectTCP}.
    """
class Server(_TLSServerMixin, Connection):
    """
    Server-side socket-stream connection transport: a socket which came from
    an C{accept()} on a listening server port.

    @ivar _base: L{Connection}, the base class that holds the useful file
        descriptor methods.  L{_TLSServerMixin} uses it to manipulate the
        transport directly when writing TLS-encrypted bytes (the equivalent
        methods on L{Server} would pass through another layer of TLS once it
        has been enabled).
    """
    _base = Connection
    _addressType = address.IPv4Address

    def __init__(self, sock, protocol, client, server, sessionno, reactor):
        """
        Initialize with a socket, a protocol, the peer's address tuple
        (host, port, ...), the listening L{Port}, and a session number.
        """
        Connection.__init__(self, sock, protocol, reactor)
        # A peer address tuple longer than (host, port) indicates IPv6.
        if len(client) != 2:
            self._addressType = address.IPv6Address
        self.server = server
        self.client = client
        self.sessionno = sessionno
        self.hostname = client[0]
        self.logstr = "%s,%s,%s" % (self._getLogPrefix(self.protocol),
                                    sessionno,
                                    self.hostname)
        self.repstr = "<%s #%s on %s>" % (self.protocol.__class__.__name__,
                                          self.sessionno,
                                          self.server._realPortNumber)
        self.startReading()
        self.connected = 1

    def __repr__(self):
        """A string representation of this connection.
        """
        return self.repstr

    def getHost(self):
        """
        Returns an L{IPv4Address} or L{IPv6Address} for the server's end of
        the connection.
        """
        return self._addressType('TCP', *self.socket.getsockname()[:2])

    def getPeer(self):
        """
        Returns an L{IPv4Address} or L{IPv6Address} for the client's end of
        the connection.
        """
        peerHost, peerPort = self.client[:2]
        return self._addressType('TCP', peerHost, peerPort)
class Port(base.BasePort, _SocketCloser):
    """
    A TCP server port, listening for connections.

    When a connection is accepted, this will call a factory's buildProtocol
    with the incoming address as an argument, according to the specification
    described in L{twisted.internet.interfaces.IProtocolFactory}.

    If you wish to change the sort of transport that will be used, the
    C{transport} attribute will be called with the signature expected for
    C{Server.__init__}, so it can be replaced.

    @ivar deferred: a deferred created when L{stopListening} is called, and
        that will fire when connection is lost.  This is not to be used it
        directly: prefer the deferred returned by L{stopListening} instead.
    @type deferred: L{defer.Deferred}

    @ivar disconnecting: flag indicating that the L{stopListening} method has
        been called and that no connections should be accepted anymore.
    @type disconnecting: C{bool}

    @ivar connected: flag set once the listen has successfully been called on
        the socket.
    @type connected: C{bool}

    @ivar _type: A string describing the connections which will be created by
        this port.  Normally this is C{"TCP"}, since this is a TCP port, but
        when the TLS implementation re-uses this class it overrides the value
        with C{"TLS"}.  Only used for logging.

    @ivar _preexistingSocket: If not C{None}, a L{socket.socket} instance which
        was created and initialized outside of the reactor and will be used to
        listen for connections (instead of a new socket being created by this
        L{Port}).
    """
    implements(interfaces.IListeningPort)

    socketType = socket.SOCK_STREAM

    transport = Server
    sessionno = 0
    interface = ''
    backlog = 50

    _type = 'TCP'

    # Actual port number being listened on, only set to a non-None
    # value when we are actually listening.
    _realPortNumber = None

    # An externally initialized socket that we will use, rather than creating
    # our own.
    _preexistingSocket = None

    addressFamily = socket.AF_INET
    _addressType = address.IPv4Address

    def __init__(self, port, factory, backlog=50, interface='', reactor=None):
        """Initialize with a numeric port to listen on.
        """
        base.BasePort.__init__(self, reactor=reactor)
        self.port = port
        self.factory = factory
        self.backlog = backlog
        # An IPv6 literal as the interface switches the whole port to IPv6.
        if abstract.isIPv6Address(interface):
            self.addressFamily = socket.AF_INET6
            self._addressType = address.IPv6Address
        self.interface = interface

    @classmethod
    def _fromListeningDescriptor(cls, reactor, fd, addressFamily, factory):
        """
        Create a new L{Port} based on an existing listening I{SOCK_STREAM}
        I{AF_INET} socket.

        Arguments are the same as to L{Port.__init__}, except where noted.

        @param fd: An integer file descriptor associated with a listening
            socket.  The socket must be in non-blocking mode.  Any additional
            attributes desired, such as I{FD_CLOEXEC}, must also be set already.

        @param addressFamily: The address family (sometimes called I{domain}) of
            the existing socket.  For example, L{socket.AF_INET}.

        @return: A new instance of C{cls} wrapping the socket given by C{fd}.
        """
        port = socket.fromfd(fd, addressFamily, cls.socketType)
        interface = port.getsockname()[0]
        self = cls(None, factory, None, interface, reactor)
        self._preexistingSocket = port
        return self

    def __repr__(self):
        if self._realPortNumber is not None:
            return "<%s of %s on %s>" % (self.__class__,
                                         self.factory.__class__, self._realPortNumber)
        else:
            return "<%s of %s (not listening)>" % (self.__class__, self.factory.__class__)

    def createInternetSocket(self):
        s = base.BasePort.createInternetSocket(self)
        if platformType == "posix" and sys.platform != "cygwin":
            # Allow fast rebinding of the port after a restart.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s

    def startListening(self):
        """Create and bind my socket, and begin listening on it.

        This is called on unserialization, and must be called after creating a
        server to begin listening on the specified port.
        """
        if self._preexistingSocket is None:
            # Create a new socket and make it listen
            try:
                skt = self.createInternetSocket()
                if self.addressFamily == socket.AF_INET6:
                    addr = _resolveIPv6(self.interface, self.port)
                else:
                    addr = (self.interface, self.port)
                skt.bind(addr)
            except socket.error, le:
                raise CannotListenError, (self.interface, self.port, le)
            skt.listen(self.backlog)
        else:
            # Re-use the externally specified socket
            skt = self._preexistingSocket
            self._preexistingSocket = None

        # Make sure that if we listened on port 0, we update that to
        # reflect what the OS actually assigned us.
        self._realPortNumber = skt.getsockname()[1]

        log.msg("%s starting on %s" % (
                self._getLogPrefix(self.factory), self._realPortNumber))

        # The order of the next 5 lines is kind of bizarre.  If no one
        # can explain it, perhaps we should re-arrange them.
        self.factory.doStart()
        self.connected = True
        self.socket = skt
        self.fileno = self.socket.fileno
        self.numberAccepts = 100

        self.startReading()

    def _buildAddr(self, address):
        host, port = address[:2]
        return self._addressType('TCP', host, port)

    def doRead(self):
        """Called when my socket is ready for reading.

        This accepts a connection and calls self.protocol() to handle the
        wire-level protocol.
        """
        try:
            if platformType == "posix":
                numAccepts = self.numberAccepts
            else:
                # win32 event loop breaks if we do more than one accept()
                # in an iteration of the event loop.
                numAccepts = 1
            for i in range(numAccepts):
                # we need this so we can deal with a factory's buildProtocol
                # calling our loseConnection
                if self.disconnecting:
                    return
                try:
                    skt, addr = self.socket.accept()
                except socket.error, e:
                    if e.args[0] in (EWOULDBLOCK, EAGAIN):
                        # No more pending connections; remember how many we
                        # accepted this time to tune the next batch size.
                        self.numberAccepts = i
                        break
                    elif e.args[0] == EPERM:
                        # Netfilter on Linux may have rejected the
                        # connection, but we get told to try to accept()
                        # anyway.
                        continue
                    elif e.args[0] in (EMFILE, ENOBUFS, ENFILE, ENOMEM, ECONNABORTED):

                        # Linux gives EMFILE when a process is not allowed
                        # to allocate any more file descriptors.  *BSD and
                        # Win32 give (WSA)ENOBUFS.  Linux can also give
                        # ENFILE if the system is out of inodes, or ENOMEM
                        # if there is insufficient memory to allocate a new
                        # dentry.  ECONNABORTED is documented as possible on
                        # both Linux and Windows, but it is not clear
                        # whether there are actually any circumstances under
                        # which it can happen (one might expect it to be
                        # possible if a client sends a FIN or RST after the
                        # server sends a SYN|ACK but before application code
                        # calls accept(2), however at least on Linux this
                        # _seems_ to be short-circuited by syncookies.

                        log.msg("Could not accept new connection (%s)" % (
                            errorcode[e.args[0]],))
                        break
                    raise

                # Don't leak the accepted socket into child processes.
                fdesc._setCloseOnExec(skt.fileno())
                protocol = self.factory.buildProtocol(self._buildAddr(addr))
                if protocol is None:
                    # The factory refused the connection.
                    skt.close()
                    continue
                s = self.sessionno
                self.sessionno = s+1
                transport = self.transport(skt, protocol, addr, self, s, self.reactor)
                protocol.makeConnection(transport)
            else:
                # The whole batch was consumed; accept more next time.
                self.numberAccepts = self.numberAccepts+20
        except:
            # Note that in TLS mode, this will possibly catch SSL.Errors
            # raised by self.socket.accept()
            #
            # There is no "except SSL.Error:" above because SSL may be
            # None if there is no SSL support.  In any case, all the
            # "except SSL.Error:" suite would probably do is log.deferr()
            # and return, so handling it here works just as well.
            log.deferr()

    def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)):
        """
        Stop accepting connections on this port.

        This will shut down the socket and call self.connectionLost().  It
        returns a deferred which will fire successfully when the port is
        actually closed, or with a failure if an error occurs shutting down.

        @param connDone: the reason to report; note that the default is
            evaluated once at class-definition time, which is intentional
            here (a shared, immutable-in-practice Failure).
        """
        self.disconnecting = True
        self.stopReading()
        if self.connected:
            self.deferred = deferLater(
                self.reactor, 0, self.connectionLost, connDone)
            return self.deferred

    stopListening = loseConnection

    def _logConnectionLostMsg(self):
        """
        Log message for closing port
        """
        log.msg('(%s Port %s Closed)' % (self._type, self._realPortNumber))

    def connectionLost(self, reason):
        """
        Cleans up the socket.
        """
        self._logConnectionLostMsg()
        self._realPortNumber = None

        base.BasePort.connectionLost(self, reason)
        self.connected = False
        self._closeSocket(True)
        del self.socket
        del self.fileno

        try:
            self.factory.doStop()
        finally:
            self.disconnecting = False

    def logPrefix(self):
        """Returns the name of my class, to prefix log entries with.
        """
        return reflect.qual(self.factory.__class__)

    def getHost(self):
        """
        Return an L{IPv4Address} or L{IPv6Address} indicating the listening
        address of this port.
        """
        host, port = self.socket.getsockname()[:2]
        return self._addressType('TCP', host, port)
class Connector(base.BaseConnector):
    """
    A L{Connector} provides an implementation of
    L{twisted.internet.interfaces.IConnector} for all POSIX-style reactors.

    @ivar _addressType: the type returned by L{Connector.getDestination}.
        Either L{IPv4Address} or L{IPv6Address}, depending on the type of
        address.
    @type _addressType: C{type}
    """
    _addressType = address.IPv4Address

    def __init__(self, host, port, factory, timeout, bindAddress, reactor=None):
        # A string port is a service name (e.g. "http"); resolve it to a
        # numeric TCP port before proceeding.
        if isinstance(port, types.StringTypes):
            try:
                port = socket.getservbyname(port, 'tcp')
            except socket.error, e:
                raise error.ServiceNameUnknownError(string="%s (%r)" % (e, port))
        self.host, self.port = host, port
        if abstract.isIPv6Address(host):
            self._addressType = address.IPv6Address
        self.bindAddress = bindAddress
        base.BaseConnector.__init__(self, factory, timeout, reactor)

    def _makeTransport(self):
        """
        Create a L{Client} bound to this L{Connector}.

        @return: a new L{Client}
        @rtype: L{Client}
        """
        return Client(self.host, self.port, self.bindAddress, self, self.reactor)

    def getDestination(self):
        """
        @see: L{twisted.internet.interfaces.IConnector.getDestination}.
        """
        return self._addressType('TCP', self.host, self.port)
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Address objects for network connections.
"""
import warnings, os
from zope.interface import implements
from twisted.internet.interfaces import IAddress
from twisted.python import util
class _IPAddress(object, util.FancyEqMixin):
    """
    Common representation of an IP socket endpoint address, shared by the
    IPv4 and IPv6 flavors.

    @ivar type: A string describing the type of transport, either 'TCP' or
        'UDP'.

    @ivar host: The presentation-format IP address, for example "127.0.0.1"
        or "::1".
    @type host: C{str}

    @ivar port: The port number.
    @type port: C{int}
    """
    implements(IAddress)

    compareAttributes = ('type', 'host', 'port')

    def __init__(self, type, host, port):
        assert type in ('TCP', 'UDP')
        self.type = type
        self.host = host
        self.port = port

    def __repr__(self):
        details = (self.__class__.__name__, self.type, self.host, self.port)
        return '%s(%s, %r, %d)' % details

    def __hash__(self):
        # Hash exactly the attributes FancyEqMixin compares, so that equal
        # addresses hash equally.
        return hash((self.type, self.host, self.port))
class IPv4Address(_IPAddress):
    """
    The address of an IPv4 socket endpoint.

    @ivar host: A dotted-quad IPv4 address string; for example, "127.0.0.1".
    @type host: C{str}
    """
    def __init__(self, type, host, port, _bwHack=None):
        _IPAddress.__init__(self, type, host, port)
        if _bwHack is None:
            return
        # The legacy _bwHack argument no longer does anything except warn.
        warnings.warn("twisted.internet.address.IPv4Address._bwHack "
                      "is deprecated since Twisted 11.0",
                      DeprecationWarning, stacklevel=2)
class IPv6Address(_IPAddress):
    """
    An L{IPv6Address} represents the address of an IPv6 socket endpoint.

    All behavior is inherited from L{_IPAddress}; this subclass exists to
    give IPv6 endpoints a distinct type.

    @ivar host: A string containing a colon-separated, hexadecimal formatted
        IPv6 address; for example, "::1".
    @type host: C{str}
    """
class UNIXAddress(object, util.FancyEqMixin):
    """
    Object representing a UNIX socket endpoint.

    @ivar name: The filename associated with this socket.
    @type name: C{str}
    """
    implements(IAddress)

    compareAttributes = ('name', )

    def __init__(self, name, _bwHack = None):
        self.name = name
        if _bwHack is not None:
            # The legacy _bwHack argument no longer does anything except warn.
            warnings.warn("twisted.internet.address.UNIXAddress._bwHack is deprecated since Twisted 11.0",
                          DeprecationWarning, stacklevel=2)

    # Only define the samefile-based __eq__ on platforms whose os.path
    # provides samefile (i.e. POSIX).
    if getattr(os.path, 'samefile', None) is not None:
        def __eq__(self, other):
            """
            overriding L{util.FancyEqMixin} to ensure the os level samefile
            check is done if the name attributes do not match.
            """
            res = super(UNIXAddress, self).__eq__(other)
            if res == False:
                # Names differ; fall back to an inode-level comparison in
                # case both names refer to the same filesystem object.
                try:
                    return os.path.samefile(self.name, other.name)
                except OSError:
                    pass
            return res

    def __repr__(self):
        return 'UNIXAddress(%r)' % (self.name,)

    def __hash__(self):
        # Hash by (inode, device) when the file exists, so two paths to the
        # same socket hash equally; otherwise fall back to the name itself.
        try:
            s1 = os.stat(self.name)
            return hash((s1.st_ino, s1.st_dev))
        except OSError:
            return hash(self.name)
# These are for buildFactory backwards compatibility due to
# stupidity-induced inconsistency.
class _ServerFactoryIPv4Address(IPv4Address):
    """
    Backwards compatibility hack; behaves just like L{IPv4Address}, except
    that it also compares equal to bare C{(host, port)} tuples, emitting a
    deprecation warning when it does so.
    """
    def __eq__(self, other):
        if isinstance(other, tuple):
            # Tuple comparison is the deprecated legacy behavior.
            warnings.warn("IPv4Address.__getitem__ is deprecated. Use attributes instead.",
                          category=DeprecationWarning, stacklevel=2)
            return (self.host, self.port) == other
        if isinstance(other, IPv4Address):
            return ((self.type, self.host, self.port)
                    == (other.type, other.host, other.port))
        return False
##
# Copyright (c) 2012-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Backports of portions of L{twisted.internet}.
(Specifically, those required for IPv6 client support).
"""
# -*- test-case-name: twisted.internet.test.test_endpoints -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementations of L{IStreamServerEndpoint} and L{IStreamClientEndpoint} that
wrap the L{IReactorTCP}, L{IReactorSSL}, and L{IReactorUNIX} interfaces.
This also implements an extensible mini-language for describing endpoints,
parsed by the L{clientFromString} and L{serverFromString} functions.
@since: 10.1
"""
import os, socket
from zope.interface import implements, directlyProvides
import warnings
from twisted.internet import interfaces, defer, error, fdesc
from twisted.internet.protocol import ClientFactory, Protocol
from twisted.plugin import IPlugin, getPlugins
from twisted.internet.interfaces import IStreamServerEndpointStringParser
from twisted.internet.interfaces import IStreamClientEndpointStringParser
from twisted.python.filepath import FilePath
#from twisted.python.systemd import ListenFDs
__all__ = ["clientFromString", "serverFromString",
"TCP4ServerEndpoint", "TCP4ClientEndpoint",
"UNIXServerEndpoint", "UNIXClientEndpoint",
"SSL4ServerEndpoint", "SSL4ClientEndpoint",
"AdoptedStreamServerEndpoint"]
class _WrappingProtocol(Protocol):
    """
    Wrap another protocol so that the wrapper's user is notified (via a
    Deferred) once the connection is established; all protocol events are
    forwarded to the wrapped protocol.

    @ivar _connectedDeferred: The L{Deferred} fired with C{_wrappedProtocol}
        when the connection is made.
    @ivar _wrappedProtocol: The L{IProtocol} provider being wrapped.
    """
    def __init__(self, connectedDeferred, wrappedProtocol):
        """
        @param connectedDeferred: The L{Deferred} to fire with the wrapped
            protocol once connected.
        @param wrappedProtocol: An L{IProtocol} provider to wrap.
        """
        self._connectedDeferred = connectedDeferred
        self._wrappedProtocol = wrappedProtocol
        # Advertise half-close support iff the wrapped protocol supports it,
        # so transports treat the wrapper exactly like the wrapped protocol.
        if interfaces.IHalfCloseableProtocol.providedBy(wrappedProtocol):
            directlyProvides(self, interfaces.IHalfCloseableProtocol)

    def logPrefix(self):
        """
        Use the wrapped protocol's log prefix when it provides one, falling
        back to its class name.
        """
        wrapped = self._wrappedProtocol
        if interfaces.ILoggingContext.providedBy(wrapped):
            return wrapped.logPrefix()
        return wrapped.__class__.__name__

    def connectionMade(self):
        """
        Attach the wrapped protocol to our transport and fire
        C{_connectedDeferred} with it.
        """
        self._wrappedProtocol.makeConnection(self.transport)
        self._connectedDeferred.callback(self._wrappedProtocol)

    def dataReceived(self, data):
        """
        Forward received data to the wrapped protocol.
        """
        return self._wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """
        Forward connection-lost notification to the wrapped protocol.
        """
        return self._wrappedProtocol.connectionLost(reason)

    def readConnectionLost(self):
        """
        Forward L{IHalfCloseableProtocol.readConnectionLost} to the wrapped
        protocol.
        """
        self._wrappedProtocol.readConnectionLost()

    def writeConnectionLost(self):
        """
        Forward L{IHalfCloseableProtocol.writeConnectionLost} to the wrapped
        protocol.
        """
        self._wrappedProtocol.writeConnectionLost()
class _WrappingFactory(ClientFactory):
    """
    A factory which wraps another factory so that every protocol it builds is
    wrapped in L{_WrappingProtocol}.

    @ivar _wrappedFactory: The I{IProtocolFactory} provider whose
        C{buildProtocol} results get wrapped.
    @ivar _onConnection: A L{Deferred} which fires with the wrapped protocol
        once a connection is established, or with a failure.
    @ivar _connector: The L{connector } managing
        the current or previous connection attempt.
    """
    protocol = _WrappingProtocol

    def __init__(self, wrappedFactory):
        """
        @param wrappedFactory: The I{IProtocolFactory} provider to wrap.
        """
        self._wrappedFactory = wrappedFactory
        self._onConnection = defer.Deferred(canceller=self._canceller)

    def startedConnecting(self, connector):
        """
        A connection attempt started; remember its connector so a later
        cancellation can stop it.
        """
        self._connector = connector

    def _canceller(self, deferred):
        """
        The outgoing connection attempt was cancelled: fail C{deferred}
        (which is C{self._onConnection}) with
        L{error.ConnectingCancelledError} and stop connecting.

        @param deferred: The cancelled L{Deferred }.

        @note: This relies on C{startedConnecting} having run first.  Using
            the public APIs that is guaranteed, because the connect call
            (C{connectTCP}/C{SSL}/C{UNIX}) always happens before this
            factory's Deferred is returned to C{connect()}'s caller.

        @return: C{None}
        """
        deferred.errback(
            error.ConnectingCancelledError(
                self._connector.getDestination()))
        self._connector.stopConnecting()

    def doStart(self):
        """
        Pass start notifications straight through to the wrapped factory.
        """
        self._wrappedFactory.doStart()

    def doStop(self):
        """
        Pass stop notifications straight through to the wrapped factory.
        """
        self._wrappedFactory.doStop()

    def buildProtocol(self, addr):
        """
        Build a protocol via the wrapped factory and wrap it; if building
        fails, errback C{self._onConnection} instead.

        @return: An instance of L{_WrappingProtocol} or C{None}
        """
        try:
            built = self._wrappedFactory.buildProtocol(addr)
        except:
            self._onConnection.errback()
        else:
            return self.protocol(self._onConnection, built)

    def clientConnectionFailed(self, connector, reason):
        """
        Errback C{self._onConnection} with the failure reason, unless it has
        already fired.
        """
        if not self._onConnection.called:
            self._onConnection.errback(reason)
class TCP4ServerEndpoint(object):
    """
    TCP server endpoint with an IPv4 configuration.

    @ivar _reactor: An L{IReactorTCP} provider.

    @type _port: int
    @ivar _port: The port number on which to listen for incoming connections.

    @type _backlog: int
    @ivar _backlog: size of the listen queue

    @type _interface: str
    @ivar _interface: the hostname to bind to, defaults to '' (all)
    """
    implements(interfaces.IStreamServerEndpoint)

    def __init__(self, reactor, port, backlog=50, interface=''):
        """
        @param reactor: An L{IReactorTCP} provider.
        @param port: The port number used for listening.
        @param backlog: size of the listen queue.
        @param interface: the hostname to bind to, defaults to '' (all).
        """
        self._reactor = reactor
        self._port = port
        # Bug fix: this dict was previously built from hard-coded literals
        # (backlog=50, interface=''), silently ignoring the constructor
        # arguments; it now mirrors the values actually used by listen().
        self._listenArgs = dict(backlog=backlog, interface=interface)
        self._backlog = backlog
        self._interface = interface

    def listen(self, protocolFactory):
        """
        Implement L{IStreamServerEndpoint.listen} to listen on a TCP socket.

        @return: A L{Deferred} firing with the listening port, or failing if
            C{listenTCP} raises synchronously.
        """
        return defer.execute(self._reactor.listenTCP,
                             self._port,
                             protocolFactory,
                             backlog=self._backlog,
                             interface=self._interface)
class TCP4ClientEndpoint(object):
    """
    TCP client endpoint with an IPv4 configuration.

    @ivar _reactor: An L{IReactorTCP} provider.

    @type _host: str
    @ivar _host: The hostname to connect to.

    @type _port: int
    @ivar _port: The port to connect to.

    @type _timeout: int
    @ivar _timeout: number of seconds to wait before assuming the connection
        has failed.

    @type _bindAddress: tuple
    @ivar _bindAddress: a (host, port) tuple of the local address to bind to,
        or None.
    """
    implements(interfaces.IStreamClientEndpoint)

    def __init__(self, reactor, host, port, timeout=30, bindAddress=None):
        """
        @param reactor: An L{IReactorTCP} provider.
        @param host: A hostname, used when connecting.
        @param port: The port number, used when connecting.
        @param timeout: number of seconds to wait before assuming the
            connection has failed.
        @param bindAddress: a (host, port) tuple of local address to bind to,
            or None.
        """
        self._reactor = reactor
        self._host = host
        self._port = port
        self._timeout = timeout
        self._bindAddress = bindAddress

    def connect(self, protocolFactory):
        """
        Implement L{IStreamClientEndpoint.connect} to connect via TCP.
        """
        try:
            wrapper = _WrappingFactory(protocolFactory)
            self._reactor.connectTCP(
                self._host, self._port, wrapper,
                timeout=self._timeout, bindAddress=self._bindAddress)
        except:
            # Turn a synchronous failure into an already-errbacked Deferred.
            return defer.fail()
        return wrapper._onConnection
class SSL4ServerEndpoint(object):
    """
    SSL secured TCP server endpoint with an IPv4 configuration.

    @ivar _reactor: An L{IReactorSSL} provider.

    @type _port: int
    @ivar _port: The port number to listen on.

    @type _sslContextFactory: L{OpenSSLCertificateOptions}
    @ivar _sslContextFactory: SSL configuration information.

    @type _backlog: int
    @ivar _backlog: size of the listen queue

    @type _interface: str
    @ivar _interface: the hostname to bind to, defaults to '' (all)
    """
    implements(interfaces.IStreamServerEndpoint)

    def __init__(self, reactor, port, sslContextFactory,
                 backlog=50, interface=''):
        """
        @param reactor: An L{IReactorSSL} provider.
        @param port: The port number used for listening.
        @param sslContextFactory: An instance of
            L{twisted.internet._sslverify.OpenSSLCertificateOptions}.
        @param backlog: size of the listen queue.
        @param interface: the hostname to bind to, defaults to '' (all).
        """
        self._reactor = reactor
        self._port = port
        self._sslContextFactory = sslContextFactory
        self._backlog = backlog
        self._interface = interface

    def listen(self, protocolFactory):
        """
        Implement L{IStreamServerEndpoint.listen} to listen for SSL on a
        TCP socket.
        """
        return defer.execute(self._reactor.listenSSL, self._port,
                             protocolFactory,
                             contextFactory=self._sslContextFactory,
                             backlog=self._backlog,
                             interface=self._interface)
class SSL4ClientEndpoint(object):
    """
    SSL secured TCP client endpoint with an IPv4 configuration.

    @ivar _reactor: An L{IReactorSSL} provider.

    @type _host: str
    @ivar _host: The hostname to connect to.

    @type _port: int
    @ivar _port: The port to connect to.

    @type _sslContextFactory: L{OpenSSLCertificateOptions}
    @ivar _sslContextFactory: SSL configuration information.

    @type _timeout: int
    @ivar _timeout: number of seconds to wait before assuming the connection
        has failed.

    @type _bindAddress: tuple
    @ivar _bindAddress: a (host, port) tuple of the local address to bind to,
        or None.
    """
    implements(interfaces.IStreamClientEndpoint)

    def __init__(self, reactor, host, port, sslContextFactory,
                 timeout=30, bindAddress=None):
        """
        @param reactor: An L{IReactorSSL} provider.
        @param host: A hostname, used when connecting.
        @param port: The port number, used when connecting.
        @param sslContextFactory: SSL configuration information, as an
            instance of L{OpenSSLCertificateOptions}.
        @param timeout: number of seconds to wait before assuming the
            connection has failed.
        @param bindAddress: a (host, port) tuple of local address to bind to,
            or None.
        """
        self._reactor = reactor
        self._host = host
        self._port = port
        self._sslContextFactory = sslContextFactory
        self._timeout = timeout
        self._bindAddress = bindAddress

    def connect(self, protocolFactory):
        """
        Implement L{IStreamClientEndpoint.connect} to connect with SSL over
        TCP.
        """
        try:
            wrapper = _WrappingFactory(protocolFactory)
            self._reactor.connectSSL(
                self._host, self._port, wrapper, self._sslContextFactory,
                timeout=self._timeout, bindAddress=self._bindAddress)
        except:
            # Turn a synchronous failure into an already-errbacked Deferred.
            return defer.fail()
        return wrapper._onConnection
class UNIXServerEndpoint(object):
"""
UnixSocket server endpoint.
@type path: str
@ivar path: a path to a unix socket on the filesystem.
@type _listenArgs: dict
@ivar _listenArgs: A C{dict} of keyword args that will be passed
to L{IReactorUNIX.listenUNIX}
@var _reactor: An L{IReactorTCP} provider.
"""
implements(interfaces.IStreamServerEndpoint)
def __init__(self, reactor, address, backlog=50, mode=0666, wantPID=0):
"""
@param reactor: An L{IReactorUNIX} provider.
@param address: The path to the Unix socket file, used when listening
@param listenArgs: An optional dict of keyword args that will be
passed to L{IReactorUNIX.listenUNIX}
@param backlog: number of connections to allow in backlog.
@param mode: mode to set on the unix socket. This parameter is
deprecated. Permissions should be set on the directory which
contains the UNIX socket.
@param wantPID: if True, create a pidfile for the socket.
"""
self._reactor = reactor
self._address = address
self._backlog = backlog
self._mode = mode
self._wantPID = wantPID
def listen(self, protocolFactory):
"""
Implement L{IStreamServerEndpoint.listen} to listen on a UNIX socket.
"""
return defer.execute(self._reactor.listenUNIX, self._address,
protocolFactory,
backlog=self._backlog,
mode=self._mode,
wantPID=self._wantPID)
class UNIXClientEndpoint(object):
    """
    UnixSocket client endpoint.

    @type _path: str
    @ivar _path: a path to a unix socket on the filesystem.

    @type _timeout: int
    @ivar _timeout: number of seconds to wait before assuming the connection
        has failed.

    @type _checkPID: bool
    @ivar _checkPID: if True, check for a pid file to verify that a server is
        listening.

    @ivar _reactor: An L{IReactorUNIX} provider.
    """
    implements(interfaces.IStreamClientEndpoint)

    def __init__(self, reactor, path, timeout=30, checkPID=0):
        """
        @param reactor: An L{IReactorUNIX} provider.
        @param path: The path to the Unix socket file, used when connecting.
        @param timeout: number of seconds to wait before assuming the
            connection has failed.
        @param checkPID: if True, check for a pid file to verify that a
            server is listening.
        """
        self._reactor = reactor
        self._path = path
        self._timeout = timeout
        self._checkPID = checkPID

    def connect(self, protocolFactory):
        """
        Implement L{IStreamClientEndpoint.connect} to connect via a
        UNIX Socket.
        """
        try:
            wrapper = _WrappingFactory(protocolFactory)
            self._reactor.connectUNIX(
                self._path, wrapper,
                timeout=self._timeout,
                checkPID=self._checkPID)
        except:
            # Turn a synchronous failure into an already-errbacked Deferred.
            return defer.fail()
        return wrapper._onConnection
class AdoptedStreamServerEndpoint(object):
    """
    An endpoint for listening on a file descriptor initialized outside of
    Twisted.

    @ivar _used: A C{bool}; C{True} once this endpoint has been used to
        listen with a factory.  Each instance may only listen once.
    """
    _close = os.close
    _setNonBlocking = staticmethod(fdesc.setNonBlocking)

    def __init__(self, reactor, fileno, addressFamily):
        """
        @param reactor: An L{IReactorSocket} provider.
        @param fileno: An integer file descriptor corresponding to a
            listening I{SOCK_STREAM} socket.
        @param addressFamily: The address family of the socket given by
            C{fileno}.
        """
        self.reactor = reactor
        self.fileno = fileno
        self.addressFamily = addressFamily
        self._used = False

    def listen(self, factory):
        """
        Implement L{IStreamServerEndpoint.listen} to adopt C{self.fileno} via
        the reactor, then close the original descriptor.
        """
        if self._used:
            return defer.fail(error.AlreadyListened())
        self._used = True

        try:
            self._setNonBlocking(self.fileno)
            port = self.reactor.adoptStreamPort(
                self.fileno, self.addressFamily, factory)
            # The reactor duplicates the descriptor; release our copy.
            self._close(self.fileno)
        except:
            return defer.fail()
        return defer.succeed(port)
def _parseTCP(factory, port, interface="", backlog=50):
    """
    Internal parser function for L{_parseServer}: convert the string
    arguments for a TCP (IPv4) stream endpoint into structured arguments.

    @param factory: the protocol factory being parsed, or C{None}.  (A
        leftover argument from when this code lived in C{strports}; now
        mostly None and unused.)
    @type factory: L{IProtocolFactory} or C{NoneType}

    @param port: the port number to bind, as a string.
    @type port: C{str}

    @param interface: the interface IP to listen on.

    @param backlog: the length of the listen queue.
    @type backlog: C{str}

    @return: a 2-tuple of (args, kwargs) describing the parameters to
        L{IReactorTCP.listenTCP} (or, modulo the factory argument, to
        L{TCP4ServerEndpoint}).
    """
    positional = (int(port), factory)
    keyword = {'interface': interface, 'backlog': int(backlog)}
    return positional, keyword
def _parseUNIX(factory, address, mode='666', backlog=50, lockfile=True):
    """
    Internal parser function for L{_parseServer}: convert the string
    arguments for a UNIX (AF_UNIX/SOCK_STREAM) stream endpoint into
    structured arguments.

    @param factory: the protocol factory being parsed, or C{None}.  (A
        leftover argument from when this code lived in C{strports}; now
        mostly None and unused.)
    @type factory: L{IProtocolFactory} or C{NoneType}

    @param address: the pathname of the unix socket.
    @type address: C{str}

    @param mode: the octal mode string to set on the socket file.

    @param backlog: the length of the listen queue.
    @type backlog: C{str}

    @param lockfile: A string '0' or '1', mapping to False and True
        respectively.  See the C{wantPID} argument to C{listenUNIX}.

    @return: a 2-tuple of (args, kwargs) describing the parameters to
        L{IReactorUNIX.listenUNIX} (or, modulo the factory argument, to
        L{UNIXServerEndpoint}).
    """
    positional = (address, factory)
    keyword = {
        'mode': int(mode, 8),          # mode is given in octal, e.g. '660'
        'backlog': int(backlog),
        'wantPID': bool(int(lockfile)),
    }
    return positional, keyword
def _parseSSL(factory, port, privateKey="server.pem", certKey=None,
              sslmethod=None, interface='', backlog=50):
    """
    Internal parser function for L{_parseServer}: convert the string
    arguments for an SSL (over TCP/IPv4) stream endpoint into structured
    arguments.

    @param factory: the protocol factory being parsed, or C{None}.  (A
        leftover argument from when this code lived in C{strports}; now
        mostly None and unused.)
    @type factory: L{IProtocolFactory} or C{NoneType}

    @param port: the port number to bind, as a string.
    @type port: C{str}

    @param privateKey: The file name of a PEM format private key file.
    @type privateKey: C{str}

    @param certKey: The file name of a PEM format certificate file; when
        C{None}, the private key file is assumed to also contain the
        certificate.
    @type certKey: C{str}

    @param sslmethod: The string name of an SSL method constant in
        C{OpenSSL.SSL}.  Must be one of: "SSLv23_METHOD", "SSLv2_METHOD",
        "SSLv3_METHOD", "TLSv1_METHOD".
    @type sslmethod: C{str}

    @param interface: the interface IP to listen on.

    @param backlog: the length of the listen queue.
    @type backlog: C{str}

    @return: a 2-tuple of (args, kwargs) describing the parameters to
        L{IReactorSSL.listenSSL} (or, modulo the factory argument, to
        L{SSL4ServerEndpoint}).
    """
    from twisted.internet import ssl
    if certKey is None:
        certKey = privateKey
    contextArgs = {}
    if sslmethod is not None:
        contextArgs['sslmethod'] = getattr(ssl.SSL, sslmethod)
    contextFactory = ssl.DefaultOpenSSLContextFactory(
        privateKey, certKey, **contextArgs)
    return ((int(port), factory, contextFactory),
            {'interface': interface, 'backlog': int(backlog)})
class _SystemdParser(object):
    """
    Stream server endpoint string parser for the I{systemd} endpoint type.

    @ivar prefix: See L{IStreamClientEndpointStringParser.prefix}.

    @ivar _sddaemon: A L{ListenFDs} instance used to translate an index into an
        actual file descriptor.
    """
    implements(IPlugin, IStreamServerEndpointStringParser)
    # NOTE(review): in this backport the ListenFDs import and the line below
    # are commented out, so _sddaemon is never assigned; _parseServer would
    # raise AttributeError if this parser were ever used.  Confirm whether
    # the systemd parser is intentionally disabled here.
    #_sddaemon = ListenFDs.fromEnvironment()
    prefix = "systemd"
    def _parseServer(self, reactor, domain, index):
        """
        Internal parser function for L{_parseServer} to convert the string
        arguments for a systemd server endpoint into structured arguments for
        L{AdoptedStreamServerEndpoint}.

        @param reactor: An L{IReactorSocket} provider.

        @param domain: The domain (or address family) of the socket inherited
            from systemd.  This is a string like C{"INET"} or C{"UNIX"}, i.e.
            the name of an address family from the L{socket} module, without
            the C{"AF_"} prefix.
        @type domain: C{str}

        @param index: An offset into the list of file descriptors inherited
            from systemd.
        @type index: C{str}

        @return: A constructed L{AdoptedStreamServerEndpoint}.  (Unlike the
            module-level parser functions, this returns the endpoint itself
            rather than an (args, kwargs) tuple.)
        """
        index = int(index)
        fileno = self._sddaemon.inheritedDescriptors()[index]
        addressFamily = getattr(socket, 'AF_' + domain)
        return AdoptedStreamServerEndpoint(reactor, fileno, addressFamily)
    def parseStreamServer(self, reactor, *args, **kwargs):
        # Delegate to another function with a sane signature.  This function
        # has an insane signature to trick zope.interface into believing the
        # interface is correctly implemented.
        return self._parseServer(reactor, *args, **kwargs)
# Map each lowercase server endpoint type prefix to the parser function
# which converts its string arguments into the (args, kwargs) for the
# corresponding endpoint constructor / reactor listen call.
_serverParsers = {"tcp": _parseTCP,
                  "unix": _parseUNIX,
                  "ssl": _parseSSL,
                  }
# Token kinds produced by _tokenize: an operator (':' or '=') or a string.
_OP, _STRING = range(2)
def _tokenize(description):
    """
    Tokenize a strports string and yield each token.

    @param description: a string as described by L{serverFromString} or
        L{clientFromString}.

    @return: an iterable of 2-tuples of (L{_OP} or L{_STRING}, string).
        Tuples starting with L{_OP} will contain a second element of either
        ':' (i.e. 'next parameter') or '=' (i.e. 'assign parameter value').
        For example, the string 'hello:greet\=ing=world' would result in a
        generator yielding these values::

            _STRING, 'hello'
            _OP, ':'
            _STRING, 'greet=ing'
            _OP, '='
            _STRING, 'world'
    """
    buf = ''
    # After '=' only ':' acts as an operator (a value may contain '=').
    allowedOps = ':='
    followingOps = {':': ':=', '=': ':'}
    chars = iter(description)
    for ch in chars:
        if ch in allowedOps:
            yield _STRING, buf
            yield _OP, ch
            buf = ''
            allowedOps = followingOps[ch]
        elif ch == '\\':
            # Backslash escapes the next character, whatever it is.
            buf += next(chars)
        else:
            buf += ch
    yield _STRING, buf
def _parse(description):
    """
    Convert a description string into a list of positional and keyword
    parameters, using logic vaguely like what Python does.

    @param description: a string as described by L{serverFromString} or
        L{clientFromString}.

    @return: a 2-tuple of C{(args, kwargs)}, where 'args' is a list of all
        ':'-separated C{str}s not containing an '=' and 'kwargs' is a map of
        all C{str}s which do contain an '='.  For example, the result of
        C{_parse('a:b:d=1:c')} would be C{(['a', 'b', 'c'], {'d': '1'})}.
    """
    positional, keyword = [], {}

    def commit(group):
        # A lone value is positional; a (name, value) pair is a keyword.
        if len(group) == 1:
            positional.append(group[0])
        else:
            keyword[group[0]] = group[1]

    group = ()
    for kind, text in _tokenize(description):
        if kind is _STRING:
            group += (text,)
        elif text == ':':
            commit(group)
            group = ()
    commit(group)
    return positional, keyword
# Mappings from description "names" to endpoint constructors.
_endpointServerFactories = {
    'TCP': TCP4ServerEndpoint,
    'SSL': SSL4ServerEndpoint,
    'UNIX': UNIXServerEndpoint,
    }
# Client-side counterparts of the server factory mapping above.
_endpointClientFactories = {
    'TCP': TCP4ClientEndpoint,
    'SSL': SSL4ClientEndpoint,
    'UNIX': UNIXClientEndpoint,
    }
# Sentinel passed by serverFromString so _parseServer raises on unqualified
# descriptions rather than falling back to the deprecated 'tcp' default.
_NO_DEFAULT = object()
def _parseServer(description, factory, default=None):
    """
    Parse a strports description into a 3-tuple of (name or plugin,
    arguments, keyword arguments).

    @param description: A description in the format explained by
        L{serverFromString}.
    @type description: C{str}

    @param factory: A 'factory' argument; this is left-over from
        twisted.application.strports, it's not really used.
    @type factory: L{IProtocolFactory} or L{None}

    @param default: Deprecated argument, specifying the default parser mode
        to use for unqualified description strings (those which do not have
        a ':' and prefix).  C{None} means warn and assume 'tcp';
        L{_NO_DEFAULT} means raise L{ValueError}.
    @type default: C{str} or C{NoneType}

    @return: a 3-tuple of (plugin or name, arguments, keyword arguments)
    """
    args, kw = _parse(description)
    # An unqualified description has no type prefix: either nothing parsed,
    # or exactly one positional argument and no keywords.
    if not args or (len(args) == 1 and not kw):
        deprecationMessage = (
            "Unqualified strport description passed to 'service'."
            "Use qualified endpoint descriptions; for example, 'tcp:%s'."
            % (description,))
        if default is None:
            default = 'tcp'
            warnings.warn(
                deprecationMessage, category=DeprecationWarning, stacklevel=4)
        elif default is _NO_DEFAULT:
            raise ValueError(deprecationMessage)
        # If the default has been otherwise specified, the user has already
        # been warned.
        args[0:0] = [default]
    endpointType = args[0]
    parser = _serverParsers.get(endpointType)
    if parser is not None:
        return (endpointType.upper(),) + parser(factory, *args[1:], **kw)
    # No built-in parser: look for a plugin claiming this prefix.
    for plugin in getPlugins(IStreamServerEndpointStringParser):
        if plugin.prefix == endpointType:
            return (plugin, args[1:], kw)
    raise ValueError("Unknown endpoint type: '%s'" % (endpointType,))
def _serverFromStringLegacy(reactor, description, default):
    """
    Underlying implementation of L{serverFromString} which avoids exposing
    the deprecated 'default' argument to anything but L{strports.service}.
    """
    nameOrPlugin, args, kw = _parseServer(description, None, default)
    if type(nameOrPlugin) is str:
        # Built-in endpoint type: chop out the (unused) factory argument and
        # construct the endpoint directly.
        args = args[:1] + args[2:]
        return _endpointServerFactories[nameOrPlugin](reactor, *args, **kw)
    # Otherwise a plugin was matched; let it build the endpoint.
    return nameOrPlugin.parseStreamServer(reactor, *args, **kw)
def serverFromString(reactor, description):
    """
    Construct a stream server endpoint from an endpoint description string.

    The format for server endpoint descriptions is a simple string: a prefix
    naming the type of endpoint, then a colon, then the arguments for that
    endpoint.  For example, an endpoint which listens on TCP port 80::

        serverFromString(reactor, "tcp:80")

    Additional arguments may be specified as keywords, separated with
    colons; for example, binding a TCP server endpoint to an interface::

        serverFromString(reactor, "tcp:80:interface=127.0.0.1")

    SSL server endpoints use the 'ssl' prefix; the private key and
    certificate files may be specified by the C{privateKey} and C{certKey}
    arguments::

        serverFromString(reactor, "ssl:443:privateKey=key.pem:certKey=crt.pem")

    If C{privateKey} isn't provided, a "server.pem" file is assumed to exist
    which contains the private key.  If C{certKey} isn't provided, the
    private key file is assumed to contain the certificate as well.

    You may escape colons in arguments with a backslash, which you will need
    to use if you want to specify a full pathname argument on Windows::

        serverFromString(reactor,
            "ssl:443:privateKey=C\\:/key.pem:certKey=C\\:/cert.pem")

    Finally, the 'unix' prefix may be used to specify a filesystem UNIX
    socket, optionally with a 'mode' argument to specify the mode of the
    socket file created by C{listen}::

        serverFromString(reactor, "unix:/var/run/finger")
        serverFromString(reactor, "unix:/var/run/finger:mode=660")

    This function is also extensible; new endpoint types may be registered
    as L{IStreamServerEndpointStringParser} plugins.  See that interface for
    more information.

    @param reactor: The server endpoint will be constructed with this
        reactor.

    @param description: The strports description to parse.

    @return: A new endpoint which can be used to listen with the parameters
        given by C{description}.

    @rtype: L{IStreamServerEndpoint}

    @raise ValueError: when the 'description' string cannot be parsed.

    @since: 10.2
    """
    return _serverFromStringLegacy(reactor, description, _NO_DEFAULT)
def quoteStringArgument(argument):
    """
    Quote an argument to L{serverFromString} and L{clientFromString}.

    Arguments are separated with colons and colons are escaped with
    backslashes, so some care is necessary if, for example, you have a
    pathname and are tempted to interpolate it into a string like this::

        serverFromString("ssl:443:privateKey=%s" % (myPathName,))

    This may appear to work, but will have portability issues (Windows
    pathnames, for example).  Usually you should just construct the
    appropriate endpoint type rather than interpolating strings, which in
    this case would be L{SSL4ServerEndpoint}.  There are some use-cases
    where you may need to generate such a string, though; for example, a
    tool to manipulate a configuration file which has strports descriptions
    in it.  To be correct in those cases, do this instead::

        serverFromString("ssl:443:privateKey=%s" %
                         (quoteStringArgument(myPathName),))

    @param argument: The part of the endpoint description string you want to
        pass through.
    @type argument: C{str}

    @return: The quoted argument.
    @rtype: C{str}
    """
    # Prefix every backslash and colon with a backslash; equivalent to
    # escaping backslashes first and then colons.
    quoted = []
    for character in argument:
        if character in '\\:':
            quoted.append('\\')
        quoted.append(character)
    return ''.join(quoted)
def _parseClientTCP(*args, **kwargs):
    """
    Perform any argument value coercion necessary for TCP client parameters.

    Valid positional arguments to this function are host and port.

    Valid keyword arguments to this function are all
    L{IReactorTCP.connectTCP} arguments.

    @return: The coerced values as a C{dict}.
    """
    if len(args) == 2:
        kwargs['host'], kwargs['port'] = args[0], int(args[1])
    elif len(args) == 1:
        # A lone positional is the port if the host was given by keyword;
        # otherwise it is the host.
        if 'host' in kwargs:
            kwargs['port'] = int(args[0])
        else:
            kwargs['host'] = args[0]
    # Coerce numeric keyword arguments, when present, from str to int.
    for numericArg in ('port', 'timeout'):
        if numericArg in kwargs:
            kwargs[numericArg] = int(kwargs[numericArg])
    return kwargs
def _loadCAsFromDir(directoryPath):
    """
    Load certificate-authority certificate objects from a given directory.

    @param directoryPath: a L{FilePath} pointing at a directory to load .pem
        files from.

    @return: a C{list} of L{OpenSSL.crypto.X509} objects.
    """
    from twisted.internet import ssl

    certs = {}
    for child in directoryPath.children():
        if child.basename().split('.')[-1].lower() != 'pem':
            continue
        try:
            pemData = child.getContent()
        except IOError:
            # Permission denied, corrupt disk; we don't care.
            continue
        try:
            certificate = ssl.Certificate.loadPEM(pemData)
        except ssl.SSL.Error:
            # Duplicate certificate, invalid certificate, etc.; we don't
            # care.
            continue
        # Keyed by digest so duplicate certificates collapse to one entry.
        certs[certificate.digest()] = certificate.original
    return certs.values()
def _parseClientSSL(*args, **kwargs):
    """
    Perform any argument value coercion necessary for SSL client parameters.

    Valid keyword arguments to this function are all
    L{IReactorSSL.connectSSL} arguments except for C{contextFactory}.
    Instead, C{certKey} (the path name of the certificate file) and
    C{privateKey} (the path name of the private key associated with the
    certificate) are accepted and used to construct a context factory.

    Valid positional arguments to this function are host and port.

    @param caCertsDir: The one parameter which is not part of
        L{IReactorSSL.connectSSL}'s signature: a path name used to construct
        a list of certificate authority certificates.  The directory will be
        scanned for files ending in C{.pem}, all of which will be considered
        valid certificate authorities for this connection.
    @type caCertsDir: C{str}

    @return: The coerced values as a C{dict}.
    """
    from twisted.internet import ssl
    kwargs = _parseClientTCP(*args, **kwargs)
    certKey = kwargs.pop('certKey', None)
    privateKey = kwargs.pop('privateKey', None)
    caCertsDir = kwargs.pop('caCertsDir', None)

    if certKey is None:
        certx509 = None
    else:
        certx509 = ssl.Certificate.loadPEM(
            FilePath(certKey).getContent()).original

    if privateKey is None:
        keyObject = None
    else:
        keyObject = ssl.PrivateCertificate.loadPEM(
            FilePath(privateKey).getContent()).privateKey.original

    # Verification is enabled exactly when a CA directory was supplied.
    verify = caCertsDir is not None
    caCerts = _loadCAsFromDir(FilePath(caCertsDir)) if verify else None

    kwargs['sslContextFactory'] = ssl.CertificateOptions(
        method=ssl.SSL.SSLv23_METHOD,
        certificate=certx509,
        privateKey=keyObject,
        verify=verify,
        caCerts=caCerts
    )
    return kwargs
def _parseClientUNIX(**kwargs):
    """
    Perform any argument value coercion necessary for UNIX client
    parameters.

    Valid keyword arguments to this function are all
    L{IReactorUNIX.connectUNIX} arguments except for C{checkPID}.  Instead,
    C{lockfile} is accepted and has the same meaning.

    @return: The coerced values as a C{dict}.
    """
    if 'lockfile' in kwargs:
        # '0'/'1' string -> False/True, renamed to the connectUNIX keyword.
        kwargs['checkPID'] = bool(int(kwargs.pop('lockfile')))
    if 'timeout' in kwargs:
        kwargs['timeout'] = int(kwargs['timeout'])
    return kwargs
# Map each upper-cased client endpoint type name to the argument-coercion
# function for that endpoint type.
_clientParsers = {
    'TCP': _parseClientTCP,
    'SSL': _parseClientSSL,
    'UNIX': _parseClientUNIX,
    }
def clientFromString(reactor, description):
    """
    Construct a client endpoint from a description string.

    Client description strings are much like server description strings,
    although they take all of their arguments as keywords, aside from host
    and port.

    You can create a TCP client endpoint with the 'host' and 'port'
    arguments, like so::

        clientFromString(reactor, "tcp:host=www.example.com:port=80")

    or, without specifying host and port keywords::

        clientFromString(reactor, "tcp:www.example.com:80")

    Or you can specify only one or the other, as in the following 2
    examples::

        clientFromString(reactor, "tcp:host=www.example.com:80")
        clientFromString(reactor, "tcp:www.example.com:port=80")

    or an SSL client endpoint with those arguments, plus the arguments used
    by the server SSL, for a client certificate::

        clientFromString(reactor, "ssl:web.example.com:443:"
                                  "privateKey=foo.pem:certKey=foo.pem")

    to specify your certificate trust roots, you can identify a directory
    with PEM files in it with the C{caCertsDir} argument::

        clientFromString(reactor, "ssl:host=web.example.com:port=443:"
                                  "caCertsDir=/etc/ssl/certs")

    This function is also extensible; new endpoint types may be registered
    as L{IStreamClientEndpointStringParser} plugins.  See that interface for
    more information.

    @param reactor: The client endpoint will be constructed with this
        reactor.

    @param description: The strports description to parse.

    @return: A new endpoint which can be used to connect with the parameters
        given by C{description}.

    @rtype: L{IStreamClientEndpoint}

    @since: 10.2
    """
    args, kwargs = _parse(description)
    rawName = args.pop(0)
    name = rawName.upper()
    # Plugins take precedence over the built-in parsers.
    for plugin in getPlugins(IStreamClientEndpointStringParser):
        if plugin.prefix.upper() == name:
            return plugin.parseStreamClient(*args, **kwargs)
    parser = _clientParsers.get(name)
    if parser is None:
        raise ValueError("Unknown endpoint type: %r" % (rawName,))
    return _endpointClientFactories[name](reactor, **parser(*args, **kwargs))
calendarserver-5.2+dfsg/twext/who/ 0000755 0001750 0001750 00000000000 12322625326 016262 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/who/index.py 0000644 0001750 0001750 00000015727 12263343324 017756 0 ustar rahul rahul # -*- test-case-name: twext.who.test.test_xml -*-
##
# Copyright (c) 2006-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Indexed directory service implementation.
"""
__all__ = [
"DirectoryService",
"DirectoryRecord",
]
from itertools import chain
from twisted.python.constants import Names, NamedConstant
from twisted.internet.defer import succeed, inlineCallbacks, returnValue
from twext.who.util import ConstantsContainer
from twext.who.util import describe, uniqueResult, iterFlags
from twext.who.idirectory import FieldName as BaseFieldName
from twext.who.expression import MatchExpression, MatchType, MatchFlags
from twext.who.directory import DirectoryService as BaseDirectoryService
from twext.who.directory import DirectoryRecord as BaseDirectoryRecord
##
# Data type extentions
##
class FieldName(Names):
    """
    Additional field names, beyond those provided by the base directory
    service, that indexed directory services know about.
    """
    # UIDs of the records that are members of a (group) record.
    memberUIDs = NamedConstant()
    memberUIDs.description = "member UIDs"
    memberUIDs.multiValue = True  # A record may have any number of members.
##
# Directory Service
##
class DirectoryService(BaseDirectoryService):
    """
    Indexed directory service.

    Maintains an in-memory index mapping field name to a mapping of field
    value to sets of matching records, for the fields in C{indexedFields}.
    Subclasses must implement L{loadRecords} to populate the index.
    """

    # All base directory field names plus the index-specific ones.
    fieldName = ConstantsContainer(chain(
        BaseDirectoryService.fieldName.iterconstants(),
        FieldName.iterconstants()
    ))

    # Fields for which an index is maintained; queries on any other field
    # fall back to a full scan (unIndexedRecordsFromMatchExpression).
    indexedFields = (
        BaseFieldName.recordType,
        BaseFieldName.uid,
        BaseFieldName.guid,
        BaseFieldName.shortNames,
        BaseFieldName.emailAddresses,
        FieldName.memberUIDs,
    )

    def __init__(self, realmName):
        BaseDirectoryService.__init__(self, realmName)
        # Start with no index; it is (re)built lazily on first access.
        self.flush()

    @property
    def index(self):
        # Reading the index triggers a (re)load of records, so the index
        # is always up to date when consulted.
        self.loadRecords()
        return self._index

    @index.setter
    def index(self, value):
        self._index = value

    def loadRecords(self):
        """
        Load records.

        Subclasses must implement this to populate C{self.index}.
        """
        raise NotImplementedError("Subclasses must implement loadRecords().")

    def flush(self):
        """
        Flush the index.
        """
        self._index = None

    @staticmethod
    def _queryFlags(flags):
        """
        Translate match flags into a C{(predicate, normalize)} pair.

        C{predicate} maps a raw match result to the final boolean result
        (identity, or negation for C{MatchFlags.NOT}); C{normalize} maps
        query and field values to comparable form (identity, or lower-case
        for C{MatchFlags.caseInsensitive}).

        @raises NotImplementedError: for an unknown flag.
        """
        predicate = lambda x: x
        normalize = lambda x: x

        if flags is not None:
            for flag in iterFlags(flags):
                if flag == MatchFlags.NOT:
                    predicate = lambda x: not x
                elif flag == MatchFlags.caseInsensitive:
                    normalize = lambda x: x.lower()
                else:
                    raise NotImplementedError(
                        "Unknown query flag: {0}".format(describe(flag))
                    )

        return predicate, normalize

    def indexedRecordsFromMatchExpression(self, expression, records=None):
        """
        Finds records in the internal indexes matching a single
        expression.

        @param expression: an expression
        @type expression: L{object}

        @param records: a set of records to narrow the search to, or
            C{None} to search all indexed records.
        """
        predicate, normalize = self._queryFlags(expression.flags)

        fieldIndex = self.index[expression.fieldName]
        matchValue = normalize(expression.fieldValue)
        matchType = expression.matchType

        if matchType == MatchType.startsWith:
            indexKeys = (
                key for key in fieldIndex
                if predicate(normalize(key).startswith(matchValue))
            )
        elif matchType == MatchType.contains:
            indexKeys = (
                key for key in fieldIndex
                if predicate(matchValue in normalize(key))
            )
        elif matchType == MatchType.equals:
            if predicate(True):
                # Plain equality: look the value up directly as a key.
                # NOTE(review): matchValue has been normalized but index
                # keys are used as stored; a caseInsensitive equals could
                # miss keys indexed with different case — confirm that
                # keys are normalized when the index is built.
                indexKeys = (matchValue,)
            else:
                # Negated equality: scan for every non-matching key.
                indexKeys = (
                    key for key in fieldIndex
                    if normalize(key) != matchValue
                )
        else:
            raise NotImplementedError(
                "Unknown match type: {0}".format(describe(matchType))
            )

        matchingRecords = set()
        for key in indexKeys:
            matchingRecords |= fieldIndex.get(key, frozenset())

        if records is not None:
            matchingRecords &= records

        return succeed(matchingRecords)

    def unIndexedRecordsFromMatchExpression(self, expression, records=None):
        """
        Finds records not in the internal indexes matching a single
        expression.

        @param expression: an expression
        @type expression: L{object}

        @param records: a set of records to narrow the search to, or
            C{None} to scan every record in the directory.
        """
        predicate, normalize = self._queryFlags(expression.flags)

        matchValue = normalize(expression.fieldValue)
        matchType = expression.matchType

        if matchType == MatchType.startsWith:
            match = lambda fieldValue: predicate(
                fieldValue.startswith(matchValue)
            )
        elif matchType == MatchType.contains:
            match = lambda fieldValue: predicate(matchValue in fieldValue)
        elif matchType == MatchType.equals:
            match = lambda fieldValue: predicate(fieldValue == matchValue)
        else:
            raise NotImplementedError(
                "Unknown match type: {0}".format(describe(matchType))
            )

        result = set()

        if records is None:
            # Scan all records via the UID index; every record has exactly
            # one UID, so each index bucket yields one unique record.
            records = (
                uniqueResult(values) for values
                in self.index[self.fieldName.uid].itervalues()
            )

        for record in records:
            fieldValues = record.fields.get(expression.fieldName, None)

            if fieldValues is None:
                continue

            for fieldValue in fieldValues:
                if match(normalize(fieldValue)):
                    result.add(record)

        return succeed(result)

    def recordsFromExpression(self, expression, records=None):
        # Dispatch match expressions to indexed or un-indexed lookup based
        # on the queried field; defer anything else to the base class.
        if isinstance(expression, MatchExpression):
            if expression.fieldName in self.indexedFields:
                return self.indexedRecordsFromMatchExpression(
                    expression, records=records
                )
            else:
                return self.unIndexedRecordsFromMatchExpression(
                    expression, records=records
                )
        else:
            return BaseDirectoryService.recordsFromExpression(
                self, expression, records=records
            )
class DirectoryRecord(BaseDirectoryRecord):
    """
    Directory record for an indexed directory service, with group
    membership support backed by the C{memberUIDs} field.
    """

    @inlineCallbacks
    def members(self):
        """
        Resolve this (group) record's member UIDs into records.
        """
        memberRecords = set()
        memberUIDs = getattr(self, "memberUIDs", ())

        for memberUID in memberUIDs:
            record = yield self.service.recordWithUID(memberUID)
            memberRecords.add(record)

        returnValue(memberRecords)

    def groups(self):
        """
        Find the group records that list this record's UID as a member.
        """
        return self.service.recordsWithFieldValue(
            FieldName.memberUIDs, self.uid
        )
calendarserver-5.2+dfsg/twext/who/expression.py 0000644 0001750 0001750 00000005136 12263343324 021037 0 ustar rahul rahul # -*- test-case-name: twext.who.test.test_expression -*-
##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Directory query expressions.
"""
__all__ = [
"MatchType",
"MatchFlags",
"MatchExpression",
]
from twisted.python.constants import Names, NamedConstant
from twisted.python.constants import Flags, FlagConstant
##
# Match expression
##
class MatchType(Names):
    """
    Query match types.
    """
    # Field value must equal the query value.
    equals = NamedConstant()
    # Field value must begin with the query value.
    startsWith = NamedConstant()
    # Field value must contain the query value as a substring.
    contains = NamedConstant()

    equals.description = "equals"
    startsWith.description = "starts with"
    contains.description = "contains"
class MatchFlags(Flags):
    """
    Match expression flags.
    """
    # Invert the sense of the match.
    NOT = FlagConstant()
    NOT.description = "not"

    # Compare values without regard to case.
    caseInsensitive = FlagConstant()
    caseInsensitive.description = "case insensitive"
class MatchExpression(object):
    """
    Query for a matching value in a given field.

    @ivar fieldName: a L{NamedConstant} specifying the field
    @ivar fieldValue: a text value to match
    @ivar matchType: a L{NamedConstant} specifying the match algorithm
    @ivar flags: L{NamedConstant} specifying additional options
    """

    def __init__(
        self,
        fieldName, fieldValue,
        matchType=MatchType.equals, flags=None
    ):
        self.fieldName = fieldName
        self.fieldValue = fieldValue
        self.matchType = matchType
        self.flags = flags

    def __repr__(self):
        # Prefer a constant's human-readable description when it has one.
        def describe(constant):
            return getattr(constant, "description", str(constant))

        if self.flags is None:
            flags = ""
        else:
            flags = " ({0})".format(describe(self.flags))

        return (
            "<{self.__class__.__name__}: {fieldName!r} "
            "{matchType} {fieldValue!r}{flags}>"
            .format(
                self=self,
                fieldName=describe(self.fieldName),
                matchType=describe(self.matchType),
                fieldValue=describe(self.fieldValue),
                flags=flags,
            )
        )
calendarserver-5.2+dfsg/twext/who/directory.py 0000644 0001750 0001750 00000026744 12263343324 020654 0 ustar rahul rahul # -*- test-case-name: twext.who.test.test_directory -*-
##
# Copyright (c) 2006-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Generic directory service base implementation
"""
__all__ = [
"DirectoryService",
"DirectoryRecord",
]
from uuid import UUID
from zope.interface import implementer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.defer import succeed, fail
from twext.who.idirectory import QueryNotSupportedError, NotAllowedError
from twext.who.idirectory import FieldName, RecordType
from twext.who.idirectory import Operand
from twext.who.idirectory import IDirectoryService, IDirectoryRecord
from twext.who.expression import MatchExpression
from twext.who.util import uniqueResult, describe
@implementer(IDirectoryService)
class DirectoryService(object):
    """
    Generic implementation of L{IDirectoryService}.

    This is a complete implementation of L{IDirectoryService}, with support
    for the query operands in L{Operand}.

    The C{recordsWith*} methods are all implemented in terms of
    L{recordsWithFieldValue}, which is in turn implemented in terms of
    L{recordsFromExpression}.
    L{recordsFromQuery} is also implemented in terms of
    L{recordsFromExpression}.

    L{recordsFromExpression} (and therefore most uses of the other methods)
    will always fail with a L{QueryNotSupportedError}.

    A subclass should therefore override L{recordsFromExpression} with an
    implementation that handles any queries that it can support and its
    superclass' implementation with any query it cannot support.

    A subclass may override L{recordsFromQuery} if it is to support
    additional operands.

    L{updateRecords} and L{removeRecords} will fail with L{NotAllowedError}
    when asked to modify data.
    A subclass should override these methods if is to allow editing of
    directory information.

    @cvar recordType: a L{Names} class or compatible object (eg.
        L{ConstantsContainer}) which contains the L{NamedConstant}s denoting
        the record types that are supported by this directory service.

    @cvar fieldName: a L{Names} class or compatible object (eg.
        L{ConstantsContainer}) which contains the L{NamedConstant}s denoting
        the record field names that are supported by this directory service.

    @cvar normalizedFields: a L{dict} mapping of field names (ie.
        L{NamedConstant}s contained in the C{fieldName} class variable) to
        callables that take a field value (a L{unicode}) and return a
        normalized field value (also a L{unicode}).
    """

    recordType = RecordType
    fieldName = FieldName

    # Coercions applied to field values before they are stored on records.
    normalizedFields = {
        FieldName.guid: lambda g: UUID(g).hex,
        FieldName.emailAddresses: lambda e: bytes(e).lower(),
    }

    def __init__(self, realmName):
        """
        @param realmName: a realm name
        @type realmName: unicode
        """
        self.realmName = realmName

    def __repr__(self):
        return (
            "<{self.__class__.__name__} {self.realmName!r}>"
            .format(self=self)
        )

    def recordTypes(self):
        return self.recordType.iterconstants()

    def recordsFromExpression(self, expression, records=None):
        """
        Finds records matching a single expression.

        @note: The implementation in L{DirectoryService} always raises
            L{QueryNotSupportedError}.

        @note: This L{DirectoryService} adds a C{records} keyword argument
            to the interface defined by L{IDirectoryService}.
            This allows the implementation of
            L{DirectoryService.recordsFromQuery} to narrow the scope of
            records being searched as it applies expressions.
            This is therefore relevant to subclasses, which need to support
            the added parameter, but not to users of L{IDirectoryService}.

        @param expression: an expression to apply
        @type expression: L{object}

        @param records: a set of records to limit the search to. C{None} if
            the whole directory should be searched.
        @type records: L{set} or L{frozenset}

        @return: The matching records.
        @rtype: deferred iterable of L{IDirectoryRecord}s

        @raises: L{QueryNotSupportedError} if the expression is not
            supported by this directory service.
        """
        return fail(QueryNotSupportedError(
            "Unknown expression: {0}".format(expression)
        ))

    @inlineCallbacks
    def recordsFromQuery(self, expressions, operand=Operand.AND):
        """
        Finds records matching a query, which is composed of an iterable of
        expressions combined with the given operand.

        @param expressions: an iterable of expressions to apply.
        @param operand: an L{Operand} (C{AND} or C{OR}).
        @return: a L{Deferred} firing with the matching records.
        @raises QueryNotSupportedError: for an unknown operand.
        """
        expressionIterator = iter(expressions)

        try:
            expression = expressionIterator.next()
        except StopIteration:
            returnValue(())

        results = set((yield self.recordsFromExpression(expression)))

        # Iterate the *iterator* (not the original iterable): the first
        # expression was already consumed and applied above.  Iterating
        # C{expressions} again, as earlier revisions did, re-evaluated the
        # first expression — harmless for the idempotent set operations
        # below, but it ran the (potentially expensive) query twice.
        for expression in expressionIterator:
            if operand == Operand.AND:
                if not results:
                    # No need to bother continuing here
                    returnValue(())
                # Narrow the search to the records matched so far.
                records = results
            else:
                records = None

            recordsMatchingExpression = frozenset((
                yield self.recordsFromExpression(expression, records=records)
            ))

            if operand == Operand.AND:
                results &= recordsMatchingExpression
            elif operand == Operand.OR:
                results |= recordsMatchingExpression
            else:
                raise QueryNotSupportedError(
                    "Unknown operand: {0}".format(operand)
                )

        returnValue(results)

    def recordsWithFieldValue(self, fieldName, value):
        """
        Find records where the named field equals the given value.
        """
        return self.recordsFromExpression(MatchExpression(fieldName, value))

    @inlineCallbacks
    def recordWithUID(self, uid):
        returnValue(uniqueResult(
            (yield self.recordsWithFieldValue(FieldName.uid, uid))
        ))

    @inlineCallbacks
    def recordWithGUID(self, guid):
        returnValue(uniqueResult(
            (yield self.recordsWithFieldValue(FieldName.guid, guid))
        ))

    def recordsWithRecordType(self, recordType):
        return self.recordsWithFieldValue(FieldName.recordType, recordType)

    @inlineCallbacks
    def recordWithShortName(self, recordType, shortName):
        returnValue(uniqueResult((yield self.recordsFromQuery((
            MatchExpression(FieldName.recordType, recordType),
            MatchExpression(FieldName.shortNames, shortName),
        )))))

    def recordsWithEmailAddress(self, emailAddress):
        return self.recordsWithFieldValue(
            FieldName.emailAddresses,
            emailAddress,
        )

    def updateRecords(self, records, create=False):
        """
        Updating is not allowed in this implementation: fails with
        L{NotAllowedError} if any records are given.  An empty update
        succeeds trivially.
        """
        for record in records:
            return fail(NotAllowedError("Record updates not allowed."))
        return succeed(None)

    def removeRecords(self, uids):
        """
        Removal is not allowed in this implementation: fails with
        L{NotAllowedError} if any UIDs are given.  An empty removal
        succeeds trivially.
        """
        for uid in uids:
            return fail(NotAllowedError("Record removal not allowed."))
        return succeed(None)
@implementer(IDirectoryRecord)
class DirectoryRecord(object):
    """
    Generic implementation of L{IDirectoryRecord}.

    This is an incomplete implementation of L{IDirectoryRecord}.

    L{groups} will always fail with L{NotImplementedError} and L{members}
    will do so if this is a group record.
    A subclass should override these methods to support group membership
    and complete this implementation.

    @cvar requiredFields: an iterable of field names that must be present
        in all directory records.
    """

    requiredFields = (
        FieldName.uid,
        FieldName.recordType,
        FieldName.shortNames,
    )

    def __init__(self, service, fields):
        """
        @param service: the directory service this record belongs to.
        @param fields: a mapping of field name constants to field values.
        @raises ValueError: if a required field is missing or empty, or if
            the record type is not one of the service's record types.
        """
        # Validate that every required field is present and non-empty.
        for fieldName in self.requiredFields:
            if fieldName not in fields or not fields[fieldName]:
                raise ValueError("{0} field is required.".format(fieldName))

            if FieldName.isMultiValue(fieldName):
                values = fields[fieldName]
                # Defensive: an empty container was already rejected by
                # the truthiness check above.
                if len(values) == 0:
                    raise ValueError(
                        "{0} field must have at least one value."
                        .format(fieldName)
                    )
                for value in values:
                    if not value:
                        raise ValueError(
                            "{0} field must not be empty.".format(fieldName)
                        )

        if (
            fields[FieldName.recordType] not in
            service.recordType.iterconstants()
        ):
            raise ValueError(
                "Record type must be one of {0!r}, not {1!r}.".format(
                    tuple(service.recordType.iterconstants()),
                    fields[FieldName.recordType],
                )
            )

        # Normalize fields
        normalizedFields = {}
        for name, value in fields.items():
            normalize = service.normalizedFields.get(name, None)

            if normalize is None:
                normalizedFields[name] = value
                continue

            # Multi-value fields are normalized element-wise and stored
            # as immutable tuples.
            if FieldName.isMultiValue(name):
                normalizedFields[name] = tuple((normalize(v) for v in value))
            else:
                normalizedFields[name] = normalize(value)

        self.service = service
        self.fields = normalizedFields

    def __repr__(self):
        return (
            "<{self.__class__.__name__} ({recordType}){shortName}>".format(
                self=self,
                recordType=describe(self.recordType),
                shortName=self.shortNames[0],
            )
        )

    def __eq__(self, other):
        # Records are equal when they come from an equal service and have
        # identical (normalized) fields.
        if IDirectoryRecord.implementedBy(other.__class__):
            return (
                self.service == other.service and
                self.fields == other.fields
            )
        return NotImplemented

    def __ne__(self, other):
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return NotImplemented
        return not eq

    def __getattr__(self, name):
        # Expose fields as attributes, eg. record.shortNames; unknown
        # names raise AttributeError as usual.
        try:
            fieldName = self.service.fieldName.lookupByName(name)
        except ValueError:
            raise AttributeError(name)

        try:
            return self.fields[fieldName]
        except KeyError:
            raise AttributeError(name)

    def description(self):
        """
        Build a multi-line, human-readable description of this record's
        fields.
        """
        description = [self.__class__.__name__, ":"]

        for name, value in self.fields.items():
            if hasattr(name, "description"):
                name = name.description
            else:
                name = str(name)

            if hasattr(value, "description"):
                value = value.description
            else:
                value = str(value)

            description.append("\n ")
            description.append(name)
            description.append(" = ")
            description.append(value)

        return "".join(description)

    def members(self):
        """
        Members of this record, if it is a group.

        @return: a L{Deferred}; fails with L{NotImplementedError} for group
            records (subclasses must implement membership) and fires with
            C{()} for non-group records.
        """
        if self.recordType == RecordType.group:
            return fail(
                NotImplementedError("Subclasses must implement members()")
            )
        return succeed(())

    def groups(self):
        """
        Groups this record is a member of; subclasses must implement this.
        """
        return fail(NotImplementedError("Subclasses must implement groups()"))
calendarserver-5.2+dfsg/twext/who/test/ 0000755 0001750 0001750 00000000000 12322625326 017241 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/who/test/test_xml.py 0000644 0001750 0001750 00000062301 12263343324 021453 0 ustar rahul rahul ##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
XML directory service tests.
"""
from time import sleep
from twisted.trial import unittest
from twisted.python.filepath import FilePath
from twisted.internet.defer import inlineCallbacks
from twext.who.idirectory import NoSuchRecordError
from twext.who.idirectory import Operand
from twext.who.expression import MatchExpression, MatchType, MatchFlags
from twext.who.xml import ParseError
from twext.who.xml import DirectoryService, DirectoryRecord
from twext.who.test import test_directory
class BaseTest(unittest.TestCase):
    """
    Shared fixture helpers for the XML directory service tests.
    """

    def service(self, xmlData=None):
        """
        Create an XML directory service backed by a fresh temporary file.
        """
        return xmlService(self.mktemp(), xmlData)

    def assertRecords(self, records, uids):
        """
        Assert that C{records} are exactly the records with C{uids}.
        """
        actualUIDs = frozenset(record.uid for record in records)
        expectedUIDs = frozenset(uids)
        self.assertEquals(actualUIDs, expectedUIDs)
class DirectoryServiceBaseTest(
    BaseTest,
    test_directory.BaseDirectoryServiceTest,
):
    """
    Tests for the basic L{IDirectoryService} API of the XML directory
    service, run against the shared test XML configuration.
    """

    def test_repr(self):
        """
        C{repr} of a service reflects whether its records are loaded.
        """
        service = self.service()

        # NOTE(review): the expected repr strings below appear to have
        # lost their angle-bracketed content during extraction; confirm
        # the expected values against the upstream test suite.
        self.assertEquals(repr(service), "")
        service.loadRecords()
        self.assertEquals(repr(service), "")

    @inlineCallbacks
    def test_recordWithUID(self):
        """
        Records can be looked up by UID; unknown UIDs yield C{None}.
        """
        service = self.service()

        record = (yield service.recordWithUID("__null__"))
        self.assertEquals(record, None)

        record = (yield service.recordWithUID("__wsanchez__"))
        self.assertEquals(record.uid, "__wsanchez__")

    @inlineCallbacks
    def test_recordWithGUID(self):
        """
        Looking up an unknown GUID yields C{None}.
        """
        service = self.service()
        record = (
            yield service.recordWithGUID(
                "6C495FCD-7E78-4D5C-AA66-BC890AD04C9D"
            )
        )
        self.assertEquals(record, None)

    @inlineCallbacks
    def test_recordsWithRecordType(self):
        """
        Records can be looked up by record type; an unknown type yields
        no records.
        """
        service = self.service()

        records = (yield service.recordsWithRecordType(object()))
        self.assertEquals(set(records), set())

        records = (
            yield service.recordsWithRecordType(service.recordType.user)
        )
        self.assertRecords(
            records,
            (
                "__wsanchez__",
                "__glyph__",
                "__sagen__",
                "__cdaboo__",
                "__dre__",
                "__exarkun__",
                "__dreid__",
                "__alyssa__",
                "__joe__",
            ),
        )

        records = (
            yield service.recordsWithRecordType(service.recordType.group)
        )
        self.assertRecords(
            records,
            (
                "__calendar-dev__",
                "__twisted__",
                "__developers__",
            ),
        )

    @inlineCallbacks
    def test_recordWithShortName(self):
        """
        Records can be looked up by any of their short names.
        """
        service = self.service()

        record = (
            yield service.recordWithShortName(
                service.recordType.user,
                "null",
            )
        )
        self.assertEquals(record, None)

        record = (
            yield service.recordWithShortName(
                service.recordType.user,
                "wsanchez",
            )
        )
        self.assertEquals(record.uid, "__wsanchez__")

        record = (
            yield service.recordWithShortName(
                service.recordType.user,
                "wilfredo_sanchez",
            )
        )
        self.assertEquals(record.uid, "__wsanchez__")

    @inlineCallbacks
    def test_recordsWithEmailAddress(self):
        """
        Records can be looked up by email address; an address may map to
        multiple records.
        """
        service = self.service()

        records = (
            yield service.recordsWithEmailAddress(
                "wsanchez@bitbucket.calendarserver.org"
            )
        )
        self.assertRecords(records, ("__wsanchez__",))

        records = (
            yield service.recordsWithEmailAddress(
                "wsanchez@devnull.twistedmatrix.com"
            )
        )
        self.assertRecords(records, ("__wsanchez__",))

        records = (
            yield service.recordsWithEmailAddress(
                "shared@example.com"
            )
        )
        self.assertRecords(records, ("__sagen__", "__dre__"))
class DirectoryServiceRealmTest(BaseTest):
    """
    Tests for realm name handling.
    """

    def test_realmNameImmutable(self):
        """
        The realm name of a service may not be reassigned.
        """
        def mutateRealmName():
            directory = self.service()
            directory.realmName = "foo"

        self.assertRaises(AssertionError, mutateRealmName)
class DirectoryServiceParsingTest(BaseTest):
    """
    Tests for parsing and (re)loading of the XML configuration data.

    NOTE(review): several XML payloads in this class appear to have lost
    their markup (angle-bracketed content) during extraction; confirm the
    literal contents against the upstream test suite.
    """

    def test_reloadInterval(self):
        """
        A reload within the refresh interval does not re-read the data.
        """
        service = self.service()

        service.loadRecords(stat=False)
        lastRefresh = service._lastRefresh
        self.assertTrue(service._lastRefresh)

        sleep(1)
        service.loadRecords(stat=False)
        self.assertEquals(lastRefresh, service._lastRefresh)

    def test_reloadStat(self):
        """
        A reload with an unchanged file does not re-read the data.
        """
        service = self.service()

        service.loadRecords(loadNow=True)
        lastRefresh = service._lastRefresh
        self.assertTrue(service._lastRefresh)

        sleep(1)
        service.loadRecords(loadNow=True)
        self.assertEquals(lastRefresh, service._lastRefresh)

    def test_badXML(self):
        """
        Data that is not valid XML raises L{ParseError}.
        """
        service = self.service(xmlData="Hello")
        self.assertRaises(ParseError, service.loadRecords)

    def test_badRootElement(self):
        """
        XML with an unexpected root element raises L{ParseError}.
        """
        service = self.service(xmlData=(
            """
            """
        ))

        self.assertRaises(ParseError, service.loadRecords)

        try:
            service.loadRecords()
        except ParseError as e:
            self.assertTrue(str(e).startswith("Incorrect root element"), e)
        else:
            raise AssertionError

    def test_noRealmName(self):
        """
        XML without a realm name raises L{ParseError}.
        """
        service = self.service(xmlData=(
            """
            """
        ))

        self.assertRaises(ParseError, service.loadRecords)

        try:
            service.loadRecords()
        except ParseError as e:
            self.assertTrue(str(e).startswith("No realm name"), e)
        else:
            raise AssertionError

    def test_unknownFieldElementsClean(self):
        """
        The default test data contains no unknown field elements.
        """
        service = self.service()
        self.assertEquals(set(service.unknownFieldElements), set())

    def test_unknownFieldElementsDirty(self):
        """
        Unknown field elements in the data are collected and reported.
        """
        service = self.service(xmlData=(
            """
            __wsanchez__
            wsanchez
            Community and Freedom Party
            """
        ))
        self.assertEquals(
            set(service.unknownFieldElements),
            set(("political-affiliation",))
        )

    def test_unknownRecordTypesClean(self):
        """
        The default test data contains no unknown record types.
        """
        service = self.service()
        self.assertEquals(set(service.unknownRecordTypes), set())

    def test_unknownRecordTypesDirty(self):
        """
        Unknown record types in the data are collected and reported.
        """
        service = self.service(xmlData=(
            """
            __d600__
            d600
            Nikon D600
            """
        ))
        self.assertEquals(set(service.unknownRecordTypes), set(("camera",)))
class DirectoryServiceQueryTest(BaseTest):
    """
    Tests for querying the XML directory service with match expressions,
    covering each match type (equals, startsWith, contains) both on
    indexed fields (eg. shortNames) and un-indexed fields (eg. fullNames),
    with and without the NOT and caseInsensitive flags.
    """

    @inlineCallbacks
    def test_queryAnd(self):
        """
        AND of two expressions intersects their results.
        """
        service = self.service()
        records = yield service.recordsFromQuery(
            (
                service.query("emailAddresses", "shared@example.com"),
                service.query("shortNames", "sagen"),
            ),
            operand=Operand.AND
        )
        self.assertRecords(records, ("__sagen__",))

    @inlineCallbacks
    def test_queryAndNoneFirst(self):
        """
        Test optimized case, where first expression yields no results.
        """
        service = self.service()
        records = yield service.recordsFromQuery(
            (
                service.query("emailAddresses", "nobody@example.com"),
                service.query("shortNames", "sagen"),
            ),
            operand=Operand.AND
        )
        self.assertRecords(records, ())

    @inlineCallbacks
    def test_queryOr(self):
        """
        OR of two expressions unions their results.
        """
        service = self.service()
        records = yield service.recordsFromQuery(
            (
                service.query("emailAddresses", "shared@example.com"),
                service.query("shortNames", "wsanchez"),
            ),
            operand=Operand.OR
        )
        self.assertRecords(records, ("__sagen__", "__dre__", "__wsanchez__"))

    @inlineCallbacks
    def test_queryNot(self):
        """
        The NOT flag inverts an expression (indexed field).
        """
        service = self.service()
        records = yield service.recordsFromQuery(
            (
                service.query("emailAddresses", "shared@example.com"),
                service.query("shortNames", "sagen", flags=MatchFlags.NOT),
            ),
            operand=Operand.AND
        )
        self.assertRecords(records, ("__dre__",))

    @inlineCallbacks
    def test_queryNotNoIndex(self):
        """
        The NOT flag inverts an expression (un-indexed field).
        """
        service = self.service()
        records = yield service.recordsFromQuery(
            (
                service.query("emailAddresses", "shared@example.com"),
                service.query(
                    "fullNames", "Andre LaBranche",
                    flags=MatchFlags.NOT
                ),
            ),
            operand=Operand.AND
        )
        self.assertRecords(records, ("__sagen__",))

    @inlineCallbacks
    def test_queryCaseInsensitive(self):
        """
        Case-insensitive equality on an indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "shortNames", "SagEn",
                flags=MatchFlags.caseInsensitive
            ),
        ))
        self.assertRecords(records, ("__sagen__",))

    @inlineCallbacks
    def test_queryCaseInsensitiveNoIndex(self):
        """
        Case-insensitive equality on an un-indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "fullNames", "moRGen SAGen",
                flags=MatchFlags.caseInsensitive
            ),
        ))
        self.assertRecords(records, ("__sagen__",))

    @inlineCallbacks
    def test_queryStartsWith(self):
        """
        Prefix match on an indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query("shortNames", "wil", matchType=MatchType.startsWith),
        ))
        self.assertRecords(records, ("__wsanchez__",))

    @inlineCallbacks
    def test_queryStartsWithNoIndex(self):
        """
        Prefix match on an un-indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "fullNames", "Wilfredo",
                matchType=MatchType.startsWith
            ),
        ))
        self.assertRecords(records, ("__wsanchez__",))

    @inlineCallbacks
    def test_queryStartsWithNot(self):
        """
        Negated prefix match on an indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "shortNames", "w",
                matchType=MatchType.startsWith,
                flags=MatchFlags.NOT,
            ),
        ))
        self.assertRecords(
            records,
            (
                '__alyssa__',
                '__calendar-dev__',
                '__cdaboo__',
                '__developers__',
                '__dre__',
                '__dreid__',
                '__exarkun__',
                '__glyph__',
                '__joe__',
                '__sagen__',
                '__twisted__',
            ),
        )

    @inlineCallbacks
    def test_queryStartsWithNotAny(self):
        """
        FIXME?: In the this case, the record __wsanchez__ has two
        shortNames, and one doesn't match the query.  Should it be
        included or not?  It is, because one matches the query, but
        should NOT require that all match?
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "shortNames", "wil",
                matchType=MatchType.startsWith,
                flags=MatchFlags.NOT,
            ),
        ))
        self.assertRecords(
            records,
            (
                '__alyssa__',
                '__calendar-dev__',
                '__cdaboo__',
                '__developers__',
                '__dre__',
                '__dreid__',
                '__exarkun__',
                '__glyph__',
                '__joe__',
                '__sagen__',
                '__twisted__',
                '__wsanchez__',
            ),
        )

    @inlineCallbacks
    def test_queryStartsWithNotNoIndex(self):
        """
        Negated prefix match on an un-indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "fullNames", "Wilfredo",
                matchType=MatchType.startsWith,
                flags=MatchFlags.NOT,
            ),
        ))
        self.assertRecords(
            records,
            (
                '__alyssa__',
                '__calendar-dev__',
                '__cdaboo__',
                '__developers__',
                '__dre__',
                '__dreid__',
                '__exarkun__',
                '__glyph__',
                '__joe__',
                '__sagen__',
                '__twisted__',
            ),
        )

    @inlineCallbacks
    def test_queryStartsWithCaseInsensitive(self):
        """
        Case-insensitive prefix match on an indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "shortNames", "WIL",
                matchType=MatchType.startsWith,
                flags=MatchFlags.caseInsensitive,
            ),
        ))
        self.assertRecords(records, ("__wsanchez__",))

    @inlineCallbacks
    def test_queryStartsWithCaseInsensitiveNoIndex(self):
        """
        Case-insensitive prefix match on an un-indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "fullNames", "wilfrEdo",
                matchType=MatchType.startsWith,
                flags=MatchFlags.caseInsensitive,
            ),
        ))
        self.assertRecords(records, ("__wsanchez__",))

    @inlineCallbacks
    def test_queryContains(self):
        """
        Substring match on an indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "shortNames", "sanchez",
                matchType=MatchType.contains
            ),
        ))
        self.assertRecords(records, ("__wsanchez__",))

    @inlineCallbacks
    def test_queryContainsNoIndex(self):
        """
        Substring match on an un-indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query("fullNames", "fred", matchType=MatchType.contains),
        ))
        self.assertRecords(records, ("__wsanchez__",))

    @inlineCallbacks
    def test_queryContainsNot(self):
        """
        Negated substring match on an indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "shortNames", "sanchez",
                matchType=MatchType.contains,
                flags=MatchFlags.NOT,
            ),
        ))
        self.assertRecords(
            records,
            (
                '__alyssa__',
                '__calendar-dev__',
                '__cdaboo__',
                '__developers__',
                '__dre__',
                '__dreid__',
                '__exarkun__',
                '__glyph__',
                '__joe__',
                '__sagen__',
                '__twisted__',
            ),
        )

    @inlineCallbacks
    def test_queryContainsNotNoIndex(self):
        """
        Negated substring match on an un-indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "fullNames", "fred",
                matchType=MatchType.contains,
                flags=MatchFlags.NOT,
            ),
        ))
        self.assertRecords(
            records,
            (
                '__alyssa__',
                '__calendar-dev__',
                '__cdaboo__',
                '__developers__',
                '__dre__',
                '__dreid__',
                '__exarkun__',
                '__glyph__',
                '__joe__',
                '__sagen__',
                '__twisted__',
            ),
        )

    @inlineCallbacks
    def test_queryContainsCaseInsensitive(self):
        """
        Case-insensitive substring match on an indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "shortNames", "Sanchez",
                matchType=MatchType.contains,
                flags=MatchFlags.caseInsensitive,
            ),
        ))
        self.assertRecords(records, ("__wsanchez__",))

    @inlineCallbacks
    def test_queryContainsCaseInsensitiveNoIndex(self):
        """
        Case-insensitive substring match on an un-indexed field.
        """
        service = self.service()
        records = yield service.recordsFromQuery((
            service.query(
                "fullNames", "frEdo",
                matchType=MatchType.contains,
                flags=MatchFlags.caseInsensitive,
            ),
        ))
        self.assertRecords(records, ("__wsanchez__",))
class DirectoryServiceMutableTest(BaseTest):
    """
    Tests for editing records in the XML directory service.
    """

    @inlineCallbacks
    def test_updateRecord(self):
        """
        An updated record replaces the original, both in memory and in the
        persisted data.
        """
        service = self.service()

        record = (yield service.recordWithUID("__wsanchez__"))

        fields = record.fields.copy()
        fields[service.fieldName.fullNames] = ["Wilfredo Sanchez Vega"]

        updatedRecord = DirectoryRecord(service, fields)
        yield service.updateRecords((updatedRecord,))

        # Verify change is present immediately
        record = (yield service.recordWithUID("__wsanchez__"))
        self.assertEquals(
            set(record.fullNames),
            set(("Wilfredo Sanchez Vega",))
        )

        # Verify change is persisted
        service.flush()
        record = (yield service.recordWithUID("__wsanchez__"))
        self.assertEquals(
            set(record.fullNames),
            set(("Wilfredo Sanchez Vega",))
        )

    @inlineCallbacks
    def test_addRecord(self):
        """
        A new record can be added with C{create=True}.
        """
        service = self.service()

        newRecord = DirectoryRecord(
            service,
            fields={
                service.fieldName.uid: "__plugh__",
                service.fieldName.recordType: service.recordType.user,
                service.fieldName.shortNames: ("plugh",),
            }
        )

        yield service.updateRecords((newRecord,), create=True)

        # Verify change is present immediately
        record = (yield service.recordWithUID("__plugh__"))
        self.assertEquals(set(record.shortNames), set(("plugh",)))

        # Verify change is persisted
        service.flush()
        record = (yield service.recordWithUID("__plugh__"))
        self.assertEquals(set(record.shortNames), set(("plugh",)))

    def test_addRecordNoCreate(self):
        """
        Updating an unknown record without C{create=True} fails with
        L{NoSuchRecordError}.
        """
        service = self.service()

        newRecord = DirectoryRecord(
            service,
            fields={
                service.fieldName.uid: "__plugh__",
                service.fieldName.recordType: service.recordType.user,
                service.fieldName.shortNames: ("plugh",),
            }
        )

        self.assertFailure(
            service.updateRecords((newRecord,)),
            NoSuchRecordError
        )

    @inlineCallbacks
    def test_removeRecord(self):
        """
        A removed record is gone, both in memory and in the persisted data.
        """
        service = self.service()

        yield service.removeRecords(("__wsanchez__",))

        # Verify change is present immediately
        self.assertEquals((yield service.recordWithUID("__wsanchez__")), None)

        # Verify change is persisted
        service.flush()
        self.assertEquals((yield service.recordWithUID("__wsanchez__")), None)

    def test_removeRecordNoExist(self):
        """
        Removing a nonexistent record is a no-op.
        """
        service = self.service()

        return service.removeRecords(("__plugh__",))
class DirectoryRecordTest(BaseTest, test_directory.BaseDirectoryRecordTest):
    """
    XML-backed directory record tests, exercising group membership in
    both directions.
    """

    @inlineCallbacks
    def test_members(self):
        """
        members() is empty for a user record and yields the member
        records (checked by UID) for group records, including nested
        groups as direct members.
        """
        service = self.service()
        record = (yield service.recordWithUID("__wsanchez__"))
        members = (yield record.members())
        self.assertEquals(set(members), set())
        record = (yield service.recordWithUID("__twisted__"))
        members = (yield record.members())
        self.assertEquals(
            set((member.uid for member in members)),
            set((
                "__wsanchez__",
                "__glyph__",
                "__exarkun__",
                "__dreid__",
                "__dre__",
            ))
        )
        record = (yield service.recordWithUID("__developers__"))
        members = (yield record.members())
        self.assertEquals(
            set((member.uid for member in members)),
            set((
                "__calendar-dev__",
                "__twisted__",
                "__alyssa__",
            ))
        )

    @inlineCallbacks
    def test_groups(self):
        """
        groups() yields the groups (checked by UID) that a record is a
        direct member of.
        """
        service = self.service()
        record = (yield service.recordWithUID("__wsanchez__"))
        groups = (yield record.groups())
        self.assertEquals(
            set(group.uid for group in groups),
            set((
                "__calendar-dev__",
                "__twisted__",
            ))
        )
class QueryMixIn(object):
    """
    Mix-in providing a convenience constructor for match expressions
    against this service's field names.
    """

    def query(self, field, value, matchType=MatchType.equals, flags=None):
        """
        Build a L{MatchExpression} for the field named C{field} on this
        service, matching C{value}.
        """
        fieldName = getattr(self.fieldName, field)
        assert fieldName is not None
        expression = MatchExpression(
            fieldName, value,
            matchType=matchType,
            flags=flags,
        )
        return expression
class TestService(DirectoryService, QueryMixIn):
    """
    XML directory service with the query() convenience from QueryMixIn.
    """
    pass
def xmlService(tmp, xmlData=None, serviceClass=None):
    """
    Instantiate an XML directory service whose backing file lives at
    C{tmp}, seeding the file with C{xmlData} (default: the module's
    test fixture) and using C{serviceClass} (default: L{TestService}).
    """
    content = testXMLConfig if xmlData is None else xmlData
    factory = TestService if serviceClass is None else serviceClass
    backingFile = FilePath(tmp)
    backingFile.setContent(content)
    return factory(backingFile)
testXMLConfig = """
__wsanchez__
wsanchez
wilfredo_sanchez
Wilfredo Sanchez
zehcnasw
wsanchez@bitbucket.calendarserver.org
wsanchez@devnull.twistedmatrix.com
__glyph__
glyph
Glyph Lefkowitz
hpylg
glyph@bitbucket.calendarserver.org
glyph@devnull.twistedmatrix.com
__sagen__
sagen
Morgen Sagen
negas
sagen@bitbucket.calendarserver.org
shared@example.com
__cdaboo__
cdaboo
Cyrus Daboo
suryc
cdaboo@bitbucket.calendarserver.org
__dre__
dre
Andre LaBranche
erd
dre@bitbucket.calendarserver.org
shared@example.com
__exarkun__
exarkun
Jean-Paul Calderone
nucraxe
exarkun@devnull.twistedmatrix.com
__dreid__
dreid
David Reid
dierd
dreid@devnull.twistedmatrix.com
__joe__
joe
Joe Schmoe
eoj
joe@example.com
__alyssa__
alyssa
Alyssa P. Hacker
assyla
alyssa@example.com
__calendar-dev__
calendar-dev
Calendar Server developers
dev@bitbucket.calendarserver.org
__wsanchez__
__glyph__
__sagen__
__cdaboo__
__dre__
__twisted__
twisted
Twisted Matrix Laboratories
hack@devnull.twistedmatrix.com
__wsanchez__
__glyph__
__exarkun__
__dreid__
__dre__
__developers__
developers
All Developers
__calendar-dev__
__twisted__
__alyssa__
"""
calendarserver-5.2+dfsg/twext/who/test/test_aggregate.py 0000644 0001750 0001750 00000015623 12263343324 022606 0 ustar rahul rahul ##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Aggregate directory service tests.
"""
from twisted.python.components import proxyForInterface
from twisted.trial import unittest
from twext.who.idirectory import IDirectoryService, DirectoryConfigurationError
from twext.who.aggregate import DirectoryService
from twext.who.util import ConstantsContainer
from twext.who.test import test_directory, test_xml
from twext.who.test.test_xml import QueryMixIn, xmlService
from twext.who.test.test_xml import TestService as XMLTestService
class BaseTest(object):
    """
    Shared fixture helpers for aggregate directory service tests.
    """

    def service(self, services=None):
        """
        Build an aggregate test service wrapping C{services} (by
        default, a single XML-backed service).
        """
        if services is None:
            services = (self.xmlService(),)

        # Wrap each sub-service in an interface proxy so the aggregate
        # DirectoryService cannot rely on implementation details of the
        # IDirectoryService objects it is handed.
        makeProxy = proxyForInterface(IDirectoryService)
        wrapped = tuple(makeProxy(subService) for subService in services)

        class TestService(DirectoryService, QueryMixIn):
            pass

        return TestService("xyzzy", wrapped)

    def xmlService(self, xmlData=None, serviceClass=None):
        """
        Create an XML-backed directory service in a fresh temp file.
        """
        return xmlService(self.mktemp(), xmlData, serviceClass)
class DirectoryServiceBaseTest(BaseTest, test_xml.DirectoryServiceBaseTest):
    """
    Run the XML service base tests against an aggregate service.
    """

    def test_repr(self):
        # NOTE(review): the expected repr literal below appears to have
        # been emptied by a text-extraction step — confirm against the
        # original source.
        service = self.service()
        self.assertEquals(repr(service), "")
class DirectoryServiceQueryTest(BaseTest, test_xml.DirectoryServiceQueryTest):
    """
    Run the XML service query tests against an aggregate service.
    """
    pass
class DirectoryServiceImmutableTest(
    BaseTest,
    test_directory.BaseDirectoryServiceImmutableTest,
):
    """
    Run the immutable-service tests against an aggregate service.
    """
    pass
class AggregatedBaseTest(BaseTest):
    """
    Fixture aggregating two XML services: one vending only user
    records, the other only group records.
    """

    def service(self):
        class UsersDirectoryService(XMLTestService):
            # Restrict this service to vending user records only.
            recordType = ConstantsContainer((XMLTestService.recordType.user,))

        class GroupsDirectoryService(XMLTestService):
            # Restrict this service to vending group records only.
            recordType = ConstantsContainer((XMLTestService.recordType.group,))

        usersService = self.xmlService(
            testXMLConfigUsers,
            UsersDirectoryService
        )
        groupsService = self.xmlService(
            testXMLConfigGroups,
            GroupsDirectoryService
        )
        return BaseTest.service(self, (usersService, groupsService))
class DirectoryServiceAggregatedBaseTest(
    AggregatedBaseTest,
    DirectoryServiceBaseTest,
):
    """
    Base tests against a two-service (users + groups) aggregate.
    """
    pass
class DirectoryServiceAggregatedQueryTest(
    AggregatedBaseTest,
    test_xml.DirectoryServiceQueryTest,
):
    """
    Query tests against a two-service (users + groups) aggregate.
    """
    pass
class DirectoryServiceAggregatedImmutableTest(
    AggregatedBaseTest,
    test_directory.BaseDirectoryServiceImmutableTest,
):
    """
    Immutable-service tests against a two-service aggregate.
    """
    pass
class DirectoryServiceTests(BaseTest, unittest.TestCase):
    """
    Tests for aggregate DirectoryService configuration errors.
    """

    def test_conflictingRecordTypes(self):
        """
        Aggregating two services that vend the same record type raises
        DirectoryConfigurationError.
        """
        self.assertRaises(
            DirectoryConfigurationError,
            BaseTest.service, self,
            (self.xmlService(), self.xmlService(testXMLConfigUsers)),
        )
testXMLConfigUsers = """
__wsanchez__
wsanchez
wilfredo_sanchez
Wilfredo Sanchez
zehcnasw
wsanchez@bitbucket.calendarserver.org
wsanchez@devnull.twistedmatrix.com
__glyph__
glyph
Glyph Lefkowitz
hpylg
glyph@bitbucket.calendarserver.org
glyph@devnull.twistedmatrix.com
__sagen__
sagen
Morgen Sagen
negas
sagen@bitbucket.calendarserver.org
shared@example.com
__cdaboo__
cdaboo
Cyrus Daboo
suryc
cdaboo@bitbucket.calendarserver.org
__dre__
dre
Andre LaBranche
erd
dre@bitbucket.calendarserver.org
shared@example.com
__exarkun__
exarkun
Jean-Paul Calderone
nucraxe
exarkun@devnull.twistedmatrix.com
__dreid__
dreid
David Reid
dierd
dreid@devnull.twistedmatrix.com
__joe__
joe
Joe Schmoe
eoj
joe@example.com
__alyssa__
alyssa
Alyssa P. Hacker
assyla
alyssa@example.com
"""
testXMLConfigGroups = """
__calendar-dev__
calendar-dev
Calendar Server developers
dev@bitbucket.calendarserver.org
__wsanchez__
__glyph__
__sagen__
__cdaboo__
__dre__
__twisted__
twisted
Twisted Matrix Laboratories
hack@devnull.twistedmatrix.com
__wsanchez__
__glyph__
__exarkun__
__dreid__
__dre__
__developers__
developers
All Developers
__calendar-dev__
__twisted__
__alyssa__
"""
calendarserver-5.2+dfsg/twext/who/test/test_util.py 0000644 0001750 0001750 00000007012 12263343324 021626 0 ustar rahul rahul ##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Directory service utility tests.
"""
from twisted.trial import unittest
from twisted.python.constants import Names, NamedConstant
from twisted.python.constants import Flags, FlagConstant
from twext.who.idirectory import DirectoryServiceError
from twext.who.util import ConstantsContainer
from twext.who.util import uniqueResult, describe
class Tools(Names):
    """
    Named constants that all carry descriptions, for describe() tests.
    """
    hammer = NamedConstant()
    screwdriver = NamedConstant()
    hammer.description = "nail pounder"
    screwdriver.description = "screw twister"
class Instruments(Names):
    """
    Named constants without descriptions; note C{hammer} deliberately
    collides with L{Tools.hammer} for the conflict tests.
    """
    hammer = NamedConstant()
    chisel = NamedConstant()
class Switches(Flags):
    """
    Flag constants; C{r}/C{g}/C{b} carry descriptions while C{black}
    does not, so describe() falls back to its name.
    """
    r = FlagConstant()
    g = FlagConstant()
    b = FlagConstant()
    r.description = "red"
    g.description = "green"
    b.description = "blue"
    black = FlagConstant()
class ConstantsContainerTest(unittest.TestCase):
    """
    Tests for L{ConstantsContainer}.
    """

    def test_conflict(self):
        """
        Two constants with the same name cannot share a container.
        """
        constants = set((Tools.hammer, Instruments.hammer))
        self.assertRaises(ValueError, ConstantsContainer, constants)

    def test_attrs(self):
        """
        Contained constants are accessible as attributes by name;
        unknown names raise AttributeError.
        """
        constants = set((Tools.hammer, Tools.screwdriver, Instruments.chisel))
        container = ConstantsContainer(constants)
        self.assertEquals(container.hammer, Tools.hammer)
        self.assertEquals(container.screwdriver, Tools.screwdriver)
        self.assertEquals(container.chisel, Instruments.chisel)
        self.assertRaises(AttributeError, lambda: container.plugh)

    def test_iterconstants(self):
        """
        iterconstants() yields exactly the contained constants.
        """
        constants = set((Tools.hammer, Tools.screwdriver, Instruments.chisel))
        container = ConstantsContainer(constants)
        self.assertEquals(
            set(container.iterconstants()),
            constants,
        )

    def test_lookupByName(self):
        """
        lookupByName() finds constants by name; unknown names raise
        ValueError.
        """
        constants = set((
            Instruments.hammer,
            Tools.screwdriver,
            Instruments.chisel,
        ))
        container = ConstantsContainer(constants)
        self.assertEquals(
            container.lookupByName("hammer"),
            Instruments.hammer,
        )
        self.assertEquals(
            container.lookupByName("screwdriver"),
            Tools.screwdriver,
        )
        self.assertEquals(
            container.lookupByName("chisel"),
            Instruments.chisel,
        )
        self.assertRaises(
            ValueError,
            container.lookupByName, "plugh",
        )
class UtilTest(unittest.TestCase):
    """
    Tests for twext.who.util helper functions.
    """

    def test_uniqueResult(self):
        """
        uniqueResult() returns the sole element of a one-element
        iterable and raises for multiple elements.
        """
        self.assertEquals(1, uniqueResult((1,)))
        self.assertRaises(DirectoryServiceError, uniqueResult, (1, 2, 3))

    def test_describe(self):
        """
        describe() prefers a constant's description, falling back to
        its name.
        """
        self.assertEquals("nail pounder", describe(Tools.hammer))
        self.assertEquals("hammer", describe(Instruments.hammer))

    def test_describeFlags(self):
        """
        describe() joins descriptions of combined flags with "|".
        """
        self.assertEquals("blue", describe(Switches.b))
        self.assertEquals("red|green", describe(Switches.r | Switches.g))
        self.assertEquals("blue|black", describe(Switches.b | Switches.black))
calendarserver-5.2+dfsg/twext/who/test/test_directory.py 0000644 0001750 0001750 00000025626 12263343324 022670 0 ustar rahul rahul ##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Generic directory service base implementation tests.
"""
from zope.interface.verify import verifyObject, BrokenMethodImplementation
from twisted.trial import unittest
from twisted.trial.unittest import SkipTest
from twisted.internet.defer import inlineCallbacks
from twext.who.idirectory import QueryNotSupportedError, NotAllowedError
from twext.who.idirectory import RecordType, FieldName
from twext.who.idirectory import IDirectoryService, IDirectoryRecord
from twext.who.directory import DirectoryService, DirectoryRecord
class ServiceMixIn(object):
    """
    Mix-in that lazily supplies a shared DirectoryService instance for
    tests.
    """

    # Realm name used by every service the mix-in creates.
    realmName = "xyzzy"

    def service(self):
        """
        Return the cached DirectoryService, creating it on first use.
        """
        cached = getattr(self, "_service", None)
        if cached is None:
            cached = DirectoryService(self.realmName)
            self._service = cached
        return cached
class BaseDirectoryServiceTest(ServiceMixIn):
    """
    Tests that every directory service implementation should pass;
    query-specific tests are skipped here and must be provided by
    subclasses.
    """

    def test_interface(self):
        """
        The service fully implements IDirectoryService.
        """
        service = self.service()
        try:
            verifyObject(IDirectoryService, service)
        except BrokenMethodImplementation as e:
            self.fail(e)

    def test_init(self):
        """
        The realm name given at construction is exposed.
        """
        service = self.service()
        self.assertEquals(service.realmName, self.realmName)

    def test_repr(self):
        # NOTE(review): the expected repr literal below appears to have
        # been emptied by a text-extraction step — confirm against the
        # original source.
        service = self.service()
        self.assertEquals(repr(service), "")

    def test_recordTypes(self):
        """
        recordTypes() matches the constants in service.recordType.
        """
        service = self.service()
        self.assertEquals(
            set(service.recordTypes()),
            set(service.recordType.iterconstants())
        )

    @inlineCallbacks
    def test_recordsFromQueryNone(self):
        """
        An empty query yields no records.
        """
        service = self.service()
        records = (yield service.recordsFromQuery(()))
        for record in records:
            # NOTE(review): failTest is not a standard TestCase method;
            # verify — self.fail() may be intended here.
            self.failTest("No records expected")

    def test_recordsFromQueryBogus(self):
        """
        An unrecognized expression fails with QueryNotSupportedError.
        """
        service = self.service()
        self.assertFailure(
            service.recordsFromQuery((object(),)),
            QueryNotSupportedError
        )

    def test_recordWithUID(self):
        raise SkipTest("Subclasses should implement this test.")

    def test_recordWithGUID(self):
        raise SkipTest("Subclasses should implement this test.")

    def test_recordsWithRecordType(self):
        raise SkipTest("Subclasses should implement this test.")

    def test_recordWithShortName(self):
        raise SkipTest("Subclasses should implement this test.")

    def test_recordsWithEmailAddress(self):
        raise SkipTest("Subclasses should implement this test.")
class DirectoryServiceTest(unittest.TestCase, BaseDirectoryServiceTest):
    """
    Tests for the stock DirectoryService base class, which refuses all
    queries with QueryNotSupportedError.
    """

    def test_recordsFromExpression(self):
        """
        recordsFromExpression() fails with QueryNotSupportedError.
        """
        service = self.service()
        # Fixed: the original body used ``result = yield(...)`` in an
        # undecorated method, which turned the test into a generator
        # that trial never iterates, so its assertions never executed.
        # assertFailure takes the Deferred directly, matching the
        # sibling tests below.
        self.assertFailure(
            service.recordsFromExpression(None),
            QueryNotSupportedError
        )

    def test_recordWithUID(self):
        """
        recordWithUID() fails with QueryNotSupportedError.
        """
        service = self.service()
        self.assertFailure(
            service.recordWithUID(None),
            QueryNotSupportedError
        )

    def test_recordWithGUID(self):
        """
        recordWithGUID() fails with QueryNotSupportedError.
        """
        service = self.service()
        self.assertFailure(
            service.recordWithGUID(None),
            QueryNotSupportedError
        )

    def test_recordsWithRecordType(self):
        """
        recordsWithRecordType() fails with QueryNotSupportedError.
        """
        service = self.service()
        self.assertFailure(
            service.recordsWithRecordType(None),
            QueryNotSupportedError
        )

    def test_recordWithShortName(self):
        """
        recordWithShortName() fails with QueryNotSupportedError.
        """
        service = self.service()
        self.assertFailure(
            service.recordWithShortName(None, None),
            QueryNotSupportedError
        )

    def test_recordsWithEmailAddress(self):
        """
        recordsWithEmailAddress() fails with QueryNotSupportedError.
        """
        service = self.service()
        self.assertFailure(
            service.recordsWithEmailAddress(None),
            QueryNotSupportedError
        )
class BaseDirectoryServiceImmutableTest(ServiceMixIn):
    """
    Tests for directory services that do not allow mutation.
    """

    def test_updateRecordsNotAllowed(self):
        """
        updateRecords() fails with NotAllowedError whether or not
        create is requested.
        """
        service = self.service()
        newRecord = DirectoryRecord(
            service,
            fields={
                service.fieldName.uid: "__plugh__",
                service.fieldName.recordType: service.recordType.user,
                service.fieldName.shortNames: ("plugh",),
            }
        )
        self.assertFailure(
            service.updateRecords((newRecord,), create=True),
            NotAllowedError,
        )
        self.assertFailure(
            service.updateRecords((newRecord,), create=False),
            NotAllowedError,
        )

    def test_removeRecordsNotAllowed(self):
        """
        removeRecords() of an empty iterable is a no-op; removing
        anything fails with NotAllowedError.
        """
        service = self.service()
        service.removeRecords(())
        self.assertFailure(
            service.removeRecords(("foo",)),
            NotAllowedError,
        )
class DirectoryServiceImmutableTest(
    unittest.TestCase,
    BaseDirectoryServiceImmutableTest,
):
    """
    Immutability tests applied to the stock DirectoryService.
    """
    pass
class BaseDirectoryRecordTest(ServiceMixIn):
    """
    Tests that every directory record implementation should pass.
    Canned field dictionaries below provide user and group fixtures.
    """

    fields_wsanchez = {
        FieldName.uid: "UID:wsanchez",
        FieldName.recordType: RecordType.user,
        FieldName.shortNames: ("wsanchez", "wilfredo_sanchez"),
        FieldName.fullNames: (
            "Wilfredo Sanchez",
            "Wilfredo Sanchez Vega",
        ),
        FieldName.emailAddresses: (
            "wsanchez@calendarserver.org",
            "wsanchez@example.com",
        )
    }

    fields_glyph = {
        FieldName.uid: "UID:glyph",
        FieldName.recordType: RecordType.user,
        FieldName.shortNames: ("glyph",),
        FieldName.fullNames: ("Glyph Lefkowitz",),
        FieldName.emailAddresses: ("glyph@calendarserver.org",)
    }

    # Mixed-case email address, used by test_initNormalize.
    fields_sagen = {
        FieldName.uid: "UID:sagen",
        FieldName.recordType: RecordType.user,
        FieldName.shortNames: ("sagen",),
        FieldName.fullNames: ("Morgen Sagen",),
        FieldName.emailAddresses: ("sagen@CalendarServer.org",)
    }

    # Group-typed fixture, used by group-related member tests.
    fields_staff = {
        FieldName.uid: "UID:staff",
        FieldName.recordType: RecordType.group,
        FieldName.shortNames: ("staff",),
        FieldName.fullNames: ("Staff",),
        FieldName.emailAddresses: ("staff@CalendarServer.org",)
    }

    def makeRecord(self, fields=None, service=None):
        """
        Create a directory record from C{fields} (default:
        C{fields_wsanchez}) in C{service} (default: the shared test
        service).
        """
        if fields is None:
            fields = self.fields_wsanchez
        if service is None:
            service = self.service()
        return DirectoryRecord(service, fields)

    def test_interface(self):
        """
        Records fully implement IDirectoryRecord.
        """
        record = self.makeRecord()
        try:
            verifyObject(IDirectoryRecord, record)
        except BrokenMethodImplementation as e:
            self.fail(e)

    def test_init(self):
        """
        Records remember their service and fields.
        """
        service = self.service()
        wsanchez = self.makeRecord(self.fields_wsanchez, service=service)
        self.assertEquals(wsanchez.service, service)
        self.assertEquals(wsanchez.fields, self.fields_wsanchez)

    def test_initWithNoUID(self):
        """
        A missing or empty UID is rejected at construction time.
        """
        fields = self.fields_wsanchez.copy()
        del fields[FieldName.uid]
        self.assertRaises(ValueError, self.makeRecord, fields)
        fields = self.fields_wsanchez.copy()
        fields[FieldName.uid] = ""
        self.assertRaises(ValueError, self.makeRecord, fields)

    def test_initWithNoRecordType(self):
        """
        A missing or empty record type is rejected.
        """
        fields = self.fields_wsanchez.copy()
        del fields[FieldName.recordType]
        self.assertRaises(ValueError, self.makeRecord, fields)
        fields = self.fields_wsanchez.copy()
        fields[FieldName.recordType] = ""
        self.assertRaises(ValueError, self.makeRecord, fields)

    def test_initWithNoShortNames(self):
        """
        Missing, empty, or empty-string short names are rejected.
        """
        fields = self.fields_wsanchez.copy()
        del fields[FieldName.shortNames]
        self.assertRaises(ValueError, self.makeRecord, fields)
        fields = self.fields_wsanchez.copy()
        fields[FieldName.shortNames] = ()
        self.assertRaises(ValueError, self.makeRecord, fields)
        fields = self.fields_wsanchez.copy()
        fields[FieldName.shortNames] = ("",)
        self.assertRaises(ValueError, self.makeRecord, fields)
        fields = self.fields_wsanchez.copy()
        fields[FieldName.shortNames] = ("wsanchez", "")
        self.assertRaises(ValueError, self.makeRecord, fields)

    def test_initWithBogusRecordType(self):
        """
        A record type that is not a known constant is rejected.
        """
        fields = self.fields_wsanchez.copy()
        fields[FieldName.recordType] = object()
        self.assertRaises(ValueError, self.makeRecord, fields)

    def test_initNormalize(self):
        """
        Email addresses are normalized to lowercase at init time.
        """
        sagen = self.makeRecord(self.fields_sagen)
        self.assertEquals(
            sagen.fields[FieldName.emailAddresses],
            ("sagen@calendarserver.org",)
        )

    def test_compare(self):
        """
        Records compare equal only when both service and fields match.
        """
        fields_glyphmod = self.fields_glyph.copy()
        del fields_glyphmod[FieldName.emailAddresses]
        plugh = DirectoryService("plugh")
        wsanchez = self.makeRecord(self.fields_wsanchez)
        wsanchezmod = self.makeRecord(self.fields_wsanchez, plugh)
        glyph = self.makeRecord(self.fields_glyph)
        glyphmod = self.makeRecord(fields_glyphmod)
        self.assertEquals(wsanchez, wsanchez)
        self.assertNotEqual(wsanchez, glyph)
        self.assertNotEqual(glyph, glyphmod) # UID matches, other fields don't
        self.assertNotEqual(glyphmod, wsanchez)
        self.assertNotEqual(wsanchez, wsanchezmod) # Different service

    def test_attributeAccess(self):
        """
        Fields are accessible as attributes named after the field.
        """
        wsanchez = self.makeRecord(self.fields_wsanchez)
        self.assertEquals(
            wsanchez.recordType,
            wsanchez.fields[FieldName.recordType]
        )
        self.assertEquals(
            wsanchez.uid,
            wsanchez.fields[FieldName.uid]
        )
        self.assertEquals(
            wsanchez.shortNames,
            wsanchez.fields[FieldName.shortNames]
        )
        self.assertEquals(
            wsanchez.emailAddresses,
            wsanchez.fields[FieldName.emailAddresses]
        )

    @inlineCallbacks
    def test_members(self):
        """
        members() of a user record is empty.  Subclasses must extend
        this test for group records (hence the final SkipTest).
        """
        wsanchez = self.makeRecord(self.fields_wsanchez)
        self.assertEquals(
            set((yield wsanchez.members())),
            set()
        )
        raise SkipTest("Subclasses should implement this test.")

    def test_groups(self):
        raise SkipTest("Subclasses should implement this test.")
class DirectoryRecordTest(unittest.TestCase, BaseDirectoryRecordTest):
    """
    Tests for the default DirectoryRecord implementation, whose group
    support is not implemented.
    """

    # Fixed: the original method contained ``yield`` but lacked
    # @inlineCallbacks, so it was a generator function that trial never
    # iterated — none of its assertions ever ran.
    @inlineCallbacks
    def test_members(self):
        """
        members() of a user record is empty; members() of a group
        record fails with NotImplementedError in the base
        implementation.
        """
        wsanchez = self.makeRecord(self.fields_wsanchez)
        self.assertEquals(
            set((yield wsanchez.members())),
            set()
        )
        staff = self.makeRecord(self.fields_staff)
        # Yield so the assertFailure Deferred is consumed within the
        # inlineCallbacks flow.
        yield self.assertFailure(staff.members(), NotImplementedError)

    def test_groups(self):
        """
        groups() fails with NotImplementedError in the base
        implementation.
        """
        wsanchez = self.makeRecord(self.fields_wsanchez)
        self.assertFailure(wsanchez.groups(), NotImplementedError)
calendarserver-5.2+dfsg/twext/who/test/test_expression.py 0000644 0001750 0001750 00000003231 12263343324 023047 0 ustar rahul rahul ##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Directory service expression tests.
"""
from twisted.trial import unittest
from twext.who.idirectory import FieldName
from twext.who.expression import MatchExpression, MatchType, MatchFlags
class MatchExpressionTest(unittest.TestCase):
    """
    Tests for L{MatchExpression}'s repr().
    """

    def test_repr(self):
        # NOTE(review): the expected repr literals below appear to have
        # been emptied by a text-extraction step — confirm against the
        # original source before relying on them.
        self.assertEquals(
            "",
            repr(MatchExpression(
                FieldName.fullNames,
                "Wilfredo Sanchez",
            )),
        )
        self.assertEquals(
            "",
            repr(MatchExpression(
                FieldName.fullNames,
                "Sanchez",
                matchType=MatchType.contains,
            )),
        )
        self.assertEquals(
            "",
            repr(MatchExpression(
                FieldName.fullNames,
                "Wilfredo",
                matchType=MatchType.startsWith,
                flags=MatchFlags.NOT,
            )),
        )
calendarserver-5.2+dfsg/twext/who/test/__init__.py 0000644 0001750 0001750 00000001213 12263343324 021346 0 ustar rahul rahul ##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Directory service integration tests
"""
calendarserver-5.2+dfsg/twext/who/aggregate.py 0000644 0001750 0001750 00000005436 12263343324 020571 0 ustar rahul rahul # -*- test-case-name: twext.who.test.test_aggregate -*-
##
# Copyright (c) 2006-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Directory service which aggregates multiple directory services.
"""
# Public API of this module.
__all__ = [
    "DirectoryService",
    "DirectoryRecord",
]
from itertools import chain
from twisted.internet.defer import gatherResults, FirstError
from twext.who.idirectory import DirectoryConfigurationError
from twext.who.idirectory import IDirectoryService
from twext.who.index import DirectoryService as BaseDirectoryService
from twext.who.index import DirectoryRecord
from twext.who.util import ConstantsContainer
class DirectoryService(BaseDirectoryService):
    """
    Aggregate directory service.

    Wraps a collection of sub-services, each of which must vend a
    disjoint set of record types; queries fan out to all of them.
    """

    def __init__(self, realmName, services):
        # Validate each sub-service and ensure no two of them vend the
        # same record type, since results would otherwise be ambiguous.
        recordTypes = set()
        for service in services:
            if not IDirectoryService.implementedBy(service.__class__):
                raise ValueError(
                    "Not a directory service: {0}".format(service)
                )
            for recordType in service.recordTypes():
                if recordType in recordTypes:
                    raise DirectoryConfigurationError(
                        "Aggregated services may not vend "
                        "the same record type: {0}"
                        .format(recordType)
                    )
                recordTypes.add(recordType)
        BaseDirectoryService.__init__(self, realmName)
        self._services = tuple(services)

    @property
    def services(self):
        # The aggregated sub-services, as an immutable tuple.
        return self._services

    @property
    def recordType(self):
        # Union of the record type constants of all sub-services,
        # computed lazily and cached on first access.
        if not hasattr(self, "_recordType"):
            self._recordType = ConstantsContainer(chain(*tuple(
                s.recordTypes()
                for s in self.services
            )))
        return self._recordType

    def recordsFromExpression(self, expression, records=None):
        """
        Query every sub-service in parallel and chain the results
        together.  The first sub-service failure is re-raised directly
        (unwrapped from gatherResults' FirstError).
        """
        ds = []
        for service in self.services:
            d = service.recordsFromExpression(expression, records)
            ds.append(d)

        def unwrapFirstError(f):
            # gatherResults wraps the first failure in FirstError;
            # propagate the underlying failure instead.
            f.trap(FirstError)
            return f.value.subFailure

        d = gatherResults(ds, consumeErrors=True)
        d.addCallback(lambda results: chain(*results))
        d.addErrback(unwrapFirstError)
        return d
calendarserver-5.2+dfsg/twext/who/xml.py 0000644 0001750 0001750 00000030412 12263343324 017433 0 ustar rahul rahul # -*- test-case-name: twext.who.test.test_xml -*-
##
# Copyright (c) 2006-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
from __future__ import absolute_import
"""
XML directory service implementation.
"""
# Public API of this module.
__all__ = [
    "ParseError",
    "DirectoryService",
    "DirectoryRecord",
]
from time import time
from xml.etree.ElementTree import parse as parseXML
from xml.etree.ElementTree import ParseError as XMLParseError
from xml.etree.ElementTree import tostring as etreeToString
from xml.etree.ElementTree import Element as XMLElement
from twisted.python.constants import Values, ValueConstant
from twisted.internet.defer import fail
from twext.who.idirectory import DirectoryServiceError
from twext.who.idirectory import NoSuchRecordError, UnknownRecordTypeError
from twext.who.idirectory import RecordType, FieldName as BaseFieldName
from twext.who.index import DirectoryService as BaseDirectoryService
from twext.who.index import DirectoryRecord
from twext.who.index import FieldName as IndexFieldName
##
# Exceptions
##
class ParseError(DirectoryServiceError):
    """
    Parse error.

    Raised when the backing XML file cannot be parsed or has an
    unexpected structure.
    """
##
# XML constants
##
class Element(Values):
    """
    XML element names used in the directory file.  Field elements carry
    a C{fieldName} attribute mapping them to the directory field they
    populate.
    """
    directory = ValueConstant("directory")
    record = ValueConstant("record")
    #
    # Field names
    #
    uid = ValueConstant("uid")
    uid.fieldName = BaseFieldName.uid
    guid = ValueConstant("guid")
    guid.fieldName = BaseFieldName.guid
    shortName = ValueConstant("short-name")
    shortName.fieldName = BaseFieldName.shortNames
    fullName = ValueConstant("full-name")
    fullName.fieldName = BaseFieldName.fullNames
    emailAddress = ValueConstant("email")
    emailAddress.fieldName = BaseFieldName.emailAddresses
    password = ValueConstant("password")
    password.fieldName = BaseFieldName.password
    memberUID = ValueConstant("member-uid")
    memberUID.fieldName = IndexFieldName.memberUIDs
class Attribute(Values):
    """
    XML attribute names used in the directory file.
    """
    realm = ValueConstant("realm")
    recordType = ValueConstant("type")
class Value(Values):
    """
    XML attribute values used in the directory file.  Record-type
    values carry a C{recordType} attribute mapping them to the
    corresponding directory record type constant.
    """
    #
    # Booleans
    #
    true = ValueConstant("true")
    false = ValueConstant("false")
    #
    # Record types
    #
    user = ValueConstant("user")
    user.recordType = RecordType.user
    group = ValueConstant("group")
    group.recordType = RecordType.group
##
# Directory Service
##
class DirectoryService(BaseDirectoryService):
"""
XML directory service.
"""
element = Element
attribute = Attribute
value = Value
refreshInterval = 4
def __init__(self, filePath):
BaseDirectoryService.__init__(self, realmName=noRealmName)
self.filePath = filePath
def __repr__(self):
realmName = self._realmName
if realmName is None:
realmName = "(not loaded)"
else:
realmName = repr(realmName)
return (
"<{self.__class__.__name__} {realmName}>".format(
self=self,
realmName=realmName,
)
)
@property
def realmName(self):
self.loadRecords()
return self._realmName
@realmName.setter
def realmName(self, value):
if value is not noRealmName:
raise AssertionError("realmName may not be set directly")
@property
def unknownRecordTypes(self):
self.loadRecords()
return self._unknownRecordTypes
@property
def unknownFieldElements(self):
self.loadRecords()
return self._unknownFieldElements
def loadRecords(self, loadNow=False, stat=True):
"""
Load records from L{self.filePath}.
Does nothing if a successful refresh has happened within the
last L{self.refreshInterval} seconds.
@param loadNow: If true, load now (ignoring
L{self.refreshInterval})
@type loadNow: L{type}
@param stat: If true, check file metadata and don't reload if
unchanged.
@type loadNow: L{type}
"""
#
# Punt if we've read the file recently
#
now = time()
if not loadNow and now - self._lastRefresh <= self.refreshInterval:
return
#
# Punt if we've read the file and it's still the same.
#
if stat:
self.filePath.restat()
cacheTag = (
self.filePath.getModificationTime(),
self.filePath.getsize()
)
if cacheTag == self._cacheTag:
return
else:
cacheTag = None
#
# Open and parse the file
#
try:
fh = self.filePath.open()
try:
etree = parseXML(fh)
except XMLParseError as e:
raise ParseError(e)
finally:
fh.close()
#
# Pull data from DOM
#
directoryNode = etree.getroot()
if directoryNode.tag != self.element.directory.value:
raise ParseError(
"Incorrect root element: {0}".format(directoryNode.tag)
)
realmName = directoryNode.get(
self.attribute.realm.value, ""
).encode("utf-8")
if not realmName:
raise ParseError("No realm name.")
unknownRecordTypes = set()
unknownFieldElements = set()
records = set()
for recordNode in directoryNode:
try:
records.add(
self.parseRecordNode(recordNode, unknownFieldElements)
)
except UnknownRecordTypeError as e:
unknownRecordTypes.add(e.token)
#
# Store results
#
index = {}
for fieldName in self.indexedFields:
index[fieldName] = {}
for record in records:
for fieldName in self.indexedFields:
values = record.fields.get(fieldName, None)
if values is not None:
if not BaseFieldName.isMultiValue(fieldName):
values = (values,)
for value in values:
index[fieldName].setdefault(value, set()).add(record)
self._realmName = realmName
self._unknownRecordTypes = unknownRecordTypes
self._unknownFieldElements = unknownFieldElements
self._cacheTag = cacheTag
self._lastRefresh = now
self.index = index
return etree
def parseRecordNode(self, recordNode, unknownFieldElements=None):
recordTypeAttribute = recordNode.get(
self.attribute.recordType.value, ""
).encode("utf-8")
if recordTypeAttribute:
try:
recordType = (
self.value.lookupByValue(recordTypeAttribute).recordType
)
except (ValueError, AttributeError):
raise UnknownRecordTypeError(recordTypeAttribute)
else:
recordType = self.recordType.user
fields = {}
fields[self.fieldName.recordType] = recordType
for fieldNode in recordNode:
try:
fieldElement = self.element.lookupByValue(fieldNode.tag)
except ValueError:
if unknownFieldElements is not None:
unknownFieldElements.add(fieldNode.tag)
try:
fieldName = fieldElement.fieldName
except AttributeError:
if unknownFieldElements is not None:
unknownFieldElements.add(fieldNode.tag)
value = fieldNode.text.encode("utf-8")
if BaseFieldName.isMultiValue(fieldName):
values = fields.setdefault(fieldName, [])
values.append(value)
else:
fields[fieldName] = value
return DirectoryRecord(self, fields)
def _uidForRecordNode(self, recordNode):
uidNode = recordNode.find(self.element.uid.value)
if uidNode is None:
raise NotImplementedError("No UID node")
return uidNode.text
def flush(self):
BaseDirectoryService.flush(self)
self._realmName = None
self._unknownRecordTypes = None
self._unknownFieldElements = None
self._cacheTag = None
self._lastRefresh = 0
def updateRecords(self, records, create=False):
    """
    Update the given directory records in the backing XML document.

    @param records: the records to write back.
    @type records: iterable of directory records

    @param create: if true, records whose UIDs are not already present in
        the document are appended as new record nodes; if false, unknown
        UIDs cause a L{NoSuchRecordError} failure to be returned.
    @type create: L{bool}
    """
    # Index the records to update by UID
    recordsByUID = dict(((record.uid, record) for record in records))

    # Index the record type -> attribute mappings.
    recordTypes = {}
    for valueName in self.value.iterconstants():
        recordType = getattr(valueName, "recordType", None)
        if recordType is not None:
            recordTypes[recordType] = valueName.value
    del valueName

    # Index the field name -> element mappings.
    fieldNames = {}
    for elementName in self.element.iterconstants():
        fieldName = getattr(elementName, "fieldName", None)
        if fieldName is not None:
            fieldNames[fieldName] = elementName.value
    del elementName

    directoryNode = self._directoryNodeForEditing()

    def fillRecordNode(recordNode, record):
        # Serialize one record's fields into recordNode: the record type
        # becomes an attribute, every other field a child element.
        for (name, value) in record.fields.items():
            if name == self.fieldName.recordType:
                if value in recordTypes:
                    recordNode.set(
                        self.attribute.recordType.value,
                        recordTypes[value]
                    )
                else:
                    raise AssertionError(
                        "Unknown record type: {0}".format(value)
                    )
            else:
                if name in fieldNames:
                    tag = fieldNames[name]

                    if BaseFieldName.isMultiValue(name):
                        values = value
                    else:
                        values = (value,)

                    # NOTE: this loop variable deliberately shadows the
                    # outer "value"; the outer one is not used afterwards.
                    for value in values:
                        subNode = XMLElement(tag)
                        subNode.text = value
                        recordNode.append(subNode)
                else:
                    raise AssertionError(
                        "Unknown field name: {0!r}".format(name)
                    )

    # Walk through the record nodes in the XML tree and apply
    # updates.
    for recordNode in directoryNode:
        uid = self._uidForRecordNode(recordNode)
        record = recordsByUID.get(uid, None)
        if record:
            # Replace the node's existing content wholesale.
            recordNode.clear()
            fillRecordNode(recordNode, record)
            del recordsByUID[uid]

    if recordsByUID:
        if not create:
            # Leftover records reference UIDs not in the document.
            return fail(NoSuchRecordError(recordsByUID.keys()))

        for uid, record in recordsByUID.items():
            recordNode = XMLElement(self.element.record.value)
            fillRecordNode(recordNode, record)
            directoryNode.append(recordNode)

    self._writeDirectoryNode(directoryNode)
def removeRecords(self, uids):
    """
    Remove the records with the given UIDs from the backing XML document.

    @param uids: the UIDs of the records to remove
    @type uids: iterable of L{bytes}
    """
    directoryNode = self._directoryNodeForEditing()

    #
    # Walk through the record nodes in the XML tree and start
    # zapping.
    #
    # FIX: iterate over a snapshot of the children.  Removing nodes from
    # an ElementTree element while iterating it directly skips the
    # sibling that follows each removed node.
    for recordNode in list(directoryNode):
        uid = self._uidForRecordNode(recordNode)
        if uid in uids:
            directoryNode.remove(recordNode)

    self._writeDirectoryNode(directoryNode)
def _directoryNodeForEditing(self):
    """
    Drop cached data and load the XML DOM.

    @return: the root element of the freshly loaded document.
    """
    self.flush()
    documentTree = self.loadRecords(loadNow=True)
    return documentTree.getroot()
def _writeDirectoryNode(self, directoryNode):
    """
    Serialize the given directory node back to the backing file, then
    flush caches so subsequent reads see the new content.
    """
    serialized = etreeToString(directoryNode)
    self.filePath.setContent(serialized)
    self.flush()
# Sentinel distinguishing "realm name not yet loaded" from a realm name
# that is legitimately None.
noRealmName = object()
calendarserver-5.2+dfsg/twext/who/idirectory.py 0000644 0001750 0001750 00000026231 12263343324 021014 0 ustar rahul rahul # -*- test-case-name: twext.who.test -*-
##
# Copyright (c) 2006-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Directory service interface.
"""
# Public API of this module: error types, constants, and the two
# directory interfaces.
__all__ = [
    "DirectoryServiceError",
    "DirectoryConfigurationError",
    "DirectoryAvailabilityError",
    "UnknownRecordTypeError",
    "QueryNotSupportedError",
    "NoSuchRecordError",
    "NotAllowedError",
    "RecordType",
    "FieldName",
    "Operand",
    "IDirectoryService",
    "IDirectoryRecord",
]
from zope.interface import Attribute, Interface
from twisted.python.constants import Names, NamedConstant
#
# Exceptions
#
class DirectoryServiceError(Exception):
    """
    Base class for errors raised by directory services.
    """
class DirectoryConfigurationError(DirectoryServiceError):
    """
    The directory service is misconfigured.
    """
class DirectoryAvailabilityError(DirectoryServiceError):
    """
    The directory service is not available.
    """
class UnknownRecordTypeError(DirectoryServiceError):
    """
    A record type was encountered that this service does not recognize.

    @ivar token: the unrecognized record type token.
    """
    def __init__(self, token):
        # Keep the offending token available to callers, and pass it
        # along as the exception message.
        self.token = token
        DirectoryServiceError.__init__(self, token)
class QueryNotSupportedError(DirectoryServiceError):
    """
    The directory service cannot perform the requested query.
    """
class NoSuchRecordError(DirectoryServiceError):
    """
    The requested directory record does not exist.
    """
class NotAllowedError(DirectoryServiceError):
    """
    The caller is not permitted to perform the requested operation.
    """
#
# Data Types
#
class RecordType(Names):
    """
    Constants for common directory record types.
    """
    user = NamedConstant()
    group = NamedConstant()

    # Human-readable descriptions, consumed by twext.who.util.describe().
    user.description = "user"
    group.description = "group"
class FieldName(Names):
    """
    Constants for common directory record field names.

    Fields are associated with either a single value or an iterable of
    values.

    @cvar uid: The primary unique identifier for a directory record.
        The associated value must be a L{unicode}.

    @cvar guid: The globally unique identifier for a directory record.
        The associated value must be a L{UUID} or C{None}.

    @cvar recordType: The type of a directory record.
        The associated value must be a L{NamedConstant}.

    @cvar shortNames: The short names for a directory record.
        The associated values must be L{unicode}s and there must be at least
        one associated value.

    @cvar fullNames: The full names for a directory record.
        The associated values must be L{unicode}s.

    @cvar emailAddresses: The email addresses for a directory record.
        The associated values must be L{unicode}s.

    @cvar password: The clear text password for a directory record.
        The associated value must be a L{unicode} or C{None}.
    """
    uid = NamedConstant()
    guid = NamedConstant()
    recordType = NamedConstant()
    shortNames = NamedConstant()
    fullNames = NamedConstant()
    emailAddresses = NamedConstant()
    password = NamedConstant()

    # Human-readable descriptions, consumed by twext.who.util.describe().
    uid.description = "UID"
    guid.description = "GUID"
    recordType.description = "record type"
    shortNames.description = "short names"
    fullNames.description = "full names"
    emailAddresses.description = "email addresses"
    password.description = "password"

    # Fields whose associated value is an iterable of values rather than
    # a single value.
    shortNames.multiValue = True
    fullNames.multiValue = True
    emailAddresses.multiValue = True

    @staticmethod
    def isMultiValue(name):
        """
        Check for whether a field is multi-value (as opposed to single-value).

        @return: C{True} if the field is multi-value, C{False} otherwise.
        @rtype: L{bool}
        """
        return getattr(name, "multiValue", False)
class Operand(Names):
    """
    Constants for common operands.
    """
    OR = NamedConstant()
    AND = NamedConstant()

    # Human-readable descriptions, consumed by twext.who.util.describe().
    OR.description = "or"
    AND.description = "and"
#
# Interfaces
#
class IDirectoryService(Interface):
    """
    Directory service.

    A directory service is a service that vends information about
    principals such as users, locations, printers, and other
    resources.  This information is provided in the form of directory
    records.

    A directory service can be queried for the types of records it
    supports, and for specific records matching certain criteria.

    A directory service may support the editing, removal and
    addition of records.
    Services that are read-only should fail with L{NotAllowedError} in
    editing methods.

    The L{FieldName.uid} field, the L{FieldName.guid} field (if not C{None}),
    and the combination of the L{FieldName.recordType} and
    L{FieldName.shortName} fields must be unique to each directory record
    vended by a directory service.
    """

    realmName = Attribute(
        "The name of the authentication realm this service represents."
    )

    def recordTypes():
        """
        Get the record types supported by this directory service.

        @return: The record types that are supported by this directory service.
        @rtype: iterable of L{NamedConstant}s
        """

    # FIX: zope.interface method declarations omit "self"; the stray
    # "self" parameter here was inconsistent with every other method of
    # this interface and would make implementations appear to be missing
    # an argument during interface verification.
    def recordsFromExpression(expression):
        """
        Find records matching an expression.

        @param expression: an expression to apply
        @type expression: L{object}

        @return: The matching records.
        @rtype: deferred iterable of L{IDirectoryRecord}s

        @raises: L{QueryNotSupportedError} if the expression is not
            supported by this directory service.
        """

    def recordsFromQuery(expressions, operand=Operand.AND):
        """
        Find records by composing a query consisting of an iterable of
        expressions and an operand.

        @param expressions: expressions to query against
        @type expressions: iterable of L{object}s

        @param operand: an operand
        @type operand: a L{NamedConstant}

        @return: The matching records.
        @rtype: deferred iterable of L{IDirectoryRecord}s

        @raises: L{QueryNotSupportedError} if the query is not
            supported by this directory service.
        """

    def recordsWithFieldValue(fieldName, value):
        """
        Find records that have the given field name with the given
        value.

        @param fieldName: a field name
        @type fieldName: L{NamedConstant}

        @param value: a value to match
        @type value: L{bytes}

        @return: The matching records.
        @rtype: deferred iterable of L{IDirectoryRecord}s
        """

    def recordWithUID(uid):
        """
        Find the record that has the given UID.

        @param uid: a UID
        @type uid: L{bytes}

        @return: The matching record or C{None} if there is no match.
        @rtype: deferred L{IDirectoryRecord}s or C{None}
        """

    def recordWithGUID(guid):
        """
        Find the record that has the given GUID.

        @param guid: a GUID
        @type guid: L{bytes}

        @return: The matching record or C{None} if there is no match.
        @rtype: deferred L{IDirectoryRecord}s or C{None}
        """

    def recordsWithRecordType(recordType):
        """
        Find the records that have the given record type.

        @param recordType: a record type
        @type recordType: L{NamedConstant}

        @return: The matching records.
        @rtype: deferred iterable of L{IDirectoryRecord}s
        """

    def recordWithShortName(recordType, shortName):
        """
        Find the record that has the given record type and short name.

        @param recordType: a record type
        @type recordType: L{NamedConstant}

        @param shortName: a short name
        @type shortName: L{bytes}

        @return: The matching record or C{None} if there is no match.
        @rtype: deferred L{IDirectoryRecord}s or C{None}
        """

    def recordsWithEmailAddress(emailAddress):
        """
        Find the records that have the given email address.

        @param emailAddress: an email address
        @type emailAddress: L{bytes}

        @return: The matching records.
        @rtype: deferred iterable of L{IDirectoryRecord}s
        """

    def updateRecords(records, create=False):
        """
        Updates existing directory records.

        @param records: the records to update
        @type records: iterable of L{IDirectoryRecord}s

        @param create: if true, create records if necessary
        @type create: boolean

        @return: unspecified
        @rtype: deferred object

        @raises L{NotAllowedError}: if the update is not allowed by the
            directory service.
        """

    def removeRecords(uids):
        """
        Removes the records with the given UIDs.

        @param uids: the UIDs of the records to remove
        @type uids: iterable of L{bytes}

        @return: unspecified
        @rtype: deferred object

        @raises L{NotAllowedError}: if the removal is not allowed by the
            directory service.
        """
class IDirectoryRecord(Interface):
    """
    Directory record.

    A directory record corresponds to a principal, and contains
    information about the principal such as identifiers, names and
    passwords.

    This information is stored in a set of fields (a mapping of field
    names and values).

    Some fields allow for multiple values while others allow only one
    value.  This is discoverable by calling L{FieldName.isMultiValue}
    on the field name.

    The field L{FieldName.recordType} will be present in all directory
    records, as all records must have a type.  Which other fields are
    required is implementation-specific.

    Principals (called group principals) may have references to other
    principals as members.  Records representing group principals will
    typically be records with the record type L{RecordType.group}, but
    it is not prohibited for other record types to have members.

    Fields may also be accessed as attributes.  For example:
    C{record.recordType} is equivalent to
    C{record.fields[FieldName.recordType]}.
    """
    service = Attribute("The L{IDirectoryService} this record exists in.")
    fields = Attribute("A mapping with L{NamedConstant} keys.")

    def members():
        """
        Find the records that are members of this group.  Only direct
        members are included; members of members are not expanded.

        @return: a deferred iterable of L{IDirectoryRecord}s which are
            direct members of this group.
        """

    def groups():
        """
        Find the group records that this record is a member of.  Only
        groups for which this record is a direct member are
        included; membership is not expanded.

        @return: a deferred iterable of L{IDirectoryRecord}s which are
            groups that this record is a member of.
        """
calendarserver-5.2+dfsg/twext/who/__init__.py 0000644 0001750 0001750 00000001256 12263343324 020376 0 ustar rahul rahul # -*- test-case-name: twext.who.test -*-
##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Directory service integration
"""
calendarserver-5.2+dfsg/twext/who/util.py 0000644 0001750 0001750 00000004764 12263343324 017623 0 ustar rahul rahul # -*- test-case-name: twext.who.test.test_util -*-
##
# Copyright (c) 2013-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Directory service module utilities.
"""
# Public API of this utility module.
__all__ = [
    "ConstantsContainer",
    "uniqueResult",
    "describe",
    "iterFlags",
]
from twisted.python.constants import FlagConstant
from twext.who.idirectory import DirectoryServiceError
class ConstantsContainer(object):
    """
    A container that holds a collection of constants and exposes each one
    by its C{name}, both as an attribute and via lookup methods.
    """
    def __init__(self, constants):
        """
        @param constants: the constants to hold; each must have a C{name}
            attribute, and names must be unique.
        @raises ValueError: if two constants share a name.
        """
        collected = {}
        for constant in constants:
            if constant.name in collected:
                raise ValueError("Name conflict: {0}".format(constant.name))
            collected[constant.name] = constant
        self._constants = collected

    def __getattr__(self, name):
        # Expose constants as attributes of the container.
        if name in self._constants:
            return self._constants[name]
        raise AttributeError(name)

    def iterconstants(self):
        """
        Iterate over all held constants.
        """
        return self._constants.itervalues()

    def lookupByName(self, name):
        """
        Look a constant up by name.

        @raises ValueError: if no constant has the given name.
        """
        if name in self._constants:
            return self._constants[name]
        raise ValueError(name)
def uniqueResult(values):
    """
    Return the single result from an iterable, or C{None} when it yields
    nothing.

    @raises DirectoryServiceError: if more than one result is present.
    """
    found = None
    for value in values:
        if found is not None:
            raise DirectoryServiceError(
                "Multiple values found where one expected."
            )
        found = value
    return found
def describe(constant):
    """
    Return a human-readable description of a constant: its C{description}
    attribute when present, falling back to its name.  Flag constants are
    described as a "|"-joined list of their member flags.
    """
    if not isinstance(constant, FlagConstant):
        return getattr(constant, "description", constant.name)

    descriptions = [
        getattr(flag, "description", flag.name)
        for flag in iterFlags(constant)
    ]
    return "|".join(descriptions)
def iterFlags(flags):
    """
    Return an iterable over the individual flags in a flag constant;
    anything already iterable is returned as-is.
    """
    if hasattr(flags, "__iter__"):
        return flags

    # Work around http://twistedmatrix.com/trac/ticket/6302
    # FIXME: This depends on a private attribute (flags._container)
    container = flags._container
    return (container.lookupByName(name) for name in flags.names)
calendarserver-5.2+dfsg/twext/web2/ 0000755 0001750 0001750 00000000000 12322625326 016324 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/client/ 0000755 0001750 0001750 00000000000 12322625326 017602 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/client/interfaces.py 0000644 0001750 0001750 00000004760 12263343324 022305 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_client -*-
##
# Copyright (c) 2007 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
from zope.interface import Interface
class IHTTPClientManager(Interface):
    """I coordinate between multiple L{HTTPClientProtocol} objects connected to a
    single server to facilitate request queuing and pipelining.
    """

    def clientBusy(proto):
        """Called when the L{HTTPClientProtocol} doesn't want to accept any more
        requests.

        @param proto: The L{HTTPClientProtocol} that is changing state.
        @type proto: L{HTTPClientProtocol}
        """
        pass

    def clientIdle(proto):
        """Called when an L{HTTPClientProtocol} is able to accept more requests.

        @param proto: The L{HTTPClientProtocol} that is changing state.
        @type proto: L{HTTPClientProtocol}
        """
        pass

    def clientPipelining(proto):
        """Called when the L{HTTPClientProtocol} determines that it is able to
        support request pipelining.

        @param proto: The L{HTTPClientProtocol} that is changing state.
        @type proto: L{HTTPClientProtocol}
        """
        pass

    def clientGone(proto):
        """Called when the L{HTTPClientProtocol} disconnects from the server.

        @param proto: The L{HTTPClientProtocol} that is changing state.
        @type proto: L{HTTPClientProtocol}
        """
        pass
calendarserver-5.2+dfsg/twext/web2/client/http.py 0000644 0001750 0001750 00000030406 12263343324 021135 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_client -*-
##
# Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
Client-side HTTP implementation.
"""
from zope.interface import implements
from twisted.internet.defer import Deferred
from twisted.protocols.basic import LineReceiver
from twisted.protocols.policies import TimeoutMixin
from twext.web2.responsecode import BAD_REQUEST, HTTP_VERSION_NOT_SUPPORTED
from twext.web2.http import parseVersion, Response
from twext.web2.http_headers import Headers
from twext.web2.stream import ProducerStream, StreamProducer, IByteStream
from twext.web2.channel.http import HTTPParser, PERSIST_NO_PIPELINE, PERSIST_PIPELINE
from twext.web2.client.interfaces import IHTTPClientManager
class ProtocolError(Exception):
    """
    Raised when an HTTP protocol-level error occurs.
    """
class ClientRequest(object):
    """
    A description of a single HTTP request to be sent to a server.
    """
    def __init__(self, method, uri, headers, stream):
        """
        @param method: The HTTP method to for this request, ex: 'GET', 'HEAD',
            'POST', etc.
        @type method: C{str}

        @param uri: The URI of the resource to request, this may be absolute or
            relative, however the interpretation of this URI is left up to the
            remote server.
        @type uri: C{str}

        @param headers: Headers to be sent to the server.  It is important to
            note that this object does not create any implicit headers.  So it
            is up to the HTTP Client to add required headers such as 'Host'.
        @type headers: C{dict}, L{twext.web2.http_headers.Headers}, or
            C{None}

        @param stream: Content body to send to the remote HTTP server.
        @type stream: L{twext.web2.stream.IByteStream}
        """
        self.method = method
        self.uri = uri

        # Normalize headers to a Headers instance; a dict or None is
        # wrapped, an existing Headers object is used directly.
        if not isinstance(headers, Headers):
            headers = Headers(headers or {})
        self.headers = headers

        # Adapt the body to IByteStream, unless there is no body at all.
        self.stream = IByteStream(stream) if stream is not None else None
class HTTPClientChannelRequest(HTTPParser):
    """
    Writes one HTTP request out over the channel's transport and parses
    the corresponding response, firing C{responseDefer} with the parsed
    L{Response}.
    """
    # A response terminated by connection close is treated as complete.
    parseCloseAsEnd = True

    # HTTP version used on the outgoing request line.
    outgoing_version = "HTTP/1.1"
    # True once we have committed to chunked transfer-encoding for the body.
    chunkedOut = False
    # True once the request has been completely written out.
    finished = False
    # True if "Connection: close" is to be sent with this request.
    closeAfter = False

    def __init__(self, channel, request, closeAfter):
        HTTPParser.__init__(self, channel)
        self.request = request
        self.closeAfter = closeAfter
        self.transport = self.channel.transport
        # Fired with the Response object once it has been parsed.
        self.responseDefer = Deferred()

    def submit(self):
        """
        Serialize and send the request line, headers, and body.
        """
        l = []
        request = self.request
        if request.method == "HEAD":
            # No incoming data will arrive.
            self.length = 0

        l.append('%s %s %s\r\n' % (request.method, request.uri,
                                   self.outgoing_version))
        if request.headers is not None:
            for name, valuelist in request.headers.getAllRawHeaders():
                for value in valuelist:
                    l.append("%s: %s\r\n" % (name, value))

        if request.stream is not None:
            if request.stream.length is not None:
                l.append("%s: %s\r\n" % ('Content-Length', request.stream.length))
            else:
                # Got a stream with no length. Send as chunked and hope, against
                # the odds, that the server actually supports chunked uploads.
                l.append("%s: %s\r\n" % ('Transfer-Encoding', 'chunked'))
                self.chunkedOut = True

        if self.closeAfter:
            l.append("%s: %s\r\n" % ('Connection', 'close'))
        else:
            l.append("%s: %s\r\n" % ('Connection', 'Keep-Alive'))

        l.append("\r\n")
        self.transport.writeSequence(l)

        # Stream the body through self.write(); _finish/_error fire on
        # completion or failure.
        d = StreamProducer(request.stream).beginProducing(self)
        d.addCallback(self._finish).addErrback(self._error)

    def registerProducer(self, producer, streaming):
        """
        Register a producer.
        """
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def write(self, data):
        # Write one piece of the request body, chunk-framing it when
        # chunked transfer-encoding is in effect.
        if not data:
            return
        elif self.chunkedOut:
            self.transport.writeSequence(("%X\r\n" % len(data), data, "\r\n"))
        else:
            self.transport.write(data)

    def _finish(self, x):
        """
        We are finished writing data.
        """
        if self.chunkedOut:
            # write last chunk and closing CRLF
            self.transport.write("0\r\n\r\n")

        self.finished = True
        self.channel.requestWriteFinished(self)
        del self.transport

    def _error(self, err):
        """
        Abort parsing, and depending of the status of the request, either fire
        the C{responseDefer} if no response has been sent yet, or close the
        stream.
        """
        self.abortParse()
        if hasattr(self, 'stream') and self.stream is not None:
            self.stream.finish(err)
        else:
            self.responseDefer.errback(err)

    def _abortWithError(self, errcode, text):
        """
        Abort parsing by forwarding a C{ProtocolError} to C{_error}.
        """
        # NOTE(review): errcode is accepted for parser-interface
        # compatibility but is not carried in the ProtocolError.
        self._error(ProtocolError(text))

    def connectionLost(self, reason):
        self._error(reason)

    def gotInitialLine(self, initialLine):
        # Handle the "HTTP/1.x CODE message" status line of the response.
        parts = initialLine.split(' ', 2)

        # Parse the initial request line
        if len(parts) != 3:
            self._abortWithError(BAD_REQUEST,
                                 "Bad response line: %s" % (initialLine,))
            return

        strversion, self.code, message = parts
        try:
            protovers = parseVersion(strversion)
            if protovers[0] != 'http':
                raise ValueError()
        except ValueError:
            self._abortWithError(BAD_REQUEST,
                                 "Unknown protocol: %s" % (strversion,))
            return

        self.version = protovers[1:3]

        # Ensure HTTP 0 or HTTP 1.
        if self.version[0] != 1:
            self._abortWithError(HTTP_VERSION_NOT_SUPPORTED,
                                 'Only HTTP 1.x is supported.')
            return

    ## FIXME: Actually creates Response, function is badly named!
    def createRequest(self):
        self.stream = ProducerStream(self.length)
        self.response = Response(self.code, self.inHeaders, self.stream)
        self.stream.registerProducer(self, True)
        del self.inHeaders

    ## FIXME: Actually processes Response, function is badly named!
    def processRequest(self):
        self.responseDefer.callback(self.response)

    def handleContentChunk(self, data):
        # Feed a piece of the response body into the outgoing stream.
        self.stream.write(data)

    def handleContentComplete(self):
        self.stream.finish()
class EmptyHTTPClientManager(object):
    """
    A dummy HTTPClientManager.  It doesn't do any client management, and is
    meant to be used only when creating an HTTPClientProtocol directly.
    """

    implements(IHTTPClientManager)

    def clientBusy(self, proto):
        # No-op: busy state is not tracked.
        pass

    def clientIdle(self, proto):
        # No-op: idle state is not tracked.
        pass

    def clientPipelining(self, proto):
        # No-op: pipelining availability is ignored.
        pass

    def clientGone(self, proto):
        # No-op: disconnection is ignored.
        pass
class HTTPClientProtocol(LineReceiver, TimeoutMixin, object):
    """
    A HTTP 1.1 Client with request pipelining support.
    """
    # The channel request currently being written out, if any.
    chanRequest = None
    # Maximum total length of response headers accepted.
    maxHeaderLength = 10240
    # Truthy while waiting for a response's initial status line.
    firstLine = 1
    readPersistent = PERSIST_NO_PIPELINE

    # inputTimeOut should be pending whenever a complete request has
    # been written but the complete response has not yet been
    # received, and be reset every time data is received.
    inputTimeOut = 60 * 4

    def __init__(self, manager=None):
        """
        @param manager: The object this client reports it state to.
        @type manager: L{IHTTPClientManager}
        """
        self.outRequest = None
        # Requests whose responses have not yet been fully read, in order.
        self.inRequests = []

        if manager is None:
            manager = EmptyHTTPClientManager()
        self.manager = manager

    def lineReceived(self, line):
        if not self.inRequests:
            # server sending random unrequested data.
            self.transport.loseConnection()
            return

        # If not currently writing this request, set timeout
        if self.inRequests[0] is not self.outRequest:
            self.setTimeout(self.inputTimeOut)

        if self.firstLine:
            self.firstLine = 0
            self.inRequests[0].gotInitialLine(line)
        else:
            self.inRequests[0].lineReceived(line)

    def rawDataReceived(self, data):
        if not self.inRequests:
            # Server sending random unrequested data.
            self.transport.loseConnection()
            return

        # If not currently writing this request, set timeout
        if self.inRequests[0] is not self.outRequest:
            self.setTimeout(self.inputTimeOut)

        self.inRequests[0].rawDataReceived(data)

    def submitRequest(self, request, closeAfter=True):
        """
        @param request: The request to send to a remote server.
        @type request: L{ClientRequest}

        @param closeAfter: If True the 'Connection: close' header will be sent,
            otherwise 'Connection: keep-alive'
        @type closeAfter: C{bool}

        @rtype: L{twisted.internet.defer.Deferred}
        @return: A Deferred which will be called back with the
            L{twext.web2.http.Response} from the server.
        """
        # Assert we're in a valid state to submit more
        assert self.outRequest is None
        assert ((self.readPersistent is PERSIST_NO_PIPELINE
                 and not self.inRequests)
                or self.readPersistent is PERSIST_PIPELINE)

        self.manager.clientBusy(self)
        if closeAfter:
            self.readPersistent = False

        self.outRequest = chanRequest = HTTPClientChannelRequest(self,
                                            request, closeAfter)
        self.inRequests.append(chanRequest)

        chanRequest.submit()
        return chanRequest.responseDefer

    def requestWriteFinished(self, request):
        assert request is self.outRequest

        self.outRequest = None
        # Tell the manager if more requests can be submitted.
        self.setTimeout(self.inputTimeOut)
        if self.readPersistent is PERSIST_PIPELINE:
            self.manager.clientPipelining(self)

    def requestReadFinished(self, request):
        assert self.inRequests[0] is request

        del self.inRequests[0]
        # Next response starts with a fresh status line.
        self.firstLine = True

        if not self.inRequests:
            if self.readPersistent:
                self.setTimeout(None)
                self.manager.clientIdle(self)
            else:
                self.transport.loseConnection()

    def setReadPersistent(self, persist):
        self.readPersistent = persist
        if not persist:
            # Tell all requests but first to abort.
            for request in self.inRequests[1:]:
                request.connectionLost(None)
            del self.inRequests[1:]

    def connectionLost(self, reason):
        self.readPersistent = False
        self.setTimeout(None)
        self.manager.clientGone(self)
        # Tell all requests to abort.
        for request in self.inRequests:
            if request is not None:
                request.connectionLost(reason)
calendarserver-5.2+dfsg/twext/web2/client/__init__.py 0000644 0001750 0001750 00000002410 12263343324 021707 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_client -*-
##
# Copyright (c) 2004 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
Twisted.web2.client: Client Implementation
"""
calendarserver-5.2+dfsg/twext/web2/channel/ 0000755 0001750 0001750 00000000000 12322625325 017733 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/channel/http.py 0000644 0001750 0001750 00000125051 12263343324 021270 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_http -*-
##
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# Copyright (c) 2008-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
import time
import warnings
import socket
from random import randint
from cStringIO import StringIO
from zope.interface import implements
from twisted.internet import interfaces, protocol, reactor
from twisted.internet.defer import succeed, Deferred
from twisted.protocols import policies, basic
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2 import http_headers
from twext.web2 import http
from twext.web2.http import RedirectResponse
from twext.web2.server import Request
from twistedcaldav.config import config
from twistedcaldav import accounting
log = Logger()
class OverloadedLoggingServerProtocol (protocol.Protocol):
    """
    Minimal protocol used when the server is overloaded: logs the overload
    condition and writes a canned 503 response as soon as a connection is
    made, then drops the connection.
    """

    def __init__(self, retryAfter, outstandingRequests):
        # Seconds for the Retry-After header; falsy to omit the header.
        self.retryAfter = retryAfter
        # Number of requests outstanding when overload was detected.
        self.outstandingRequests = outstandingRequests

    def connectionMade(self):
        log.info(overloaded=self)

        self.transport.write(
            "HTTP/1.0 503 Service Unavailable\r\n"
            "Content-Type: text/html\r\n"
        )

        if self.retryAfter:
            self.transport.write(
                "Retry-After: %s\r\n" % (self.retryAfter,)
            )

        # FIX: restore the HTML markup of the response body; the markup
        # had been stripped, leaving bare text fragments and unbalanced
        # string literals.
        self.transport.write(
            "Connection: close\r\n\r\n"
            "<html><head><title>Service Unavailable</title></head>"
            "<body><h1>Service Unavailable</h1>"
            "The server is currently overloaded, "
            "please try again later.</body></html>"
        )

        self.transport.loseConnection()
class SSLRedirectRequest(Request):
    """
    An L{SSLRedirectRequest} prevents processing if the request is over plain
    HTTP; instead, it redirects to HTTPS.
    """
    def process(self):
        ignored, secure = self.chanRequest.getHostInfo()

        # Secure connections are processed normally.
        if secure:
            return super(SSLRedirectRequest, self).process()

        # Plain HTTP: redirect to the HTTPS equivalent, omitting the port
        # when it is the default 443.
        if config.SSLPort == 443:
            location = (
                "https://%s%s"
                % (config.ServerHostName, self.uri)
            )
        else:
            location = (
                "https://%s:%d%s"
                % (config.ServerHostName, config.SSLPort, self.uri)
            )
        return super(SSLRedirectRequest, self).writeResponse(
            RedirectResponse(location)
        )
# >%
# Connection persistence modes: keep the connection open but wait for each
# response before sending the next request, vs. full request pipelining.
PERSIST_NO_PIPELINE, PERSIST_PIPELINE = (1,2)

# Cache of address -> hostname results used by _cachedGetHostByAddr().
_cachedHostNames = {}
def _cachedGetHostByAddr(hostaddr):
hostname = _cachedHostNames.get(hostaddr)
if hostname is None:
try:
hostname = socket.gethostbyaddr(hostaddr)[0]
except socket.herror:
hostname = hostaddr
_cachedHostNames[hostaddr]=hostname
return hostname
class StringTransport(object):
    """
    An in-memory stand-in for a transport: wraps a StringIO buffer and adds
    the C{writeSequence} method from the transport API. Every other
    attribute (C{write}, C{getvalue}, ...) is delegated to the buffer.
    """

    def __init__(self):
        self.s = StringIO()

    def writeSequence(self, seq):
        # Join once and write in a single call, mirroring transport semantics.
        joined = ''.join(seq)
        self.s.write(joined)

    def __getattr__(self, attr):
        # Delegate everything else to the underlying StringIO. Access via
        # __dict__ to avoid re-entering __getattr__ for 's' itself.
        return getattr(self.__dict__['s'], attr)
class AbortedException(Exception):
    """
    Raised by the low-level parsing code (via C{_abortWithError}) to unwind
    out of request processing once the connection has been aborted; callers
    in the channel catch it and stop handling the current request.
    """
    pass
class HTTPParser(object):
    """This class handles the parsing side of HTTP processing. With a suitable
    subclass, it can parse either the client side or the server side of the
    connection.
    """

    # Class config:
    parseCloseAsEnd = False

    # Instance vars
    # chunkedIn state machine: 0/False = not chunked; 1 = expecting a chunk
    # size line; 2 = expecting the CRLF that follows chunk data; 3 = eating
    # trailer headers after the final "0" chunk.
    chunkedIn = False
    # Running total of header bytes seen, checked against maxHeaderLength.
    headerlen = 0
    # Bytes of body (or of the current chunk) still expected; None = unknown.
    length = None
    inHeaders = None
    # Accumulates a header line plus any folded continuation lines.
    partialHeader = ''
    # Hop-by-hop headers split off from inHeaders by splitConnectionHeaders.
    connHeaders = None
    finishedReading = False

    channel = None

    # For subclassing...
    # Needs attributes:
    #  version

    # Needs functions:
    #  createRequest()
    #  processRequest()
    #  _abortWithError()
    #  handleContentChunk(data)
    #  handleContentComplete()

    # Needs functions to exist on .channel
    #  channel.maxHeaderLength
    #  channel.requestReadFinished(self)
    #  channel.setReadPersistent(self, persistent)
    # (from LineReceiver):
    #  channel.setRawMode()
    #  channel.setLineMode(extraneous)
    #  channel.pauseProducing()
    #  channel.resumeProducing()
    #  channel.stopProducing()

    def __init__(self, channel):
        """
        @param channel: the connection-level object (an L{HTTPChannel}) that
            feeds this parser lines and raw data.
        """
        self.inHeaders = http_headers.Headers()
        self.channel = channel

    def lineReceived(self, line):
        """
        Handle one line of input while in line mode: either a chunk-framing
        line (when dechunking), the blank line ending the headers, or a
        (possibly folded) header line.
        """
        if self.chunkedIn:
            # Parsing a chunked input
            if self.chunkedIn == 1:
                # First we get a line like "chunk-size [';' chunk-extension]"
                # (where chunk extension is just random crap as far as we're concerned)
                # RFC says to ignore any extensions you don't recognize -- that's all of them.
                chunksize = line.split(';', 1)[0]
                try:
                    self.length = int(chunksize, 16)
                except:
                    # NOTE: bare except is safe here only because int() raises
                    # ValueError; _abortWithError raises AbortedException.
                    self._abortWithError(responsecode.BAD_REQUEST, "Invalid chunk size, not a hex number: %s!" % chunksize)
                if self.length < 0:
                    self._abortWithError(responsecode.BAD_REQUEST, "Invalid chunk size, negative.")

                if self.length == 0:
                    # We're done, parse the trailers line
                    self.chunkedIn = 3
                else:
                    # Read self.length bytes of raw data
                    self.channel.setRawMode()
            elif self.chunkedIn == 2:
                # After we got data bytes of the appropriate length, we end up here,
                # waiting for the CRLF, then go back to get the next chunk size.
                if line != '':
                    self._abortWithError(responsecode.BAD_REQUEST, "Excess %d bytes sent in chunk transfer mode" % len(line))
                self.chunkedIn = 1
            elif self.chunkedIn == 3:
                # TODO: support Trailers (maybe! but maybe not!)

                # After getting the final "0" chunk we're here, and we *EAT MERCILESSLY*
                # any trailer headers sent, and wait for the blank line to terminate the
                # request.
                if line == '':
                    self.allContentReceived()
        # END of chunk handling
        elif line == '':
            # Empty line => End of headers
            if self.partialHeader:
                self.headerReceived(self.partialHeader)
            self.partialHeader = ''
            self.allHeadersReceived()    # can set chunkedIn
            self.createRequest()
            if self.chunkedIn:
                # stay in linemode waiting for chunk header
                pass
            elif self.length == 0:
                # no content expected
                self.allContentReceived()
            else:
                # await raw data as content
                self.channel.setRawMode()
                # Should I do self.pauseProducing() here?
            self.processRequest()
        else:
            self.headerlen += len(line)
            if self.headerlen > self.channel.maxHeaderLength:
                self._abortWithError(responsecode.BAD_REQUEST, 'Headers too long.')

            if line[0] in ' \t':
                # Append a header continuation
                self.partialHeader += line
            else:
                # A new header starts; flush the previously accumulated one.
                if self.partialHeader:
                    self.headerReceived(self.partialHeader)
                self.partialHeader = line

    def rawDataReceived(self, data):
        """Handle incoming content."""
        datalen = len(data)
        if datalen < self.length:
            # Partial body/chunk: pass it through and keep counting down.
            self.handleContentChunk(data)
            self.length = self.length - datalen
        else:
            # We have at least the rest of the expected bytes; anything past
            # self.length belongs to the next protocol element (CRLF, next
            # chunk header, or a pipelined request).
            self.handleContentChunk(data[:self.length])
            extraneous = data[self.length:]
            channel = self.channel # could go away from allContentReceived.
            if not self.chunkedIn:
                self.allContentReceived()
            else:
                # NOTE: in chunked mode, self.length is the size of the current chunk,
                # so we still have more to read.
                self.chunkedIn = 2 # Read next chunksize
            channel.setLineMode(extraneous)

    def headerReceived(self, line):
        """
        Store this header away. Check for too much header data (>
        channel.maxHeaderLength) and non-ASCII characters; abort the
        connection with C{BAD_REQUEST} if so.
        """
        nameval = line.split(':', 1)
        if len(nameval) != 2:
            self._abortWithError(responsecode.BAD_REQUEST, "No ':' in header.")

        name, val = nameval
        for field in name, val:
            try:
                field.decode('ascii')
            except UnicodeDecodeError:
                self._abortWithError(responsecode.BAD_REQUEST,
                                     "Headers must be ASCII")

        val = val.lstrip(' \t')
        self.inHeaders.addRawHeader(name, val)

    def allHeadersReceived(self):
        """
        Called once the blank line ending the header section is seen: split
        off hop-by-hop headers and derive framing/persistence parameters.
        """
        # Split off connection-related headers
        connHeaders = self.splitConnectionHeaders()

        # Set connection parameters from headers
        self.setConnectionParams(connHeaders)
        self.connHeaders = connHeaders

    def allContentReceived(self):
        """
        Called when the message body (if any) is fully read; notifies the
        channel first so it can reset line-mode state before user code runs.
        """
        self.finishedReading = True
        self.channel.requestReadFinished(self)
        self.handleContentComplete()

    def splitConnectionHeaders(self):
        """
        Split off connection control headers from normal headers.

        The normal headers are then passed on to user-level code, while the
        connection headers are stashed in .connHeaders and used for things like
        request/response framing.

        This corresponds roughly with the HTTP RFC's description of 'hop-by-hop'
        vs 'end-to-end' headers in RFC2616 S13.5.1, with the following
        exceptions:
          - proxy-authenticate and proxy-authorization are not treated as
            connection headers.
          - content-length is, as it is intimately related with low-level HTTP
            parsing, and is made available to user-level code via the stream
            length, rather than a header value. (except for HEAD responses, in
            which case it is NOT used by low-level HTTP parsing, and IS kept in
            the normal headers.
        """
        def move(name):
            # Relocate the named header from inHeaders to connHeaders.
            h = inHeaders.getRawHeaders(name, None)
            if h is not None:
                inHeaders.removeHeader(name)
                connHeaders.setRawHeaders(name, h)

        # NOTE: According to HTTP spec, we're supposed to eat the
        # 'Proxy-Authenticate' and 'Proxy-Authorization' headers also, but that
        # doesn't sound like a good idea to me, because it makes it impossible
        # to have a non-authenticating transparent proxy in front of an
        # authenticating proxy. An authenticating proxy can eat them itself.
        #
        # 'Proxy-Connection' is an undocumented HTTP 1.0 abomination.
        connHeaderNames = ['content-length', 'connection', 'keep-alive', 'te',
                           'trailers', 'transfer-encoding', 'upgrade',
                           'proxy-connection']
        inHeaders = self.inHeaders
        connHeaders = http_headers.Headers()

        move('connection')
        if self.version < (1,1):
            # Remove all headers mentioned in Connection, because a HTTP 1.0
            # proxy might have erroneously forwarded it from a 1.1 client.
            for name in connHeaders.getHeader('connection', ()):
                if inHeaders.hasHeader(name):
                    inHeaders.removeHeader(name)
        else:
            # Otherwise, just add the headers listed to the list of those to move
            connHeaderNames.extend(connHeaders.getHeader('connection', ()))

        # If the request was HEAD, self.length has been set to 0 by
        # HTTPClientRequest.submit; in this case, Content-Length should
        # be treated as a response header, not a connection header.

        # Note: this assumes the invariant that .length will always be None
        # coming into this function, unless this is a HEAD request.
        if self.length is not None:
            connHeaderNames.remove('content-length')

        for headername in connHeaderNames:
            move(headername)

        return connHeaders

    def setConnectionParams(self, connHeaders):
        """
        Derive connection persistence (none / keep-alive / pipelining) and
        message framing (chunked vs. content-length vs. read-to-close) from
        the hop-by-hop headers, then inform the channel.
        """
        # Figure out persistent connection stuff
        if self.version >= (1,1):
            if 'close' in connHeaders.getHeader('connection', ()):
                readPersistent = False
            else:
                readPersistent = PERSIST_PIPELINE
        elif 'keep-alive' in connHeaders.getHeader('connection', ()):
            readPersistent = PERSIST_NO_PIPELINE
        else:
            readPersistent = False

        # Okay, now implement section 4.4 Message Length to determine
        # how to find the end of the incoming HTTP message.
        transferEncoding = connHeaders.getHeader('transfer-encoding')

        if transferEncoding:
            if transferEncoding[-1] == 'chunked':
                # Chunked
                self.chunkedIn = 1
                # Cut off the chunked encoding (cause it's special)
                transferEncoding = transferEncoding[:-1]
            elif not self.parseCloseAsEnd:
                # Would close on end of connection, except this can't happen for
                # client->server data. (Well..it could actually, since TCP has half-close
                # but the HTTP spec says it can't, so we'll pretend it's right.)
                self._abortWithError(responsecode.BAD_REQUEST, "Transfer-Encoding received without chunked in last position.")

            # TODO: support gzip/etc encodings.
            # FOR NOW: report an error if the client uses any encodings.
            # They shouldn't, because we didn't send a TE: header saying it's okay.
            if transferEncoding:
                self._abortWithError(responsecode.NOT_IMPLEMENTED, "Transfer-Encoding %s not supported." % transferEncoding)
        else:
            # No transfer-coding.
            self.chunkedIn = 0
            if self.parseCloseAsEnd:
                # If no Content-Length, then it's indeterminate length data
                # (unless the responsecode was one of the special no body ones)
                # Also note that for HEAD requests, connHeaders won't have
                # content-length even if the response did.
                if self.code in http.NO_BODY_CODES:
                    self.length = 0
                else:
                    self.length = connHeaders.getHeader('content-length', self.length)

                # If it's an indeterminate stream without transfer encoding, it must be
                # the last request.
                if self.length is None:
                    readPersistent = False
            else:
                # If no Content-Length either, assume no content.
                self.length = connHeaders.getHeader('content-length', 0)

        # Set the calculated persistence
        self.channel.setReadPersistent(readPersistent)

    def abortParse(self):
        """
        Stop reading this request early (protocol error / connection abort);
        the connection can no longer be reused since framing is now unknown.
        """
        # If we're erroring out while still reading the request
        if not self.finishedReading:
            self.finishedReading = True
            self.channel.setReadPersistent(False)
            self.channel.requestReadFinished(self)

    # producer interface
    # These pass flow-control through to the channel, but only while the
    # request body is still being read; afterwards they are no-ops.
    def pauseProducing(self):
        if not self.finishedReading:
            self.channel.pauseProducing()

    def resumeProducing(self):
        if not self.finishedReading:
            self.channel.resumeProducing()

    def stopProducing(self):
        if not self.finishedReading:
            self.channel.stopProducing()
class HTTPChannelRequest(HTTPParser):
    """This class handles the state and parsing for one HTTP request.
    It is responsible for all the low-level connection oriented behavior.
    Thus, it takes care of keep-alive, de-chunking, etc., and passes
    the non-connection headers on to the user-level Request object."""

    command = path = version = None
    queued = 0
    request = None

    out_version = "HTTP/1.1"

    def __init__(self, channel, queued=0):
        """
        @param channel: the owning L{HTTPChannel}.
        @param queued: true when an earlier pipelined request is still being
            answered; output is then buffered until noLongerQueued() fires.
        """
        HTTPParser.__init__(self, channel)
        self.queued=queued

        # Buffer writes to a string until we're first in line
        # to write a response
        if queued:
            self.transport = StringTransport()
        else:
            self.transport = self.channel.transport

        # set the version to a fallback for error generation
        self.version = (1,0)

    def gotInitialLine(self, initialLine):
        """
        Parse the request line ("METHOD URI HTTP/x.y"), tolerating the
        abbreviated HTTP/0.9 forms; aborts with BAD_REQUEST on anything else.
        """
        parts = initialLine.split()

        # Parse the initial request line
        if len(parts) != 3:
            if len(parts) == 1:
                # Bare "METHOD" -- assume "/" as the URI (HTTP/0.9 style).
                parts.append('/')
            if len(parts) == 2 and parts[1][0] == '/':
                parts.append('HTTP/0.9')
            else:
                self._abortWithError(responsecode.BAD_REQUEST, 'Bad request line: %s' % initialLine)

        self.command, self.path, strversion = parts
        try:
            protovers = http.parseVersion(strversion)
            if protovers[0] != 'http':
                raise ValueError()
        except ValueError:
            self._abortWithError(responsecode.BAD_REQUEST, "Unknown protocol: %s" % strversion)

        self.version = protovers[1:3]

        # Ensure HTTP 0 or HTTP 1.
        if self.version[0] > 1:
            self._abortWithError(responsecode.HTTP_VERSION_NOT_SUPPORTED, 'Only HTTP 0.9 and HTTP 1.x are supported.')

        if self.version[0] == 0:
            # simulate end of headers, as HTTP 0 doesn't have headers.
            self.lineReceived('')

    def lineLengthExceeded(self, line, wasFirst=False):
        # URI-too-long only applies to the request line itself.
        code = wasFirst and responsecode.REQUEST_URI_TOO_LONG or responsecode.BAD_REQUEST
        self._abortWithError(code, 'Header line too long.')

    def createRequest(self):
        # Hand the parsed request over to the user-level request factory;
        # inHeaders now belongs to the request object, so drop our reference.
        self.request = self.channel.requestFactory(self, self.command, self.path, self.version, self.length, self.inHeaders)
        del self.inHeaders

    def processRequest(self):
        self.request.process()

    def handleContentChunk(self, data):
        self.request.handleContentChunk(data)

    def handleContentComplete(self):
        self.request.handleContentComplete()

    ############## HTTPChannelRequest *RESPONSE* methods #############
    producer = None
    chunkedOut = False
    finished = False

    ##### Request Callbacks #####
    def writeIntermediateResponse(self, code, headers=None):
        # 1xx interim responses exist only in HTTP/1.1.
        if self.version >= (1,1):
            self._writeHeaders(code, headers, False)

    def writeHeaders(self, code, headers):
        self._writeHeaders(code, headers, True)

    def _writeHeaders(self, code, headers, addConnectionHeaders):
        """
        Emit the status line and headers. When addConnectionHeaders is true
        (final responses), also decide chunking vs. connection close.
        """
        # HTTP 0.9 doesn't have headers.
        if self.version[0] == 0:
            return

        l = []
        code_message = responsecode.RESPONSES.get(code, "Unknown Status")

        l.append('%s %s %s\r\n' % (self.out_version, code,
                                   code_message))
        if headers is not None:
            for name, valuelist in headers.getAllRawHeaders():
                for value in valuelist:
                    l.append("%s: %s\r\n" % (name, value))

        if addConnectionHeaders:
            # if we don't have a content length, we send data in
            # chunked mode, so that we can support persistent connections.
            if (headers.getHeader('content-length') is None and
                self.command != "HEAD" and code not in http.NO_BODY_CODES):
                if self.version >= (1,1):
                    l.append("%s: %s\r\n" % ('Transfer-Encoding', 'chunked'))
                    self.chunkedOut = True
                else:
                    # Cannot use persistent connections if we can't do chunking
                    self.channel.dropQueuedRequests()

            if self.channel.isLastRequest(self):
                l.append("%s: %s\r\n" % ('Connection', 'close'))
            elif self.version < (1,1):
                l.append("%s: %s\r\n" % ('Connection', 'Keep-Alive'))

        l.append("\r\n")
        self.transport.writeSequence(l)

    def write(self, data):
        """
        Write a piece of the response body, wrapping it in chunk framing when
        the response is being sent chunked. Empty writes are dropped (an
        empty chunk would terminate a chunked response prematurely).
        """
        if not data:
            return
        elif self.chunkedOut:
            self.transport.writeSequence(("%X\r\n" % len(data), data, "\r\n"))
        else:
            self.transport.write(data)

    def finish(self):
        """We are finished writing data."""
        if self.finished:
            warnings.warn("Warning! request.finish called twice.", stacklevel=2)
            return

        if self.chunkedOut:
            # write last chunk and closing CRLF
            self.transport.write("0\r\n\r\n")

        self.finished = True
        if not self.queued:
            self._cleanup()

    def abortConnection(self, closeWrite=True):
        """Abort the HTTP connection because of some kind of unrecoverable
        error. If closeWrite=False, then only abort reading, but leave
        the writing side alone. This is mostly for internal use by
        the HTTP request parsing logic, so that it can call an error
        page generator.

        Otherwise, completely shut down the connection.
        """
        self.abortParse()
        if closeWrite:
            if self.producer:
                self.producer.stopProducing()
                self.unregisterProducer()

            self.finished = True
            if self.queued:
                # Still buffered: throw away anything written so far.
                self.transport.reset()
                self.transport.truncate()
            else:
                self._cleanup()

    def getHostInfo(self):
        # Returns (resolved local hostname, whether the transport is TLS).
        return self.channel._host, self.channel._secure

    def getRemoteHost(self):
        return self.channel.transport.getPeer()

    ##### End Request Callbacks #####

    def _abortWithError(self, errorcode, text=''):
        """Handle low level protocol errors."""
        headers = http_headers.Headers()
        # +1 accounts for the trailing "\n" written below.
        headers.setHeader('content-length', len(text)+1)

        self.abortConnection(closeWrite=False)
        self.writeHeaders(errorcode, headers)
        self.write(text)
        self.write("\n")
        self.finish()
        log.warn("Aborted request (%d) %s" % (errorcode, text))
        # Unwind the caller's parsing via AbortedException.
        raise AbortedException

    def _cleanup(self):
        """Called when have finished responding and are no longer queued."""
        if self.producer:
            log.error(RuntimeError("Producer was not unregistered for %s" % self))
            self.unregisterProducer()

        self.channel.requestWriteFinished(self)
        del self.transport

    # methods for channel - end users should not use these

    def noLongerQueued(self):
        """Notify the object that it is no longer queued.

        We start writing whatever data we have to the transport, etc.

        This method is not intended for users.
        """
        if not self.queued:
            raise RuntimeError, "noLongerQueued() got called unnecessarily."

        self.queued = 0

        # set transport to real one and send any buffer data
        data = self.transport.getvalue()
        self.transport = self.channel.transport
        if data:
            self.transport.write(data)

        # if we have producer, register it with transport
        if (self.producer is not None) and not self.finished:
            self.transport.registerProducer(self.producer, True)

        # if we're finished, clean up
        if self.finished:
            self._cleanup()

    # consumer interface
    def registerProducer(self, producer, streaming):
        """Register a producer.
        """
        if self.producer:
            raise ValueError, "registering producer %s before previous one (%s) was unregistered" % (producer, self.producer)

        self.producer = producer

        if self.queued:
            # Can't stream into a memory buffer; hold the producer until we
            # are attached to the real transport in noLongerQueued().
            producer.pauseProducing()
        else:
            self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the producer."""
        if not self.queued:
            self.transport.unregisterProducer()
        self.producer = None

    def connectionLost(self, reason):
        """connection was lost"""
        if self.queued and self.producer:
            self.producer.stopProducing()
            self.producer = None
        if self.request:
            self.request.connectionLost(reason)
class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin, object):
    """A receiver for HTTP requests. Handles splitting up the connection
    for the multiple HTTPChannelRequests that may be in progress on this
    channel.

    @ivar timeOut: number of seconds to wait before terminating an
    idle connection.

    @ivar maxPipeline: number of outstanding in-progress requests
    to allow before pausing the input.

    @ivar maxHeaderLength: number of bytes of header to accept from
    the client.
    """

    implements(interfaces.IHalfCloseableProtocol)

    ## Configuration parameters. Set in instances or subclasses.

    # How many simultaneous requests to handle.
    maxPipeline = 4

    # Timeout when between two requests
    betweenRequestsTimeOut = 15
    # Timeout between lines or bytes while reading a request
    inputTimeOut = 60 * 4
    # Timeout between end of request read and end of response write
    idleTimeOut = 60 * 5
    # Timeout when closing non-persistent connection
    closeTimeOut = 20

    # maximum length of headers (10KiB)
    maxHeaderLength = 10240

    # Allow persistent connections?
    allowPersistentConnections = True

    # ChannelRequest
    chanRequestFactory = HTTPChannelRequest
    requestFactory = http.Request

    # _first_line: 2 = expecting the request line; 1 = expecting the request
    # line but tolerate ONE leading blank line (IE quirk); 0 = in headers.
    _first_line = 2
    readPersistent = PERSIST_PIPELINE

    _readLost = False
    _writeLost = False

    # Pending reactor delayed-call used to force-close a hung connection.
    _abortTimer = None
    # The ChannelRequest currently being read (None between requests).
    chanRequest = None

    def _callLater(self, secs, fun):
        # Indirection point so tests/subclasses can control scheduling.
        reactor.callLater(secs, fun)

    def __init__(self):
        # the request queue
        self.requests = []

    def connectionMade(self):
        # Record whether this is a TLS connection and the resolved local
        # hostname; both are reported to requests via getHostInfo().
        self._secure = interfaces.ISSLTransport(self.transport, None) is not None
        address = self.transport.getHost()
        self._host = _cachedGetHostByAddr(address.host)
        self.setTimeout(self.inputTimeOut)
        self.factory.addConnectedChannel(self)

    def lineReceived(self, line):
        """
        Dispatch a received line: start a new request if we're between
        requests, otherwise feed it to the in-progress request's parser.
        """
        if self._first_line:
            self.setTimeout(self.inputTimeOut)

            # if this connection is not persistent, drop any data which
            # the client (illegally) sent after the last request.
            if not self.readPersistent:
                self.dataReceived = self.lineReceived = lambda *args: None
                return

            # IE sends an extraneous empty line (\r\n) after a POST request;
            # eat up such a line, but only ONCE
            if not line and self._first_line == 1:
                self._first_line = 2
                return

            self._first_line = 0

            if not self.allowPersistentConnections:
                # Don't allow a second request
                self.readPersistent = False

            try:
                self.chanRequest = self.chanRequestFactory(self, len(self.requests))
                self.requests.append(self.chanRequest)
                self.chanRequest.gotInitialLine(line)
            except AbortedException:
                pass
        else:
            try:
                self.chanRequest.lineReceived(line)
            except AbortedException:
                pass

    def lineLengthExceeded(self, line):
        """
        LineReceiver overflow hook: answer with an error response, faking a
        request object first if the overflow was on the request line itself.
        """
        if self._first_line:
            # Fabricate a request object to respond to the line length violation.
            self.chanRequest = self.chanRequestFactory(self,
                                                       len(self.requests))
            self.requests.append(self.chanRequest)
            self.chanRequest.gotInitialLine("GET fake HTTP/1.0")
        try:
            self.chanRequest.lineLengthExceeded(line, self._first_line)
        except AbortedException:
            pass

    def rawDataReceived(self, data):
        # Body bytes for the current request; reset the input timeout.
        self.setTimeout(self.inputTimeOut)
        try:
            self.chanRequest.rawDataReceived(data)
        except AbortedException:
            pass

    def requestReadFinished(self, request):
        """
        Called by the parser when one request has been fully read; prepares
        the channel to read the next pipelined request (or pauses input).
        """
        if(self.readPersistent is PERSIST_NO_PIPELINE or
           len(self.requests) >= self.maxPipeline):
            self.pauseProducing()

        # reset state variables
        self._first_line = 1
        self.chanRequest = None
        self.setLineMode()

        # Set an idle timeout, in case this request takes a long
        # time to finish generating output.
        if len(self.requests) > 0:
            self.setTimeout(self.idleTimeOut)

    def _startNextRequest(self):
        # notify next request, if present, it can start writing
        del self.requests[0]

        if self._writeLost:
            self.transport.loseConnection()
        elif self.requests:
            self.requests[0].noLongerQueued()

            # resume reading if allowed to
            if(not self._readLost and
               self.readPersistent is not PERSIST_NO_PIPELINE and
               len(self.requests) < self.maxPipeline):
                self.resumeProducing()
        elif self._readLost:
            # No more incoming data, they already closed!
            self.transport.loseConnection()
        else:
            # no requests in queue, resume reading
            self.setTimeout(self.betweenRequestsTimeOut)
            self.resumeProducing()

    def setReadPersistent(self, persistent):
        if self.readPersistent:
            # only allow it to be set if it's not currently False
            self.readPersistent = persistent

    def dropQueuedRequests(self):
        """Called when a response is written that forces a connection close."""
        self.readPersistent = False
        # Tell all requests but first to abort.
        for request in self.requests[1:]:
            request.connectionLost(None)
        del self.requests[1:]

    def isLastRequest(self, request):
        # Is this channel handling the last possible request
        return not self.readPersistent and self.requests[-1] == request

    def requestWriteFinished(self, request):
        """Called by first request in queue when it is done."""
        if request != self.requests[0]: raise TypeError

        # Don't del because we haven't finished cleanup, so,
        # don't want queue len to be 0 yet.
        self.requests[0] = None

        if self.readPersistent or len(self.requests) > 1:
            # Do this in the next reactor loop so as to
            # not cause huge call stacks with fast
            # incoming requests.
            self._callLater(0, self._startNextRequest)
        else:
            # Set an abort timer in case an orderly close hangs
            self.setTimeout(None)
            self._abortTimer = reactor.callLater(self.closeTimeOut, self._abortTimeout)
            #reactor.callLater(0.1, self.transport.loseConnection)
            self.transport.loseConnection()

    def timeoutConnection(self):
        #log.info("Timing out client: %s" % str(self.transport.getPeer()))
        # Set an abort timer in case an orderly close hangs
        self._abortTimer = reactor.callLater(self.closeTimeOut, self._abortTimeout)
        policies.TimeoutMixin.timeoutConnection(self)

    def _abortTimeout(self):
        # The orderly close never completed; force the socket shut.
        log.error("Connection aborted - took too long to close: {c}", c=str(self.transport.getPeer()))
        self._abortTimer = None
        self.transport.abortConnection()

    def readConnectionLost(self):
        """Read connection lost"""
        # If in the lingering-close state, lose the socket.
        if self._abortTimer:
            self._abortTimer.cancel()
            self._abortTimer = None
            self.transport.loseConnection()
            return

        # If between requests, drop connection
        # when all current requests have written their data.
        self._readLost = True
        if not self.requests:
            # No requests in progress, lose now.
            self.transport.loseConnection()

        # If currently in the process of reading a request, this is
        # probably a client abort, so lose the connection.
        if self.chanRequest:
            self.transport.loseConnection()

    def connectionLost(self, reason):
        self.factory.removeConnectedChannel(self)

        self._writeLost = True
        self.readConnectionLost()
        self.setTimeout(None)

        # Tell all requests to abort.
        for request in self.requests:
            if request is not None:
                request.connectionLost(reason)
class OverloadedServerProtocol(protocol.Protocol):
    """
    Protocol that answers any connection with a fixed HTTP/1.0 503
    (Service Unavailable) response and closes immediately. Returned by
    L{HTTPFactory.buildProtocol} when the request limit is exceeded.
    """

    def connectionMade(self):
        # NOTE: the HTML markup below was mangled in the archived copy of this
        # file (the <...> tags were stripped, breaking the string literals);
        # reconstructed here to restore valid syntax and the intended body.
        self.transport.write("HTTP/1.0 503 Service Unavailable\r\n"
                             "Content-Type: text/html\r\n"
                             "Connection: close\r\n\r\n"
                             "<html><head><title>503 Service Unavailable</title></head>"
                             "<body><h1>Service Unavailable</h1>"
                             "The server is currently overloaded, "
                             "please try again later.</body></html>")
        self.transport.loseConnection()
class HTTPFactory(protocol.ServerFactory):
    """
    Factory for HTTP server.

    @ivar outstandingRequests: the number of currently connected HTTP channels.

    @type outstandingRequests: C{int}

    @ivar connectedChannels: all the channels that have currently active
    connections.

    @type connectedChannels: C{set} of L{HTTPChannel}
    """

    protocol = HTTPChannel

    protocolArgs = None

    def __init__(self, requestFactory, maxRequests=600, **kwargs):
        """
        @param requestFactory: callable producing user-level request objects;
            stashed into protocolArgs so each channel gets it as an attribute.
        @param maxRequests: cap on simultaneously connected channels beyond
            which new connections get an overloaded response.
        @param kwargs: extra attributes to set on every built channel.
        """
        self.maxRequests = maxRequests
        kwargs['requestFactory'] = requestFactory
        self.protocolArgs = kwargs
        self.connectedChannels = set()
        self.allConnectionsClosedDeferred = None

    def buildProtocol(self, addr):
        # Refuse with a canned 503 responder once we are at capacity.
        if self.outstandingRequests >= self.maxRequests:
            return OverloadedServerProtocol()

        channel = protocol.ServerFactory.buildProtocol(self, addr)

        # Push the configured attributes (requestFactory etc.) onto the
        # freshly built channel instance.
        for attrName, attrValue in self.protocolArgs.iteritems():
            setattr(channel, attrName, attrValue)
        return channel

    def addConnectedChannel(self, channel):
        """
        Add a connected channel to the set of currently connected channels and
        increase the outstanding request count.
        """
        self.connectedChannels.add(channel)

    def removeConnectedChannel(self, channel):
        """
        Remove a connected channel from the set of currently connected channels
        and decrease the outstanding request count.

        If someone is waiting for all the requests to be completed,
        self.allConnectionsClosedDeferred will be non-None; fire that callback
        when the number of outstanding requests hits zero.
        """
        self.connectedChannels.remove(channel)

        waiter = self.allConnectionsClosedDeferred
        if waiter is not None and self.outstandingRequests == 0:
            waiter.callback(None)

    @property
    def outstandingRequests(self):
        return len(self.connectedChannels)

    def allConnectionsClosed(self):
        """
        Return a Deferred that will fire when all outstanding requests have completed.
        @return: A Deferred with a result of None
        """
        if not self.outstandingRequests:
            return succeed(None)
        self.allConnectionsClosedDeferred = Deferred()
        return self.allConnectionsClosedDeferred
class HTTP503LoggingFactory (HTTPFactory):
    """
    Factory for HTTP server which emits a 503 response when overloaded.
    """
    def __init__(self, requestFactory, maxRequests=600, retryAfter=0, vary=False, **kwargs):
        """
        @param retryAfter: seconds advertised via Retry-After when rejecting.
        @param vary: when true, jitter retryAfter uniformly over
            [retryAfter/2, retryAfter*3/2] so rejected clients don't stampede
            back simultaneously.
        """
        self.retryAfter = retryAfter
        self.vary = vary
        HTTPFactory.__init__(self, requestFactory, maxRequests, **kwargs)

    def buildProtocol(self, addr):
        delay = self.retryAfter
        if self.vary:
            delay = randint(int(delay * 1/2), int(delay * 3/2))

        # At capacity: hand back a logging 503 responder instead of a channel.
        if self.outstandingRequests >= self.maxRequests:
            return OverloadedLoggingServerProtocol(delay, self.outstandingRequests)

        channel = protocol.ServerFactory.buildProtocol(self, addr)

        # Copy the configured per-channel attributes onto the new instance.
        for attrName, attrValue in self.protocolArgs.iteritems():
            setattr(channel, attrName, attrValue)
        return channel
class HTTPLoggingChannelRequest(HTTPChannelRequest):
    """
    A channel request that, when the "HTTP" accounting category is enabled,
    captures the full wire-level request and response (with timings, and with
    Basic credentials masked) and emits them via the accounting module when
    the response finishes.
    """

    class TransportLoggingWrapper(object):
        """
        Transport proxy that appends everything written to a log list before
        passing it through to the real transport.
        """

        def __init__(self, transport, logData):
            self.transport = transport
            self.logData = logData

        def write(self, data):
            if self.logData is not None and data:
                self.logData.append(data)
            self.transport.write(data)

        def writeSequence(self, seq):
            if self.logData is not None and seq:
                self.logData.append(''.join(seq))
            self.transport.writeSequence(seq)

        def __getattr__(self, attr):
            # Delegate everything else to the wrapped transport; __dict__
            # access avoids recursing into __getattr__ for 'transport'.
            return getattr(self.__dict__['transport'], attr)

    class LogData(object):
        """Holder for the captured request and response byte fragments."""
        def __init__(self):
            self.request = []
            self.response = []

    def __init__(self, channel, queued=0):
        super(HTTPLoggingChannelRequest, self).__init__(channel, queued)

        if accounting.accountingEnabledForCategory("HTTP"):
            self.logData = HTTPLoggingChannelRequest.LogData()
            # Wrap whatever transport the base class chose (buffered or real)
            # so response bytes are also captured.
            self.transport = HTTPLoggingChannelRequest.TransportLoggingWrapper(self.transport, self.logData.response)
        else:
            self.logData = None

    def gotInitialLine(self, initialLine):
        if self.logData is not None:
            self.startTime = time.time()
            self.logData.request.append(">>>> Request starting at: %.3f\r\n\r\n" % (self.startTime,))
            self.logData.request.append("%s\r\n" % (initialLine,))
        super(HTTPLoggingChannelRequest, self).gotInitialLine(initialLine)

    def lineReceived(self, line):
        if self.logData is not None:
            # We don't want to log basic credentials
            loggedLine = line
            if line.lower().startswith("authorization:"):
                bits = line[14:].strip().split(" ")
                if bits[0].lower() == "basic" and len(bits) == 2:
                    # Mask the base64 credential blob, preserving its length.
                    loggedLine = "%s %s %s" % (line[:14], bits[0], "X" * len(bits[1]))
            self.logData.request.append("%s\r\n" % (loggedLine,))
        super(HTTPLoggingChannelRequest, self).lineReceived(line)

    def handleContentChunk(self, data):
        if self.logData is not None:
            self.logData.request.append(data)
        super(HTTPLoggingChannelRequest, self).handleContentChunk(data)

    def handleContentComplete(self):
        if self.logData is not None:
            doneTime = time.time()
            self.logData.request.append("\r\n\r\n>>>> Request complete at: %.3f (elapsed: %.1f ms)" % (doneTime, 1000 * (doneTime - self.startTime),))
        super(HTTPLoggingChannelRequest, self).handleContentComplete()

    def writeHeaders(self, code, headers):
        if self.logData is not None:
            doneTime = time.time()
            self.logData.response.append("\r\n\r\n<<<< Response sending at: %.3f (elapsed: %.1f ms)\r\n\r\n" % (doneTime, 1000 * (doneTime - self.startTime),))
        super(HTTPLoggingChannelRequest, self).writeHeaders(code, headers)

    def finish(self):
        # Finish first so the final chunked terminator (if any) is captured
        # in the response log before it is emitted to accounting.
        super(HTTPLoggingChannelRequest, self).finish()

        if self.logData is not None:
            doneTime = time.time()
            self.logData.response.append("\r\n\r\n<<<< Response complete at: %.3f (elapsed: %.1f ms)\r\n" % (doneTime, 1000 * (doneTime - self.startTime),))
            accounting.emitAccounting("HTTP", "", "".join(self.logData.request) + "".join(self.logData.response), self.command)

# Install the logging variant as the channel's default request class; it is
# a no-op passthrough unless "HTTP" accounting is enabled.
HTTPChannel.chanRequestFactory = HTTPLoggingChannelRequest
class LimitingHTTPFactory(HTTPFactory):
    """
    HTTPFactory which stores maxAccepts on behalf of the MaxAcceptPortMixin

    @ivar myServer: a reference to a L{MaxAcceptTCPServer} that this
        L{LimitingHTTPFactory} will limit. This must be set externally.
    """

    def __init__(self, requestFactory, maxRequests=600, maxAccepts=100,
                 **kwargs):
        HTTPFactory.__init__(self, requestFactory, maxRequests, **kwargs)
        self.maxAccepts = maxAccepts

    def buildProtocol(self, addr):
        """
        Override L{HTTPFactory.buildProtocol} in order to avoid ever returning
        an L{OverloadedServerProtocol}; this should be handled in other ways.
        """
        channel = protocol.ServerFactory.buildProtocol(self, addr)
        for attrName, attrValue in self.protocolArgs.iteritems():
            setattr(channel, attrName, attrValue)
        return channel

    def addConnectedChannel(self, channel):
        """
        Override L{HTTPFactory.addConnectedChannel} to pause listening on the
        socket when there are too many outstanding channels.
        """
        HTTPFactory.addConnectedChannel(self, channel)
        if self.outstandingRequests >= self.maxRequests:
            # Stop accepting until load drops back under the limit.
            self.myServer.myPort.stopReading()

    def removeConnectedChannel(self, channel):
        """
        Override L{HTTPFactory.removeConnectedChannel} to resume listening on
        the socket when the outstanding channel count drops below the limit.
        """
        HTTPFactory.removeConnectedChannel(self, channel)
        if self.outstandingRequests < self.maxRequests:
            self.myServer.myPort.startReading()
# Public API of this module (used by "from ... import *").
__all__ = [
    "HTTPFactory",
    "HTTP503LoggingFactory",
    "LimitingHTTPFactory",
    "SSLRedirectRequest",
]
calendarserver-5.2+dfsg/twext/web2/channel/__init__.py 0000644 0001750 0001750 00000002570 12263343324 022050 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_cgi,twext.web2.test.test_http -*-
##
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
Various backend channel implementations for web2.
"""
from twext.web2.channel.http import HTTPFactory

# The HTTP channel factory is the only public name of this package.
__all__ = ['HTTPFactory']
calendarserver-5.2+dfsg/twext/web2/dav/ 0000755 0001750 0001750 00000000000 12322625326 017076 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/dav/test/ 0000755 0001750 0001750 00000000000 12322625325 020054 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/dav/test/test_move.py 0000644 0001750 0001750 00000010672 12263343324 022441 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import os
import twext.web2.dav.test.util
import twext.web2.dav.test.test_copy
from twext.web2 import responsecode
from twext.web2.dav.test.util import serialize
from twext.web2.dav.test.test_copy import sumFile
class MOVE(twext.web2.dav.test.util.TestCase):
    """
    MOVE request tests.

    Each test hands a checker callback to the shared COPY/MOVE request
    driver (module-level C{work}, delegating to the COPY test module) and
    verifies both the HTTP response code and the resulting filesystem state.
    """
    # FIXME:
    # Check that properties are being moved

    def test_MOVE_create(self):
        """
        MOVE to new resource: expect 201 (Created) with a location header,
        and the source content to appear intact at the destination.
        """
        def test(response, path, isfile, sum, uri, depth, dst_path):
            if response.code != responsecode.CREATED:
                self.fail("Incorrect response code for MOVE %s: %s != %s"
                          % (uri, response.code, responsecode.CREATED))

            if response.headers.getHeader("location") is None:
                # Fixed typo in message: "Reponse" -> "Response".
                self.fail("Response to MOVE %s with CREATE status is missing location: header."
                          % (uri,))

            if isfile:
                if not os.path.isfile(dst_path):
                    self.fail("MOVE %s produced no output file" % (uri,))
                if sum != sumFile(dst_path):
                    self.fail("MOVE %s produced different file" % (uri,))
            else:
                if not os.path.isdir(dst_path):
                    self.fail("MOVE %s produced no output directory" % (uri,))
                if sum != sumFile(dst_path):
                    # Fixed message: previously said "isdir ... produced
                    # different directory", which didn't name the method.
                    self.fail("MOVE %s produced different directory" % (uri,))

        return serialize(self.send, work(self, test))

    def test_MOVE_exists(self):
        """
        MOVE to existing resource without overwrite: expect 412
        (Precondition Failed).
        """
        def test(response, path, isfile, sum, uri, depth, dst_path):
            if response.code != responsecode.PRECONDITION_FAILED:
                self.fail("Incorrect response code for MOVE without overwrite %s: %s != %s"
                          % (uri, response.code, responsecode.PRECONDITION_FAILED))
            else:
                # FIXME: Check XML error code (2518bis)
                pass

        return serialize(self.send, work(self, test, overwrite=False))

    def test_MOVE_overwrite(self):
        """
        MOVE to existing resource with overwrite header: expect 204
        (No Content).
        """
        def test(response, path, isfile, sum, uri, depth, dst_path):
            if response.code != responsecode.NO_CONTENT:
                self.fail("Incorrect response code for MOVE with overwrite %s: %s != %s"
                          % (uri, response.code, responsecode.NO_CONTENT))
            else:
                # FIXME: Check XML error code (2518bis)
                pass

        return serialize(self.send, work(self, test, overwrite=True))

    def test_MOVE_no_parent(self):
        """
        MOVE to a destination whose parent collection does not exist:
        expect 409 (Conflict).
        """
        def test(response, path, isfile, sum, uri, depth, dst_path):
            if response.code != responsecode.CONFLICT:
                self.fail("Incorrect response code for MOVE with no parent %s: %s != %s"
                          % (uri, response.code, responsecode.CONFLICT))
            else:
                # FIXME: Check XML error code (2518bis)
                pass

        return serialize(self.send, work(self, test, dst=os.path.join(self.docroot, "elvislives!")))
def work(self, test, overwrite=None, dst=None):
    """
    Drive the shared COPY/MOVE request generator from the COPY test module,
    restricted to the default depth only.
    """
    copy_work = twext.web2.dav.test.test_copy.work
    return copy_work(self, test, overwrite, dst, depths=(None,))
calendarserver-5.2+dfsg/twext/web2/dav/test/tworequest_client.py 0000644 0001750 0001750 00000001713 11337102650 024204 0 ustar rahul rahul import socket, sys
# Test client: sends two pipelined PUT requests on a single connection and
# echoes whatever the server returns to stdout (progress goes to stderr).
# Usage: tworequest_client.py <test_type> <port> <socket_type>
# NOTE(review): Python 2 only code ("print >>" syntax, socket.ssl).
test_type = sys.argv[1]
port = int(sys.argv[2])
socket_type = sys.argv[3]

s = socket.socket(socket.AF_INET)
s.connect(("127.0.0.1", port))
# Large receive buffer so both responses can arrive without stalling.
s.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 40000)

if socket_type == 'ssl':
    # socket.ssl is the legacy Python 2 SSL wrapper (removed in Python 3).
    s2 = socket.ssl(s)
    send=s2.write
    recv=s2.read
else:
    send=s.send
    recv=s.recv

print >> sys.stderr, ">> Making %s request to port %d" % (socket_type, port)

# First PUT request, sent header-by-header.
send("PUT /forbidden HTTP/1.1\r\n")
send("Host: localhost\r\n")

print >> sys.stderr, ">> Sending lots of data"
send("Content-Length: 100\r\n\r\n")
send("X"*100)

# Second PUT pipelined on the same connection, before reading any response.
send("PUT /forbidden HTTP/1.1\r\n")
send("Host: localhost\r\n")

print >> sys.stderr, ">> Sending lots of data"
send("Content-Length: 100\r\n\r\n")
send("X"*100)

#import time
#time.sleep(5)
print >> sys.stderr, ">> Getting data"
data=''
# Accumulate until ~300KB, EOF, or any receive error.
while len(data) < 299999:
    try:
        x=recv(10000)
    except:
        # NOTE(review): bare except — presumably to swallow SSL/socket EOF
        # errors at connection close; confirm it shouldn't be narrowed.
        break
    if x == '':
        break
    data+=x
sys.stdout.write(data)
calendarserver-5.2+dfsg/twext/web2/dav/test/test_put.py 0000644 0001750 0001750 00000012232 12263343324 022275 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import os
import filecmp
from twext.web2 import responsecode
from twext.web2.iweb import IResponse
from twext.web2.stream import FileStream
from twext.web2.http import HTTPError
import twext.web2.dav.test.util
from twext.web2.test.test_server import SimpleRequest
from twext.web2.dav.test.util import serialize
class PUT(twext.web2.dav.test.util.TestCase):
    """
    PUT request
    """
    def test_PUT_simple(self):
        """
        PUT request: upload every readable file in the docroot to /dst and
        verify status, on-disk content, and the etag response header.
        """
        dst_path = os.path.join(self.docroot, "dst")

        def checkResult(response, path):
            response = IResponse(response)

            # Either created-new or overwrote-existing is acceptable here.
            if response.code not in (
                responsecode.CREATED,
                responsecode.NO_CONTENT
            ):
                self.fail("PUT failed: %s" % (response.code,))

            if not os.path.isfile(dst_path):
                self.fail("PUT failed to create file %s." % (dst_path,))

            if not filecmp.cmp(path, dst_path):
                self.fail("PUT failed to preserve data for file %s in file %s." % (path, dst_path))

            etag = response.headers.getHeader("etag")
            if not etag:
                self.fail("No etag header in PUT response %r." % (response,))

        #
        # We need to serialize these request & test iterations because they can
        # interfere with each other.
        #
        def work():
            dst_uri = "/dst"

            for name in os.listdir(self.docroot):
                if name == "dst":
                    continue

                path = os.path.join(self.docroot, name)

                # Can't really PUT something you can't read
                if not os.path.isfile(path): continue

                def do_test(response): checkResult(response, path)

                request = SimpleRequest(self.site, "PUT", dst_uri)
                request.stream = FileStream(file(path, "rb"))

                yield (request, do_test)

        return serialize(self.send, work())

    def test_PUT_again(self):
        """
        PUT on existing resource with If-None-Match header: walk through a
        fixed sequence of expected status codes, toggling the precondition
        header and the existence of the destination between requests.
        """
        dst_path = os.path.join(self.docroot, "dst")
        dst_uri = "/dst"

        def work():
            for code in (
                responsecode.CREATED,
                responsecode.PRECONDITION_FAILED,
                responsecode.NO_CONTENT,
                responsecode.PRECONDITION_FAILED,
                responsecode.NO_CONTENT,
                responsecode.CREATED,
            ):
                def checkResult(response, code=code):
                    response = IResponse(response)

                    if response.code != code:
                        self.fail("Incorrect response code for PUT (%s != %s)"
                                  % (response.code, code))

                def onError(f):
                    # Precondition failures surface as HTTPError failures.
                    f.trap(HTTPError)
                    return checkResult(f.value.response)

                request = SimpleRequest(self.site, "PUT", dst_uri)
                request.stream = FileStream(file(__file__, "rb"))

                if code == responsecode.CREATED:
                    # Ensure the target is absent so If-None-Match: * succeeds.
                    if os.path.isfile(dst_path):
                        os.remove(dst_path)
                    request.headers.setHeader("if-none-match", ("*",))
                elif code == responsecode.PRECONDITION_FAILED:
                    request.headers.setHeader("if-none-match", ("*",))

                yield (request, (checkResult, onError))

        return serialize(self.send, work())

    def test_PUT_no_parent(self):
        """
        PUT with no parent collection: expect 409 (Conflict).
        """
        dst_uri = "/put/no/parent"

        def checkResult(response):
            response = IResponse(response)

            if response.code != responsecode.CONFLICT:
                self.fail("Incorrect response code for PUT with no parent (%s != %s)"
                          % (response.code, responsecode.CONFLICT))

        request = SimpleRequest(self.site, "PUT", dst_uri)
        request.stream = FileStream(file(__file__, "rb"))

        return self.send(request, checkResult)
calendarserver-5.2+dfsg/twext/web2/dav/test/test_acl.py 0000644 0001750 0001750 00000035665 12263343324 022243 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import os
from twisted.cred.portal import Portal
from twext.web2 import responsecode
from twext.web2.auth import basic
from twext.web2.stream import MemoryStream
from twext.web2.dav.util import davXMLFromStream
from twext.web2.dav.auth import TwistedPasswordProperty, IPrincipal, DavRealm, TwistedPropertyChecker, AuthenticationWrapper
from twext.web2.dav.fileop import rmdir
from twext.web2.test.test_server import SimpleRequest
from twext.web2.dav.test.util import Site, serialize
from twext.web2.dav.test.test_resource import \
TestDAVPrincipalResource, TestPrincipalsCollection
from txdav.xml import element
import twext.web2.dav.test.util
class ACL(twext.web2.dav.test.util.TestCase):
    """
    RFC 3744 (WebDAV ACL) tests.

    The document root built by createDocumentRoot() contains a set of files
    ("none", "read", "read-write", "unlock", "all") and directories
    ("nobind", "bind", "unbind") whose names describe the privileges granted
    on them; each test then issues requests as user01 and checks that the
    server's response code matches the granted privilege.
    """
    def createDocumentRoot(self):
        # Build a fresh docroot plus a principal hierarchy containing a
        # single authenticated test user, user01 (password "user01").
        docroot = self.mktemp()
        os.mkdir(docroot)
        userResource = TestDAVPrincipalResource("/principals/users/user01")
        userResource.writeDeadProperty(TwistedPasswordProperty("user01"))
        principalCollection = TestPrincipalsCollection(
            "/principals/",
            children={"users": TestPrincipalsCollection(
                "/principals/users/",
                children={"user01": userResource})})
        rootResource = self.resource_class(
            docroot, principalCollections=(principalCollection,))
        portal = Portal(DavRealm())
        portal.registerChecker(TwistedPropertyChecker())
        credentialFactories = (basic.BasicCredentialFactory(""),)
        loginInterfaces = (IPrincipal,)
        self.site = Site(AuthenticationWrapper(
            rootResource,
            portal,
            credentialFactories,
            credentialFactories,
            loginInterfaces
        ))
        rootResource.setAccessControlList(self.grant(element.All()))

        # Files named after the single privilege granted on each.
        for name, acl in (
            ("none"       , self.grant()),
            ("read"       , self.grant(element.Read())),
            ("read-write" , self.grant(element.Read(), element.Write())),
            ("unlock"     , self.grant(element.Unlock())),
            ("all"        , self.grant(element.All())),
        ):
            filename = os.path.join(docroot, name)
            if not os.path.isfile(filename):
                file(filename, "w").close()
            resource = self.resource_class(filename)
            resource.setAccessControlList(acl)

        # Directories named after the bind privileges granted on each.
        for name, acl in (
            ("nobind" , self.grant()),
            ("bind"   , self.grant(element.Bind())),
            ("unbind" , self.grant(element.Bind(), element.Unbind())),
        ):
            dirname = os.path.join(docroot, name)
            if not os.path.isdir(dirname):
                os.mkdir(dirname)
            resource = self.resource_class(dirname)
            resource.setAccessControlList(acl)
        return docroot

    def restore(self):
        # Get rid of whatever messed up state the test has now so that we'll
        # get a fresh docroot.  This isn't very cool; tests should be doing
        # less so that they don't need a fresh copy of this state.
        if hasattr(self, "_docroot"):
            rmdir(self._docroot)
            del self._docroot

    def test_COPY_MOVE_source(self):
        """
        Verify source access controls during COPY and MOVE.
        """
        def work():
            dst_path = os.path.join(self.docroot, "copy_dst")
            dst_uri = "/" + os.path.basename(dst_path)

            for src, status in (
                ("nobind", responsecode.FORBIDDEN),
                ("bind", responsecode.FORBIDDEN),
                ("unbind", responsecode.CREATED),
            ):
                src_path = os.path.join(self.docroot, "src_" + src)
                src_uri = "/" + os.path.basename(src_path)

                if not os.path.isdir(src_path):
                    os.mkdir(src_path)
                    src_resource = self.resource_class(src_path)
                    src_resource.setAccessControlList({
                        "nobind": self.grant(),
                        "bind"  : self.grant(element.Bind()),
                        "unbind": self.grant(element.Bind(), element.Unbind())
                    }[src])
                    for name, acl in (
                        ("none"       , self.grant()),
                        ("read"       , self.grant(element.Read())),
                        ("read-write" , self.grant(element.Read(), element.Write())),
                        ("unlock"     , self.grant(element.Unlock())),
                        ("all"        , self.grant(element.All())),
                    ):
                        filename = os.path.join(src_path, name)
                        if not os.path.isfile(filename):
                            file(filename, "w").close()
                        self.resource_class(filename).setAccessControlList(acl)

                for method in ("COPY", "MOVE"):
                    # COPY needs Read on the source; MOVE additionally needs
                    # Unbind on the source's parent (the per-src "status").
                    for name, code in (
                        ("none"       , {"COPY": responsecode.FORBIDDEN, "MOVE": status}[method]),
                        ("read"       , {"COPY": responsecode.CREATED, "MOVE": status}[method]),
                        ("read-write" , {"COPY": responsecode.CREATED, "MOVE": status}[method]),
                        ("unlock"     , {"COPY": responsecode.FORBIDDEN, "MOVE": status}[method]),
                        ("all"        , {"COPY": responsecode.CREATED, "MOVE": status}[method]),
                    ):
                        path = os.path.join(src_path, name)
                        uri = src_uri + "/" + name

                        request = SimpleRequest(self.site, method, uri)
                        request.headers.setHeader("destination", dst_uri)
                        _add_auth_header(request)

                        def test(response, code=code, path=path):
                            # Remove the destination so the next iteration
                            # starts from a clean slate.
                            if os.path.isfile(dst_path):
                                os.remove(dst_path)

                            if response.code != code:
                                return self.oops(request, response, code, method, name)

                        yield (request, test)

        return serialize(self.send, work())

    def test_COPY_MOVE_dest(self):
        """
        Verify destination access controls during COPY and MOVE.
        """
        def work():
            src_path = os.path.join(self.docroot, "read")
            uri = "/" + os.path.basename(src_path)

            for method in ("COPY", "MOVE"):
                for name, code in (
                    ("nobind" , responsecode.FORBIDDEN),
                    ("bind"   , responsecode.CREATED),
                    ("unbind" , responsecode.CREATED),
                ):
                    dst_parent_path = os.path.join(self.docroot, name)
                    dst_path = os.path.join(dst_parent_path, "dst")

                    request = SimpleRequest(self.site, method, uri)
                    request.headers.setHeader("destination", "/" + name + "/dst")
                    _add_auth_header(request)

                    def test(response, code=code, dst_path=dst_path):
                        if os.path.isfile(dst_path):
                            os.remove(dst_path)

                        if response.code != code:
                            return self.oops(request, response, code, method, name)

                    yield (request, test)

                # MOVE consumes the source, so rebuild the docroot before
                # the next method pass.
                self.restore()

        return serialize(self.send, work())

    def test_DELETE(self):
        """
        Verify access controls during DELETE.
        """
        def work():
            for name, code in (
                ("nobind" , responsecode.FORBIDDEN),
                ("bind"   , responsecode.FORBIDDEN),
                ("unbind" , responsecode.NO_CONTENT),
            ):
                collection_path = os.path.join(self.docroot, name)
                path = os.path.join(collection_path, "dst")

                file(path, "w").close()

                request = SimpleRequest(self.site, "DELETE", "/" + name + "/dst")
                _add_auth_header(request)

                def test(response, code=code, path=path):
                    if response.code != code:
                        return self.oops(request, response, code, "DELETE", name)

                yield (request, test)

        return serialize(self.send, work())

    def test_UNLOCK(self):
        """
        Verify access controls during UNLOCK of unowned lock.
        """
        raise NotImplementedError()

    test_UNLOCK.todo = "access controls on UNLOCK unimplemented"

    def test_MKCOL_PUT(self):
        """
        Verify access controls during MKCOL.
        """
        for method in ("MKCOL", "PUT"):
            def work():
                for name, code in (
                    ("nobind" , responsecode.FORBIDDEN),
                    ("bind"   , responsecode.CREATED),
                    ("unbind" , responsecode.CREATED),
                ):
                    collection_path = os.path.join(self.docroot, name)
                    path = os.path.join(collection_path, "dst")

                    if os.path.isfile(path):
                        os.remove(path)
                    elif os.path.isdir(path):
                        os.rmdir(path)

                    request = SimpleRequest(self.site, method, "/" + name + "/dst")
                    _add_auth_header(request)

                    def test(response, code=code, path=path):
                        if response.code != code:
                            return self.oops(request, response, code, method, name)

                    yield (request, test)

            # NOTE(review): this return executes on the first loop iteration,
            # so only the "MKCOL" case is ever exercised and the "PUT" pass
            # never runs — confirm whether that is intended.
            return serialize(self.send, work())

    def test_PUT_exists(self):
        """
        Verify access controls during PUT of existing file.
        """
        def work():
            for name, code in (
                ("none"       , responsecode.FORBIDDEN),
                ("read"       , responsecode.FORBIDDEN),
                ("read-write" , responsecode.NO_CONTENT),
                ("unlock"     , responsecode.FORBIDDEN),
                ("all"        , responsecode.NO_CONTENT),
            ):
                path = os.path.join(self.docroot, name)

                request = SimpleRequest(self.site, "PUT", "/" + name)
                _add_auth_header(request)

                def test(response, code=code, path=path):
                    if response.code != code:
                        return self.oops(request, response, code, "PUT", name)

                yield (request, test)

        return serialize(self.send, work())

    def test_PROPFIND(self):
        """
        Verify access controls during PROPFIND.
        """
        raise NotImplementedError()

    test_PROPFIND.todo = "access controls on PROPFIND unimplemented"

    def test_PROPPATCH(self):
        """
        Verify access controls during PROPPATCH.
        """
        def work():
            for name, code in (
                ("none"       , responsecode.FORBIDDEN),
                ("read"       , responsecode.FORBIDDEN),
                ("read-write" , responsecode.MULTI_STATUS),
                ("unlock"     , responsecode.FORBIDDEN),
                ("all"        , responsecode.MULTI_STATUS),
            ):
                path = os.path.join(self.docroot, name)

                request = SimpleRequest(self.site, "PROPPATCH", "/" + name)
                # An empty property-update body is sufficient to exercise
                # the access check.
                request.stream = MemoryStream(
                    element.WebDAVDocument(element.PropertyUpdate()).toxml()
                )
                _add_auth_header(request)

                def test(response, code=code, path=path):
                    if response.code != code:
                        return self.oops(request, response, code, "PROPPATCH", name)

                yield (request, test)

        return serialize(self.send, work())

    def test_GET_REPORT(self):
        """
        Verify access controls during GET and REPORT.
        """
        def work():
            for method in ("GET", "REPORT"):
                if method == "GET":
                    ok = responsecode.OK
                elif method == "REPORT":
                    ok = responsecode.MULTI_STATUS
                else:
                    raise AssertionError("We shouldn't be here. (method = %r)" % (method,))

                for name, code in (
                    ("none"       , responsecode.FORBIDDEN),
                    ("read"       , ok),
                    ("read-write" , ok),
                    ("unlock"     , responsecode.FORBIDDEN),
                    ("all"        , ok),
                ):
                    path = os.path.join(self.docroot, name)

                    request = SimpleRequest(self.site, method, "/" + name)
                    if method == "REPORT":
                        request.stream = MemoryStream(element.PrincipalPropertySearch().toxml())
                    _add_auth_header(request)

                    def test(response, code=code, path=path):
                        if response.code != code:
                            return self.oops(request, response, code, method, name)

                    yield (request, test)

        return serialize(self.send, work())

    def oops(self, request, response, code, method, name):
        """
        Produce a detailed failure: fetch the resource's ACL and include it,
        along with any XML response body, in the failure message.
        Returns a Deferred that fires after self.fail() raises.
        """
        def gotResponseData(doc):
            if doc is None:
                doc_xml = None
            else:
                doc_xml = doc.toxml()

            def fail(acl):
                self.fail("Incorrect status code %s (!= %s) for %s of resource %s with %s ACL: %s\nACL: %s"
                          % (response.code, code, method, request.uri, name, doc_xml, acl.toxml()))

            def getACL(resource):
                return resource.accessControlList(request)

            d = request.locateResource(request.uri)
            d.addCallback(getACL)
            d.addCallback(fail)
            return d

        d = davXMLFromStream(response.stream)
        d.addCallback(gotResponseData)
        return d
def _add_auth_header(request):
    """
    Attach HTTP basic-auth credentials for the test principal user01
    (Python 2 base64 codec via str.encode).
    """
    credentials = "user01:user01".encode("base64")
    request.headers.setHeader("authorization", ("basic", credentials))
calendarserver-5.2+dfsg/twext/web2/dav/test/test_mkcol.py 0000644 0001750 0001750 00000006001 12263343324 022567 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import os
from twext.web2 import responsecode
from twext.web2.iweb import IResponse
from twext.web2.stream import MemoryStream
from twext.web2.dav.fileop import rmdir
from twext.web2.test.test_server import SimpleRequest
import twext.web2.dav.test.util
class MKCOL(twext.web2.dav.test.util.TestCase):
    """
    MKCOL request
    """
    # FIXME:
    # Try in nonexistant parent collection.
    # Try on existing resource.
    # Try with request body?

    def test_MKCOL(self):
        """
        MKCOL request: expect 201 (Created) and the directory on disk.
        """
        path, uri = self.mkdtemp("collection")

        # Remove the directory so MKCOL actually has to create it.
        rmdir(path)

        def check_result(response):
            response = IResponse(response)

            if response.code != responsecode.CREATED:
                self.fail("MKCOL response %s != %s" % (response.code, responsecode.CREATED))

            if not os.path.isdir(path):
                self.fail("MKCOL did not create directory %s" % (path,))

        request = SimpleRequest(self.site, "MKCOL", uri)

        return self.send(request, check_result)

    def test_MKCOL_invalid_body(self):
        """
        MKCOL request with invalid request body
        (Any body at all is invalid in our implementation; there is no
        such thing as a valid body.)
        """
        path, uri = self.mkdtemp("collection")

        rmdir(path)

        def check_result(response):
            response = IResponse(response)

            if response.code != responsecode.UNSUPPORTED_MEDIA_TYPE:
                self.fail("MKCOL response %s != %s" % (response.code, responsecode.UNSUPPORTED_MEDIA_TYPE))

            # A rejected MKCOL must not leave a directory behind.
            if os.path.isdir(path):
                self.fail("MKCOL incorrectly created directory %s" % (path,))

        request = SimpleRequest(self.site, "MKCOL", uri)
        request.stream = MemoryStream("This is not a valid MKCOL request body.")

        return self.send(request, check_result)
calendarserver-5.2+dfsg/twext/web2/dav/test/test_static.py 0000644 0001750 0001750 00000004524 12263343324 022761 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twext.web2.dav.test import util
from txdav.xml import element as davxml
from twext.web2.stream import readStream
from twext.web2.test.test_server import SimpleRequest
class DAVFileTest(util.TestCase):
    """
    Tests for the HTML rendering of DAVFile resources.
    """
    def test_renderPrivileges(self):
        """
        Verify that a directory listing includes children which you
        don't have access to.
        """
        request = SimpleRequest(self.site, "GET", "/")

        def setEmptyACL(resource):
            resource.setAccessControlList(davxml.ACL()) # Empty ACL = no access
            return resource

        def renderRoot(_):
            d = request.locateResource("/")
            d.addCallback(lambda r: r.render(request))
            return d

        def assertListing(response):
            # Drain the response stream and check the inaccessible child
            # still shows up in the rendered listing.
            data = []
            d = readStream(response.stream, lambda s: data.append(str(s)))
            d.addCallback(lambda _: self.failIf(
                'dir2/' not in "".join(data),
                "'dir2' expected in listing: %r" % (data,)
            ))
            return d

        d = request.locateResource("/dir2")
        d.addCallback(setEmptyACL)
        d.addCallback(renderRoot)
        d.addCallback(assertListing)
        return d
calendarserver-5.2+dfsg/twext/web2/dav/test/data/ 0000755 0001750 0001750 00000000000 12322625325 020765 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/ 0000755 0001750 0001750 00000000000 12322625325 021565 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPFIND_request.xml 0000644 0001750 0001750 00000000514 11337102650 025274 0 ustar rahul rahul
calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/REPORT_request.xml 0000644 0001750 0001750 00000000442 11337102650 025066 0 ustar rahul rahul
calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPFIND_response.xml 0000644 0001750 0001750 00000004246 11337102650 025450 0 ustar rahul rahul
/uploads/
2005-07-05T23:08:01Z
Tue, 05 Jul 2005 23:08:01 GMT
"77a99-66-27dd9640"
httpd/unix-directory
HTTP/1.1 200 OK
/uploads/foo.txt
2005-07-05T23:08:08Z
19
Tue, 05 Jul 2005 23:08:08 GMT
"77a9f-13-28486600"
F
text/plain
HTTP/1.1 200 OK
calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPFIND_bad.xml 0000644 0001750 0001750 00000000144 11337102650 024331 0 ustar rahul rahul
calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPPATCH_request.xml 0000644 0001750 0001750 00000001730 11337102650 025414 0 ustar rahul rahul
value0
value1
value2
value3
value4
value5
value6
value7
value8
value9
calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/PROPFIND_nonamespace.xml 0000644 0001750 0001750 00000000172 11337102650 026075 0 ustar rahul rahul
calendarserver-5.2+dfsg/twext/web2/dav/test/data/xml/REPORT_response.xml 0000644 0001750 0001750 00000003473 11337102650 025243 0 ustar rahul rahul
http://www.webdav.org/foo.html
http://repo.webdav.org/his/23
http://repo.webdav.org/his/23/ver/1
Fred
http://www.webdav.org/ws/dev/sally
HTTP/1.1 200 OK
http://repo.webdav.org/his/23/ver/2
Sally
http://repo.webdav.org/act/add-refresh-cmd
HTTP/1.1 200 OK
HTTP/1.1 200 OK
HTTP/1.1 200 OK
calendarserver-5.2+dfsg/twext/web2/dav/test/data/quota_100.txt 0000644 0001750 0001750 00000000144 11337102650 023232 0 ustar rahul rahul 123456789
123456789
123456789
123456789
123456789
123456789
123456789
123456789
123456789
123456789
calendarserver-5.2+dfsg/twext/web2/dav/test/test_prop.py 0000644 0001750 0001750 00000033302 12263343324 022446 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twext.web2 import responsecode
from twext.web2.iweb import IResponse
from twext.web2.stream import MemoryStream
from twext.web2 import http_headers
from twext.web2.dav.util import davXMLFromStream
from twext.web2.test.test_server import SimpleRequest
from twext.web2.dav.test.util import serialize
from txdav.xml import element as davxml
from txdav.xml.element import dav_namespace, lookupElement
import twext.web2.dav.test.util
# Dynamic live properties excluded from the PROPFIND comparisons below,
# because their values can change between requests.
dynamicLiveProperties = (
    (dav_namespace, "quota-available-bytes" ),
    (dav_namespace, "quota-used-bytes"      ),
)
class PROP(twext.web2.dav.test.util.TestCase):
"""
PROPFIND, PROPPATCH requests
"""
    def liveProperties(self):
        """
        Return instances of the site root's DAV-namespace live properties,
        excluding the dynamic (quota) ones listed in dynamicLiveProperties.
        """
        return [lookupElement(qname)() for qname in self.site.resource.liveProperties() if (qname[0] == dav_namespace) and qname not in dynamicLiveProperties]
    def test_PROPFIND_basic(self):
        """
        PROPFIND request: ask for every static live property on "/" and
        verify the multistatus response returns exactly those properties.
        """
        def check_result(response):
            response = IResponse(response)

            if response.code != responsecode.MULTI_STATUS:
                self.fail("Incorrect response code for PROPFIND (%s != %s)"
                          % (response.code, responsecode.MULTI_STATUS))

            content_type = response.headers.getHeader("content-type")
            if content_type not in (http_headers.MimeType("text", "xml"),
                                    http_headers.MimeType("application", "xml")):
                self.fail("Incorrect content-type for PROPFIND response (%r not in %r)"
                          % (content_type, (http_headers.MimeType("text", "xml"),
                                            http_headers.MimeType("application", "xml"))))

            return davXMLFromStream(response.stream).addCallback(check_xml)

        def check_xml(doc):
            multistatus = doc.root_element

            if not isinstance(multistatus, davxml.MultiStatus):
                self.fail("PROPFIND response XML root element is not multistatus: %r" % (multistatus,))

            for response in multistatus.childrenOfType(davxml.PropertyStatusResponse):
                if response.childOfType(davxml.HRef) == "/":
                    for propstat in response.childrenOfType(davxml.PropertyStatus):
                        status = propstat.childOfType(davxml.Status)
                        properties = propstat.childOfType(davxml.PropertyContainer).children

                        if status.code != responsecode.OK:
                            self.fail("PROPFIND failed (status %s) to locate live properties: %s"
                                      % (status.code, properties))

                        # Tick off each returned property against the request;
                        # anything extra or missing is a failure.
                        properties_to_find = [p.qname() for p in self.liveProperties()]

                        for property in properties:
                            qname = property.qname()
                            if qname in properties_to_find:
                                properties_to_find.remove(qname)
                            else:
                                self.fail("PROPFIND found property we didn't ask for: %r" % (property,))

                        if properties_to_find:
                            self.fail("PROPFIND failed to find properties: %r" % (properties_to_find,))

                    break
            else:
                self.fail("No response for URI /")

        query = davxml.PropertyFind(davxml.PropertyContainer(*self.liveProperties()))

        request = SimpleRequest(self.site, "PROPFIND", "/")

        depth = "1"
        if depth is not None:
            request.headers.setHeader("depth", depth)

        request.stream = MemoryStream(query.toxml())

        return self.send(request, check_result)
def test_PROPFIND_list(self):
    """
    PROPFIND with allprop, propname

    Issues a depth-0 PROPFIND against / once with DAV:allprop and once
    with DAV:propname, and verifies the returned property set in each
    case.
    """
    def check_result(which):
        # Bind the query type ("allprop" or "propname") into the response
        # checker so check_xml knows which property set to expect.
        def _check_result(response):
            response = IResponse(response)
            if response.code != responsecode.MULTI_STATUS:
                self.fail("Incorrect response code for PROPFIND (%s != %s)"
                          % (response.code, responsecode.MULTI_STATUS))
            return davXMLFromStream(response.stream).addCallback(check_xml, which)
        return _check_result

    def check_xml(doc, which):
        response = doc.root_element.childOfType(davxml.PropertyStatusResponse)

        self.failUnless(
            response.childOfType(davxml.HRef) == "/",
            "Incorrect response URI: %s != /" % (response.childOfType(davxml.HRef),)
        )

        for propstat in response.childrenOfType(davxml.PropertyStatus):
            status = propstat.childOfType(davxml.Status)
            properties = propstat.childOfType(davxml.PropertyContainer).children

            if status.code != responsecode.OK:
                self.fail("PROPFIND failed (status %s) to locate live properties: %s"
                          % (status.code, properties))

            # allprop omits hidden properties; propname reports everything.
            if which.name == "allprop":
                properties_to_find = [p.qname() for p in self.liveProperties() if not p.hidden]
            else:
                properties_to_find = [p.qname() for p in self.liveProperties()]

            for property in properties:
                qname = property.qname()
                if qname in properties_to_find:
                    properties_to_find.remove(qname)
                elif qname[0] != dav_namespace:
                    # Properties outside the DAV: namespace are tolerated.
                    pass
                else:
                    self.fail("PROPFIND with %s found property we didn't expect: %r" % (which.name, property))

                if which.name == "propname":
                    # Element should be empty
                    self.failUnless(len(property.children) == 0)
                else:
                    # Element should have a value, unless the property exists and is empty...
                    # Verify that there is a value for live properties for which we know
                    # that this should be the case.
                    if property.namespace == dav_namespace and property.name in (
                        "getetag",
                        "getcontenttype",
                        "getlastmodified",
                        "creationdate",
                        "displayname",
                    ):
                        self.failIf(
                            len(property.children) == 0,
                            "Property has no children: %r" % (property.toxml(),)
                        )

            if properties_to_find:
                self.fail("PROPFIND with %s failed to find properties: %r" % (which.name, properties_to_find))
            # NOTE: a dead re-assignment of `properties` that followed this
            # check (its value was never used again) has been removed.

    def work():
        # Yield one (request, checker) pair per query type for serialize().
        for which in (davxml.AllProperties(), davxml.PropertyName()):
            query = davxml.PropertyFind(which)

            request = SimpleRequest(self.site, "PROPFIND", "/")
            request.headers.setHeader("depth", "0")
            request.stream = MemoryStream(query.toxml())

            yield (request, check_result(which))

    return serialize(self.send, work())
def test_PROPPATCH_basic(self):
    """
    PROPPATCH

    Sets a single dead property (SpiffyProperty) on / and verifies the
    multi-status response reports the change with status 200.
    """
    # FIXME:
    # Do PROPFIND to make sure it's still there
    # Test nonexistent resource
    # Test None namespace in property
    def check_patch_response(response):
        # Verify the transport-level details of the PROPPATCH response.
        response = IResponse(response)
        # NOTE(review): the message below says PROPFIND but this checks a
        # PROPPATCH response — looks copy-pasted from the PROPFIND tests.
        if response.code != responsecode.MULTI_STATUS:
            self.fail("Incorrect response code for PROPFIND (%s != %s)"
                      % (response.code, responsecode.MULTI_STATUS))
        # Both text/xml and application/xml are acceptable for WebDAV bodies.
        content_type = response.headers.getHeader("content-type")
        if content_type not in (http_headers.MimeType("text", "xml"),
                                http_headers.MimeType("application", "xml")):
            self.fail("Incorrect content-type for PROPPATCH response (%r not in %r)"
                      % (content_type, (http_headers.MimeType("text", "xml"),
                                        http_headers.MimeType("application", "xml"))))
        return davXMLFromStream(response.stream).addCallback(check_patch_xml)
    def check_patch_xml(doc):
        # Verify the parsed multi-status body of the PROPPATCH response.
        multistatus = doc.root_element
        if not isinstance(multistatus, davxml.MultiStatus):
            self.fail("PROPFIND response XML root element is not multistatus: %r" % (multistatus,))
        # Requested a property change one resource, so there should be exactly one response
        response = multistatus.childOfType(davxml.Response)
        # Should have a response description (its contents are arbitrary)
        response.childOfType(davxml.ResponseDescription)
        # Requested property change was on /
        self.failUnless(
            response.childOfType(davxml.HRef) == "/",
            "Incorrect response URI: %s != /" % (response.childOfType(davxml.HRef),)
        )
        # Requested one property change, so there should be exactly one property status
        propstat = response.childOfType(davxml.PropertyStatus)
        # And the contained property should be a SpiffyProperty
        self.failIf(
            propstat.childOfType(davxml.PropertyContainer).childOfType(SpiffyProperty) is None,
            "Not a SpiffyProperty in PROPPATCH property status: %s" % (propstat.toxml())
        )
        # And the status should be 200
        self.failUnless(
            propstat.childOfType(davxml.Status).code == responsecode.OK,
            "Incorrect status code for PROPPATCH of property %s: %s != %s"
            % (propstat.childOfType(davxml.PropertyContainer).toxml(),
               propstat.childOfType(davxml.Status).code, responsecode.OK)
        )
    # Build the PROPPATCH body: set one dead property on the root resource.
    patch = davxml.PropertyUpdate(
        davxml.Set(
            davxml.PropertyContainer(
                SpiffyProperty.fromString("This is a spiffy resource.")
            )
        )
    )
    request = SimpleRequest(self.site, "PROPPATCH", "/")
    request.stream = MemoryStream(patch.toxml())
    return self.send(request, check_patch_response)
def test_PROPPATCH_liveprop(self):
    """
    PROPPATCH on a live property

    Attempting to overwrite a live property (DAV:getetag) must be
    rejected with 403 FORBIDDEN in the propstat.
    """
    etag = davxml.GETETag.fromString("some-etag-string")
    update = davxml.PropertyUpdate(davxml.Set(davxml.PropertyContainer(etag)))

    return self._simple_PROPPATCH(update, etag, responsecode.FORBIDDEN, "edit of live property")
def test_PROPPATCH_exists_not(self):
    """
    PROPPATCH remove a non-existant property

    Removing a property that was never set must still succeed (200 OK),
    per the WebDAV PROPPATCH semantics.
    """
    # Timeout isn't a valid property, so it won't exist.
    missing = davxml.Timeout()
    update = davxml.PropertyUpdate(davxml.Remove(davxml.PropertyContainer(missing)))

    return self._simple_PROPPATCH(update, missing, responsecode.OK, "remove of non-existant property")
def _simple_PROPPATCH(self, patch, prop, expected_code, what):
    """
    Send a PROPPATCH request for C{patch} against / and verify that the
    multi-status response contains C{prop} with status C{expected_code}.

    @param patch: a L{davxml.PropertyUpdate} element to submit.
    @param prop: the property element expected back in the propstat.
    @param expected_code: the per-property status code expected for C{prop}.
    @param what: human-readable description used in failure messages.
    """
    def check_result(response):
        # PROPPATCH always answers with a 207 multi-status envelope.
        response = IResponse(response)
        if response.code != responsecode.MULTI_STATUS:
            self.fail("Incorrect response code for PROPPATCH (%s != %s)"
                      % (response.code, responsecode.MULTI_STATUS))
        return davXMLFromStream(response.stream).addCallback(check_xml)
    def check_xml(doc):
        # One resource was patched, so inspect the single response element.
        response = doc.root_element.childOfType(davxml.Response)
        propstat = response.childOfType(davxml.PropertyStatus)
        self.failUnless(
            response.childOfType(davxml.HRef) == "/",
            "Incorrect response URI: %s != /" % (response.childOfType(davxml.HRef),)
        )
        self.failIf(
            propstat.childOfType(davxml.PropertyContainer).childOfType(prop) is None,
            "Not a %s in PROPPATCH property status: %s" % (prop.sname(), propstat.toxml())
        )
        self.failUnless(
            propstat.childOfType(davxml.Status).code == expected_code,
            "Incorrect status code for PROPPATCH %s: %s != %s"
            % (what, propstat.childOfType(davxml.Status).code, expected_code)
        )
    request = SimpleRequest(self.site, "PROPPATCH", "/")
    request.stream = MemoryStream(patch.toxml())
    return self.send(request, check_result)
class SpiffyProperty (davxml.WebDAVTextElement):
    """
    A dead text property in a private namespace, used as the target of the
    PROPPATCH tests above.
    """
    namespace = "http://twistedmatrix.com/ns/private/tests"
    name = "spiffyproperty"
calendarserver-5.2+dfsg/twext/web2/dav/test/test_copy.py 0000644 0001750 0001750 00000016015 12263343324 022442 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import filecmp
import os
import urllib
from hashlib import md5

import twext.web2.dav.test.util
from twext.web2 import responsecode
from twext.web2.test.test_server import SimpleRequest
from twext.web2.dav.test.util import dircmp, serialize
from twext.web2.dav.fileop import rmdir
class COPY(twext.web2.dav.test.util.TestCase):
    """
    COPY request
    """
    # FIXME:
    # Check that properties are being copied

    def test_COPY_create(self):
        """
        COPY to new resource.
        """
        def test(response, path, isfile, sum, uri, depth, dst_path):
            if response.code != responsecode.CREATED:
                self.fail("Incorrect response code for COPY %s (depth=%r): %s != %s"
                          % (uri, depth, response.code, responsecode.CREATED))

            if response.headers.getHeader("location") is None:
                self.fail("Response to COPY %s (depth=%r) with CREATE status is missing location: header."
                          % (uri, depth))

            if os.path.isfile(path):
                if not os.path.isfile(dst_path):
                    self.fail("COPY %s (depth=%r) produced no output file" % (uri, depth))
                # BUGFIX: this used the builtin cmp(path, dst_path), which
                # compares the two *path strings* (always unequal), so a
                # content mismatch could never be detected.  filecmp.cmp with
                # shallow=False actually compares file contents.
                if not filecmp.cmp(path, dst_path, shallow=False):
                    self.fail("COPY %s (depth=%r) produced different file" % (uri, depth))
                os.remove(dst_path)

            elif os.path.isdir(path):
                if not os.path.isdir(dst_path):
                    self.fail("COPY %s (depth=%r) produced no output directory" % (uri, depth))

                if depth in ("infinity", None):
                    # Missing depth defaults to infinity: full recursive copy.
                    if dircmp(path, dst_path):
                        self.fail("COPY %s (depth=%r) produced different directory" % (uri, depth))
                elif depth == "0":
                    # Depth 0 copies the collection itself but not its members.
                    for filename in os.listdir(dst_path):
                        self.fail("COPY %s (depth=%r) shouldn't copy directory contents (eg. %s)" % (uri, depth, filename))
                else:
                    raise AssertionError("Unknown depth: %r" % (depth,))

                rmdir(dst_path)

            else:
                self.fail("Source %s is neither a file nor a directory"
                          % (path,))

        return serialize(self.send, work(self, test))

    def test_COPY_exists(self):
        """
        COPY to existing resource.
        """
        def test(response, path, isfile, sum, uri, depth, dst_path):
            if response.code != responsecode.PRECONDITION_FAILED:
                self.fail("Incorrect response code for COPY without overwrite %s: %s != %s"
                          % (uri, response.code, responsecode.PRECONDITION_FAILED))
            else:
                # FIXME: Check XML error code (2518bis)
                pass

        return serialize(self.send, work(self, test, overwrite=False))

    def test_COPY_overwrite(self):
        """
        COPY to existing resource with overwrite header.
        """
        def test(response, path, isfile, sum, uri, depth, dst_path):
            if response.code != responsecode.NO_CONTENT:
                self.fail("Incorrect response code for COPY with overwrite %s: %s != %s"
                          % (uri, response.code, responsecode.NO_CONTENT))
            else:
                # FIXME: Check XML error code (2518bis)
                pass

            self.failUnless(os.path.exists(dst_path), "COPY didn't produce file: %s" % (dst_path,))

        return serialize(self.send, work(self, test, overwrite=True))

    def test_COPY_no_parent(self):
        """
        COPY to resource with no parent.
        """
        def test(response, path, isfile, sum, uri, depth, dst_path):
            if response.code != responsecode.CONFLICT:
                self.fail("Incorrect response code for COPY with no parent %s: %s != %s"
                          % (uri, response.code, responsecode.CONFLICT))
            else:
                # FIXME: Check XML error code (2518bis)
                pass

        return serialize(self.send, work(self, test, dst=os.path.join(self.docroot, "elvislives!")))
def work(self, test, overwrite=None, dst=None, depths=("0", "infinity", None)):
    """
    Generate (request, callback) pairs for serialize(): one COPY request per
    resource in the test docroot per depth value, bound to C{test}.

    @param self: the test case (this is a module-level helper, called as
        C{work(self, test)} from the COPY test methods).
    @param test: callback invoked with
        (response, path, isfile, sum, uri, depth, dst_path) per request.
    @param overwrite: if not None, a conflicting file is created at the
        destination and an overwrite: header with this value is sent.
    @param dst: destination directory; created under the docroot if None.
    @param depths: depth: header values to exercise (None sends no header).
    """
    if dst is None:
        dst = os.path.join(self.docroot, "dst")
        os.mkdir(dst)

    for basename in os.listdir(self.docroot):
        if basename == "dst":
            continue

        uri = urllib.quote("/" + basename)
        path = os.path.join(self.docroot, basename)
        isfile = os.path.isfile(path)
        checksum = sumFile(path)  # renamed from `sum`: avoid shadowing the builtin
        dst_path = os.path.join(dst, basename)
        dst_uri = urllib.quote("/dst/" + basename)

        if not isfile:
            uri += "/"
            dst_uri += "/"

        if overwrite is not None:
            # Create a file at dst_path to create a conflict.
            # (open() replaces the Python-2-only file() builtin.)
            open(dst_path, "w").close()

        for depth in depths:
            # Default arguments freeze the current loop values for the callback.
            def do_test(response, path=path, isfile=isfile, sum=checksum, uri=uri, depth=depth, dst_path=dst_path):
                test(response, path, isfile, sum, uri, depth, dst_path)

            request = SimpleRequest(self.site, self.__class__.__name__, uri)
            request.headers.setHeader("destination", dst_uri)
            if depth is not None:
                request.headers.setHeader("depth", depth)
            if overwrite is not None:
                request.headers.setHeader("overwrite", overwrite)

            yield (request, do_test)
def sumFile(path):
    """
    Compute an MD5 digest summarizing the contents of a file, or of a whole
    directory tree (file names, file contents, and subdirectory names).

    @param path: a filesystem path to a file or directory.
    @return: the 16-byte binary MD5 digest.
    @raise AssertionError: if C{path} is neither a file nor a directory.
    """
    m = md5()

    if os.path.isfile(path):
        # open() in binary mode replaces the Python-2-only file() builtin,
        # and the with-statement guarantees the handle is closed.
        with open(path, "rb") as f:
            m.update(f.read())

    elif os.path.isdir(path):
        for dirpath, subdirs, files in os.walk(path):  # don't shadow builtin `dir`
            for filename in files:
                # Fold file names in as well so renames change the digest.
                m.update(filename)
                with open(os.path.join(dirpath, filename), "rb") as f:
                    m.update(f.read())
            for dirname in subdirs:
                m.update(dirname + "/")

    else:
        raise AssertionError("Neither a file nor a directory: %r" % (path,))

    return m.digest()
calendarserver-5.2+dfsg/twext/web2/dav/test/test_http.py 0000644 0001750 0001750 00000006702 12263343324 022451 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import errno
from twisted.python.failure import Failure
from twext.web2 import responsecode
from twext.web2.http import HTTPError
from twext.web2.dav.http import ErrorResponse, statusForFailure
import twext.web2.dav.test.util
class HTTP(twext.web2.dav.test.util.TestCase):
    """
    Tests for the HTTP utility functions in twext.web2.dav.http, chiefly
    statusForFailure().
    """
    def test_statusForFailure_errno(self):
        """
        statusForFailure() maps exceptions carrying well-known errno values
        onto the corresponding HTTP response codes.
        """
        expectations = (
            (errno.EACCES, "Permission denied", responsecode.FORBIDDEN),
            (errno.EPERM, "Permission denied", responsecode.FORBIDDEN),
            (errno.ENOSPC, "No space available", responsecode.INSUFFICIENT_STORAGE_SPACE),
            (errno.ENOENT, "No such file", responsecode.NOT_FOUND),
        )
        # Both IOError and OSError carry errno values; exercise each class.
        for exceptionClass in (IOError, OSError):
            for errnoValue, message, expectedCode in expectations:
                self._check_exception(exceptionClass(errnoValue, message), expectedCode)

    def test_statusForFailure_HTTPError(self):
        """
        statusForFailure() extracts the response code from an HTTPError,
        whether built from a bare code or from an ErrorResponse.
        """
        for code in responsecode.RESPONSES:
            self._check_exception(HTTPError(code), code)
            errorResponse = ErrorResponse(code, ("http://twistedmatrix.com/", "bar"))
            self._check_exception(HTTPError(errorResponse), code)

    def test_statusForFailure_exception(self):
        """
        statusForFailure() maps known exception types to response codes and
        re-raises unknown ones.
        """
        self._check_exception(NotImplementedError("Duh..."), responsecode.NOT_IMPLEMENTED)

        class UnknownException (Exception):
            pass

        try:
            self._check_exception(UnknownException(), None)
        except UnknownException:
            pass
        else:
            self.fail("Unknown exception should have re-raised.")

    def _check_exception(self, exception, result):
        """
        Raise C{exception}, capture it as a Failure, and verify that
        statusForFailure() produces C{result} for it.
        """
        try:
            raise exception
        except Exception:
            failure = Failure()
        else:
            raise AssertionError("We shouldn't be here.")

        status = statusForFailure(failure)
        self.failUnless(
            status == result,
            "Failure %r (%s) generated incorrect status code: %s != %s"
            % (failure, failure.value, status, result)
        )
calendarserver-5.2+dfsg/twext/web2/dav/test/test_resource.py 0000644 0001750 0001750 00000041146 12263343324 023322 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twisted.internet.defer import DeferredList, waitForDeferred, deferredGenerator, succeed
from twisted.cred.portal import Portal
from twext.web2 import responsecode
from twext.web2.http import HTTPError
from twext.web2.auth import basic
from twext.web2.server import Site
from txdav.xml import element as davxml
from twext.web2.dav.resource import DAVResource, AccessDeniedError, \
DAVPrincipalResource, DAVPrincipalCollectionResource, davPrivilegeSet
from twext.web2.dav.auth import TwistedPasswordProperty, DavRealm, TwistedPropertyChecker, IPrincipal, AuthenticationWrapper
from twext.web2.test.test_server import SimpleRequest
from twext.web2.dav.test.util import InMemoryPropertyStore
import twext.web2.dav.test.util
class TestCase(twext.web2.dav.test.util.TestCase):
    """
    Base test case for this module: resets the TestResource class-level
    property-store cache before each test so state doesn't leak between tests.
    """
    def setUp(self):
        twext.web2.dav.test.util.TestCase.setUp(self)
        # TestResource.deadProperties() caches stores by URI at class level;
        # start each test with an empty cache.
        TestResource._cachedPropertyStores = {}
class GenericDAVResource(TestCase):
    """
    Tests for DAVResource.findChildren() over a small in-memory resource tree.
    """
    def setUp(self):
        TestCase.setUp(self)
        # Build a fixed tree: two files and two directories at the root;
        # /file2 and /dir2/ use AuthAllResource, which denies unauthenticated
        # access (exploited by test_findChildrenWithPrivileges).
        rootresource = TestResource(None, {
            "file1": TestResource("/file1"),
            "file2": AuthAllResource("/file2"),
            "dir1": TestResource("/dir1/", {
                "subdir1": TestResource("/dir1/subdir1/", {})
            }),
            "dir2": AuthAllResource("/dir2/", {
                "file1": TestResource("/dir2/file1"),
                "file2": TestResource("/dir2/file2"),
                "subdir1": TestResource("/dir2/subdir1/", {
                    "file1": TestResource("/dir2/subdir1/file1"),
                    "file2": TestResource("/dir2/subdir1/file2")
                })
            })
        })
        self.site = Site(rootresource)
    def test_findChildren(self):
        """
        This test asserts that we have:
        1) not found any unexpected children
        2) found all expected children
        It does this for all depths C{"0"}, C{"1"}, and C{"infinity"}
        """
        expected_children = {
            "0": [],
            "1": [
                "/file1",
                "/file2",
                "/dir1/",
                "/dir2/",
            ],
            "infinity": [
                "/file1",
                "/file2",
                "/dir1/",
                "/dir1/subdir1/",
                "/dir2/",
                "/dir2/file1",
                "/dir2/file2",
                "/dir2/subdir1/",
                "/dir2/subdir1/file1",
                "/dir2/subdir1/file2",
            ],
        }
        request = SimpleRequest(self.site, "GET", "/")
        # Old-style twisted deferredGenerator idiom: yield a waitForDeferred
        # wrapper, then fetch its result with getResult().
        resource = waitForDeferred(request.locateResource("/"))
        yield resource
        resource = resource.getResult()
        def checkChildren(resource, uri):
            # Per-child callback: sort each visited URI into found/unexpected.
            # Reads `depth`, `found_children` and `unexpected_children` from
            # the enclosing scope of the loop below.
            self.assertEquals(uri, resource.uri)
            if uri not in expected_children[depth]:
                unexpected_children.append(uri)
            else:
                found_children.append(uri)
        for depth in ["0", "1", "infinity"]:
            found_children = []
            unexpected_children = []
            fc = resource.findChildren(depth, request, checkChildren)
            completed = waitForDeferred(fc)
            yield completed
            completed.getResult()
            self.assertEquals(
                unexpected_children, [],
                "Found unexpected children: %r" % (unexpected_children,)
            )
            expected_children[depth].sort()
            found_children.sort()
            self.assertEquals(expected_children[depth], found_children)
    # Wrap the generator so trial runs it as a deferred-returning test.
    test_findChildren = deferredGenerator(test_findChildren)
    def test_findChildrenWithPrivileges(self):
        """
        This test revokes read privileges for the C{"/file2"} and C{"/dir2/"}
        resource to verify that we can not find them giving our unauthenticated
        privileges.
        """
        expected_children = [
            "/file1",
            "/dir1/",
        ]
        request = SimpleRequest(self.site, "GET", "/")
        resource = waitForDeferred(request.locateResource("/"))
        yield resource
        resource = resource.getResult()
        def checkChildren(resource, uri):
            self.assertEquals(uri, resource.uri)
            if uri not in expected_children:
                unexpected_children.append(uri)
            else:
                found_children.append(uri)
        found_children = []
        unexpected_children = []
        # Current (unauthenticated) privileges should exclude the
        # AuthAllResource children from the traversal.
        privileges = waitForDeferred(resource.currentPrivileges(request))
        yield privileges
        privileges = privileges.getResult()
        fc = resource.findChildren("1", request, checkChildren, privileges)
        completed = waitForDeferred(fc)
        yield completed
        completed.getResult()
        self.assertEquals(
            unexpected_children, [],
            "Found unexpected children: %r" % (unexpected_children,)
        )
        expected_children.sort()
        found_children.sort()
        self.assertEquals(expected_children, found_children)
    test_findChildrenWithPrivileges = deferredGenerator(test_findChildrenWithPrivileges)
    def test_findChildrenCallbackRaises(self):
        """
        Verify that when the user callback raises an exception
        the completion deferred returned by findChildren errbacks
        TODO: Verify that the user callback doesn't get called subsequently
        """
        def raiseOnChild(resource, uri):
            raise Exception("Oh no!")
        def findChildren(resource):
            return self.assertFailure(
                resource.findChildren("infinity", request, raiseOnChild),
                Exception
            )
        request = SimpleRequest(self.site, "GET", "/")
        d = request.locateResource("/").addCallback(findChildren)
        return d
class AccessTests(TestCase):
    """
    Tests for authentication and privilege checks on DAV resources, using
    basic auth over an AuthenticationWrapper-wrapped site.
    """
    def setUp(self):
        TestCase.setUp(self)
        # Two principals whose passwords live in a dead property, as read by
        # TwistedPropertyChecker.
        gooduser = TestDAVPrincipalResource("/users/gooduser")
        gooduser.writeDeadProperty(TwistedPasswordProperty("goodpass"))
        baduser = TestDAVPrincipalResource("/users/baduser")
        baduser.writeDeadProperty(TwistedPasswordProperty("badpass"))
        rootresource = TestPrincipalsCollection("/", {
            "users": TestResource("/users/",
                                  {"gooduser": gooduser,
                                   "baduser": baduser})
        })
        # /protected grants All privileges to gooduser only.
        protected = TestResource(
            "/protected", principalCollections=[rootresource])
        protected.setAccessControlList(davxml.ACL(
            davxml.ACE(
                davxml.Principal(davxml.HRef("/users/gooduser")),
                davxml.Grant(davxml.Privilege(davxml.All())),
                davxml.Protected()
            )
        ))
        rootresource.children["protected"] = protected
        portal = Portal(DavRealm())
        portal.registerChecker(TwistedPropertyChecker())
        credentialFactories = (basic.BasicCredentialFactory(""),)
        loginInterfaces = (IPrincipal,)
        self.rootresource = rootresource
        # NOTE(review): credentialFactories is passed twice to
        # AuthenticationWrapper — presumably once for each of two factory
        # parameters of its signature; confirm against AuthenticationWrapper.
        self.site = Site(AuthenticationWrapper(
            self.rootresource,
            portal,
            credentialFactories,
            credentialFactories,
            loginInterfaces,
        ))
    def checkSecurity(self, request):
        """
        Locate the resource named by the given request's URI, then authorize it
        for the 'Read' permission.
        """
        d = request.locateResource(request.uri)
        d.addCallback(lambda r: r.authorize(request, (davxml.Read(),)))
        return d
    def assertErrorResponse(self, error, expectedcode, otherExpectations=lambda err: None):
        # Check the HTTP code on the error's response, then run any extra
        # caller-supplied checks against the error.
        self.assertEquals(error.response.code, expectedcode)
        otherExpectations(error)
    def test_checkPrivileges(self):
        """
        DAVResource.checkPrivileges()
        """
        ds = []
        # AuthAllResource grants All only to authenticated principals.
        authAllResource = AuthAllResource()
        requested_access = (davxml.All(),)
        site = Site(authAllResource)
        def expectError(failure):
            # Unauthenticated access should produce exactly one denial for
            # the requested privilege on the resource itself (subpath None).
            failure.trap(AccessDeniedError)
            errors = failure.value.errors
            self.failUnless(len(errors) == 1)
            subpath, denials = errors[0]
            self.failUnless(subpath is None)
            self.failUnless(
                tuple(denials) == requested_access,
                "%r != %r" % (tuple(denials), requested_access)
            )
        def expectOK(result):
            self.failUnlessEquals(result, None)
        def _checkPrivileges(resource):
            d = resource.checkPrivileges(request, requested_access)
            return d
        # No auth; should deny
        request = SimpleRequest(site, "GET", "/")
        d = request.locateResource("/").addCallback(_checkPrivileges).addErrback(expectError)
        ds.append(d)
        # Has auth; should allow
        request = SimpleRequest(site, "GET", "/")
        request.authnUser = davxml.Principal(davxml.HRef("/users/d00d"))
        request.authzUser = davxml.Principal(davxml.HRef("/users/d00d"))
        d = request.locateResource("/")
        d.addCallback(_checkPrivileges)
        d.addCallback(expectOK)
        ds.append(d)
        return DeferredList(ds)
    def test_authorize(self):
        """
        Authorizing a known user with the correct password will not raise an
        exception, indicating that the user is properly authorized given their
        credentials.
        """
        request = SimpleRequest(self.site, "GET", "/protected")
        request.headers.setHeader(
            "authorization",
            ("basic", "gooduser:goodpass".encode("base64")))
        return self.checkSecurity(request)
    def test_badUsernameOrPassword(self):
        # Wrong password: expect 401 with a WWW-Authenticate challenge.
        request = SimpleRequest(self.site, "GET", "/protected")
        request.headers.setHeader(
            "authorization",
            ("basic", "gooduser:badpass".encode("base64"))
        )
        d = self.assertFailure(self.checkSecurity(request), HTTPError)
        def expectWwwAuth(err):
            self.failUnless(err.response.headers.hasHeader("WWW-Authenticate"),
                            "No WWW-Authenticate header present.")
        d.addCallback(self.assertErrorResponse, responsecode.UNAUTHORIZED, expectWwwAuth)
        return d
    def test_lacksPrivileges(self):
        # Valid credentials but no ACL grant on /protected: expect 403.
        request = SimpleRequest(self.site, "GET", "/protected")
        request.headers.setHeader(
            "authorization",
            ("basic", "baduser:badpass".encode("base64"))
        )
        d = self.assertFailure(self.checkSecurity(request), HTTPError)
        d.addCallback(self.assertErrorResponse, responsecode.FORBIDDEN)
        return d
##
# Utilities
##
class TestResource (DAVResource):
    """A simple test resource used for creating trees of
    DAV Resources
    """
    # Class-level cache of InMemoryPropertyStore instances keyed by URI, so a
    # resource's dead properties survive re-creation of the resource object.
    # Reset by TestCase.setUp() between tests.
    _cachedPropertyStores = {}
    # Default ACL: everyone is granted all privileges.
    acl = davxml.ACL(
        davxml.ACE(
            davxml.Principal(davxml.All()),
            davxml.Grant(davxml.Privilege(davxml.All())),
            davxml.Protected(),
        )
    )
    def __init__(self, uri=None, children=None, principalCollections=()):
        """
        @param uri: A string respresenting the URI of the given resource
        @param children: a dictionary of names to Resources
        """
        DAVResource.__init__(self, principalCollections=principalCollections)
        self.children = children
        self.uri = uri
    def deadProperties(self):
        """
        Retrieve deadProperties from a special place in memory
        """
        if not hasattr(self, "_dead_properties"):
            # Share one store per URI across all instances (see
            # _cachedPropertyStores above).
            dp = TestResource._cachedPropertyStores.get(self.uri)
            if dp is None:
                TestResource._cachedPropertyStores[self.uri] = InMemoryPropertyStore(self)
                dp = TestResource._cachedPropertyStores[self.uri]
            self._dead_properties = dp
        return self._dead_properties
    def isCollection(self):
        # A resource constructed with a children dict (even empty) is a
        # collection; children=None means a plain file-like resource.
        return self.children is not None
    def listChildren(self):
        return self.children.keys()
    def supportedPrivileges(self, request):
        return succeed(davPrivilegeSet)
    def currentPrincipal(self, request):
        # Use the request's authorization principal if set; otherwise the
        # unauthenticated principal.
        if hasattr(request, "authzUser"):
            return request.authzUser
        else:
            return davxml.Principal(davxml.Unauthenticated())
    def locateChild(self, request, segments):
        # Standard web2 traversal: consume one path segment per step.
        child = segments[0]
        if child == "":
            return self, segments[1:]
        elif child in self.children:
            return self.children[child], segments[1:]
        else:
            raise HTTPError(404)
    def setAccessControlList(self, acl):
        # Instance attribute shadows the class-level default ACL.
        self.acl = acl
    def accessControlList(self, request, **kwargs):
        return succeed(self.acl)
class TestPrincipalsCollection(DAVPrincipalCollectionResource, TestResource):
    """
    A full implementation of L{IDAVPrincipalCollectionResource}, implemented as
    a L{TestResource} which assumes a single L{TestResource} child named
    'users'.
    """
    def __init__(self, url, children):
        DAVPrincipalCollectionResource.__init__(self, url)
        TestResource.__init__(self, url, children, principalCollections=(self,))

    def principalForUser(self, user):
        """
        Look up the principal resource for the named user in the 'users' child.

        @see L{IDAVPrincipalCollectionResource.principalForUser}.
        """
        return self.principalForShortName('users', user)

    def principalForAuthID(self, creds):
        """
        Look up the principal matching the username carried by a set of
        credentials.

        Although web2.dav itself never calls this, CalendarServer imports this
        test class and relies on the method being present.

        @param creds: credentials which identify a user
        @type creds: L{twisted.cred.credentials.IUsernameHashedPassword} or
            L{twisted.cred.credentials.IUsernamePassword}
        @return: a DAV principal resource representing a user.
        @rtype: L{IDAVPrincipalResource} or C{NoneType}
        """
        # XXX either move this to CalendarServer entirely or document it on
        # IDAVPrincipalCollectionResource
        return self.principalForShortName('users', creds.username)

    def principalForShortName(self, type, shortName):
        """
        Look up a principal by principal-type collection name and short name.

        @param type: a short string (such as 'users' or 'groups') naming both
            the principal type and a child resource of this collection.
        @param shortName: the principal's name within that child resource.
        @return: the matching DAV principal resource, or C{None} when either
            the type collection or the named principal is absent.
        @rtype: L{IDAVPrincipalResource} or C{NoneType}
        """
        # XXX either move this to CalendarServer entirely or document it on
        # IDAVPrincipalCollectionResource
        container = self.children.get(type, None)
        if not container:
            return None
        return container.children.get(shortName, None)
class AuthAllResource (TestResource):
    """
    Give Authenticated principals all privileges and deny everyone else.
    """
    # Overrides TestResource.acl: grant All privileges to any authenticated
    # principal; unauthenticated requests match no ACE and are denied.
    acl = davxml.ACL(
        davxml.ACE(
            davxml.Principal(davxml.Authenticated()),
            davxml.Grant(davxml.Privilege(davxml.All())),
            davxml.Protected(),
        )
    )
class TestDAVPrincipalResource(DAVPrincipalResource, TestResource):
    """
    Get deadProperties from TestResource
    """
    def principalURL(self):
        # A principal's URL is simply this resource's own URI.
        return self.uri
calendarserver-5.2+dfsg/twext/web2/dav/test/test_lock.py 0000644 0001750 0001750 00000003426 12263343324 022422 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twisted.trial.unittest import SkipTest
import twext.web2.dav.test.util
class LOCK_UNLOCK(twext.web2.dav.test.util.TestCase):
    """
    LOCK, UNLOCK requests

    Placeholder test case: locking is not implemented, so the single test is
    skipped and marked as an expected failure via the trial .todo attribute.
    """
    # FIXME:
    # Check PUT
    # Check POST
    # Check PROPPATCH
    # Check LOCK
    # Check UNLOCK
    # Check MOVE, COPY
    # Check DELETE
    # Check MKCOL
    # Check null resource
    # Check collections
    # Check depth
    # Check If header
    # Refresh lock
    def test_LOCK_UNLOCK(self):
        """
        LOCK, UNLOCK request
        """
        raise SkipTest("test unimplemented")
    # trial treats this test as an expected failure until LOCK is implemented.
    test_LOCK_UNLOCK.todo = "LOCK/UNLOCK unimplemented"
calendarserver-5.2+dfsg/twext/web2/dav/test/test_report_expand.py 0000644 0001750 0001750 00000002764 12263343324 024350 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twisted.trial.unittest import SkipTest
import twext.web2.dav.test.util
class REPORT_expand(twext.web2.dav.test.util.TestCase):
    """
    DAV:expand-property REPORT request
    """
    def test_REPORT_expand_property(self):
        """
        DAV:expand-property REPORT request.
        """
        # Fixed typo in skip message ("unimplemeted" -> "unimplemented")
        # so it reads correctly in trial's skip output.
        raise SkipTest("test unimplemented")
calendarserver-5.2+dfsg/twext/web2/dav/test/test_util.py 0000644 0001750 0001750 00000013202 12263343324 022440 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twisted.trial import unittest
from twext.web2.dav import util
class Utilities(unittest.TestCase):
    """
    Tests for the URL helper functions in L{twext.web2.dav.util}.
    """
    def test_normalizeURL(self):
        """
        normalizeURL()
        """
        # (input URL, expected normalized form)
        cases = [
            ("http://server//foo", "http://server/foo"),
            ("http://server/foo/..", "http://server/"),
            ("/foo/bar/..//", "/foo"),
            ("/foo/bar/.//", "/foo/bar"),
            ("//foo///bar/../baz", "/foo/baz"),
            ("//foo///bar/./baz", "/foo/bar/baz"),
            ("///../", "/"),
            ("/..", "/"),
        ]
        for url, expected in cases:
            self.assertEquals(util.normalizeURL(url), expected)
    def test_joinURL(self):
        """
        joinURL()
        """
        # (argument tuple for joinURL, expected joined URL)
        cases = [
            (("http://server/foo/",), "http://server/foo/"),
            (("http://server/foo", "/bar"), "http://server/foo/bar"),
            (("http://server/foo", "bar"), "http://server/foo/bar"),
            (("http://server/foo/", "/bar"), "http://server/foo/bar"),
            (("http://server/foo/", "/bar/.."), "http://server/foo"),
            (("http://server/foo/", "/bar/."), "http://server/foo/bar"),
            (("http://server/foo/", "/bar/../"), "http://server/foo/"),
            (("http://server/foo/", "/bar/./"), "http://server/foo/bar/"),
            (("http://server/foo/../", "/bar"), "http://server/bar"),
            (("/foo/",), "/foo/"),
            (("/foo", "/bar"), "/foo/bar"),
            (("/foo", "bar"), "/foo/bar"),
            (("/foo/", "/bar"), "/foo/bar"),
            (("/foo/", "/bar/.."), "/foo"),
            (("/foo/", "/bar/."), "/foo/bar"),
            (("/foo/", "/bar/../"), "/foo/"),
            (("/foo/", "/bar/./"), "/foo/bar/"),
            (("/foo/../", "/bar"), "/bar"),
            (("/foo", "/../"), "/"),
            (("/foo", "/./"), "/foo/"),
        ]
        for args, expected in cases:
            self.assertEquals(util.joinURL(*args), expected)
    def test_parentForURL(self):
        """
        parentForURL()
        """
        # (input URL, expected parent URL or None at the root)
        cases = [
            ("http://server/", None),
            ("http://server//", None),
            ("http://server/foo/..", None),
            ("http://server/foo/../", None),
            ("http://server/foo/.", "http://server/"),
            ("http://server/foo/./", "http://server/"),
            ("http://server/foo", "http://server/"),
            ("http://server//foo", "http://server/"),
            ("http://server/foo/bar/..", "http://server/"),
            ("http://server/foo/bar/.", "http://server/foo/"),
            ("http://server/foo/bar", "http://server/foo/"),
            ("http://server/foo/bar/", "http://server/foo/"),
            ("http://server/foo/bar?x=1&y=2", "http://server/foo/"),
            ("http://server/foo/bar/?x=1&y=2", "http://server/foo/"),
            ("/", None),
            ("/foo/..", None),
            ("/foo/../", None),
            ("/foo/.", "/"),
            ("/foo/./", "/"),
            ("/foo", "/"),
            ("/foo", "/"),
            ("/foo/bar/..", "/"),
            ("/foo/bar/.", "/foo/"),
            ("/foo/bar", "/foo/"),
            ("/foo/bar/", "/foo/"),
            ("/foo/bar?x=1&y=2", "/foo/"),
            ("/foo/bar/?x=1&y=2", "/foo/"),
        ]
        for url, expected in cases:
            self.assertEquals(util.parentForURL(url), expected)
calendarserver-5.2+dfsg/twext/web2/dav/test/test_quota.py 0000644 0001750 0001750 00000015467 12263343324 022633 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twext.web2 import responsecode
from twext.web2.iweb import IResponse
from twext.web2.stream import FileStream
import twext.web2.dav.test.util
from twext.web2.test.test_server import SimpleRequest
from twext.web2.dav.test.util import Site
from txdav.xml import element as davxml
import os
class QuotaBase(twext.web2.dav.test.util.TestCase):
    """
    Shared fixture for the quota tests: a temporary document root whose
    site resource is configured with a 100000-byte quota.
    """
    def createDocumentRoot(self):
        """
        Create the temporary docroot, grant inherited full access on it,
        and set a quota root of 100000 bytes on the site resource.
        """
        path = self.mktemp()
        os.mkdir(path)
        root = self.resource_class(path)
        root.setAccessControlList(self.grantInherit(davxml.All()))
        self.site = Site(root)
        self.site.resource.setQuotaRoot(None, 100000)
        return path
    def checkQuota(self, value):
        """
        Return a Deferred that asserts the current quota use equals
        C{value}.
        """
        def verify(quota):
            self.assertEqual(quota, value)
        return self.site.resource.currentQuotaUse(None).addCallback(verify)
class QuotaEmpty(QuotaBase):
    """
    Quota reporting on an untouched document root.
    """
    def test_Empty_Quota(self):
        """
        A fresh quota root reports zero usage.
        """
        return self.checkQuota(0)
class QuotaPUT(QuotaBase):
    """
    Quota accounting when a resource is stored with PUT.
    """
    def test_Quota_PUT(self):
        """
        Quota change on PUT
        """
        dst_uri = "/dst"
        def checkResult(response):
            response = IResponse(response)
            if response.code != responsecode.CREATED:
                self.fail("Incorrect response code for PUT (%s != %s)"
                          % (response.code, responsecode.CREATED))
            # The fixture file is exactly 100 bytes, so usage must be 100.
            return self.checkQuota(100)
        request = SimpleRequest(self.site, "PUT", dst_uri)
        # open() replaces the deprecated file() constructor (identical
        # behavior in Python 2; file() does not exist in Python 3).
        # FileStream takes ownership of the open file object.
        request.stream = FileStream(open(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb"))
        return self.send(request, checkResult)
class QuotaDELETE(QuotaBase):
    """
    Quota accounting when a resource is removed with DELETE.
    """
    def test_Quota_DELETE(self):
        """
        Quota change on DELETE
        """
        dst_uri = "/dst"
        def checkPUTResult(response):
            response = IResponse(response)
            if response.code != responsecode.CREATED:
                self.fail("Incorrect response code for PUT (%s != %s)"
                          % (response.code, responsecode.CREATED))
            def doDelete(_ignore):
                def checkDELETEResult(response):
                    response = IResponse(response)
                    if response.code != responsecode.NO_CONTENT:
                        # Fixed message: this branch verifies the DELETE
                        # response, not the PUT (copy/paste error).
                        self.fail("Incorrect response code for DELETE (%s != %s)"
                                  % (response.code, responsecode.NO_CONTENT))
                    # Deleting the resource must return its 100 bytes.
                    return self.checkQuota(0)
                request = SimpleRequest(self.site, "DELETE", dst_uri)
                return self.send(request, checkDELETEResult)
            # The PUT consumed 100 bytes; verify, then delete and re-check.
            d = self.checkQuota(100)
            d.addCallback(doDelete)
            return d
        request = SimpleRequest(self.site, "PUT", dst_uri)
        # open() replaces the deprecated file() constructor (identical
        # behavior in Python 2; file() does not exist in Python 3).
        request.stream = FileStream(open(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb"))
        return self.send(request, checkPUTResult)
class OverQuotaPUT(QuotaBase):
    """
    A PUT that would exceed the configured quota must be refused.
    """
    def test_Quota_PUT(self):
        """
        Quota change on PUT
        """
        dst_uri = "/dst"
        # Lower the quota below the 100-byte fixture so the PUT cannot fit.
        self.site.resource.setQuotaRoot(None, 90)
        def checkResult(response):
            response = IResponse(response)
            if response.code != responsecode.INSUFFICIENT_STORAGE_SPACE:
                self.fail("Incorrect response code for PUT (%s != %s)"
                          % (response.code, responsecode.INSUFFICIENT_STORAGE_SPACE))
            # Nothing was stored, so usage must remain zero.
            return self.checkQuota(0)
        request = SimpleRequest(self.site, "PUT", dst_uri)
        # open() replaces the deprecated file() constructor (identical
        # behavior in Python 2; file() does not exist in Python 3).
        request.stream = FileStream(open(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb"))
        return self.send(request, checkResult)
class QuotaOKAdjustment(QuotaBase):
    """
    Manual quota adjustment within bounds succeeds.
    """
    def test_Quota_OK_Adjustment(self):
        """
        Quota adjustment OK
        """
        dst_uri = "/dst"
        def checkPUTResult(response):
            response = IResponse(response)
            if response.code != responsecode.CREATED:
                self.fail("Incorrect response code for PUT (%s != %s)"
                          % (response.code, responsecode.CREATED))
            def doOKAdjustment(_ignore):
                def checkAdjustmentResult(_ignore):
                    # 100 bytes used, adjusted down by 90 -> 10 remain.
                    return self.checkQuota(10)
                d = self.site.resource.quotaSizeAdjust(None, -90)
                d.addCallback(checkAdjustmentResult)
                return d
            # The PUT consumed 100 bytes; verify before adjusting.
            d = self.checkQuota(100)
            d.addCallback(doOKAdjustment)
            return d
        request = SimpleRequest(self.site, "PUT", dst_uri)
        # open() replaces the deprecated file() constructor (identical
        # behavior in Python 2; file() does not exist in Python 3).
        request.stream = FileStream(open(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb"))
        return self.send(request, checkPUTResult)
class QuotaBadAdjustment(QuotaBase):
    """
    A quota adjustment larger than the current usage is rejected.
    """
    def test_Quota_Bad_Adjustment(self):
        """
        Quota adjustment too much
        """
        dst_uri = "/dst"
        def checkPUTResult(response):
            response = IResponse(response)
            if response.code != responsecode.CREATED:
                self.fail("Incorrect response code for PUT (%s != %s)"
                          % (response.code, responsecode.CREATED))
            def doBadAdjustment(_ignore):
                def checkAdjustmentResult(_ignore):
                    # Adjusting by -200 would take usage below zero, so
                    # usage must remain at 100.
                    return self.checkQuota(100)
                d = self.site.resource.quotaSizeAdjust(None, -200)
                d.addCallback(checkAdjustmentResult)
                return d
            # The PUT consumed 100 bytes; verify before adjusting.
            d = self.checkQuota(100)
            d.addCallback(doBadAdjustment)
            return d
        request = SimpleRequest(self.site, "PUT", dst_uri)
        # open() replaces the deprecated file() constructor (identical
        # behavior in Python 2; file() does not exist in Python 3).
        request.stream = FileStream(open(os.path.join(os.path.dirname(__file__), "data", "quota_100.txt"), "rb"))
        return self.send(request, checkPUTResult)
calendarserver-5.2+dfsg/twext/web2/dav/test/test_xattrprops.py 0000644 0001750 0001750 00000041451 12022736174 023722 0 ustar rahul rahul # Copyright (c) 2009 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twext.web2.dav.xattrprops}.
"""
from zlib import compress, decompress
from pickle import dumps
from cPickle import UnpicklingError
from twext.python.filepath import CachingFilePath as FilePath
from twisted.trial.unittest import TestCase
from twext.web2.responsecode import NOT_FOUND, INTERNAL_SERVER_ERROR
from twext.web2.responsecode import FORBIDDEN
from twext.web2.http import HTTPError
from twext.web2.dav.static import DAVFile
from txdav.xml.element import Depth, WebDAVDocument
try:
from twext.web2.dav.xattrprops import xattrPropertyStore
except ImportError:
xattrPropertyStore = None
else:
from xattr import xattr
class ExtendedAttributesPropertyStoreTests(TestCase):
    """
    Tests for L{xattrPropertyStore}.
    """
    # Skip the whole suite when the optional xattr package is unavailable
    # (see the guarded import at module level).
    if xattrPropertyStore is None:
        skip = "xattr package missing, cannot test xattr property store"
    def setUp(self):
        """
        Create a resource and a xattr property store for it.
        """
        self.resourcePath = FilePath(self.mktemp())
        self.resourcePath.setContent("")
        # Direct handle on the file's extended attributes, used to plant
        # and inspect raw values behind the property store's back.
        self.attrs = xattr(self.resourcePath.path)
        self.resource = DAVFile(self.resourcePath.path)
        self.propertyStore = xattrPropertyStore(self.resource)
    def test_getAbsent(self):
        """
        L{xattrPropertyStore.get} raises L{HTTPError} with a I{NOT FOUND}
        response code if passed the name of an attribute for which there is no
        corresponding value.
        """
        error = self.assertRaises(HTTPError, self.propertyStore.get, ("foo", "bar"))
        self.assertEquals(error.response.code, NOT_FOUND)
    def _forbiddenTest(self, method):
        # Shared helper: make the parent directory inaccessible and assert
        # that calling the named property-store method maps the resulting
        # EPERM to a FORBIDDEN HTTPError.
        # Remove access to the directory containing the file so that getting
        # extended attributes from it fails with EPERM.
        self.resourcePath.parent().chmod(0)
        # Make sure to restore access to it later so that it can be deleted
        # after the test run is finished.
        self.addCleanup(self.resourcePath.parent().chmod, 0700)
        # Try to get a property from it - and fail.
        document = self._makeValue()
        error = self.assertRaises(
            HTTPError,
            getattr(self.propertyStore, method),
            document.root_element.qname())
        # Make sure that the status is FORBIDDEN, a roughly reasonable mapping
        # of the EPERM failure.
        self.assertEquals(error.response.code, FORBIDDEN)
    def _missingTest(self, method):
        # NOTE(review): this helper is byte-for-byte identical to
        # _forbiddenTest and does not appear to be called anywhere in this
        # class — looks like a leftover copy; confirm before removing.
        # Remove access to the directory containing the file so that getting
        # extended attributes from it fails with EPERM.
        self.resourcePath.parent().chmod(0)
        # Make sure to restore access to it later so that it can be deleted
        # after the test run is finished.
        self.addCleanup(self.resourcePath.parent().chmod, 0700)
        # Try to get a property from it - and fail.
        document = self._makeValue()
        error = self.assertRaises(
            HTTPError,
            getattr(self.propertyStore, method),
            document.root_element.qname())
        # Make sure that the status is FORBIDDEN, a roughly reasonable mapping
        # of the EPERM failure.
        self.assertEquals(error.response.code, FORBIDDEN)
    def test_getErrors(self):
        """
        If there is a problem getting the specified property (aside from the
        property not existing), L{xattrPropertyStore.get} raises L{HTTPError}
        with a status code which is determined by the nature of the problem.
        """
        self._forbiddenTest('get')
    def test_getMissing(self):
        """
        Test missing file.
        """
        # A store wrapped around a path that was never created on disk.
        resourcePath = FilePath(self.mktemp())
        resource = DAVFile(resourcePath.path)
        propertyStore = xattrPropertyStore(resource)
        # Try to get a property from it - and fail.
        document = self._makeValue()
        error = self.assertRaises(
            HTTPError,
            propertyStore.get,
            document.root_element.qname())
        # Make sure that the status is NOT FOUND.
        self.assertEquals(error.response.code, NOT_FOUND)
    def _makeValue(self, uid=None):
        """
        Create and return any old WebDAVDocument for use by the get tests.
        """
        # A DAV:depth element whose text doubles as a per-uid marker.
        element = Depth(uid if uid is not None else "0")
        document = WebDAVDocument(element)
        return document
    def _setValue(self, originalDocument, value, uid=None):
        # Write C{value} directly into the extended attribute the property
        # store would use for this element (and optional per-user uid),
        # bypassing the store itself.
        element = originalDocument.root_element
        attribute = (
            self.propertyStore.deadPropertyXattrPrefix +
            (uid if uid is not None else "") +
            element.sname())
        self.attrs[attribute] = value
    def _getValue(self, originalDocument, uid=None):
        # Read the raw extended-attribute value the property store would
        # use for this element (and optional per-user uid).
        element = originalDocument.root_element
        attribute = (
            self.propertyStore.deadPropertyXattrPrefix +
            (uid if uid is not None else "") +
            element.sname())
        return self.attrs[attribute]
    def _checkValue(self, originalDocument, uid=None):
        # Assert that the store parses the previously planted value back
        # into the expected Depth element.
        property = originalDocument.root_element.qname()
        # Try to load it via xattrPropertyStore.get
        loadedDocument = self.propertyStore.get(property, uid)
        # XXX Why isn't this a WebDAVDocument?
        self.assertIsInstance(loadedDocument, Depth)
        self.assertEquals(str(loadedDocument), uid if uid else "0")
    def test_getXML(self):
        """
        If there is an XML document associated with the property name passed to
        L{xattrPropertyStore.get}, that value is parsed into a
        L{WebDAVDocument}, the root element of which C{get} then returns.
        """
        document = self._makeValue()
        self._setValue(document, document.toxml())
        self._checkValue(document)
    def test_getCompressed(self):
        """
        If there is a compressed value associated with the property name passed
        to L{xattrPropertyStore.get}, that value is decompressed and parsed
        into a L{WebDAVDocument}, the root element of which C{get} then
        returns.
        """
        document = self._makeValue()
        self._setValue(document, compress(document.toxml()))
        self._checkValue(document)
    def test_getPickled(self):
        """
        If there is a pickled document associated with the property name passed
        to L{xattrPropertyStore.get}, that value is unpickled into a
        L{WebDAVDocument}, the root element of which is returned.
        """
        document = self._makeValue()
        self._setValue(document, dumps(document))
        self._checkValue(document)
    def test_getUpgradeXML(self):
        """
        If the value associated with the property name passed to
        L{xattrPropertyStore.get} is an uncompressed XML document, it is
        upgraded on access by compressing it.
        """
        document = self._makeValue()
        originalValue = document.toxml()
        self._setValue(document, originalValue)
        self._checkValue(document)
        # After the read the stored value must be in the compressed format.
        self.assertEquals(
            decompress(self._getValue(document)), document.root_element.toxml(pretty=False))
    def test_getUpgradeCompressedPickle(self):
        """
        If the value associated with the property name passed to
        L{xattrPropertyStore.get} is a compressed pickled document, it is
        upgraded on access to the compressed XML format.
        """
        document = self._makeValue()
        self._setValue(document, compress(dumps(document)))
        self._checkValue(document)
        # After the read the stored value must be compressed XML, not pickle.
        self.assertEquals(
            decompress(self._getValue(document)), document.root_element.toxml(pretty=False))
    def test_getInvalid(self):
        """
        If the value associated with the property name passed to
        L{xattrPropertyStore.get} cannot be interpreted, an error is logged and
        L{HTTPError} is raised with the I{INTERNAL SERVER ERROR} response code.
        """
        document = self._makeValue()
        self._setValue(
            document,
            "random garbage goes here! \0 that nul is definitely garbage")
        property = document.root_element.qname()
        error = self.assertRaises(HTTPError, self.propertyStore.get, property)
        self.assertEquals(error.response.code, INTERNAL_SERVER_ERROR)
        # The failed unpickling attempt must have been logged exactly once.
        self.assertEquals(
            len(self.flushLoggedErrors(UnpicklingError)), 1)
    def test_set(self):
        """
        L{xattrPropertyStore.set} accepts a L{WebDAVElement} and stores a
        compressed XML document representing it in an extended attribute.
        """
        document = self._makeValue()
        self.propertyStore.set(document.root_element)
        self.assertEquals(
            decompress(self._getValue(document)), document.root_element.toxml(pretty=False))
    def test_delete(self):
        """
        L{xattrPropertyStore.delete} deletes the named property.
        """
        document = self._makeValue()
        self.propertyStore.set(document.root_element)
        self.propertyStore.delete(document.root_element.qname())
        # The raw xattr must be gone as well.
        self.assertRaises(KeyError, self._getValue, document)
    def test_deleteNonExistent(self):
        """
        L{xattrPropertyStore.delete} does nothing if passed a property which
        has no value.
        """
        document = self._makeValue()
        self.propertyStore.delete(document.root_element.qname())
        self.assertRaises(KeyError, self._getValue, document)
    def test_deleteErrors(self):
        """
        If there is a problem deleting the specified property (aside from the
        property not existing), L{xattrPropertyStore.delete} raises
        L{HTTPError} with a status code which is determined by the nature of
        the problem.
        """
        # Remove the file so that deleting extended attributes of it fails with
        # EEXIST.
        self.resourcePath.remove()
        # Try to delete a property from it - and fail.
        document = self._makeValue()
        error = self.assertRaises(
            HTTPError,
            self.propertyStore.delete, document.root_element.qname())
        # Make sure that the status is NOT FOUND, a roughly reasonable mapping
        # of the EEXIST failure.
        self.assertEquals(error.response.code, NOT_FOUND)
    def test_contains(self):
        """
        L{xattrPropertyStore.contains} returns C{True} if the given property
        has a value, C{False} otherwise.
        """
        document = self._makeValue()
        self.assertFalse(
            self.propertyStore.contains(document.root_element.qname()))
        self._setValue(document, document.toxml())
        self.assertTrue(
            self.propertyStore.contains(document.root_element.qname()))
    def test_containsError(self):
        """
        If there is a problem checking if the specified property exists (aside
        from the property not existing), L{xattrPropertyStore.contains} raises
        L{HTTPError} with a status code which is determined by the nature of
        the problem.
        """
        self._forbiddenTest('contains')
    def test_containsMissing(self):
        """
        Test missing file.
        """
        # A store wrapped around a path that was never created on disk.
        resourcePath = FilePath(self.mktemp())
        resource = DAVFile(resourcePath.path)
        propertyStore = xattrPropertyStore(resource)
        # Try to get a property from it - and fail.
        document = self._makeValue()
        self.assertFalse(propertyStore.contains(document.root_element.qname()))
    def test_list(self):
        """
        L{xattrPropertyStore.list} returns a C{list} of property names
        associated with the wrapped file.
        """
        # Plant two raw attributes under the store's prefix and expect
        # their (namespace, name) pairs back.
        prefix = self.propertyStore.deadPropertyXattrPrefix
        self.attrs[prefix + '{foo}bar'] = 'baz'
        self.attrs[prefix + '{bar}baz'] = 'quux'
        self.assertEquals(
            set(self.propertyStore.list()),
            set([(u'foo', u'bar'), (u'bar', u'baz')]))
    def test_listError(self):
        """
        If there is a problem checking if the specified property exists (aside
        from the property not existing), L{xattrPropertyStore.contains} raises
        L{HTTPError} with a status code which is determined by the nature of
        the problem.
        """
        # Remove access to the directory containing the file so that getting
        # extended attributes from it fails with EPERM.
        self.resourcePath.parent().chmod(0)
        # Make sure to restore access to it later so that it can be deleted
        # after the test run is finished.
        self.addCleanup(self.resourcePath.parent().chmod, 0700)
        # Try to get a property from it - and fail.
        self._makeValue()
        error = self.assertRaises(HTTPError, self.propertyStore.list)
        # Make sure that the status is FORBIDDEN, a roughly reasonable mapping
        # of the EPERM failure.
        self.assertEquals(error.response.code, FORBIDDEN)
    def test_listMissing(self):
        """
        Test missing file.
        """
        # A store wrapped around a path that was never created on disk.
        resourcePath = FilePath(self.mktemp())
        resource = DAVFile(resourcePath.path)
        propertyStore = xattrPropertyStore(resource)
        # Try to get a property from it - and fail.
        self.assertEqual(propertyStore.list(), [])
    def test_get_uids(self):
        """
        L{xattrPropertyStore.get} accepts a L{WebDAVElement} and stores a
        compressed XML document representing it in an extended attribute.
        """
        # Plant one value per uid, then verify each comes back under its
        # own uid (None means the shared/global value).
        for uid in (None, "123", "456",):
            document = self._makeValue(uid)
            self._setValue(document, document.toxml(), uid=uid)
        for uid in (None, "123", "456",):
            document = self._makeValue(uid)
            self._checkValue(document, uid=uid)
    def test_set_uids(self):
        """
        L{xattrPropertyStore.set} accepts a L{WebDAVElement} and stores a
        compressed XML document representing it in an extended attribute.
        """
        for uid in (None, "123", "456",):
            document = self._makeValue(uid)
            self.propertyStore.set(document.root_element, uid=uid)
            self.assertEquals(
                decompress(self._getValue(document, uid)), document.root_element.toxml(pretty=False))
    def test_delete_uids(self):
        """
        L{xattrPropertyStore.set} accepts a L{WebDAVElement} and stores a
        compressed XML document representing it in an extended attribute.
        """
        # Deleting under one uid must leave the other uids' values intact.
        for delete_uid in (None, "123", "456",):
            for uid in (None, "123", "456",):
                document = self._makeValue(uid)
                self.propertyStore.set(document.root_element, uid=uid)
            self.propertyStore.delete(document.root_element.qname(), uid=delete_uid)
            self.assertRaises(KeyError, self._getValue, document, uid=delete_uid)
            for uid in (None, "123", "456",):
                if uid == delete_uid:
                    continue
                document = self._makeValue(uid)
                self.assertEquals(
                    decompress(self._getValue(document, uid)), document.root_element.toxml(pretty=False))
    def test_contains_uids(self):
        """
        L{xattrPropertyStore.contains} returns C{True} if the given property
        has a value, C{False} otherwise.
        """
        for uid in (None, "123", "456",):
            document = self._makeValue(uid)
            self.assertFalse(
                self.propertyStore.contains(document.root_element.qname(), uid=uid))
            self._setValue(document, document.toxml(), uid=uid)
            self.assertTrue(
                self.propertyStore.contains(document.root_element.qname(), uid=uid))
    def test_list_uids(self):
        """
        L{xattrPropertyStore.list} returns a C{list} of property names
        associated with the wrapped file.
        """
        # Plant three attributes for each uid (None = global); the third
        # one embeds the uid in the property name itself so the
        # filterByUID=False listing below can be told apart per user.
        prefix = self.propertyStore.deadPropertyXattrPrefix
        for uid in (None, "123", "456",):
            user = uid if uid is not None else ""
            self.attrs[prefix + '%s{foo}bar' % (user,)] = 'baz%s' % (user,)
            self.attrs[prefix + '%s{bar}baz' % (user,)] = 'quux%s' % (user,)
            self.attrs[prefix + '%s{moo}mar%s' % (user, user,)] = 'quux%s' % (user,)
        for uid in (None, "123", "456",):
            user = uid if uid is not None else ""
            self.assertEquals(
                set(self.propertyStore.list(uid)),
                set([
                    (u'foo', u'bar'),
                    (u'bar', u'baz'),
                    (u'moo', u'mar%s' % (user,)),
                ]))
        # Unfiltered listing returns (namespace, name, uid) triples for
        # every stored uid at once.
        self.assertEquals(
            set(self.propertyStore.list(filterByUID=False)),
            set([
                (u'foo', u'bar', None),
                (u'bar', u'baz', None),
                (u'moo', u'mar', None),
                (u'foo', u'bar', "123"),
                (u'bar', u'baz', "123"),
                (u'moo', u'mar123', "123"),
                (u'foo', u'bar', "456"),
                (u'bar', u'baz', "456"),
                (u'moo', u'mar456', "456"),
            ]))
calendarserver-5.2+dfsg/twext/web2/dav/test/test_pipeline.py 0000644 0001750 0001750 00000005755 12263343324 023306 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import sys, os
from twisted.internet import utils
from twext.web2.test import test_server
from twext.web2 import resource
from twext.web2 import http
from twext.web2.test import test_http
from twisted.internet.defer import waitForDeferred, deferredGenerator
from twisted.python import util
class Pipeline(test_server.BaseCase):
    """
    Exercise handling of a pipelined request.
    """
    class TestResource(resource.LeafResource):
        def render(self, req):
            # Echo the host and path so the response is easy to verify.
            body = "Host:%s, Path:%s" % (req.host, req.path)
            return http.Response(stream=body)
    def setUp(self):
        self.root = self.TestResource()
    def chanrequest(self, root, uri, length, headers, method, version, prepath, content):
        # Remember the channel request so the test can inspect its stream.
        self.cr = super(Pipeline, self).chanrequest(
            root, uri, length, headers, method, version, prepath, content)
        return self.cr
    def test_root(self):
        def checkStreamDrained(ignored):
            # The request body must have been consumed fully even though
            # the method was rejected with 405.
            self.assertTrue(self.cr.request.stream.length == 0)
        d = self.assertResponse(
            (self.root, 'http://host/path', {"content-type":"text/plain",}, "PUT", None, '', "This is some text."),
            (405, {}, None))
        return d.addCallback(checkStreamDrained)
class SSLPipeline(test_http.SSLServerTest):
    @deferredGenerator
    def testAdvancedWorkingness(self):
        """
        Run the two-request helper client against the server and verify
        that both pipelined requests are refused with 403 responses.
        """
        clientArgs = ('-u', util.sibpath(__file__, "tworequest_client.py"), "basic",
                      str(self.port), self.type)
        d = waitForDeferred(utils.getProcessOutputAndValue(
            sys.executable, args=clientArgs, env=os.environ))
        yield d
        out, err, code = d.getResult()
        self.assertEquals(code, 0, "Error output:\n%s" % (err,))
        self.assertEquals(out, "HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\nHTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n")
calendarserver-5.2+dfsg/twext/web2/dav/test/test_options.py 0000644 0001750 0001750 00000004214 12263343324 023161 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twext.web2.iweb import IResponse
import twext.web2.dav.test.util
from twext.web2.test.test_server import SimpleRequest
class OPTIONS(twext.web2.dav.test.util.TestCase):
    """
    OPTIONS request
    """
    def test_DAV1(self):
        """
        DAV level 1
        """
        return self._test_level("1")
    def test_DAV2(self):
        """
        DAV level 2
        """
        return self._test_level("2")
    test_DAV2.todo = "DAV level 2 unimplemented"
    def test_ACL(self):
        """
        DAV ACL
        """
        return self._test_level("access-control")
    def _test_level(self, level):
        """
        Send OPTIONS to / and assert that C{level} appears in the DAV
        response header.
        """
        def checkDAVHeader(response):
            response = IResponse(response)
            dav = response.headers.getHeader("dav")
            if not dav:
                self.fail("no DAV header: %s" % (response.headers,))
            self.assertIn(level, dav, "no DAV level %s header" % (level,))
            return response
        request = SimpleRequest(self.site, "OPTIONS", "/")
        return self.send(request, checkDAVHeader)
calendarserver-5.2+dfsg/twext/web2/dav/test/__init__.py 0000644 0001750 0001750 00000002276 12263343324 022174 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
Tests for twext.web2.dav.
"""
calendarserver-5.2+dfsg/twext/web2/dav/test/util.py 0000644 0001750 0001750 00000026230 12263343324 021406 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import os
from urllib import quote as url_quote
from filecmp import dircmp as DirCompare
from tempfile import mkdtemp
from shutil import copy
from twisted.trial import unittest
from twisted.internet import address
from twisted.internet.defer import Deferred
from twext.python.log import Logger
from twext.web2.http import HTTPError, StatusResponse
from twext.web2 import responsecode, server
from twext.web2 import http_headers
from twext.web2 import stream
from twext.web2.dav.resource import TwistedACLInheritable
from twext.web2.dav.static import DAVFile
from twext.web2.dav.util import joinURL
from txdav.xml import element
from txdav.xml.base import encodeXMLName
from twext.web2.http_headers import MimeType
from twext.web2.dav.util import allDataFromStream
log = Logger()
class SimpleRequest(server.Request):
    """
    A L{SimpleRequest} can be used in cases where a L{server.Request} object is
    necessary but it is beneficial to bypass the concrete transport (and
    associated logic with the C{chanRequest} attribute).
    """
    # HTTP protocol version advertised by this fake request.
    clientproto = (1, 1)

    def __init__(self, site, method, uri, headers=None, content=None):
        """
        @param site: the site (root resource container) handling the request.
        @param method: the HTTP method name, e.g. C{"GET"}.
        @param uri: the request URI path.
        @param headers: optional L{http_headers.Headers}; a fresh (empty)
            header set is created when omitted or empty.
        @param content: optional request body as a string.
        """
        # Normalize a missing/empty headers value into a real Headers object.
        if not headers:
            headers = http_headers.Headers(headers)
        super(SimpleRequest, self).__init__(
            site=site,
            chanRequest=None,  # no transport: this request never hits the wire
            command=method,
            path=uri,
            version=self.clientproto,
            contentLength=len(content or ''),
            headers=headers)
        self.stream = stream.MemoryStream(content or '')
        self.remoteAddr = address.IPv4Address('TCP', '127.0.0.1', 0)
        self._parseURL()
        # Fixed host/port, since there is no real channel to query.
        self.host = 'localhost'
        self.port = 8080

    def writeResponse(self, response):
        # chanRequest is None for a SimpleRequest, so headers are normally
        # not written anywhere; the response is just handed back.
        if self.chanRequest:
            self.chanRequest.writeHeaders(response.code, response.headers)
        return response
class InMemoryPropertyStore (object):
    """
    A dead property store which keeps all properties in a plain dict.

    DO NOT USE OUTSIDE OF UNIT TESTS!
    """
    def __init__(self, resource):
        # Maps (namespace, name) qname tuples to serialized XML text.
        self._dict = {}

    def get(self, qname):
        """
        Return the stored property as a WebDAV element, raising a 404
        L{HTTPError} when the property has never been set.
        """
        try:
            xml_text = self._dict[qname]
        except KeyError:
            raise HTTPError(StatusResponse(
                responsecode.NOT_FOUND,
                "No such property: %s" % (encodeXMLName(*qname),)
            ))
        # Re-parse the serialized form back into an element tree.
        return element.WebDAVDocument.fromString(xml_text).root_element

    def set(self, property):
        # Properties are stored serialized, keyed by their qualified name.
        self._dict[property.qname()] = property.toxml()

    def delete(self, qname):
        # Deleting a property that was never set is not an error.
        self._dict.pop(qname, None)

    def contains(self, qname):
        return qname in self._dict

    def list(self):
        return self._dict.keys()
class TestFile (DAVFile):
    """
    A L{DAVFile} whose dead property stores live in memory and are shared,
    per filesystem path, by all instances for the lifetime of a test.
    """
    # Class-wide cache mapping filesystem paths to property stores.
    _cachedPropertyStores = {}

    def deadProperties(self):
        """
        Return (creating and caching on first use) the in-memory dead
        property store for this resource's path.
        """
        if not hasattr(self, "_dead_properties"):
            cache = TestFile._cachedPropertyStores
            store = cache.get(self.fp.path)
            if store is None:
                store = InMemoryPropertyStore(self)
                cache[self.fp.path] = store
            self._dead_properties = store
        return self._dead_properties

    def parent(self):
        """
        Return this resource's parent directory as a L{TestFile}.
        """
        return TestFile(self.fp.parent())
class TestCase (unittest.TestCase):
    """
    Base test case for twext.web2.dav tests.  Provides a lazily-created
    document root populated with a fixed directory/file layout, a
    lazily-created L{Site} rooted there, and helpers for issuing requests.
    """
    # Resource class used to build the site; subclasses may override.
    resource_class = TestFile

    def grant(*privileges):
        """
        Build an ACL granting each of C{privileges} to all principals
        (non-inheritable).
        """
        return element.ACL(*[
            element.ACE(
                element.Grant(element.Privilege(privilege)),
                element.Principal(element.All())
            )
            for privilege in privileges
        ])
    grant = staticmethod(grant)

    def grantInherit(*privileges):
        """
        Build an ACL granting each of C{privileges} to all principals,
        marked as inheritable by child resources.
        """
        return element.ACL(*[
            element.ACE(
                element.Grant(element.Privilege(privilege)),
                element.Principal(element.All()),
                TwistedACLInheritable()
            )
            for privilege in privileges
        ])
    grantInherit = staticmethod(grantInherit)

    def createDocumentRoot(self):
        """
        Create a temporary document root with a fixed directory layout,
        copy a few sample files into several of those directories, and grant
        an inheritable all-privileges ACL on the root resource.
        @return: the path to the new document root.
        """
        docroot = self.mktemp()
        os.mkdir(docroot)
        rootresource = self.resource_class(docroot)
        rootresource.setAccessControlList(self.grantInherit(element.All()))
        dirnames = (
            os.path.join(docroot, "dir1"),                          # 0
            os.path.join(docroot, "dir2"),                          # 1
            os.path.join(docroot, "dir2", "subdir1"),               # 2
            os.path.join(docroot, "dir3"),                          # 3
            os.path.join(docroot, "dir4"),                          # 4
            os.path.join(docroot, "dir4", "subdir1"),               # 5
            os.path.join(docroot, "dir4", "subdir1", "subsubdir1"), # 6
            os.path.join(docroot, "dir4", "subdir2"),               # 7
            os.path.join(docroot, "dir4", "subdir2", "dir1"),       # 8
            os.path.join(docroot, "dir4", "subdir2", "dir2"),       # 9
        )
        for dir in dirnames:
            os.mkdir(dir)
        # Use this test package's own source files as sample file content.
        src = os.path.dirname(__file__)
        filenames = [
            os.path.join(src, f)
            for f in os.listdir(src)
            if os.path.isfile(os.path.join(src, f))
        ]
        # Populate the docroot and dirnames[3] through dirnames[8] with the
        # first five sample files each.
        for dirname in (docroot,) + dirnames[3:8 + 1]:
            for filename in filenames[:5]:
                copy(filename, dirname)
        return docroot

    def _getDocumentRoot(self):
        # Lazily create the document root on first access.
        if not hasattr(self, "_docroot"):
            log.info("Setting up docroot for %s" % (self.__class__,))
            self._docroot = self.createDocumentRoot()
        return self._docroot

    def _setDocumentRoot(self, value):
        self._docroot = value

    # Path to the document root; created on demand.
    docroot = property(_getDocumentRoot, _setDocumentRoot)

    def _getSite(self):
        # Lazily create the site, rooted at the document root, on first use.
        if not hasattr(self, "_site"):
            rootresource = self.resource_class(self.docroot)
            rootresource.setAccessControlList(self.grantInherit(element.All()))
            self._site = Site(rootresource)
        return self._site

    def _setSite(self, site):
        self._site = site

    # The Site serving the document root; created on demand.
    site = property(_getSite, _setSite)

    def setUp(self):
        unittest.TestCase.setUp(self)
        # Reset the shared property-store cache so tests stay isolated.
        TestFile._cachedPropertyStores = {}

    def tearDown(self):
        unittest.TestCase.tearDown(self)

    def mkdtemp(self, prefix):
        """
        Creates a new directory in the document root and returns its path and
        URI.
        """
        path = mkdtemp(prefix=prefix + "_", dir=self.docroot)
        uri = joinURL("/", url_quote(os.path.basename(path))) + "/"
        return (os.path.abspath(path), uri)

    def send(self, request, callback=None):
        """
        Invoke the logic involved in traversing a given L{server.Request} as if
        a client had sent it; call C{locateResource} to look up the resource to
        be rendered, and render it by calling its C{renderHTTP} method.
        @param request: A L{server.Request} (generally, to avoid real I/O, a
            L{SimpleRequest}) already associated with a site.
        @param callback: an optional callback, or C{(callback, errback)}
            tuple, added to the response L{Deferred}.
        @return: asynchronously return a response object or L{None}
        @rtype: L{Deferred} firing L{Response} or L{None}
        """
        log.info("Sending %s request for URI %s" % (request.method, request.uri))
        d = request.locateResource(request.uri)
        d.addCallback(lambda resource: resource.renderHTTP(request))
        d.addCallback(request._cbFinishRender)
        if callback:
            if type(callback) is tuple:
                # A (callback, errback) pair.
                d.addCallbacks(*callback)
            else:
                d.addCallback(callback)
        return d

    def simpleSend(self, method, path="/", body="", mimetype="text",
                   subtype="xml", resultcode=responsecode.OK, headers=()):
        """
        Assemble and send a simple request using L{SimpleRequest}. This
        L{SimpleRequest} is associated with this L{TestCase}'s C{site}
        attribute.
        @param method: the HTTP method
        @type method: L{bytes}
        @param path: the absolute path portion of the HTTP URI
        @type path: L{bytes}
        @param body: the content body of the request
        @type body: L{bytes}
        @param mimetype: the main type of the mime type of the body of the
            request
        @type mimetype: L{bytes}
        @param subtype: the subtype of the mimetype of the body of the request
        @type subtype: L{bytes}
        @param resultcode: The expected result code for the response to the
            request.
        @type resultcode: L{int}
        @param headers: An iterable of 2-tuples of C{(header, value)}; headers
            to set on the outgoing request.
        @return: a L{Deferred} which fires with a L{bytes} if the request was
            successfully processed and fails with an L{HTTPError} if not; or,
            if the resultcode does not match the response's code, fails with
            L{FailTest}.
        """
        request = SimpleRequest(self.site, method, path, content=body)
        if headers is not None:
            for k, v in headers:
                request.headers.setHeader(k, v)
        request.headers.setHeader("content-type", MimeType(mimetype, subtype))
        def checkResult(response):
            # Verify the status code, then collect the body (if any).
            self.assertEqual(response.code, resultcode)
            if response.stream is None:
                return None
            return allDataFromStream(response.stream)
        return self.send(request, None).addCallback(checkResult)
class Site:
    """
    Minimal stand-in for a web site: simply holds the root resource.
    """
    # FIXME: There is no ISite interface; there should be.
    # implements(ISite)

    def __init__(self, resource):
        self.resource = resource
def dircmp(dir1, dir2):
    """
    Compare two directory trees.

    @return: C{True} if the directories differ in any way (entries present on
        only one side, files whose contents differ, or entries that could not
        be compared), C{False} if they match.
    """
    comparison = DirCompare(dir1, dir2)
    differences = (
        comparison.left_only,
        comparison.right_only,
        comparison.diff_files,
        comparison.common_funny,
        comparison.funny_files,
    )
    return any(differences)
def serialize(f, work):
    """
    Serialize a sequence of asynchronous operations: call C{f} with each
    argument tuple produced by the C{work} iterator, waiting for the
    L{Deferred} returned by each call to fire before starting the next.

    @param f: a callable returning a L{Deferred}.
    @param work: an iterator yielding argument tuples for C{f}.
    @return: a L{Deferred} which fires with C{None} once C{work} is
        exhausted, or errbacks with the first failure encountered.
    """
    d = Deferred()

    def oops(error):
        d.errback(error)

    def do_serialize(_):
        try:
            # Use the next() builtin rather than the Python 2-only
            # iterator .next() method, so this works on 2.6+ and 3.x alike.
            args = next(work)
        except StopIteration:
            d.callback(None)
        else:
            r = f(*args)
            r.addCallback(do_serialize)
            r.addErrback(oops)

    do_serialize(None)
    return d
calendarserver-5.2+dfsg/twext/web2/dav/test/test_delete.py 0000644 0001750 0001750 00000005253 12263343324 022734 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import os
import urllib
import random
from twext.web2 import responsecode
from twext.web2.iweb import IResponse
from twext.web2.test.test_server import SimpleRequest
from twext.web2.dav.test.util import serialize
import twext.web2.dav.test.util
class DELETE(twext.web2.dav.test.util.TestCase):
    """
    DELETE request
    """
    # FIXME:
    # Try setting unwriteable perms on file, then delete
    # Try check response XML for error in some but not all files

    def test_DELETE(self):
        """
        Issue a DELETE for every entry in the document root and verify that
        each one both succeeds and actually removes its target from disk.
        """
        def check_result(response, path):
            response = IResponse(response)
            if response.code != responsecode.NO_CONTENT:
                self.fail("DELETE response %s != %s" % (response.code, responsecode.NO_CONTENT))
            if os.path.exists(path):
                self.fail("DELETE did not remove path %s" % (path,))

        def work():
            for filename in os.listdir(self.docroot):
                path = os.path.join(self.docroot, filename)
                uri = urllib.quote("/" + filename)
                # Collections are addressed with a trailing slash.
                if os.path.isdir(path):
                    uri = uri + "/"

                def do_test(response, path=path):
                    return check_result(response, path)

                request = SimpleRequest(self.site, "DELETE", uri)

                # Exercise both an explicit Depth header and its absence.
                depth = random.choice(("infinity", None))
                if depth is not None:
                    request.headers.setHeader("depth", depth)

                yield (request, do_test)

        return serialize(self.send, work())
calendarserver-5.2+dfsg/twext/web2/dav/test/test_report.py 0000644 0001750 0001750 00000005235 12263343324 023005 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from twext.web2.iweb import IResponse
from twext.web2.stream import MemoryStream
from twext.web2 import responsecode
import twext.web2.dav.test.util
from twext.web2.test.test_server import SimpleRequest
from txdav.xml import element as davxml
class REPORT(twext.web2.dav.test.util.TestCase):
    """
    REPORT request
    """
    def test_REPORT_no_body(self):
        """
        REPORT request with no body
        """
        request = SimpleRequest(self.site, "REPORT", "/")
        request.stream = MemoryStream("")

        def verify(response):
            # A bodiless REPORT must be rejected as a bad request.
            response = IResponse(response)
            if response.code != responsecode.BAD_REQUEST:
                self.fail("Unexpected response code for REPORT with no body: %s"
                          % (response.code,))

        return self.send(request, verify)

    def test_REPORT_unknown(self):
        """
        Unknown/bogus report type
        """
        class GoofyReport (davxml.WebDAVUnknownElement):
            namespace = "GOOFY:"
            name = "goofy-report"
            def __init__(self):
                super(GoofyReport, self).__init__()

        request = SimpleRequest(self.site, "REPORT", "/")
        request.stream = MemoryStream(GoofyReport().toxml())

        def verify(response):
            # An unrecognized report type must be refused outright.
            response = IResponse(response)
            if response.code != responsecode.FORBIDDEN:
                self.fail("Unexpected response code for unknown REPORT: %s"
                          % (response.code,))

        return self.send(request, verify)
calendarserver-5.2+dfsg/twext/web2/dav/test/test_auth.py 0000644 0001750 0001750 00000005436 12263343324 022436 0 ustar rahul rahul ##
# Copyright (c) 2012-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
import collections
from twext.web2.dav.auth import AuthenticationWrapper
import twext.web2.dav.test.util
class AutoWrapperTestCase(twext.web2.dav.test.util.TestCase):
    def test_basicAuthPrevention(self):
        """
        Ensure authentication factories which are not safe to use over an
        "unencrypted wire" are not advertised when an insecure (i.e. non-SSL)
        connection is made.
        """
        # NOTE: the field spec here was previously the single string
        # "scheme," (a parenthesized string, not a tuple) -- namedtuple only
        # tolerated it because it splits field-name strings on commas.  Use
        # a real one-element tuple so the field name is unambiguous.
        FakeFactory = collections.namedtuple("FakeFactory", ("scheme",))
        wireEncryptedfactories = [FakeFactory("basic"), FakeFactory("digest"), FakeFactory("xyzzy")]
        wireUnencryptedfactories = [FakeFactory("digest"), FakeFactory("xyzzy")]

        class FakeChannel(object):
            def __init__(self, secure):
                self.secure = secure

            def getHostInfo(self):
                # Second element of the tuple reports whether the transport
                # is secure; the first is ignored by the wrapper.
                return "ignored", self.secure

        class FakeRequest(object):
            def __init__(self, secure):
                self.portal = None
                self.loginInterfaces = None
                self.credentialFactories = None
                self.chanRequest = FakeChannel(secure)

        wrapper = AuthenticationWrapper(None, None,
            wireEncryptedfactories, wireUnencryptedfactories, None)

        req = FakeRequest(True)  # Connection is over SSL
        wrapper.hook(req)
        self.assertEquals(
            set(req.credentialFactories.keys()),
            set(["basic", "digest", "xyzzy"])
        )

        req = FakeRequest(False)  # Connection is not over SSL
        wrapper.hook(req)
        self.assertEquals(
            set(req.credentialFactories.keys()),
            set(["digest", "xyzzy"])
        )
calendarserver-5.2+dfsg/twext/web2/dav/idav.py 0000644 0001750 0001750 00000030224 12263343324 020373 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
web2.dav interfaces.
"""
__all__ = [ "IDAVResource", "IDAVPrincipalResource", "IDAVPrincipalCollectionResource", ]
from twext.web2.iweb import IResource
class IDAVResource(IResource):
    """
    WebDAV resource.
    """
    def isCollection():
        """
        Checks whether this resource is a collection resource.
        @return: C{True} if this resource is a collection resource, C{False}
            otherwise.
        """

    def findChildren(depth, request, callback, privileges, inherited_aces):
        """
        Returns an iterable of child resources for the given depth.
        Because resources do not know their request URIs, children are
        returned as tuples C{(resource, uri)}, where C{resource} is the child
        resource and C{uri} is a URL path relative to this resource.
        @param depth: the search depth (one of C{"0"}, C{"1"}, or C{"infinity"})
        @param request: The current L{IRequest} responsible for this call.
        @param callback: C{callable} that will be called for each child found
        @param privileges: the list of L{Privilege}s to test for. This should
            default to None.
        @param inherited_aces: a list of L{Privilege}s for aces being inherited
            from the parent collection used to bypass inheritance lookup.
        @return: An L{Deferred} that fires when all the children have been found
        """

    def hasProperty(property, request):
        """
        Checks whether the given property is defined on this resource.
        @param property: an empty L{davxml.WebDAVElement} instance or a qname
            tuple.
        @param request: the request being processed.
        @return: a deferred value of C{True} if the given property is set on
            this resource, or C{False} otherwise.
        """

    def readProperty(property, request):
        """
        Reads the given property on this resource.
        @param property: an empty L{davxml.WebDAVElement} class or instance, or
            a qname tuple.
        @param request: the request being processed.
        @return: a deferred L{davxml.WebDAVElement} instance
            containing the value of the given property.
        @raise HTTPError: (containing a response with a status code of
            L{responsecode.CONFLICT}) if C{property} is not set on this
            resource.
        """

    def writeProperty(property, request):
        """
        Writes the given property on this resource.
        @param property: a L{davxml.WebDAVElement} instance.
        @param request: the request being processed.
        @return: an empty deferred which fires when the operation is completed.
        @raise HTTPError: (containing a response with a status code of
            L{responsecode.CONFLICT}) if C{property} is a read-only property.
        """

    def removeProperty(property, request):
        """
        Removes the given property from this resource.
        @param property: a L{davxml.WebDAVElement} instance or a qname tuple.
        @param request: the request being processed.
        @return: an empty deferred which fires when the operation is completed.
        @raise HTTPError: (containing a response with a status code of
            L{responsecode.CONFLICT}) if C{property} is a read-only property or
            if the property does not exist.
        """

    def listProperties(request):
        """
        @param request: the request being processed.
        @return: a deferred iterable of qnames for all properties
            defined for this resource.
        """

    def supportedReports():
        """
        @return: an iterable of L{davxml.Report} elements for each report
            supported by this resource.
        """

    def authorize(request, privileges, recurse=False):
        """
        Verify that the given request is authorized to perform actions that
        require the given privileges.
        @param request: the request being processed.
        @param privileges: an iterable of L{davxml.WebDAVElement} elements
            denoting access control privileges.
        @param recurse: C{True} if a recursive check on all child
            resources of this resource should be performed as well,
            C{False} otherwise.
        @return: a Deferred which fires with C{None} when authorization is
            complete, or errbacks with L{HTTPError} (containing a response with
            a status code of L{responsecode.UNAUTHORIZED}) if not authorized.
        """

    def principalCollections():
        """
        @return: an iterable of L{IDAVPrincipalCollectionResource}s which
            contain principals used in ACLs for this resource.
        """

    def setAccessControlList(acl):
        """
        Sets the access control list containing the access control list for
        this resource.
        @param acl: an L{davxml.ACL} element.
        """

    def supportedPrivileges(request):
        """
        @param request: the request being processed.
        @return: a L{Deferred} with an L{davxml.SupportedPrivilegeSet} result
            describing the access control privileges which are supported by
            this resource.
        """

    def currentPrivileges(request):
        """
        @param request: the request being processed.
        @return: a sequence of the access control privileges which are
            set for the currently authenticated user.
        """

    def accessControlList(request, inheritance=True, expanding=False):
        """
        Obtains the access control list for this resource.
        @param request: the request being processed.
        @param inheritance: if True, replace inherited privileges with those
            from the resource being inherited from, if False just return
            whatever is set in this ACL.
        @param expanding: if C{True}, method is called during parent
            inheritance expansion, if C{False} then not doing parent expansion.
        @return: a deferred L{davxml.ACL} element containing the
            access control list for this resource.
        """

    def privilegesForPrincipal(principal, request):
        """
        Evaluate the set of privileges that apply to the specified principal.
        This involves examining all ACEs and granting/denying as appropriate
        for the specified principal's membership of the ACE's principal.
        @param request: the request being processed.
        @return: a list of L{Privilege}s that are allowed on this resource for
            the specified principal.
        """

    ##
    # Quota
    ##

    def quota(request):
        """
        Get current available & used quota values for this resource's quota
        root collection.
        @return: a C{tuple} containing two C{int}'s the first is
            quota-available-bytes, the second is quota-used-bytes, or
            C{None} if quota is not defined on the resource.
        """

    def hasQuota(request):
        """
        Check whether this resource is under quota control by checking each
        parent to see if it has a quota root.
        @return: C{True} if under quota control, C{False} if not.
        """

    def hasQuotaRoot(request):
        """
        Determine whether the resource has a quota root.
        @return: a C{True} if this resource has quota root, C{False} otherwise.
        """

    def quotaRoot(request):
        """
        Get the quota root (max. allowed bytes) value for this collection.
        @return: a C{int} containing the maximum allowed bytes if this
            collection is quota-controlled, or C{None} if not quota controlled.
        """

    def setQuotaRoot(request, maxsize):
        """
        Set the quota root (max. allowed bytes) value for this collection.
        @param maxsize: a C{int} containing the maximum allowed bytes for the
            contents of this collection.
        """

    def quotaSize(request):
        """
        Get the size of this resource (if it is a collection, get the total
        for all children as well).
        TODO: Take into account size of dead-properties.
        @return: a L{Deferred} with a C{int} result containing the size of the
            resource.
        """

    def currentQuotaUse(request):
        """
        Get the cached quota use value, or if not present (or invalid)
        determine quota use by brute force.
        @return: an L{Deferred} with a C{int} result containing the current
            used byte count if this collection is quota-controlled, or C{None}
            if not quota controlled.
        """

    def updateQuotaUse(request, adjust):
        """
        Adjust current quota use on this and all parent collections that also
        have quota roots.
        @param adjust: a C{int} containing the number of bytes added (positive)
            or removed (negative) that should be used to adjust the cached
            total.
        @return: an L{Deferred} with a C{int} result containing the current
            used byte if this collection is quota-controlled, or C{None} if
            not quota controlled.
        """
class IDAVPrincipalResource (IDAVResource):
    """
    WebDAV principal resource. (RFC 3744, section 2)
    """
    def alternateURIs():
        """
        Provides the URIs of network resources with additional descriptive
        information about the principal, for example, a URI to an LDAP record.
        (RFC 3744, section 4.1)
        @return: an iterable of URIs.
        """

    def principalURL():
        """
        Provides the URL which must be used to identify this principal in ACL
        requests. (RFC 3744, section 4.2)
        @return: a URL.
        """

    def groupMembers():
        """
        Provides the principal URLs of principals that are direct members of
        this (group) principal. (RFC 3744, section 4.3)
        @return: a deferred returning an iterable of principal URLs.
        """

    def expandedGroupMembers():
        """
        Provides the principal URLs of principals that are members of this
        (group) principal, as well as members of any group principal which are
        members of this one.
        @return: a L{Deferred} that fires with an iterable of principal URLs.
        """

    def groupMemberships():
        """
        Provides the URLs of the group principals in which the principal is
        directly a member. (RFC 3744, section 4.4)
        @return: a deferred containing an iterable of group principal URLs.
        """
class IDAVPrincipalCollectionResource(IDAVResource):
    """
    WebDAV principal collection resource. (RFC 3744, section 5.8)
    """
    def principalCollectionURL():
        """
        Provides a URL for this resource which may be used to identify this
        resource in ACL requests. (RFC 3744, section 5.8)
        @return: a URL.
        """

    def principalForUser(user):
        """
        Retrieve the principal for a given username.
        @param user: the (short) name of a user.
        @type user: C{str}
        @return: the resource representing the DAV principal resource for the
            given username.
        @rtype: L{IDAVPrincipalResource}
        """
calendarserver-5.2+dfsg/twext/web2/dav/http.py 0000644 0001750 0001750 00000031076 12263343324 020435 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
##
"""
HTTP Utilities
"""
__all__ = [
"ErrorResponse",
"NeedPrivilegesResponse",
"MultiStatusResponse",
"ResponseQueue",
"PropertyStatusResponseQueue",
"statusForFailure",
"errorForFailure",
"messageForFailure",
]
import errno
from twisted.python.failure import Failure
from twisted.python.filepath import InsecurePath
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.iweb import IResponse
from twext.web2.http import Response, HTTPError, StatusResponse
from twext.web2.http_headers import MimeType
from twext.web2.dav.util import joinURL
from txdav.xml import element
log = Logger()
class ErrorResponse(Response):
    """
    A L{Response} object which contains a status code and a L{element.Error}
    element.
    Renders itself as a DAV:error XML document.
    """
    error = None
    unregistered = True  # base class is already registered

    def __init__(self, code, error, description=None):
        """
        @param code: a response code.
        @param error: an L{WebDAVElement} identifying the error, or a
            tuple C{(namespace, name)} with which to create an empty element
            denoting the error. (The latter is useful in the case of
            preconditions and postconditions, not all of which have defined
            XML element classes.)
        @param description: an optional string that, if present, will get
            wrapped in a (twisted_dav_namespace, error-description) element.
        """
        # A plain (namespace, name) tuple becomes an empty unknown element.
        if type(error) is tuple:
            unknown = element.WebDAVUnknownElement()
            unknown.namespace, unknown.name = error
            error = unknown

        self.description = description

        children = [error]
        if self.description:
            children.append(element.ErrorDescription(self.description))
        output = element.Error(*children).toxml()

        Response.__init__(self, code=code, stream=output)
        self.headers.setHeader("content-type", MimeType("text", "xml"))
        self.error = error

    def __repr__(self):
        return "<%s %s %s>" % (self.__class__.__name__, self.code, self.error.sname())
class NeedPrivilegesResponse (ErrorResponse):
    def __init__(self, base_uri, errors):
        """
        An error response which is due to insufficient privileges, as
        determined by L{DAVResource.checkPrivileges}.
        @param base_uri: the base URI for the resources with errors (the URI of
            the resource on which C{checkPrivileges} was called).
        @param errors: a sequence of tuples, as returned by
            C{checkPrivileges}.
        """
        denials = []
        for subpath, privileges in errors:
            # A None subpath means the error applies to the base resource.
            uri = base_uri if subpath is None else joinURL(base_uri, subpath)
            denials.extend(
                element.Resource(element.HRef(uri), element.Privilege(p))
                for p in privileges
            )
        super(NeedPrivilegesResponse, self).__init__(responsecode.FORBIDDEN, element.NeedPrivileges(*denials))
class MultiStatusResponse (Response):
    """
    Multi-status L{Response} object.
    Renders itself as a DAV:multi-status XML document.
    """
    def __init__(self, xml_responses):
        """
        @param xml_responses: an iterable of element.Response objects.
        """
        body = element.MultiStatus(*xml_responses).toxml()
        Response.__init__(self, code=responsecode.MULTI_STATUS, stream=body)
        self.headers.setHeader("content-type", MimeType("text", "xml"))
class ResponseQueue (object):
    """
    Stores a list of (typically error) responses for use in a
    L{MultiStatusResponse}.
    """
    def __init__(self, path_basename, method, success_response):
        """
        @param path_basename: the base path for all responses to be added to
            the queue.
            All paths for responses added to the queue must start with
            C{path_basename}, which will be stripped from the beginning of
            each path to determine the response's URI.
        @param method: the name of the method generating the queue.
        @param success_response: the response to return in lieu of a
            L{MultiStatusResponse} if no responses are added to this queue.
        """
        self.responses = []
        self.path_basename = path_basename
        self.path_basename_len = len(path_basename)
        self.method = method
        self.success_response = success_response

    def add(self, path, what):
        """
        Add a response.
        @param path: a path, which must be a subpath of C{path_basename} as
            provided to L{__init__}.
        @param what: a status code or a L{Failure} for the given path.
        """
        assert path.startswith(self.path_basename), "%s does not start with %s" % (path, self.path_basename)

        # Derive (code, error, message) from either a bare status code or a
        # Failure instance.
        if type(what) is int:
            code, error, message = what, None, responsecode.RESPONSES[what]
        elif isinstance(what, Failure):
            code = statusForFailure(what)
            error = errorForFailure(what)
            message = messageForFailure(what)
        else:
            raise AssertionError("Unknown data type: %r" % (what,))

        if code > 400: # Error codes only
            log.error("Error during %s for %s: %s" % (self.method, path, message))

        # The response URI is the path relative to the base.
        uri = path[self.path_basename_len:]

        children = [
            element.HRef(uri),
            element.Status.fromResponseCode(code),
        ]
        if error is not None:
            children.append(error)
        if message is not None:
            children.append(element.ResponseDescription(message))
        self.responses.append(element.StatusResponse(*children))

    def response(self):
        """
        Generate a L{MultiStatusResponse} with the responses contained in the
        queue or, if no such responses, return the C{success_response} provided
        to L{__init__}.
        @return: the response.
        """
        if self.responses:
            return MultiStatusResponse(self.responses)
        return self.success_response
class PropertyStatusResponseQueue (object):
    """
    Stores a list of propstat elements for use in a L{Response}
    in a L{MultiStatusResponse}.
    """
    def __init__(self, method, uri, success_response):
        """
        @param method: the name of the method generating the queue.
        @param uri: the URI for the response.
        @param success_response: the status to return if no
            L{PropertyStatus} are added to this queue.
        """
        self.method = method
        self.uri = uri
        self.propstats = []
        self.success_response = success_response

    def add(self, what, property):
        """
        Add a response.
        @param what: a status code or a L{Failure} for the given path.
        @param property: the property whose status is being reported.
        @raise AssertionError: if C{what} is neither an C{int} status code nor
            a L{Failure}.
        """
        if type(what) is int:
            code = what
            error = None
            message = responsecode.RESPONSES[code]
        elif isinstance(what, Failure):
            code = statusForFailure(what)
            error = errorForFailure(what)
            message = messageForFailure(what)
        else:
            raise AssertionError("Unknown data type: %r" % (what,))

        if len(property.children) > 0:
            # Re-instantiate as empty element.
            property = element.WebDAVUnknownElement.withName(property.namespace, property.name)

        # Log 4xx/5xx results.  The previous test (code > 400) incorrectly
        # skipped 400 (Bad Request) itself, which is an error code too.
        if code >= 400:  # Error codes only
            log.error("Error during %s for %s: %s" % (self.method, property, message))

        children = []
        children.append(element.PropertyContainer(property))
        children.append(element.Status.fromResponseCode(code))
        if error is not None:
            children.append(error)
        if message is not None:
            children.append(element.ResponseDescription(message))
        self.propstats.append(element.PropertyStatus(*children))

    def error(self):
        """
        Convert any 2xx codes in the propstat responses to 424 Failed Dependency.
        """
        for index, propstat in enumerate(self.propstats):
            # Check the status
            changed_status = False
            newchildren = []
            for child in propstat.children:
                # Use floor division so the 2xx test still works under true
                # division (Python 3 or "from __future__ import division");
                # on Python 2 ints this is identical to "/".
                if isinstance(child, element.Status) and (child.code // 100 == 2):
                    # Change the code
                    newchildren.append(element.Status.fromResponseCode(responsecode.FAILED_DEPENDENCY))
                    changed_status = True
                elif changed_status and isinstance(child, element.ResponseDescription):
                    newchildren.append(element.ResponseDescription(responsecode.RESPONSES[responsecode.FAILED_DEPENDENCY]))
                else:
                    newchildren.append(child)
            self.propstats[index] = element.PropertyStatus(*newchildren)

    def response(self):
        """
        Generate a response from the responses contained in the queue or, if
        there are no such responses, return the C{success_response} provided to
        L{__init__}.
        @return: a L{element.PropertyStatusResponse}.
        """
        if self.propstats:
            return element.PropertyStatusResponse(
                element.HRef(self.uri),
                *self.propstats
            )
        else:
            return element.StatusResponse(
                element.HRef(self.uri),
                element.Status.fromResponseCode(self.success_response)
            )
##
# Exceptions and response codes
##
def statusForFailure(failure, what=None):
    """
    Map a L{Failure} to an HTTP response code.

    @param failure: a L{Failure}.
    @param what: a description of what was going on when the failure occurred.
        If C{what} is not C{None}, emit a corresponding debug log message.
    @return: a response code corresponding to the given C{failure}.
    """
    def msg(err):
        # Only log when the caller told us what was happening.
        if what is not None:
            log.debug("%s while %s" % (err, what))

    if failure.check(IOError, OSError):
        errnum = failure.value[0]
        if errnum in (errno.EACCES, errno.EPERM):
            msg("Permission denied")
            return responsecode.FORBIDDEN
        if errnum == errno.ENOSPC:
            msg("Out of storage space")
            return responsecode.INSUFFICIENT_STORAGE_SPACE
        if errnum == errno.ENOENT:
            msg("Not found")
            return responsecode.NOT_FOUND
        # Unrecognized errno: propagate the original exception.
        failure.raiseException()
    elif failure.check(NotImplementedError):
        msg("Unimplemented error")
        return responsecode.NOT_IMPLEMENTED
    elif failure.check(InsecurePath):
        msg("Insecure path")
        return responsecode.FORBIDDEN
    elif failure.check(HTTPError):
        code = IResponse(failure.value.response).code
        msg("%d response" % (code,))
        return code
    else:
        # Unknown failure: propagate the original exception.
        failure.raiseException()
def errorForFailure(failure):
    """
    Extract a DAV error element from a failure, if it carries one.

    @param failure: a L{Failure}.
    @return: an L{element.Error} wrapping the error element of the failed
        response if C{failure} wraps an L{HTTPError} whose response is an
        L{ErrorResponse}; C{None} otherwise.
    """
    if not failure.check(HTTPError):
        return None
    response = failure.value.response
    if not isinstance(response, ErrorResponse):
        return None
    return element.Error(response.error)
def messageForFailure(failure):
    """
    Extract a human-readable message from a failure.

    @param failure: a L{Failure}.
    @return: the response description if C{failure} wraps an L{HTTPError}
        whose response is an L{ErrorResponse} or L{StatusResponse}; the
        string form of the failure otherwise.
    """
    if failure.check(HTTPError):
        response = failure.value.response
        # Both response types expose the same "description" attribute.
        if isinstance(response, (ErrorResponse, StatusResponse)):
            return response.description
    return str(failure)
calendarserver-5.2+dfsg/twext/web2/dav/method/ 0000755 0001750 0001750 00000000000 12322625325 020355 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/dav/method/report_acl_principal_prop_set.py 0000644 0001750 0001750 00000013716 12263343324 027045 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_report_expand -*-
##
# Copyright (c) 2006-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV acl-principal-prop-set report
"""
__all__ = ["report_DAV__acl_principal_prop_set"]
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from txdav.xml import element as davxml
from twext.web2.dav.http import ErrorResponse
from twext.web2.dav.http import MultiStatusResponse
from twext.web2.dav.method import prop_common
from twext.web2.dav.method.report import NumberOfMatchesWithinLimits
from twext.web2.dav.method.report import max_number_of_matches
log = Logger()
def report_DAV__acl_principal_prop_set(self, request, acl_prinicpal_prop_set):
    """
    Generate an acl-principal-prop-set REPORT. (RFC 3744, section 9.2)

    Returns, in a multistatus response, the requested properties for each
    distinct principal referenced by an ACE in this resource's ACL.

    @param request: the L{Request} for the REPORT.
    @param acl_prinicpal_prop_set: the L{davxml.ACLPrincipalPropSet} root
        element of the REPORT request body.
    """
    # Verify root element
    if not isinstance(acl_prinicpal_prop_set, davxml.ACLPrincipalPropSet):
        raise ValueError("%s expected as root element, not %s."
                         % (davxml.ACLPrincipalPropSet.sname(), acl_prinicpal_prop_set.sname()))

    # Depth must be "0"
    depth = request.headers.getHeader("depth", "0")
    if depth != "0":
        log.error("Error in prinicpal-prop-set REPORT, Depth set to %s" % (depth,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,)))

    #
    # Check authentication and access controls
    #
    # Reading this resource's ACL requires the read-acl privilege.
    x = waitForDeferred(self.authorize(request, (davxml.ReadACL(),)))
    yield x
    x.getResult()

    # Get a single DAV:prop element from the REPORT request body
    propertiesForResource = None
    propElement = None
    for child in acl_prinicpal_prop_set.children:
        if child.qname() == ("DAV:", "prop"):
            if propertiesForResource is not None:
                log.error("Only one DAV:prop element allowed")
                raise HTTPError(StatusResponse(
                    responsecode.BAD_REQUEST,
                    "Only one DAV:prop element allowed"
                ))
            propertiesForResource = prop_common.propertyListForResource
            propElement = child

    if propertiesForResource is None:
        log.error("Error in acl-principal-prop-set REPORT, no DAV:prop element")
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "No DAV:prop element"))

    # Enumerate principals on ACL in current resource, de-duplicating as we
    # resolve each ACE's principal.
    principals = []

    acl = waitForDeferred(self.accessControlList(request))
    yield acl
    acl = acl.getResult()

    for ace in acl.children:
        resolved = waitForDeferred(self.resolvePrincipal(ace.principal.children[0], request))
        yield resolved
        resolved = resolved.getResult()
        if resolved is not None and resolved not in principals:
            principals.append(resolved)

    # Run report for each referenced principal
    try:
        responses = []
        matchcount = 0
        for principal in principals:
            # Check size of results is within limit
            matchcount += 1
            if matchcount > max_number_of_matches:
                raise NumberOfMatchesWithinLimits(max_number_of_matches)

            resource = waitForDeferred(request.locateResource(str(principal)))
            yield resource
            resource = resource.getResult()

            if resource is not None:
                #
                # Check authentication and access controls
                #
                # A principal the requester may not read is reported with a
                # FORBIDDEN status rather than aborting the whole report.
                x = waitForDeferred(resource.authorize(request, (davxml.Read(),)))
                yield x
                try:
                    x.getResult()
                except HTTPError:
                    responses.append(davxml.StatusResponse(
                        principal,
                        davxml.Status.fromResponseCode(responsecode.FORBIDDEN)
                    ))
                else:
                    d = waitForDeferred(prop_common.responseForHref(
                        request,
                        responses,
                        principal,
                        resource,
                        propertiesForResource,
                        propElement
                    ))
                    yield d
                    d.getResult()
            else:
                log.error("Requested principal resource not found: %s" % (str(principal),))
                responses.append(davxml.StatusResponse(
                    principal,
                    davxml.Status.fromResponseCode(responsecode.NOT_FOUND)
                ))
    except NumberOfMatchesWithinLimits:
        log.error("Too many matching components")
        raise HTTPError(ErrorResponse(
            responsecode.FORBIDDEN,
            davxml.NumberOfMatchesWithinLimits()
        ))

    yield MultiStatusResponse(responses)

report_DAV__acl_principal_prop_set = deferredGenerator(report_DAV__acl_principal_prop_set)
calendarserver-5.2+dfsg/twext/web2/dav/method/put.py 0000644 0001750 0001750 00000007420 12263343324 021542 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_put -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV PUT method
"""
__all__ = ["preconditions_PUT", "http_PUT"]
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from txdav.xml import element as davxml
from twext.web2.dav.method import put_common
from twext.web2.dav.util import parentForURL
log = Logger()
def preconditions_PUT(self, request):
    """
    Enforce the preconditions for a PUT request.

    If this resource exists, the authenticated principal needs the
    write-content privilege on it; otherwise the principal needs the bind
    privilege on an existing parent collection (CONFLICT if the parent does
    not exist).  Also rejects any unrecognized Content-* header with
    501 Not Implemented, as required by HTTP/1.1.

    @param request: the L{Request} being processed.
    """
    #
    # Check authentication and access controls
    #
    if self.exists():
        x = waitForDeferred(self.authorize(request, (davxml.WriteContent(),)))
        yield x
        x.getResult()
    else:
        parent = waitForDeferred(request.locateResource(parentForURL(request.uri)))
        yield parent
        parent = parent.getResult()

        if not parent.exists():
            raise HTTPError(
                StatusResponse(
                    responsecode.CONFLICT,
                    "cannot PUT to non-existent parent"))

        # Creating a new child requires the bind privilege on the parent.
        x = waitForDeferred(parent.authorize(request, (davxml.Bind(),)))
        yield x
        x.getResult()

    #
    # HTTP/1.1 (RFC 2068, section 9.6) requires that we respond with a Not
    # Implemented error if we get a Content-* header which we don't
    # recognize and handle properly.
    #
    for header, value in request.headers.getAllRawHeaders():
        if header.startswith("Content-") and header not in (
            #"Content-Base",     # Doesn't make sense in PUT?
            #"Content-Encoding", # Requires that we decode it?
            "Content-Language",
            "Content-Length",
            #"Content-Location", # Doesn't make sense in PUT?
            "Content-MD5",
            #"Content-Range",    # FIXME: Need to implement this
            "Content-Type",
        ):
            log.error("Client sent unrecognized content header in PUT request: %s"
                      % (header,))
            raise HTTPError(StatusResponse(
                responsecode.NOT_IMPLEMENTED,
                "Unrecognized content header %r in request." % (header,)
            ))

preconditions_PUT = deferredGenerator(preconditions_PUT)
def http_PUT(self, request):
    """
    Respond to a PUT request. (RFC 2518, section 8.7)

    All of the storage work is delegated to L{put_common.storeResource}.
    """
    log.info("Writing request stream to %s" % (self,))

    # The request URI is deliberately not passed along as a response URI:
    # PUT is plain HTTP, which is not specified to return the
    # WebDAV-specific MULTI_STATUS response.
    return put_common.storeResource(
        request,
        destination=self,
        destination_uri=request.uri,
    )
calendarserver-5.2+dfsg/twext/web2/dav/method/mkcol.py 0000644 0001750 0001750 00000005501 12263343324 022035 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_mkcol -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV MKCOL method
"""
__all__ = ["http_MKCOL"]
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from txdav.xml import element as davxml
from twext.web2.dav.fileop import mkcollection
from twext.web2.dav.util import noDataFromStream, parentForURL
log = Logger()
def http_MKCOL(self, request):
    """
    Respond to a MKCOL request. (RFC 2518, section 8.3)

    Authorizes the bind privilege on the parent collection, verifies that
    this resource does not already exist and that the parent is a
    collection, rejects any request body with UNSUPPORTED_MEDIA_TYPE, then
    creates the collection.

    @param request: the L{Request} being processed.
    """
    parent = waitForDeferred(request.locateResource(parentForURL(request.uri)))
    yield parent
    parent = parent.getResult()

    x = waitForDeferred(parent.authorize(request, (davxml.Bind(),)))
    yield x
    x.getResult()

    if self.exists():
        log.error("Attempt to create collection where file exists: %s"
                  % (self,))
        raise HTTPError(responsecode.NOT_ALLOWED)

    if not parent.isCollection():
        log.error("Attempt to create collection with non-collection parent: %s"
                  % (self,))
        raise HTTPError(StatusResponse(
            responsecode.CONFLICT,
            "Parent resource is not a collection."
        ))

    #
    # Read request body
    #
    x = waitForDeferred(noDataFromStream(request.stream))
    yield x
    try:
        x.getResult()
    # "except ValueError, e:" is Python-2-only syntax; the "as" form is
    # equivalent and works on Python 2.6+ and Python 3.
    except ValueError as e:
        log.error("Error while handling MKCOL body: %s" % (e,))
        raise HTTPError(responsecode.UNSUPPORTED_MEDIA_TYPE)

    response = waitForDeferred(mkcollection(self.fp))
    yield response
    yield response.getResult()

http_MKCOL = deferredGenerator(http_MKCOL)
calendarserver-5.2+dfsg/twext/web2/dav/method/copymove.py 0000644 0001750 0001750 00000022224 12263343324 022572 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_copy,twext.web2.dav.test.test_move -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV COPY and MOVE methods.
"""
__all__ = ["http_COPY", "http_MOVE"]
from twisted.internet.defer import waitForDeferred, deferredGenerator
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.dav.fileop import move
from twext.web2.http import HTTPError, StatusResponse
from twext.web2.filter.location import addLocation
from txdav.xml import element as davxml
from twext.web2.dav.idav import IDAVResource
from twext.web2.dav.method import put_common
from twext.web2.dav.util import parentForURL
# FIXME: This is circular
import twext.web2.dav.static
log = Logger()
def http_COPY(self, request):
    """
    Respond to a COPY request. (RFC 2518, section 8.8)

    Validates the request via L{prepareForCopy}, performs the access control
    checks for reading the source and writing (or binding into) the
    destination, then hands the actual work off to
    L{put_common.storeResource}.

    @param request: the L{Request} being processed.
    """
    r = waitForDeferred(prepareForCopy(self, request))
    yield r
    r = r.getResult()

    destination, destination_uri, depth = r

    #
    # Check authentication and access controls
    #
    # Reading is checked recursively since a collection copy visits children.
    x = waitForDeferred(self.authorize(request, (davxml.Read(),), recurse=True))
    yield x
    x.getResult()

    if destination.exists():
        x = waitForDeferred(destination.authorize(
            request,
            (davxml.WriteContent(), davxml.WriteProperties()),
            recurse=True
        ))
        yield x
        x.getResult()
    else:
        # The destination does not exist yet, so what needs authorizing is
        # binding a new child into its parent collection.
        destparent = waitForDeferred(request.locateResource(parentForURL(destination_uri)))
        yield destparent
        destparent = destparent.getResult()

        x = waitForDeferred(destparent.authorize(request, (davxml.Bind(),)))
        yield x
        x.getResult()

        # May need to add a location header
        addLocation(request, destination_uri)

    #x = waitForDeferred(copy(self.fp, destination.fp, destination_uri, depth))
    x = waitForDeferred(put_common.storeResource(request,
                                                 source=self,
                                                 source_uri=request.uri,
                                                 destination=destination,
                                                 destination_uri=destination_uri,
                                                 deletesource=False,
                                                 depth=depth
                                                 ))
    yield x
    yield x.getResult()

http_COPY = deferredGenerator(http_COPY)
def http_MOVE(self, request):
    """
    Respond to a MOVE request. (RFC 2518, section 8.9)

    Validates the request via L{prepareForCopy}, performs the access control
    checks for unbinding from the source parent and binding into the
    destination, then either performs a fast in-place rename (same-directory
    move to a new resource) or a full store-and-delete via
    L{put_common.storeResource}.

    @param request: the L{Request} being processed.
    """
    r = waitForDeferred(prepareForCopy(self, request))
    yield r
    r = r.getResult()

    destination, destination_uri, depth = r

    #
    # Check authentication and access controls
    #
    # Moving away requires the unbind privilege on the source's parent.
    parentURL = parentForURL(request.uri)
    parent = waitForDeferred(request.locateResource(parentURL))
    yield parent
    parent = parent.getResult()

    x = waitForDeferred(parent.authorize(request, (davxml.Unbind(),)))
    yield x
    x.getResult()

    if destination.exists():
        x = waitForDeferred(destination.authorize(
            request,
            (davxml.Bind(), davxml.Unbind()),
            recurse=True
        ))
        yield x
        x.getResult()
    else:
        destparentURL = parentForURL(destination_uri)
        destparent = waitForDeferred(request.locateResource(destparentURL))
        yield destparent
        destparent = destparent.getResult()

        x = waitForDeferred(destparent.authorize(request, (davxml.Bind(),)))
        yield x
        x.getResult()

        # May need to add a location header
        addLocation(request, destination_uri)

    #
    # RFC 2518, section 8.9 says that we must act as if the Depth header is set
    # to infinity, and that the client must omit the Depth header or set it to
    # infinity.
    #
    # This seems somewhat at odds with the notion that a bad request should be
    # rejected outright; if the client sends a bad depth header, the client is
    # broken, and section 8 suggests that a bad request should be rejected...
    #
    # Let's play it safe for now and ignore broken clients.
    #
    if self.isCollection() and depth != "infinity":
        msg = "Client sent illegal depth header value for MOVE: %s" % (depth,)
        log.error(msg)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg))

    # Lets optimise a move within the same directory to a new resource as a simple move
    # rather than using the full transaction based storeResource api. This allows simple
    # "rename" operations to work quickly.
    # NOTE: "destparent" is only bound in the destination-does-not-exist
    # branch above; the short-circuiting "and" guarantees it is not
    # evaluated when the destination exists.
    if (not destination.exists()) and destparent == parent:
        x = waitForDeferred(move(self.fp, request.uri, destination.fp, destination_uri, depth))
    else:
        x = waitForDeferred(put_common.storeResource(request,
            source=self,
            source_uri=request.uri,
            destination=destination,
            destination_uri=destination_uri,
            deletesource=True,
            depth=depth))
    yield x
    yield x.getResult()

http_MOVE = deferredGenerator(http_MOVE)
def prepareForCopy(self, request):
    """
    Perform the request validation shared by COPY and MOVE.

    Checks the Depth header, verifies that the source exists and that a
    Destination header is present, then locates the destination resource
    and finishes validation in L{_prepareForCopy}.

    @param request: the current request.
    @return: a L{Deferred} firing a C{(destination, destination_uri, depth)}
        tuple.
    @raise HTTPError: if validation fails.
    """
    # Depth may only be "0" or "infinity" for these methods.
    depth = request.headers.getHeader("depth", "infinity")
    if depth not in ("0", "infinity"):
        error = "Client sent illegal depth header value: %s" % (depth,)
        log.error(error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    # The source resource must exist.
    if not self.exists():
        log.error("File not found: %s" % (self,))
        raise HTTPError(StatusResponse(
            responsecode.NOT_FOUND,
            "Source resource %s not found." % (request.uri,)
        ))

    # A Destination header is mandatory.
    destination_uri = request.headers.getHeader("destination")
    if not destination_uri:
        error = "No destination header in %s request." % (request.method,)
        log.error(error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    deferred = request.locateResource(destination_uri)
    deferred.addCallback(_prepareForCopy, destination_uri, request, depth)
    return deferred
def _prepareForCopy(destination, destination_uri, request, depth):
    """
    Validate the destination of a COPY or MOVE once it has been located.

    @param destination: the resource found at the Destination header URI.
    @param destination_uri: the Destination header URI.
    @param request: the current request.
    @param depth: the (already validated) Depth header value.
    @return: a C{(destination, destination_uri, depth)} tuple.
    @raise HTTPError: if the destination cannot be used.
    """
    # The destination must adapt to IDAVResource.
    try:
        destination = IDAVResource(destination)
    except TypeError:
        log.error("Attempt to %s to a non-DAV resource: (%s) %s"
                  % (request.method, destination.__class__, destination_uri))
        raise HTTPError(StatusResponse(
            responsecode.FORBIDDEN,
            "Destination %s is not a WebDAV resource." % (destination_uri,)
        ))

    # FIXME: Right now we don't know how to copy to a non-DAVFile resource.
    # We may need some more API in IDAVResource.
    # So far, we need: .exists(), .fp.parent()
    if not isinstance(destination, twext.web2.dav.static.DAVFile):
        log.error("DAV copy between non-DAVFile DAV resources isn't implemented")
        raise HTTPError(StatusResponse(
            responsecode.NOT_IMPLEMENTED,
            "Destination %s is not a DAVFile resource." % (destination_uri,)
        ))

    # An existing destination may only be replaced if the client permits it
    # via the Overwrite header (which defaults to true).
    overwrite = request.headers.getHeader("overwrite", True)
    if destination.exists() and not overwrite:
        log.error("Attempt to %s onto existing file without overwrite flag enabled: %s"
                  % (request.method, destination))
        raise HTTPError(StatusResponse(
            responsecode.PRECONDITION_FAILED,
            "Destination %s already exists." % (destination_uri,)
        ))

    # The destination's parent must be an existing collection.
    if not destination.parent().isCollection():
        log.error("Attempt to %s to a resource with no parent: %s"
                  % (request.method, destination.fp.path))
        raise HTTPError(StatusResponse(responsecode.CONFLICT, "No parent collection."))

    return destination, destination_uri, depth
calendarserver-5.2+dfsg/twext/web2/dav/method/report_principal_property_search.py 0000644 0001750 0001750 00000017464 12263343324 027610 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_report_expand -*-
##
# Copyright (c) 2006-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV principal-property-search report
"""
__all__ = ["report_DAV__principal_property_search"]
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from txdav.xml.base import PCDATAElement
from txdav.xml import element
from txdav.xml.element import dav_namespace
from twext.web2.dav.http import ErrorResponse, MultiStatusResponse
from twext.web2.dav.method import prop_common
from twext.web2.dav.method.report import NumberOfMatchesWithinLimits
from twext.web2.dav.method.report import max_number_of_matches
from twext.web2.dav.resource import isPrincipalResource
log = Logger()
def report_DAV__principal_property_search(self, request, principal_property_search):
    """
    Generate a principal-property-search REPORT. (RFC 3744, section 9.4)

    Searches principal resources for properties whose text content contains
    the supplied match strings, returning the requested properties for each
    matching principal in a multistatus response.

    @param request: the L{Request} for the REPORT.
    @param principal_property_search: the L{element.PrincipalPropertySearch}
        root element of the REPORT request body.
    """
    # Verify root element
    if not isinstance(principal_property_search, element.PrincipalPropertySearch):
        raise ValueError("%s expected as root element, not %s."
                         % (element.PrincipalPropertySearch.sname(), principal_property_search.sname()))

    # Only handle Depth: 0
    depth = request.headers.getHeader("depth", "0")
    if depth != "0":
        log.error("Error in prinicpal-property-search REPORT, Depth set to %s" % (depth,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,)))

    # Get a single DAV:prop element from the REPORT request body
    propertiesForResource = None
    propElement = None
    propertySearches = []  # list of (property elements, lowercased match text)
    applyTo = False
    for child in principal_property_search.children:
        if child.qname() == (dav_namespace, "prop"):
            propertiesForResource = prop_common.propertyListForResource
            propElement = child
        elif child.qname() == (dav_namespace, "apply-to-principal-collection-set"):
            # Search all principal collections instead of just this resource.
            applyTo = True
        elif child.qname() == (dav_namespace, "property-search"):
            props = child.childOfType(element.PropertyContainer)
            props.removeWhitespaceNodes()
            match = child.childOfType(element.Match)
            propertySearches.append((props.children, str(match).lower()))

    def nodeMatch(node, match):
        """
        See if the content of the supplied node matches the supplied text.
        Try to follow the matching guidance in rfc3744 section 9.4.1.
        @param node: the property element to match.
        @param match: the text to match against.
        @return: True if the property matches, False otherwise.
        """
        node.removeWhitespaceNodes()
        # NOTE(review): the else branch returns from inside the loop, so at
        # most the first non-text child is ever descended into -- confirm
        # whether examining only that child is intended.
        for child in node.children:
            if isinstance(child, PCDATAElement):
                comp = str(child).lower()
                if comp.find(match) != -1:
                    return True
            else:
                return nodeMatch(child, match)
        else:
            # for-else: no child produced a match.
            return False

    def propertySearch(resource, request):
        """
        Test the resource to see if it contains properties matching the
        property-search specification in this report.
        @param resource: the L{DAVFile} for the resource to test.
        @param request: the current request.
        @return: True if the resource has matching properties, False otherwise.
        """
        for props, match in propertySearches:
            # Test each property
            for prop in props:
                try:
                    propvalue = waitForDeferred(resource.readProperty(prop.qname(), request))
                    yield propvalue
                    propvalue = propvalue.getResult()
                    if propvalue and not nodeMatch(propvalue, match):
                        yield False
                        return
                except HTTPError:
                    # No property => no match
                    yield False
                    return
        yield True
    propertySearch = deferredGenerator(propertySearch)

    # Run report
    try:
        resources = []
        responses = []
        matchcount = 0

        if applyTo:
            # Search every principal collection visible from this resource.
            for principalCollection in self.principalCollections():
                uri = principalCollection.principalCollectionURL()
                resource = waitForDeferred(request.locateResource(uri))
                yield resource
                resource = resource.getResult()
                if resource:
                    resources.append((resource, uri))
        else:
            resources.append((self, request.uri))

        # Loop over all collections and principal resources within
        for resource, ruri in resources:

            # Do some optimisation of access control calculation by determining any inherited ACLs outside of
            # the child resource loop and supply those to the checkPrivileges on each child.
            filteredaces = waitForDeferred(resource.inheritedACEsforChildren(request))
            yield filteredaces
            filteredaces = filteredaces.getResult()

            children = []
            d = waitForDeferred(resource.findChildren("infinity", request, lambda x, y: children.append((x,y)),
                                                      privileges=(element.Read(),), inherited_aces=filteredaces))
            yield d
            d.getResult()

            for child, uri in children:
                if isPrincipalResource(child):
                    d = waitForDeferred(propertySearch(child, request))
                    yield d
                    d = d.getResult()
                    if d:
                        # Check size of results is within limit
                        matchcount += 1
                        if matchcount > max_number_of_matches:
                            raise NumberOfMatchesWithinLimits(max_number_of_matches)

                        d = waitForDeferred(prop_common.responseForHref(
                            request,
                            responses,
                            element.HRef.fromString(uri),
                            child,
                            propertiesForResource,
                            propElement
                        ))
                        yield d
                        d.getResult()
    except NumberOfMatchesWithinLimits:
        log.error("Too many matching components in prinicpal-property-search report")
        raise HTTPError(ErrorResponse(
            responsecode.FORBIDDEN,
            element.NumberOfMatchesWithinLimits()
        ))

    yield MultiStatusResponse(responses)

report_DAV__principal_property_search = deferredGenerator(report_DAV__principal_property_search)
calendarserver-5.2+dfsg/twext/web2/dav/method/delete_common.py 0000644 0001750 0001750 00000004643 12263343324 023550 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_delete -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
##
"""
WebDAV DELETE method
"""
__all__ = ["deleteResource"]
from twisted.internet.defer import waitForDeferred, deferredGenerator
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError
from twext.web2.dav.fileop import delete
log = Logger()
def deleteResource(request, resource, resource_uri, depth="0"):
    """
    Delete a resource and keep quota accounting in step.

    @param request: the L{twext.web2.server.Request} for the current HTTP request.
    @param resource: the resource to delete.
    @param resource_uri: the URI of the resource being deleted.
    @param depth: a C{str} containing the DELETE Depth header value.
    @return: (via the deferred generator) the response from the delete operation.
    """
    if not resource.exists():
        log.error("File not found: %s" % (resource,))
        raise HTTPError(responsecode.NOT_FOUND)

    # Capture the resource's size up front: once the resource is gone we can
    # no longer ask for it, but we still need the number to credit the quota.
    quota = waitForDeferred(resource.quota(request))
    yield quota
    quota = quota.getResult()

    size_before = 0
    if quota is not None:
        size_deferred = waitForDeferred(resource.quotaSize(request))
        yield size_deferred
        size_before = size_deferred.getResult()

    # Perform the actual delete.
    delete_deferred = waitForDeferred(delete(resource_uri, resource.fp, depth))
    yield delete_deferred
    response = delete_deferred.getResult()

    # Credit the quota with the space the resource used to occupy.
    if quota is not None:
        adjust_deferred = waitForDeferred(resource.quotaSizeAdjust(request, -size_before))
        yield adjust_deferred
        adjust_deferred.getResult()

    yield response

deleteResource = deferredGenerator(deleteResource)
calendarserver-5.2+dfsg/twext/web2/dav/method/put_common.py 0000644 0001750 0001750 00000025511 12263343324 023113 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# DRI: Cyrus Daboo, cdaboo@apple.com
##
"""
PUT/COPY/MOVE common behavior.
"""
__version__ = "0.0"
__all__ = ["storeResource"]
from twisted.python.failure import Failure
from twext.python.filepath import CachingFilePath as FilePath
from twisted.internet.defer import deferredGenerator, maybeDeferred, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.dav.fileop import copy, delete, put
from twext.web2.dav.http import ErrorResponse
from twext.web2.dav.resource import TwistedGETContentMD5
from twext.web2.stream import MD5Stream
from twext.web2.http import HTTPError
from twext.web2.http_headers import generateContentType
from twext.web2.iweb import IResponse
from twext.web2.stream import MemoryStream
from txdav.xml import element as davxml
from txdav.xml.base import dav_namespace
log = Logger()
def storeResource(
    request,
    source=None, source_uri=None, data=None,
    destination=None, destination_uri=None,
    deletesource=False,
    depth="0"
):
    """
    Function that does common PUT/COPY/MOVE behaviour.

    The store is transactional: before any filesystem change a ``.rollback``
    copy of the affected file(s) is made, and on any exception the nested
    C{RollbackState} restores the original server state. Quota is checked
    before the store and adjusted afterwards.

    @param request: the L{twext.web2.server.Request} for the current HTTP request.
    @param source: the L{DAVFile} for the source resource to copy from, or None if source data
        is to be read from the request.
    @param source_uri: the URI for the source resource.
    @param data: a C{str} to copy data from instead of the request stream.
    @param destination: the L{DAVFile} for the destination resource to copy into.
    @param destination_uri: the URI for the destination resource.
    @param deletesource: True if the source resource is to be deleted on successful completion, False otherwise.
    @param depth: a C{str} containing the COPY/MOVE Depth header value.
    @return: status response.
    """
    # Argument invariants: destination is always required; a source, when
    # given, must come with its URI; MOVE (deletesource) requires a source.
    try:
        assert request is not None and destination is not None and destination_uri is not None
        assert (source is None) or (source is not None and source_uri is not None)
        assert not deletesource or (deletesource and source is not None)
    except AssertionError:
        log.error("Invalid arguments to storeResource():")
        log.error("request=%s\n" % (request,))
        log.error("source=%s\n" % (source,))
        log.error("source_uri=%s\n" % (source_uri,))
        log.error("data=%s\n" % (data,))
        log.error("destination=%s\n" % (destination,))
        log.error("destination_uri=%s\n" % (destination_uri,))
        log.error("deletesource=%s\n" % (deletesource,))
        log.error("depth=%s\n" % (depth,))
        raise

    class RollbackState(object):
        """
        This class encapsulates the state needed to rollback the entire PUT/COPY/MOVE
        transaction, leaving the server state the same as it was before the request was
        processed. The DoRollback method will actually execute the rollback operations.
        """

        def __init__(self):
            # active: False once either Rollback or Commit has run, making
            # both operations idempotent no-ops afterwards.
            self.active = True
            # FilePath of the ".rollback" backup of the source, if any.
            self.source_copy = None
            # FilePath of the ".rollback" backup of the destination, if any.
            self.destination_copy = None
            # True when the destination did not exist before this request.
            self.destination_created = False
            # True once the source has actually been deleted (MOVE).
            self.source_deleted = False

        def Rollback(self):
            """
            Rollback the server state. Do not allow this to raise another exception. If
            rollback fails then we are going to be left in an awkward state that will need
            to be cleaned up eventually.
            """
            if self.active:
                self.active = False
                # NOTE(review): informational rollback/commit progress is
                # logged via log.error throughout -- presumably to guarantee
                # visibility; confirm the level is intended.
                log.error("Rollback: rollback")
                try:
                    # Restore a deleted MOVE source from its backup.
                    if self.source_copy and self.source_deleted:
                        self.source_copy.moveTo(source.fp)
                        log.error("Rollback: source restored %s to %s" % (self.source_copy.path, source.fp.path))
                        self.source_copy = None
                        self.source_deleted = False
                    # Restore an overwritten destination, or remove a freshly
                    # created one.
                    if self.destination_copy:
                        destination.fp.remove()
                        log.error("Rollback: destination restored %s to %s" % (self.destination_copy.path, destination.fp.path))
                        self.destination_copy.moveTo(destination.fp)
                        self.destination_copy = None
                    elif self.destination_created:
                        destination.fp.remove()
                        log.error("Rollback: destination removed %s" % (destination.fp.path,))
                        self.destination_created = False
                except:
                    # Swallow deliberately: rollback must never raise.
                    log.error("Rollback: exception caught and not handled: %s" % Failure())

        def Commit(self):
            """
            Commit the resource changes by wiping the rollback state.
            """
            if self.active:
                log.error("Rollback: commit")
                self.active = False
                # Remove the now-unneeded backup copies.
                if self.source_copy:
                    self.source_copy.remove()
                    log.error("Rollback: removed source backup %s" % (self.source_copy.path,))
                    self.source_copy = None
                if self.destination_copy:
                    self.destination_copy.remove()
                    log.error("Rollback: removed destination backup %s" % (self.destination_copy.path,))
                    self.destination_copy = None
                self.destination_created = False
                self.source_deleted = False

    rollback = RollbackState()

    try:
        """
        Handle validation operations here.
        """

        """
        Handle rollback setup here.
        """

        # Do quota checks on destination and source before we start messing with adding other files
        destquota = waitForDeferred(destination.quota(request))
        yield destquota
        destquota = destquota.getResult()
        if destquota is not None and destination.exists():
            old_dest_size = waitForDeferred(destination.quotaSize(request))
            yield old_dest_size
            old_dest_size = old_dest_size.getResult()
        else:
            old_dest_size = 0

        if source is not None:
            sourcequota = waitForDeferred(source.quota(request))
            yield sourcequota
            sourcequota = sourcequota.getResult()
            if sourcequota is not None and source.exists():
                old_source_size = waitForDeferred(source.quotaSize(request))
                yield old_source_size
                old_source_size = old_source_size.getResult()
            else:
                old_source_size = 0
        else:
            sourcequota = None
            old_source_size = 0

        # We may need to restore the original resource data if the PUT/COPY/MOVE fails,
        # so rename the original file in case we need to rollback.
        overwrite = destination.exists()
        if overwrite:
            rollback.destination_copy = FilePath(destination.fp.path)
            rollback.destination_copy.path += ".rollback"
            destination.fp.copyTo(rollback.destination_copy)
        else:
            rollback.destination_created = True

        if deletesource:
            rollback.source_copy = FilePath(source.fp.path)
            rollback.source_copy.path += ".rollback"
            source.fp.copyTo(rollback.source_copy)

        """
        Handle actual store operations here.
        """

        # Do put or copy based on whether source exists
        if source is not None:
            # COPY/MOVE: replicate the source file tree into the destination.
            response = maybeDeferred(copy, source.fp, destination.fp, destination_uri, depth)
        else:
            # PUT: stream the request body (or the supplied data) through an
            # MD5-computing stream so we can record the checksum afterwards.
            datastream = request.stream
            if data is not None:
                datastream = MemoryStream(data)
            md5 = MD5Stream(datastream)
            response = maybeDeferred(put, md5, destination.fp)

        response = waitForDeferred(response)
        yield response
        response = response.getResult()

        # Update the MD5 value on the resource
        if source is not None:
            # Copy MD5 value from source to destination
            if source.hasDeadProperty(TwistedGETContentMD5):
                md5 = source.readDeadProperty(TwistedGETContentMD5)
                destination.writeDeadProperty(md5)
        else:
            # Finish MD5 calc and write dead property
            md5.close()
            md5 = md5.getMD5()
            destination.writeDeadProperty(TwistedGETContentMD5.fromString(md5))

        # Update the content-type value on the resource if it is not been copied or moved
        if source is None:
            content_type = request.headers.getHeader("content-type")
            if content_type is not None:
                destination.writeDeadProperty(davxml.GETContentType.fromString(generateContentType(content_type)))

        response = IResponse(response)

        # Do quota check on destination
        if destquota is not None:
            # Get size of new/old resources
            new_dest_size = waitForDeferred(destination.quotaSize(request))
            yield new_dest_size
            new_dest_size = new_dest_size.getResult()
            diff_size = new_dest_size - old_dest_size
            # destquota[0] is the remaining available space.
            if diff_size >= destquota[0]:
                log.error("Over quota: available %d, need %d" % (destquota[0], diff_size))
                raise HTTPError(ErrorResponse(
                    responsecode.INSUFFICIENT_STORAGE_SPACE,
                    (dav_namespace, "quota-not-exceeded")
                ))
            d = waitForDeferred(destination.quotaSizeAdjust(request, diff_size))
            yield d
            d.getResult()

        if deletesource:
            # Delete the source resource
            if sourcequota is not None:
                delete_size = 0 - old_source_size
                d = waitForDeferred(source.quotaSizeAdjust(request, delete_size))
                yield d
                d.getResult()
            delete(source_uri, source.fp, depth)
            rollback.source_deleted = True

        # Can now commit changes and forget the rollback details
        rollback.Commit()

        yield response
        return

    except:
        # Roll back changes to original server state. Note this may do nothing
        # if the rollback has already ocurred or changes already committed.
        rollback.Rollback()
        raise

storeResource = deferredGenerator(storeResource)
calendarserver-5.2+dfsg/twext/web2/dav/method/delete.py 0000644 0001750 0001750 00000004514 12263343324 022175 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_delete -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV DELETE method
"""
__all__ = ["http_DELETE"]
from twisted.internet.defer import waitForDeferred, deferredGenerator
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError
from txdav.xml import element as davxml
from twext.web2.dav.method.delete_common import deleteResource
from twext.web2.dav.util import parentForURL
log = Logger()
def http_DELETE(self, request):
    """
    Respond to a DELETE request. (RFC 2518, section 8.6)

    Authorization requires DAV:unbind on the parent collection; the actual
    removal (including quota bookkeeping) is delegated to L{deleteResource}.
    """
    if not self.exists():
        log.error("File not found: %s" % (self,))
        raise HTTPError(responsecode.NOT_FOUND)

    depth = request.headers.getHeader("depth", "infinity")

    #
    # Check authentication and access controls
    #
    parent_d = waitForDeferred(request.locateResource(parentForURL(request.uri)))
    yield parent_d
    parent = parent_d.getResult()

    auth = waitForDeferred(parent.authorize(request, (davxml.Unbind(),)))
    yield auth
    auth.getResult()

    # Hand off to the shared delete implementation and return its response.
    result = waitForDeferred(deleteResource(request, self, request.uri, depth))
    yield result
    yield result.getResult()

http_DELETE = deferredGenerator(http_DELETE)
calendarserver-5.2+dfsg/twext/web2/dav/method/prop_common.py 0000644 0001750 0001750 00000007417 12147725751 023301 0 ustar rahul rahul ##
# Cyrus Daboo, cdaboo@apple.com
# Copyright 2006-2012 Apple Computer, Inc. All Rights Reserved.
##
__all__ = [
"responseForHref",
"propertyListForResource",
]
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twisted.python.failure import Failure
from twext.python.log import Logger
from twext.web2 import responsecode
from txdav.xml import element
from twext.web2.dav.http import statusForFailure
from twext.web2.dav.method.propfind import propertyName
log = Logger()
def responseForHref(request, responses, href, resource, propertiesForResource, propertyreq):
    """
    Append a multi-status response element for C{href} to C{responses}.

    When C{propertiesForResource} is provided it is called to gather the
    requested properties grouped by status code and a propstat-style
    response is appended; otherwise a bare 200 status response is appended.
    """
    if propertiesForResource is None:
        # No properties requested: just report the href with an OK status.
        responses.append(
            element.StatusResponse(
                href,
                element.Status.fromResponseCode(responsecode.OK),
            )
        )
        return

    d = waitForDeferred(propertiesForResource(request, propertyreq, resource))
    yield d
    by_status = d.getResult()

    # One propstat per status code that actually has properties.
    propstats = [
        element.PropertyStatus(
            element.PropertyContainer(*by_status[code]),
            element.Status.fromResponseCode(code),
        )
        for code in by_status if by_status[code]
    ]
    if propstats:
        responses.append(element.PropertyStatusResponse(href, *propstats))

responseForHref = deferredGenerator(responseForHref)
def propertyListForResource(request, prop, resource):
    """
    Fetch the properties named inside a DAV:prop element from a resource.

    @param request: the L{IRequest} for the current request.
    @param prop: the L{PropertyContainer} element for the properties of interest.
    @param resource: the L{DAVFile} for the targetted resource.
    @return: a map of OK and NOT FOUND property values.
    """
    requested = prop.children
    return _namedPropertiesForResource(request, requested, resource)
def _namedPropertiesForResource(request, props, resource):
    """
    Return the specified properties on the specified resource.

    @param request: the L{IRequest} for the current request.
    @param props: a list of property elements or qname tuples for the
        properties of interest.
    @param resource: the L{DAVFile} for the targetted resource.
    @return: (via the deferred generator) a map of status code to list of
        property values; always contains OK and NOT_FOUND keys.
    """
    properties_by_status = {
        responsecode.OK: [],
        responsecode.NOT_FOUND: [],
    }

    # The property listing is loop-invariant: fetch it once instead of once
    # per requested property, and bind it to its own name instead of
    # clobbering the `props` parameter while it is being iterated.
    available = waitForDeferred(resource.listProperties(request))
    yield available
    available = available.getResult()

    for requested in props:
        # Accept either a WebDAVElement instance or a raw qname tuple.
        if isinstance(requested, element.WebDAVElement):
            qname = requested.qname()
        else:
            qname = requested

        if qname in available:
            try:
                value = waitForDeferred(resource.readProperty(qname, request))
                yield value
                value = value.getResult()
                properties_by_status[responsecode.OK].append(value)
            except:
                # Deliberately broad: statusForFailure maps any failure to an
                # HTTP status so the caller can report it per-property.
                f = Failure()
                status = statusForFailure(f, "getting property: %s" % (qname,))
                if status != responsecode.NOT_FOUND:
                    log.error("Error reading property %r for resource %s: %s" %
                              (qname, request.uri, f.value))
                if status not in properties_by_status:
                    properties_by_status[status] = []
                properties_by_status[status].append(propertyName(qname))
        else:
            properties_by_status[responsecode.NOT_FOUND].append(propertyName(qname))

    yield properties_by_status

_namedPropertiesForResource = deferredGenerator(_namedPropertiesForResource)
calendarserver-5.2+dfsg/twext/web2/dav/method/get.py 0000644 0001750 0001750 00000004261 12263343324 021511 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_lock -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV GET and HEAD methods
"""
__all__ = ["http_OPTIONS", "http_HEAD", "http_GET"]
import twext
from txdav.xml import element as davxml
from twext.web2.dav.util import parentForURL
def http_OPTIONS(self, request):
    """
    Respond to an OPTIONS request, running the access check first.
    """
    def forward(_):
        # Once authorized, defer to the base class implementation.
        return super(twext.web2.dav.resource.DAVResource, self).http_OPTIONS(request)
    return authorize(self, request).addCallback(forward)
def http_HEAD(self, request):
    """
    Respond to a HEAD request, running the access check first.
    """
    def forward(_):
        # Once authorized, defer to the base class implementation.
        return super(twext.web2.dav.resource.DAVResource, self).http_HEAD(request)
    return authorize(self, request).addCallback(forward)
def http_GET(self, request):
    """
    Respond to a GET request, running the access check first.
    """
    def forward(_):
        # Once authorized, defer to the base class implementation.
        return super(twext.web2.dav.resource.DAVResource, self).http_GET(request)
    return authorize(self, request).addCallback(forward)
def authorize(self, request):
    """
    Kick off an access-control check for a read-style method.

    An existing resource requires DAV:read on itself; a missing resource is
    checked for DAV:bind on its parent collection instead.

    @return: a deferred firing once authorization has completed (or failing
        with the authorization error).
    """
    if not self.exists():
        parent_d = request.locateResource(parentForURL(request.uri))
        return parent_d.addCallback(
            lambda parent: parent.authorize(request, (davxml.Bind(),))
        )
    return self.authorize(request, (davxml.Read(),))
calendarserver-5.2+dfsg/twext/web2/dav/method/report_principal_match.py 0000644 0001750 0001750 00000021374 12263343324 025466 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_report_expand -*-
##
# Copyright (c) 2006-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV principal-match report
"""
__all__ = ["report_DAV__principal_match"]
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import StatusResponse, HTTPError
from txdav.xml import element
from txdav.xml.element import dav_namespace
from twext.web2.dav.http import ErrorResponse, MultiStatusResponse
from twext.web2.dav.method import prop_common
from twext.web2.dav.method.report import NumberOfMatchesWithinLimits
from twext.web2.dav.method.report import max_number_of_matches
from twext.web2.dav.resource import isPrincipalResource
log = Logger()
def report_DAV__principal_match(self, request, principal_match):
    """
    Generate a principal-match REPORT. (RFC 3744, section 9.3)

    Two modes: with DAV:self (the default), report the principal resources
    under this collection that represent the authenticated principal or its
    group memberships; with DAV:principal-property, report child resources
    whose named property is an href pointing at a matching principal.
    Raises 403 if more than C{max_number_of_matches} resources match.
    """
    # Verify root element
    if not isinstance(principal_match, element.PrincipalMatch):
        raise ValueError("%s expected as root element, not %s."
            % (element.PrincipalMatch.sname(), principal_match.sname()))

    # Only handle Depth: 0
    depth = request.headers.getHeader("depth", "0")
    if depth != "0":
        log.error("Non-zero depth is not allowed: %s" % (depth,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,)))

    # Get a single DAV:prop element from the REPORT request body
    propertiesForResource = None
    propElement = None
    principalPropElement = None
    # lookForPrincipals selects the DAV:self mode (True) vs the
    # DAV:principal-property mode (False).
    lookForPrincipals = True

    for child in principal_match.children:
        if child.qname() == (dav_namespace, "prop"):
            propertiesForResource = prop_common.propertyListForResource
            propElement = child

        elif child.qname() == (dav_namespace, "self"):
            lookForPrincipals = True

        elif child.qname() == (dav_namespace, "principal-property"):
            # Must have one and only one property in this element
            if len(child.children) != 1:
                log.error("Wrong number of properties in DAV:principal-property: %s"
                    % (len(child.children),))
                raise HTTPError(StatusResponse(
                    responsecode.BAD_REQUEST,
                    "DAV:principal-property must contain exactly one property"
                ))
            lookForPrincipals = False
            principalPropElement = child.children[0]

    # Run report for each referenced principal
    try:
        responses = []
        matchcount = 0

        # The authenticated principal's URL (first child of the
        # DAV:current-user-principal-ish element).
        myPrincipalURL = self.currentPrincipal(request).children[0]

        if lookForPrincipals:
            # Find the set of principals that represent "self".

            # First add "self"
            principal = waitForDeferred(request.locateResource(str(myPrincipalURL)))
            yield principal
            principal = principal.getResult()
            selfItems = [principal, ]

            # Get group memberships for "self" and add each of those
            d = waitForDeferred(principal.groupMemberships())
            yield d
            memberships = d.getResult()
            selfItems.extend(memberships)

            # Now add each principal found to the response provided the principal resource is a child of
            # the current resource.
            for principal in selfItems:
                # Get all the URIs that point to the principal resource
                # FIXME: making the assumption that the principalURL() is the URL of the resource we found
                principal_uris = [principal.principalURL()]
                principal_uris.extend(principal.alternateURIs())

                # Compare each one to the request URI and return at most one that matches
                for uri in principal_uris:
                    if uri.startswith(request.uri):
                        # Check size of results is within limit
                        matchcount += 1
                        if matchcount > max_number_of_matches:
                            raise NumberOfMatchesWithinLimits(max_number_of_matches)

                        d = waitForDeferred(prop_common.responseForHref(
                            request,
                            responses,
                            element.HRef.fromString(uri),
                            principal,
                            propertiesForResource,
                            propElement
                        ))
                        yield d
                        d.getResult()
                        # At most one response per principal.
                        break
        else:
            # Do some optimisation of access control calculation by determining any inherited ACLs outside of
            # the child resource loop and supply those to the checkPrivileges on each child.
            filteredaces = waitForDeferred(self.inheritedACEsforChildren(request))
            yield filteredaces
            filteredaces = filteredaces.getResult()

            children = []
            d = waitForDeferred(self.findChildren("infinity", request, lambda x, y: children.append((x, y)),
                privileges=(element.Read(),), inherited_aces=filteredaces))
            yield d
            d.getResult()

            for child, uri in children:
                # Try to read the requested property from this resource
                try:
                    prop = waitForDeferred(child.readProperty(principalPropElement.qname(), request))
                    yield prop
                    prop = prop.getResult()
                    if prop: prop.removeWhitespaceNodes()

                    # The property must be a single DAV:href to qualify.
                    if prop and len(prop.children) == 1 and isinstance(prop.children[0], element.HRef):
                        # Find principal associated with this property and test it
                        principal = waitForDeferred(request.locateResource(str(prop.children[0])))
                        yield principal
                        principal = principal.getResult()

                        if principal and isPrincipalResource(principal):
                            d = waitForDeferred(principal.principalMatch(myPrincipalURL))
                            yield d
                            matched = d.getResult()
                            if matched:
                                # Check size of results is within limit
                                matchcount += 1
                                if matchcount > max_number_of_matches:
                                    raise NumberOfMatchesWithinLimits(max_number_of_matches)

                                d = waitForDeferred(prop_common.responseForHref(
                                    request,
                                    responses,
                                    element.HRef.fromString(uri),
                                    child,
                                    propertiesForResource,
                                    propElement
                                ))
                                yield d
                                d.getResult()
                except HTTPError:
                    # Just ignore a failure to access the property. We treat this like a property that does not exist
                    # or does not match the principal.
                    pass

    except NumberOfMatchesWithinLimits:
        log.error("Too many matching components in principal-match report")
        raise HTTPError(ErrorResponse(
            responsecode.FORBIDDEN,
            element.NumberOfMatchesWithinLimits()
        ))

    yield MultiStatusResponse(responses)

report_DAV__principal_match = deferredGenerator(report_DAV__principal_match)
calendarserver-5.2+dfsg/twext/web2/dav/method/acl.py 0000644 0001750 0001750 00000006250 12263343324 021471 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_lock -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV ACL method
"""
__all__ = ["http_ACL"]
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import StatusResponse, HTTPError
from txdav.xml import element as davxml
from twext.web2.dav.http import ErrorResponse
from twext.web2.dav.util import davXMLFromStream
log = Logger()
def http_ACL(self, request):
    """
    Respond to a ACL request. (RFC 3744, section 8.1)

    Requires DAV:write-acl on the resource. The request body must be a
    DAV:acl element, which is merged into the resource's access control
    list via C{mergeAccessControlList}.

    @return: (via the deferred generator) C{responsecode.OK} on success, or
        a 403 L{ErrorResponse} naming the failed precondition.
    """
    if not self.exists():
        log.error("File not found: %s" % (self,))
        yield responsecode.NOT_FOUND
        return

    #
    # Check authentication and access controls
    #
    x = waitForDeferred(self.authorize(request, (davxml.WriteACL(),)))
    yield x
    x.getResult()

    #
    # Read request body
    #
    doc = waitForDeferred(davXMLFromStream(request.stream))
    yield doc
    try:
        doc = doc.getResult()
    except ValueError as e:
        log.error("Error while handling ACL body: %s" % (e,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e)))

    #
    # Set properties
    #
    if doc is None:
        error = "Request XML body is required."
        log.error("Error: {err}", err=error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    #
    # Parse request
    #
    acl = doc.root_element
    if not isinstance(acl, davxml.ACL):
        # The original interpolated davxml.PropertyUpdate.sname() into a
        # string with no conversion specifier, which raised TypeError here
        # instead of returning the intended 400 response.
        error = ("Request XML body must be a %s element."
                 % (davxml.ACL.sname(),))
        log.error("Error: {err}", err=error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    #
    # Do ACL merger
    #
    result = waitForDeferred(self.mergeAccessControlList(acl, request))
    yield result
    result = result.getResult()

    #
    # Return response
    #
    if result is None:
        yield responsecode.OK
    else:
        yield ErrorResponse(responsecode.FORBIDDEN, result)

http_ACL = deferredGenerator(http_ACL)
calendarserver-5.2+dfsg/twext/web2/dav/method/report_principal_search_property_set.py 0000644 0001750 0001750 00000005547 12263343324 030462 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_report_expand -*-
##
# Copyright (c) 2006-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV principal-search-property-set report
"""
__all__ = ["report_DAV__principal_search_property_set"]
from twisted.internet.defer import deferredGenerator
from twext.python.log import Logger
from twext.web2 import responsecode
from txdav.xml import element as davxml
from twext.web2.http import HTTPError, Response, StatusResponse
from twext.web2.stream import MemoryStream
log = Logger()
def report_DAV__principal_search_property_set(self, request, principal_search_property_set):
    """
    Generate a principal-search-property-set REPORT. (RFC 3744, section 9.5)

    The response body is simply the XML serialization of the resource's
    principal search property set.
    """
    # The parser must have handed us the expected root element.
    if not isinstance(principal_search_property_set, davxml.PrincipalSearchPropertySet):
        raise ValueError("%s expected as root element, not %s."
            % (davxml.PrincipalSearchPropertySet.sname(), principal_search_property_set.sname()))

    # This report is only defined for Depth: 0.
    depth = request.headers.getHeader("depth", "0")
    if depth != "0":
        log.error("Error in principal-search-property-set REPORT, Depth set to %s" % (depth,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,)))

    # Ask the resource for its property set; None means the report is not
    # supported here.
    propset = self.principalSearchPropertySet()
    if propset is None:
        log.error("Error in principal-search-property-set REPORT not supported on: %s" % (self,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Not allowed on this resource"))

    yield Response(code=responsecode.OK, stream=MemoryStream(propset.toxml()))

report_DAV__principal_search_property_set = deferredGenerator(report_DAV__principal_search_property_set)
calendarserver-5.2+dfsg/twext/web2/dav/method/report.py 0000644 0001750 0001750 00000010657 12263343324 022253 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_report -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV REPORT method
"""
__all__ = [
"max_number_of_matches",
"NumberOfMatchesWithinLimits",
"http_REPORT",
]
import string
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from twext.web2.dav.http import ErrorResponse
from twext.web2.dav.util import davXMLFromStream
from txdav.xml import element as davxml
from txdav.xml.element import lookupElement
from txdav.xml.base import encodeXMLName
log = Logger()
max_number_of_matches = 500
class NumberOfMatchesWithinLimits(Exception):
    """
    Raised when a query would produce more matches than the server-imposed
    ceiling; carries that ceiling so it can be reported to the client.
    """

    def __init__(self, limit):
        Exception.__init__(self)
        self.limit = limit

    def maxLimit(self):
        """Return the match limit that was exceeded."""
        return self.limit
def http_REPORT(self, request):
    """
    Respond to a REPORT request. (RFC 3253, section 3.6)

    Dispatches the parsed report body to a method named
    C{report_<namespace>_<name>} on this resource (non-identifier characters
    mapped to underscores), after verifying that the report is advertised in
    the resource's supported-report-set.

    This is an old-style Twisted C{deferredGenerator}: each Deferred is
    wrapped in C{waitForDeferred}, yielded, and then unwrapped with
    C{getResult()}; the final C{yield} of a non-waitForDeferred value is the
    generator's "return value".
    """
    if not self.exists():
        log.error("File not found: %s" % (self,))
        raise HTTPError(responsecode.NOT_FOUND)
    #
    # Check authentication and access controls
    #
    x = waitForDeferred(self.authorize(request, (davxml.Read(),)))
    yield x
    x.getResult()
    #
    # Read request body
    #
    try:
        doc = waitForDeferred(davXMLFromStream(request.stream))
        yield doc
        doc = doc.getResult()
    except ValueError, e:
        # Malformed XML in the request body.
        log.error("Error while handling REPORT body: %s" % (e,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e)))
    if doc is None:
        # REPORT requires a body naming the report to run.
        raise HTTPError(StatusResponse(
            responsecode.BAD_REQUEST,
            "REPORT request body may not be empty"
        ))
    #
    # Parse request
    #
    namespace = doc.root_element.namespace
    name = doc.root_element.name
    # Characters legal in a Python identifier; anything else in the report's
    # qname is mapped to "_" when forming the dispatch method name.
    ok = string.ascii_letters + string.digits + "_"
    def to_method(s):
        # Sanitize s into "report_<s>" with non-identifier chars replaced.
        out = []
        for c in s:
            if c in ok:
                out.append(c)
            else:
                out.append("_")
        return "report_" + "".join(out)
    if namespace:
        method_name = to_method("_".join((namespace, name)))
        if namespace == davxml.dav_namespace:
            request.submethod = "DAV:" + name
        else:
            request.submethod = encodeXMLName(namespace, name)
    else:
        # No namespace: dispatch on the bare element name.
        method_name = to_method(name)
        request.submethod = name
    try:
        method = getattr(self, method_name)
        # Also double-check via supported-reports property
        reports = self.supportedReports()
        # NOTE(review): if lookupElement raises KeyError for an unknown qname
        # (rather than returning a falsey value), it would escape this
        # AttributeError handler — confirm lookupElement's contract.
        test = lookupElement((namespace, name))
        if not test:
            raise AttributeError()
        test = davxml.Report(test())
        if test not in reports:
            raise AttributeError()
    except AttributeError:
        #
        # Requested report is not supported.
        #
        log.error("Unsupported REPORT %s for resource %s (no method %s)"
            % (encodeXMLName(namespace, name), self, method_name))
        raise HTTPError(ErrorResponse(
            responsecode.FORBIDDEN,
            davxml.SupportedReport()
        ))
    # Run the report method and hand its response back as our result.
    d = waitForDeferred(method(request, doc.root_element))
    yield d
    yield d.getResult()

http_REPORT = deferredGenerator(http_REPORT)
calendarserver-5.2+dfsg/twext/web2/dav/method/proppatch.py 0000644 0001750 0001750 00000016376 12263343324 022744 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_prop.PROP.test_PROPPATCH -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV-aware static resources.
"""
__all__ = ["http_PROPPATCH"]
from twisted.python.failure import Failure
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from txdav.xml import element as davxml
from twext.web2.dav.http import MultiStatusResponse, PropertyStatusResponseQueue
from twext.web2.dav.util import davXMLFromStream
log = Logger()
def http_PROPPATCH(self, request):
    """
    Respond to a PROPPATCH request. (RFC 2518, section 8.2)

    Property sets/removes are applied atomically: if any single operation
    fails, every change that did succeed is rolled back via an undo queue and
    reported with 424 (Failed Dependency). If the client sent
    C{Prefer: return=minimal} and nothing failed, a bare 200 is returned
    instead of a multi-status body.

    This is an old-style Twisted C{deferredGenerator}: Deferreds are wrapped
    in C{waitForDeferred}, yielded, then unwrapped with C{getResult()}.
    """
    if not self.exists():
        log.error("File not found: %s" % (self,))
        raise HTTPError(responsecode.NOT_FOUND)

    # Check authentication and access controls.
    x = waitForDeferred(self.authorize(request, (davxml.WriteProperties(),)))
    yield x
    x.getResult()

    #
    # Read request body
    #
    try:
        doc = waitForDeferred(davXMLFromStream(request.stream))
        yield doc
        doc = doc.getResult()
    except ValueError as e:
        log.error("Error while handling PROPPATCH body: %s" % (e,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e)))

    if doc is None:
        error = "Request XML body is required."
        # FIX: the {err}-style Logger takes format arguments by keyword;
        # passing `error` positionally never filled in the message.
        log.error("Error: {err}", err=error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    #
    # Parse request
    #
    update = doc.root_element
    if not isinstance(update, davxml.PropertyUpdate):
        # FIX: the original format string had no %s placeholder, so applying
        # "%" raised TypeError instead of producing this message.
        error = ("Request XML body must be a %s element."
                 % (davxml.PropertyUpdate.sname(),))
        log.error("Error: {err}", err=error)
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

    responses = PropertyStatusResponseQueue("PROPPATCH", request.uri, responsecode.NO_CONTENT)
    undoActions = []
    gotError = False

    # Look for Prefer header
    prefer = request.headers.getHeader("prefer", {})
    returnMinimal = any([key == "return" and value == "minimal" for key, value, _ignore_args in prefer])

    try:
        #
        # Update properties
        #
        for setOrRemove in update.children:
            assert len(setOrRemove.children) == 1
            container = setOrRemove.children[0]
            assert isinstance(container, davxml.PropertyContainer)
            properties = container.children

            def do(action, property, removing=False):
                """
                Perform action(property, request) while maintaining an
                undo queue; yields True on success, False on failure.
                """
                has = waitForDeferred(self.hasProperty(property, request))
                yield has
                has = has.getResult()
                if has:
                    # Property exists: undo means restoring the old value.
                    oldProperty = waitForDeferred(self.readProperty(property, request))
                    yield oldProperty
                    oldProperty = oldProperty.getResult()

                    def undo():
                        return self.writeProperty(oldProperty, request)
                else:
                    # Property doesn't exist: undo means removing it again.
                    def undo():
                        return self.removeProperty(property, request)
                try:
                    x = waitForDeferred(action(property, request))
                    yield x
                    x.getResult()
                except KeyError as e:
                    # Removing a non-existent property is OK according to WebDAV
                    if removing:
                        responses.add(responsecode.OK, property)
                        yield True
                        return
                    else:
                        # Convert KeyError exception into HTTPError
                        responses.add(
                            Failure(exc_value=HTTPError(StatusResponse(responsecode.FORBIDDEN, str(e)))),
                            property
                        )
                        yield False
                        return
                except:
                    # Record the full Failure so the per-property status
                    # reflects the real error.
                    responses.add(Failure(), property)
                    yield False
                    return
                else:
                    responses.add(responsecode.OK, property)
                    # Only add undo action for those that succeed because
                    # those that fail will not have changed anything.
                    undoActions.append(undo)
                    yield True
                    return
            do = deferredGenerator(do)

            if isinstance(setOrRemove, davxml.Set):
                for property in properties:
                    ok = waitForDeferred(do(self.writeProperty, property))
                    yield ok
                    ok = ok.getResult()
                    if not ok:
                        gotError = True
            elif isinstance(setOrRemove, davxml.Remove):
                for property in properties:
                    ok = waitForDeferred(do(self.removeProperty, property, True))
                    yield ok
                    ok = ok.getResult()
                    if not ok:
                        gotError = True
            else:
                raise AssertionError("Unknown child of PropertyUpdate: %s" % (setOrRemove,))
    except:
        #
        # On an unexpected error we must back out any operations already
        # performed, because PROPPATCH is an all-or-nothing request, then
        # re-raise for the containing scope to report.
        #
        for action in undoActions:
            x = waitForDeferred(action())
            yield x
            x.getResult()
        raise

    #
    # If we had an error we need to undo any changes that did succeed and
    # change status of those to 424 Failed Dependency.
    #
    if gotError:
        for action in undoActions:
            x = waitForDeferred(action())
            yield x
            x.getResult()
        responses.error()

    #
    # Return response - use 200 if Prefer:return=minimal set and no errors
    #
    if returnMinimal and not gotError:
        yield responsecode.OK
    else:
        yield MultiStatusResponse([responses.response()])

http_PROPPATCH = deferredGenerator(http_PROPPATCH)
calendarserver-5.2+dfsg/twext/web2/dav/method/report_expand.py 0000644 0001750 0001750 00000016435 12263343324 023612 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_report_expand -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
##
"""
WebDAV expand-property report
"""
__all__ = ["report_DAV__expand_property"]
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.python.failure import Failure
from twext.python.log import Logger
from twext.web2 import responsecode
from txdav.xml import element
from txdav.xml.element import dav_namespace
from twext.web2.dav.http import statusForFailure, MultiStatusResponse
from twext.web2.dav.method import prop_common
from twext.web2.dav.method.propfind import propertyName
from twext.web2.dav.resource import AccessDeniedError
from twext.web2.dav.util import parentForURL
from twext.web2.http import HTTPError, StatusResponse
log = Logger()
@inlineCallbacks
def report_DAV__expand_property(self, request, expand_property):
    """
    Generate an expand-property REPORT. (RFC 3253, section 3.8)

    Reads each requested top-level property from this resource, then, for
    every DAV:href found in a property's value, dereferences the href and
    inlines the requested sub-properties of the target resource (subject to
    DAV:read access checks on both the target and its parent).

    TODO: for simplicity we will only support one level of expansion.

    @param expand_property: the parsed DAV:expand-property request body.
    @return: (via Deferred) a L{MultiStatusResponse}.
    @raise HTTPError: BAD_REQUEST for non-zero Depth, NOT_IMPLEMENTED for
        nested expansion requests.
    """
    # Verify root element
    if not isinstance(expand_property, element.ExpandProperty):
        raise ValueError("%s expected as root element, not %s."
            % (element.ExpandProperty.sname(), expand_property.sname()))
    # Only handle Depth: 0
    depth = request.headers.getHeader("depth", "0")
    if depth != "0":
        log.error("Non-zero depth is not allowed: %s" % (depth,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, "Depth %s not allowed" % (depth,)))
    #
    # Get top level properties to expand and make sure we only have one level
    #
    # Maps (namespace, name) of each top-level property to the list of
    # (namespace, name) sub-properties to fetch from dereferenced hrefs.
    properties = {}
    for property in expand_property.children:
        namespace = property.attributes.get("namespace", dav_namespace)
        name = property.attributes.get("name", "")
        # Make sure children have no children
        props_to_find = []
        for child in property.children:
            if child.children:
                log.error("expand-property REPORT only supports single level expansion")
                raise HTTPError(StatusResponse(
                    responsecode.NOT_IMPLEMENTED,
                    "expand-property REPORT only supports single level expansion"
                ))
            child_namespace = child.attributes.get("namespace", dav_namespace)
            child_name = child.attributes.get("name", "")
            props_to_find.append((child_namespace, child_name))
        properties[(namespace, name)] = props_to_find
    #
    # Generate the expanded responses status for each top-level property
    #
    properties_by_status = {
        responsecode.OK : [],
        responsecode.NOT_FOUND : [],
    }
    filteredaces = None
    lastParent = None
    for qname in properties.iterkeys():
        try:
            prop = (yield self.readProperty(qname, request))
            # Form the PROPFIND-style DAV:prop element we need later
            props_to_return = element.PropertyContainer(*properties[qname])
            # Now dereference any HRefs
            responses = []
            for href in prop.children:
                if isinstance(href, element.HRef):
                    # Locate the Href resource and its parent
                    resource_uri = str(href)
                    child = (yield request.locateResource(resource_uri))
                    if not child or not child.exists():
                        responses.append(element.StatusResponse(href, element.Status.fromResponseCode(responsecode.NOT_FOUND)))
                        continue
                    parent = (yield request.locateResource(parentForURL(resource_uri)))
                    # Check privileges on parent - must have at least DAV:read
                    try:
                        yield parent.checkPrivileges(request, (element.Read(),))
                    except AccessDeniedError:
                        responses.append(element.StatusResponse(href, element.Status.fromResponseCode(responsecode.FORBIDDEN)))
                        continue
                    # Cache the last parent's inherited aces for checkPrivileges optimization
                    if lastParent != parent:
                        lastParent = parent
                        # Do some optimisation of access control calculation by determining any inherited ACLs outside of
                        # the child resource loop and supply those to the checkPrivileges on each child.
                        filteredaces = (yield parent.inheritedACEsforChildren(request))
                    # Check privileges - must have at least DAV:read
                    try:
                        yield child.checkPrivileges(request, (element.Read(),), inherited_aces=filteredaces)
                    except AccessDeniedError:
                        responses.append(element.StatusResponse(href, element.Status.fromResponseCode(responsecode.FORBIDDEN)))
                        continue
                    # Now retrieve all the requested properties on the HRef resource
                    yield prop_common.responseForHref(
                        request,
                        responses,
                        href,
                        child,
                        prop_common.propertyListForResource,
                        props_to_return,
                    )
            # Replace the bare hrefs with the expanded per-href responses.
            prop.children = responses
            properties_by_status[responsecode.OK].append(prop)
        except:
            # Record the failure against this property's name; the Failure
            # is mapped to an HTTP status for the multi-status body.
            f = Failure()
            log.error("Error reading property %r for resource %s: %s" % (qname, request.uri, f.value))
            status = statusForFailure(f, "getting property: %s" % (qname,))
            if status not in properties_by_status: properties_by_status[status] = []
            properties_by_status[status].append(propertyName(qname))
    # Build the overall response
    propstats = [
        element.PropertyStatus(
            element.PropertyContainer(*properties_by_status[status]),
            element.Status.fromResponseCode(status)
        )
        for status in properties_by_status if properties_by_status[status]
    ]
    returnValue(MultiStatusResponse((element.PropertyStatusResponse(element.HRef(request.uri), *propstats),)))
calendarserver-5.2+dfsg/twext/web2/dav/method/__init__.py 0000644 0001750 0001750 00000003206 12263343324 022467 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV methods.
Modules in this package provide the implementation of
twext.web2.dav.static.DAVFile's dispatched methods.
"""
__all__ = [
"acl",
"copymove",
"delete",
"get",
"lock",
"mkcol",
"propfind",
"proppatch",
"prop_common",
"put",
"put_common",
"report",
"report_acl_principal_prop_set",
"report_expand",
"report_principal_match",
"report_principal_property_search",
"report_principal_search_property_set",
]
calendarserver-5.2+dfsg/twext/web2/dav/method/lock.py 0000644 0001750 0001750 00000003152 12263343324 021660 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_lock -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV LOCK and UNLOCK methods
"""
__all__ = ["http_LOCK", "http_UNLOCK"]
from twext.web2 import responsecode
def http_LOCK(self, request):
    """
    Respond to a LOCK request. (RFC 2518, section 8.10)

    WebDAV locking is not implemented by this resource, so every LOCK
    request is refused.
    """
    # Unconditional refusal: no lock support exists here.
    return responsecode.NOT_IMPLEMENTED
def http_UNLOCK(self, request):
    """
    Respond to a UNLOCK request. (RFC 2518, section 8.11)

    Since LOCK is not implemented, there is never a lock to release; every
    UNLOCK request is likewise refused.
    """
    # Unconditional refusal: no lock support exists here.
    return responsecode.NOT_IMPLEMENTED
calendarserver-5.2+dfsg/twext/web2/dav/method/propfind.py 0000644 0001750 0001750 00000021337 12263343324 022556 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_prop.PROP.test_PROPFIND -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV PROPFIND method
"""
__all__ = [
"http_PROPFIND",
"propertyName",
]
from twisted.python.failure import Failure
from twisted.internet.defer import deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2.http import HTTPError
from twext.web2 import responsecode
from twext.web2.http import StatusResponse
from txdav.xml import element as davxml
from twext.web2.dav.http import MultiStatusResponse, statusForFailure, \
ErrorResponse
from twext.web2.dav.util import normalizeURL, davXMLFromStream
log = Logger()
def http_PROPFIND(self, request):
    """
    Respond to a PROPFIND request. (RFC 2518, section 8.1)

    Supports allprop, propname and named-property queries; honors the Depth
    header (depth:infinity is refused by policy), the Prefer header
    (return=minimal, depth-noroot) and the legacy Brief header.

    This is an old-style Twisted C{deferredGenerator}: Deferreds are wrapped
    in C{waitForDeferred}, yielded, then unwrapped with C{getResult()}; the
    final yielded L{MultiStatusResponse} is the generator's result.
    """
    if not self.exists():
        log.error("File not found: %s" % (self,))
        raise HTTPError(responsecode.NOT_FOUND)
    #
    # Check authentication and access controls
    #
    x = waitForDeferred(self.authorize(request, (davxml.Read(),)))
    yield x
    x.getResult()
    #
    # Read request body
    #
    try:
        doc = waitForDeferred(davXMLFromStream(request.stream))
        yield doc
        doc = doc.getResult()
    except ValueError as e:
        log.error("Error while handling PROPFIND body: %s" % (e,))
        raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, str(e)))

    if doc is None:
        # No request body means get all properties.
        search_properties = "all"
    else:
        #
        # Parse request
        #
        find = doc.root_element
        if not isinstance(find, davxml.PropertyFind):
            error = ("Non-%s element in PROPFIND request body: %s"
                     % (davxml.PropertyFind.sname(), find))
            log.error("Error: {err}", err=error)
            raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, error))

        container = find.children[0]
        if isinstance(container, davxml.AllProperties):
            # Get all properties
            search_properties = "all"
        elif isinstance(container, davxml.PropertyName):
            # Get names only
            search_properties = "names"
        elif isinstance(container, davxml.PropertyContainer):
            properties = container.children
            search_properties = [(p.namespace, p.name) for p in properties]
        else:
            raise AssertionError("Unexpected element type in %s: %s"
                                 % (davxml.PropertyFind.sname(), container))

    #
    # Generate XML output stream
    #
    request_uri = request.uri
    depth = request.headers.getHeader("depth", "infinity")

    # By policy we will never allow a depth:infinity propfind
    if depth == "infinity":
        raise HTTPError(ErrorResponse(responsecode.FORBIDDEN, davxml.PropfindFiniteDepth()))

    # Look for Prefer header first, then try Brief
    prefer = request.headers.getHeader("prefer", {})
    returnMinimal = any([key == "return" and value == "minimal" for key, value, _ignore_args in prefer])
    noRoot = any([key == "depth-noroot" and value is None for key, value, _ignore_args in prefer])
    if not returnMinimal:
        returnMinimal = request.headers.getHeader("brief", False)

    xml_responses = []

    # FIXME: take advantage of the new generative properties of findChildren
    my_url = normalizeURL(request_uri)
    if self.isCollection() and not my_url.endswith("/"):
        my_url += "/"

    # Do some optimisation of access control calculation by determining any
    # inherited ACLs outside of the child resource loop and supply those to
    # the checkPrivileges on each child.
    filtered_aces = waitForDeferred(self.inheritedACEsforChildren(request))
    yield filtered_aces
    filtered_aces = filtered_aces.getResult()

    if depth in ("1", "infinity") and noRoot:
        # Prefer: depth-noroot — don't include this resource itself.
        resources = []
    else:
        resources = [(self, my_url)]

    d = self.findChildren(depth, request, lambda x, y: resources.append((x, y)), (davxml.Read(),), inherited_aces=filtered_aces)
    x = waitForDeferred(d)
    yield x
    x.getResult()

    for resource, uri in resources:
        # FIX: compare strings with ==, not "is"; the identity comparison
        # only worked by the accident of CPython literal interning.
        if search_properties == "names":
            try:
                resource_properties = waitForDeferred(resource.listProperties(request))
                yield resource_properties
                resource_properties = resource_properties.getResult()
            except:
                log.error("Unable to get properties for resource %r" % (resource,))
                raise

            properties_by_status = {
                responsecode.OK: [propertyName(p) for p in resource_properties]
            }
        else:
            properties_by_status = {
                responsecode.OK        : [],
                responsecode.NOT_FOUND : [],
            }

            if search_properties == "all":
                properties_to_enumerate = waitForDeferred(resource.listAllprop(request))
                yield properties_to_enumerate
                properties_to_enumerate = properties_to_enumerate.getResult()
            else:
                properties_to_enumerate = search_properties

            for property in properties_to_enumerate:
                has = waitForDeferred(resource.hasProperty(property, request))
                yield has
                has = has.getResult()
                if has:
                    try:
                        resource_property = waitForDeferred(resource.readProperty(property, request))
                        yield resource_property
                        resource_property = resource_property.getResult()
                    except:
                        # Map the failure to an HTTP status for this property.
                        f = Failure()
                        status = statusForFailure(f, "getting property: %s" % (property,))
                        if status not in properties_by_status:
                            properties_by_status[status] = []
                        if not returnMinimal or status != responsecode.NOT_FOUND:
                            properties_by_status[status].append(propertyName(property))
                    else:
                        if resource_property is not None:
                            properties_by_status[responsecode.OK].append(resource_property)
                        elif not returnMinimal:
                            properties_by_status[responsecode.NOT_FOUND].append(propertyName(property))
                elif not returnMinimal:
                    properties_by_status[responsecode.NOT_FOUND].append(propertyName(property))

        propstats = []
        for status in properties_by_status:
            properties = properties_by_status[status]
            if not properties:
                continue
            xml_status = davxml.Status.fromResponseCode(status)
            xml_container = davxml.PropertyContainer(*properties)
            xml_propstat = davxml.PropertyStatus(xml_container, xml_status)
            propstats.append(xml_propstat)

        # Always need to have at least one propstat present (required by Prefer header behavior)
        if len(propstats) == 0:
            propstats.append(davxml.PropertyStatus(
                davxml.PropertyContainer(),
                davxml.Status.fromResponseCode(responsecode.OK)
            ))

        xml_resource = davxml.HRef(uri)
        xml_response = davxml.PropertyStatusResponse(xml_resource, *propstats)
        xml_responses.append(xml_response)

    #
    # Return response
    #
    yield MultiStatusResponse(xml_responses)

http_PROPFIND = deferredGenerator(http_PROPFIND)
##
# Utilities
##
def propertyName(name):
    """
    Build a property element carrying only a qname, suitable for a
    PROPFIND/REPORT response entry where the property's value is unknown
    or intentionally omitted.

    @param name: a C{(namespace, name)} tuple identifying the property.
    @return: a L{davxml.WebDAVUnknownElement} with that namespace and name.
    """
    namespace, localname = name
    unknown = davxml.WebDAVUnknownElement()
    unknown.namespace = namespace
    unknown.name = localname
    return unknown
calendarserver-5.2+dfsg/twext/web2/dav/resource.py 0000644 0001750 0001750 00000301110 12263343324 021272 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_resource -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
from __future__ import print_function
"""
WebDAV resources.
"""
__all__ = [
"DAVPropertyMixIn",
"DAVResource",
"DAVLeafResource",
"DAVPrincipalResource",
"DAVPrincipalCollectionResource",
"AccessDeniedError",
"isPrincipalResource",
"TwistedACLInheritable",
"TwistedGETContentMD5",
"TwistedQuotaRootProperty",
"allACL",
"readonlyACL",
"davPrivilegeSet",
"unauthenticatedPrincipal",
]
import cPickle as pickle
import urllib
from zope.interface import implements
from twisted.cred.error import LoginFailed, UnauthorizedLogin
from twisted.python.failure import Failure
from twisted.internet.defer import (
Deferred, maybeDeferred, succeed, inlineCallbacks, returnValue
)
from twisted.internet import reactor
from twext.python.log import Logger
from txdav.xml import element
from txdav.xml.base import encodeXMLName
from txdav.xml.element import WebDAVElement, WebDAVEmptyElement, WebDAVTextElement
from txdav.xml.element import dav_namespace
from txdav.xml.element import twisted_dav_namespace, twisted_private_namespace
from txdav.xml.element import registerElement, lookupElement
from twext.web2 import responsecode
from twext.web2.http import HTTPError, RedirectResponse, StatusResponse
from twext.web2.http_headers import generateContentType
from twext.web2.iweb import IResponse
from twext.web2.resource import LeafResource
from twext.web2.server import NoURLForResourceError
from twext.web2.static import MetaDataMixin, StaticRenderMixin
from twext.web2.auth.wrapper import UnauthorizedResponse
from twext.web2.dav.idav import IDAVResource, IDAVPrincipalResource, IDAVPrincipalCollectionResource
from twext.web2.dav.http import NeedPrivilegesResponse
from twext.web2.dav.noneprops import NonePropertyStore
from twext.web2.dav.util import unimplemented, parentForURL, joinURL
from twext.web2.dav.auth import PrincipalCredentials
from twistedcaldav import customxml
log = Logger()
class DAVPropertyMixIn (MetaDataMixin):
"""
Mix-in class which implements the DAV property access API in
L{IDAVResource}.
There are three categories of DAV properties, for the purposes of
how this class manages them. A X{property} is either a X{live
property} or a X{dead property}, and live properties are split
into two categories:
1. Dead properties. There are properties that the server simply
stores as opaque data. These are store in the X{dead property
store}, which is provided by subclasses via the
L{deadProperties} method.
2. Live properties which are always computed. These properties
aren't stored anywhere (by this class) but instead are derived
from the resource state or from data that is persisted
elsewhere. These are listed in the L{liveProperties}
attribute and are handled explicitly by the L{readProperty}
method.
3. Live properties may be acted on specially and are stored in
the X{dead property store}. These are not listed in the
L{liveProperties} attribute, but may be handled specially by
the property access methods. For example, L{writeProperty}
might validate the data and refuse to write data it deems
inappropriate for a given property.
There are two sets of property access methods. The first group
(L{hasProperty}, etc.) provides access to all properties. They
automatically figure out which category a property falls into and
act accordingly.
The second group (L{hasDeadProperty}, etc.) accesses the dead
property store directly and bypasses any live property logic that
exists in the first group of methods. These methods are used by
the first group of methods, and there are cases where they may be
needed by other methods. I{Accessing dead properties directly
should be done with caution.} Bypassing the live property logic
means that values may not be the correct ones for use in DAV
requests such as PROPFIND, and may be bypassing security checks.
In general, one should never bypass the live property logic as
part of a client request for property data.
Properties in the L{twisted_private_namespace} namespace are
internal to the server and should not be exposed to clients. They
can only be accessed via the dead property store.
"""
# Note: The DAV:owner and DAV:group live properties are only
# meaningful if you are using ACL semantics (ie. Unix-like) which
# use them. This (generic) class does not.
def liveProperties(self):
return (
(dav_namespace, "resourcetype"),
(dav_namespace, "getetag"),
(dav_namespace, "getcontenttype"),
(dav_namespace, "getcontentlength"),
(dav_namespace, "getlastmodified"),
(dav_namespace, "creationdate"),
(dav_namespace, "displayname"),
(dav_namespace, "supportedlock"),
(dav_namespace, "supported-report-set"), # RFC 3253, section 3.1.5
#(dav_namespace, "owner" ), # RFC 3744, section 5.1
#(dav_namespace, "group" ), # RFC 3744, section 5.2
(dav_namespace, "supported-privilege-set"), # RFC 3744, section 5.3
(dav_namespace, "current-user-privilege-set"), # RFC 3744, section 5.4
(dav_namespace, "current-user-principal"), # RFC 5397, Section 3
(dav_namespace, "acl"), # RFC 3744, section 5.5
(dav_namespace, "acl-restrictions"), # RFC 3744, section 5.6
(dav_namespace, "inherited-acl-set"), # RFC 3744, section 5.7
(dav_namespace, "principal-collection-set"), # RFC 3744, section 5.8
(dav_namespace, "quota-available-bytes"), # RFC 4331, section 3
(dav_namespace, "quota-used-bytes"), # RFC 4331, section 4
(twisted_dav_namespace, "resource-class"),
)
def deadProperties(self):
"""
Provides internal access to the WebDAV dead property store.
You probably shouldn't be calling this directly if you can use
the property accessors in the L{IDAVResource} API instead.
However, a subclass must override this method to provide it's
own dead property store.
This implementation returns an instance of
L{NonePropertyStore}, which cannot store dead properties.
Subclasses must override this method if they wish to store
dead properties.
@return: a dict-like object from which one can read and to
which one can write dead properties. Keys are qname
tuples (i.e. C{(namespace, name)}) as returned by
L{WebDAVElement.qname()} and values are
L{WebDAVElement} instances.
"""
if not hasattr(self, "_dead_properties"):
self._dead_properties = NonePropertyStore(self)
return self._dead_properties
def hasProperty(self, property, request):
"""
See L{IDAVResource.hasProperty}.
"""
if type(property) is tuple:
qname = property
else:
qname = (property.namespace, property.name)
if qname[0] == twisted_private_namespace:
return succeed(False)
# Need to special case the dynamic live properties
namespace, name = qname
if namespace == dav_namespace:
if name in ("quota-available-bytes", "quota-used-bytes"):
d = self.hasQuota(request)
d.addCallback(lambda result: result)
return d
return succeed(
qname in self.liveProperties() or
self.deadProperties().contains(qname)
)
    def readProperty(self, property, request):
        """
        See L{IDAVResource.readProperty}.
        Computes live property values from resource state where one is
        defined, falls through to the dead property store otherwise,
        and refuses to expose properties in the private server
        namespace.
        @param property: a C{(namespace, name)} qname tuple or a
            property element naming the property to read.
        @param request: the request being processed.
        @return: a L{Deferred} firing the property value element.
        """
        @inlineCallbacks
        def defer():
            # Normalize to both the qname tuple (for comparisons) and
            # the serialized name (for error messages).
            if type(property) is tuple:
                qname = property
                sname = encodeXMLName(*property)
            else:
                qname = property.qname()
                sname = property.sname()
            namespace, name = qname
            if namespace == dav_namespace:
                if name == "resourcetype":
                    # Allow live property to be overridden by dead property
                    if self.deadProperties().contains(qname):
                        returnValue(self.deadProperties().get(qname))
                    if self.isCollection():
                        returnValue(element.ResourceType.collection) #@UndefinedVariable
                    returnValue(element.ResourceType.empty) #@UndefinedVariable
                if name == "getetag":
                    etag = (yield self.etag())
                    if etag is None:
                        returnValue(None)
                    returnValue(element.GETETag(etag.generate()))
                if name == "getcontenttype":
                    mimeType = self.contentType()
                    if mimeType is None:
                        returnValue(None)
                    returnValue(element.GETContentType(generateContentType(mimeType)))
                if name == "getcontentlength":
                    length = self.contentLength()
                    if length is None:
                        # TODO: really we should "render" the resource and
                        # determine its size from that but for now we just
                        # return an empty element.
                        returnValue(element.GETContentLength(""))
                    else:
                        returnValue(element.GETContentLength(str(length)))
                if name == "getlastmodified":
                    lastModified = self.lastModified()
                    if lastModified is None:
                        returnValue(None)
                    returnValue(element.GETLastModified.fromDate(lastModified))
                if name == "creationdate":
                    creationDate = self.creationDate()
                    if creationDate is None:
                        returnValue(None)
                    returnValue(element.CreationDate.fromDate(creationDate))
                if name == "displayname":
                    displayName = self.displayName()
                    if displayName is None:
                        returnValue(None)
                    returnValue(element.DisplayName(displayName))
                if name == "supportedlock":
                    returnValue(element.SupportedLock(
                        element.LockEntry(
                            element.LockScope.exclusive, #@UndefinedVariable
                            element.LockType.write #@UndefinedVariable
                        ),
                        element.LockEntry(
                            element.LockScope.shared, #@UndefinedVariable
                            element.LockType.write #@UndefinedVariable
                        ),
                    ))
                if name == "supported-report-set":
                    returnValue(element.SupportedReportSet(*[
                        element.SupportedReport(report,)
                        for report in self.supportedReports()
                    ]))
                if name == "supported-privilege-set":
                    returnValue((yield self.supportedPrivileges(request)))
                if name == "acl-restrictions":
                    returnValue(element.ACLRestrictions())
                if name == "inherited-acl-set":
                    returnValue(element.InheritedACLSet(*self.inheritedACLSet()))
                if name == "principal-collection-set":
                    returnValue(element.PrincipalCollectionSet(*[
                        element.HRef(
                            principalCollection.principalCollectionURL()
                        )
                        for principalCollection in self.principalCollections()
                    ]))
                # Helper for ACL-protected properties: check the given
                # privileges before computing the value via callback();
                # access failures become 401 responses.
                @inlineCallbacks
                def ifAllowed(privileges, callback):
                    try:
                        yield self.checkPrivileges(request, privileges)
                        result = yield callback()
                    except AccessDeniedError:
                        raise HTTPError(StatusResponse(
                            responsecode.UNAUTHORIZED,
                            "Access denied while reading property %s."
                            % (sname,)
                        ))
                    returnValue(result)
                if name == "current-user-privilege-set":
                    @inlineCallbacks
                    def callback():
                        privs = yield self.currentPrivileges(request)
                        returnValue(element.CurrentUserPrivilegeSet(*privs))
                    returnValue((yield ifAllowed(
                        (element.ReadCurrentUserPrivilegeSet(),),
                        callback
                    )))
                if name == "acl":
                    @inlineCallbacks
                    def callback():
                        acl = yield self.accessControlList(request)
                        if acl is None:
                            acl = element.ACL()
                        returnValue(acl)
                    returnValue(
                        (yield ifAllowed((element.ReadACL(),), callback))
                    )
                if name == "current-user-principal":
                    returnValue(element.CurrentUserPrincipal(
                        self.currentPrincipal(request).children[0]
                    ))
                if name == "quota-available-bytes":
                    qvalue = yield self.quota(request)
                    if qvalue is None:
                        # Quota does not apply: the property does not exist.
                        raise HTTPError(StatusResponse(
                            responsecode.NOT_FOUND,
                            "Property %s does not exist." % (sname,)
                        ))
                    else:
                        returnValue(element.QuotaAvailableBytes(str(qvalue[0])))
                if name == "quota-used-bytes":
                    qvalue = yield self.quota(request)
                    if qvalue is None:
                        raise HTTPError(StatusResponse(
                            responsecode.NOT_FOUND,
                            "Property %s does not exist." % (sname,)
                        ))
                    else:
                        returnValue(element.QuotaUsedBytes(str(qvalue[1])))
            elif namespace == twisted_dav_namespace:
                if name == "resource-class":
                    returnValue(ResourceClass(self.__class__.__name__))
            elif namespace == twisted_private_namespace:
                # Server-private properties are never exposed to clients.
                raise HTTPError(StatusResponse(
                    responsecode.FORBIDDEN,
                    "Properties in the %s namespace are private to the server."
                    % (sname,)
                ))
            # No live property matched: fall back to the dead property store.
            returnValue(self.deadProperties().get(qname))
        return defer()
def writeProperty(self, property, request):
"""
See L{IDAVResource.writeProperty}.
"""
assert isinstance(property, WebDAVElement), (
"Not a property: %r" % (property,)
)
def defer():
if property.protected:
raise HTTPError(StatusResponse(
responsecode.FORBIDDEN,
"Protected property %s may not be set."
% (property.sname(),)
))
if property.namespace == twisted_private_namespace:
raise HTTPError(StatusResponse(
responsecode.FORBIDDEN,
"Properties in the %s namespace are private to the server."
% (property.sname(),)
))
return self.deadProperties().set(property)
return maybeDeferred(defer)
def removeProperty(self, property, request):
"""
See L{IDAVResource.removeProperty}.
"""
def defer():
if type(property) is tuple:
qname = property
sname = encodeXMLName(*property)
else:
qname = property.qname()
sname = property.sname()
if qname in self.liveProperties():
raise HTTPError(StatusResponse(
responsecode.FORBIDDEN,
"Live property %s cannot be deleted." % (sname,)
))
if qname[0] == twisted_private_namespace:
raise HTTPError(StatusResponse(
responsecode.FORBIDDEN,
"Properties in the %s namespace are private to the server."
% (qname[0],)
))
return self.deadProperties().delete(qname)
return maybeDeferred(defer)
@inlineCallbacks
def listProperties(self, request):
"""
See L{IDAVResource.listProperties}.
"""
qnames = set(self.liveProperties())
# Add dynamic live properties that exist
dynamicLiveProperties = (
(dav_namespace, "quota-available-bytes"),
(dav_namespace, "quota-used-bytes"),
)
for dqname in dynamicLiveProperties:
has = (yield self.hasProperty(dqname, request))
if not has:
qnames.remove(dqname)
for qname in self.deadProperties().list():
if (
qname not in qnames and
qname[0] != twisted_private_namespace
):
qnames.add(qname)
returnValue(qnames)
def listAllprop(self, request):
"""
Some DAV properties should not be returned to a C{DAV:allprop}
query. RFC 3253 defines several such properties. This method
computes a subset of the property qnames returned by
L{listProperties} by filtering out elements whose class have
the C{.hidden} attribute set to C{True}.
@return: a list of qnames of properties which are defined and
are appropriate for use in response to a C{DAV:allprop}
query.
"""
def doList(qnames):
result = []
for qname in qnames:
try:
if not lookupElement(qname).hidden:
result.append(qname)
except KeyError:
# Unknown element
result.append(qname)
return result
d = self.listProperties(request)
d.addCallback(doList)
return d
def hasDeadProperty(self, property):
"""
Same as L{hasProperty}, but bypasses the live property store
and checks directly from the dead property store.
"""
if type(property) is tuple:
qname = property
else:
qname = property.qname()
return self.deadProperties().contains(qname)
def readDeadProperty(self, property):
"""
Same as L{readProperty}, but bypasses the live property store
and reads directly from the dead property store.
"""
if type(property) is tuple:
qname = property
else:
qname = property.qname()
return self.deadProperties().get(qname)
    def writeDeadProperty(self, property):
        """
        Same as L{writeProperty}, but bypasses the live property store
        and writes directly to the dead property store. Note that
        this should not be used unless you know that you are writing
        to an overrideable live property, as this bypasses the logic
        which protects protected properties. The result of writing to
        a non-overrideable live property with this method is
        undefined; the value in the dead property store may or may not
        be ignored when reading the property with L{readProperty}.
        @param property: the property element to store.
        """
        self.deadProperties().set(property)
def removeDeadProperty(self, property):
"""
Same as L{removeProperty}, but bypasses the live property
store and acts directly on the dead property store.
"""
if self.hasDeadProperty(property):
if type(property) is tuple:
qname = property
else:
qname = property.qname()
self.deadProperties().delete(qname)
#
# Overrides some methods in MetaDataMixin in order to allow DAV properties
# to override the values of some HTTP metadata.
#
def contentType(self):
if self.hasDeadProperty((element.dav_namespace, "getcontenttype")):
return self.readDeadProperty(
(element.dav_namespace, "getcontenttype")
).mimeType()
else:
return super(DAVPropertyMixIn, self).contentType()
def displayName(self):
if self.hasDeadProperty((element.dav_namespace, "displayname")):
return str(self.readDeadProperty(
(element.dav_namespace, "displayname")
))
else:
return super(DAVPropertyMixIn, self).displayName()
class DAVResource (DAVPropertyMixIn, StaticRenderMixin):
"""
WebDAV resource.
"""
implements(IDAVResource)
def __init__(self, principalCollections=None):
"""
@param principalCollections: an iterable of
L{IDAVPrincipalCollectionResource}s which contain
principals to be used in ACLs for this resource.
"""
if principalCollections is not None:
self._principalCollections = frozenset([
IDAVPrincipalCollectionResource(principalCollection)
for principalCollection in principalCollections
])
##
# DAV
##
    def davComplianceClasses(self):
        """
        This implementation raises L{NotImplementedError}; a subclass
        must override this method.
        @return: a sequence of strings denoting WebDAV compliance
            classes. For example, a DAV level 2 server might return
            ("1", "2").
        """
        unimplemented(self)
    def isCollection(self):
        """
        See L{IDAVResource.isCollection}.
        This implementation raises L{NotImplementedError}; a subclass
        must override this method.
        """
        unimplemented(self)
    def findChildren(
        self, depth, request, callback,
        privileges=None, inherited_aces=None
    ):
        """
        See L{IDAVResource.findChildren}.
        This implementation works for C{depth} values of C{"0"},
        C{"1"}, and C{"infinity"}, as long as C{self.listChildren} is
        implemented.
        @param depth: a C{str}: "0", "1" or "infinity".
        @param request: the request being processed.
        @param callback: invoked with C{(child, childpath)} for each
            qualifying child (collections get a trailing "/"); invoked
            with C{(None, childpath + "/")} for children that could
            not be located.
        @param privileges: when not C{None}, an iterable of privileges
            each child must satisfy; children failing the check are
            skipped silently.
        @param inherited_aces: pre-computed inheritable ACEs from the
            parent hierarchy, passed through to the privilege check.
        @return: a L{Deferred} firing C{None} when traversal completes.
        """
        assert depth in ("0", "1", "infinity"), "Invalid depth: %s" % (depth,)
        if depth == "0" or not self.isCollection():
            return succeed(None)
        completionDeferred = Deferred()
        basepath = request.urlForResource(self)
        children = []
        # A child that fails the privilege check is skipped; continue
        # with the next one.
        def checkPrivilegesError(failure):
            failure.trap(AccessDeniedError)
            reactor.callLater(0, getChild)
        def checkPrivileges(child):
            if child is None:
                return None
            if privileges is None:
                return child
            d = child.checkPrivileges(
                request, privileges,
                inherited_aces=inherited_aces
            )
            d.addCallback(lambda _: child)
            return d
        def gotChild(child, childpath):
            if child is None:
                # Unresolvable child: report it with a trailing slash.
                callback(None, childpath + "/")
            else:
                if child.isCollection():
                    callback(child, childpath + "/")
                    if depth == "infinity":
                        # Recurse into the subtree before resuming
                        # with our own remaining children.
                        d = child.findChildren(
                            depth, request,
                            callback, privileges
                        )
                        d.addCallback(lambda x: reactor.callLater(0, getChild))
                        return d
                else:
                    callback(child, childpath)
            # Schedule the next child via the reactor so deep listings
            # do not grow the Python call stack.
            reactor.callLater(0, getChild)
        def getChild():
            try:
                childname = children.pop()
            except IndexError:
                # No children left: traversal is complete.
                completionDeferred.callback(None)
            else:
                childpath = joinURL(basepath, urllib.quote(childname))
                d = request.locateChildResource(self, childname)
                d.addCallback(checkPrivileges)
                d.addCallbacks(gotChild, checkPrivilegesError, (childpath,))
                d.addErrback(completionDeferred.errback)
        def gotChildren(listChildrenResult):
            children[:] = list(listChildrenResult)
            getChild()
        maybeDeferred(self.listChildren).addCallback(gotChildren)
        return completionDeferred
    @inlineCallbacks
    def findChildrenFaster(
        self, depth, request, okcallback, badcallback, missingcallback, unavailablecallback,
        names, privileges, inherited_aces
    ):
        """
        See L{IDAVResource.findChildren}.
        This implementation works for C{depth} values of C{"0"},
        C{"1"}, and C{"infinity"}, as long as C{self.listChildren} is
        implemented.
        @param depth: a C{str} for the depth: "0", "1" and "infinity"
            only allowed.
        @param request: the L{Request} for the current request in
            progress
        @param okcallback: a callback function used on all resources
            that pass the privilege check, or C{None}
        @param badcallback: a callback function used on all resources
            that fail the privilege check, or C{None}
        @param missingcallback: a callback function used on all resources
            that are missing, or C{None}
        @param unavailablecallback: a callback function used on all
            resources that cannot be located, or C{None}
        @param names: a C{list} of C{str}'s containing the names of
            the child resources to lookup. If empty or C{None} all
            children will be examined, otherwise only the ones in the
            list.
        @param privileges: a list of privileges to check.
        @param inherited_aces: the list of parent ACEs that are
            inherited by all children.
        @return: a L{Deferred} firing C{None} when traversal is done.
        """
        assert depth in ("0", "1", "infinity"), "Invalid depth: %s" % (depth,)
        if depth == "0" or not self.isCollection():
            returnValue(None)
        # First find all depth 1 children
        names1 = []
        namesDeep = []
        collections1 = []
        if names:
            # Partition the requested names into immediate children
            # (no "/") and deeper relative paths.
            for name in names:
                (names1 if name.rstrip("/").find("/") == -1 else namesDeep).append(name.rstrip("/"))
        #children = []
        #yield self.findChildren("1", request, lambda x, y: children.append((x, y)), privileges=None, inherited_aces=None)
        children = []
        basepath = request.urlForResource(self)
        childnames = list((yield self.listChildren()))
        for childname in childnames:
            childpath = joinURL(basepath, urllib.quote(childname))
            try:
                child = (yield request.locateChildResource(self, childname))
            except HTTPError, e:
                log.error("Resource cannot be located: %s" % (str(e),))
                if unavailablecallback:
                    unavailablecallback(childpath)
                continue
            if child is not None:
                # Track collections regardless of the name filter so a
                # depth:infinity traversal can recurse into them.
                if child.isCollection():
                    collections1.append((child, childpath + "/"))
                if names and childname not in names1:
                    continue
                if child.isCollection():
                    children.append((child, childpath + "/"))
                else:
                    children.append((child, childpath))
        if missingcallback:
            for name in set(names1) - set(childnames):
                missingcallback(joinURL(basepath, urllib.quote(name)))
        # Generate (acl,supported_privs) map: group children sharing
        # the same (pickled) ACL and privilege set so each distinct
        # ACL is evaluated only once.
        aclmap = {}
        for resource, url in children:
            acl = (yield resource.accessControlList(
                request, inheritance=False, inherited_aces=inherited_aces
            ))
            supportedPrivs = (yield resource.supportedPrivileges(request))
            aclmap.setdefault(
                (pickle.dumps(acl), supportedPrivs),
                (acl, supportedPrivs, [])
            )[2].append((resource, url))
        # Now determine whether each ace satisfies privileges
        #print(aclmap)
        for items in aclmap.itervalues():
            checked = (yield self.checkACLPrivilege(
                request, items[0], items[1], privileges, inherited_aces
            ))
            if checked:
                for resource, url in items[2]:
                    if okcallback:
                        okcallback(resource, url)
            else:
                if badcallback:
                    for resource, url in items[2]:
                        badcallback(resource, url)
        if depth == "infinity":
            # Split names into child collection groups
            child_collections = {}
            for name in namesDeep:
                collection, name = name.split("/", 1)
                child_collections.setdefault(collection, []).append(name)
            for collection, url in collections1:
                collection_name = url.split("/")[-2]
                if collection_name in child_collections:
                    collection_inherited_aces = (
                        yield collection.inheritedACEsforChildren(request)
                    )
                    yield collection.findChildrenFaster(
                        depth, request, okcallback, badcallback, missingcallback, unavailablecallback,
                        child_collections[collection_name] if names else None, privileges,
                        inherited_aces=collection_inherited_aces
                    )
        returnValue(None)
    @inlineCallbacks
    def checkACLPrivilege(
        self, request, acl, privyset, privileges, inherited_aces
    ):
        """
        Check whether the current principal on C{request} is granted
        all of C{privileges} by C{acl} combined with
        C{inherited_aces}.
        @param acl: the resource's own L{element.ACL}, or C{None} when
            access is disabled.
        @param privyset: the supported privilege set used to resolve
            aggregate privileges.
        @return: a L{Deferred} firing C{True} only when every
            requested privilege is matched by an allowing ACE and none
            is denied.
        """
        if acl is None:
            returnValue(False)
        principal = self.currentPrincipal(request)
        # Other principal types don't make sense as actors.
        assert principal.children[0].name in ("unauthenticated", "href"), (
            "Principal is not an actor: %r" % (principal,)
        )
        acl = self.fullAccessControlList(acl, inherited_aces)
        pending = list(privileges)
        denied = []
        for ace in acl.children:
            for privilege in tuple(pending):
                # Skip ACEs that do not mention this privilege, either
                # directly or as an aggregate.
                if not self.matchPrivilege(
                    element.Privilege(privilege), ace.privileges, privyset
                ):
                    continue
                match = (yield self.matchPrincipal(
                    principal, ace.principal, request
                ))
                if match:
                    if ace.invert:
                        continue
                else:
                    if not ace.invert:
                        continue
                # First matching ACE decides this privilege.
                pending.remove(privilege)
                if not ace.allow:
                    denied.append(privilege)
        returnValue(len(denied) + len(pending) == 0)
def fullAccessControlList(self, acl, inherited_aces):
"""
See L{IDAVResource.accessControlList}.
This implementation looks up the ACL in the private property
C{(L{twisted_private_namespace}, "acl")}.
If no ACL has been stored for this resource, it returns the value
returned by C{defaultAccessControlList}.
If access is disabled it will return C{None}.
"""
#
# Inheritance is problematic. Here is what we do:
#
# 1. A private element is defined for use inside
# of a . This private element is removed when the ACE is
# exposed via WebDAV.
#
# 2. When checking ACLs with inheritance resolution, the server must
# examine all parent resources of the current one looking for any
# elements.
#
# If those are defined, the relevant ace is applied to the ACL on the
# current resource.
#
# Dynamically update privileges for those ace's that are inherited.
if acl:
aces = list(acl.children)
else:
aces = []
aces.extend(inherited_aces)
acl = element.ACL(*aces)
return acl
def supportedReports(self):
"""
See L{IDAVResource.supportedReports}.
This implementation lists the three main ACL reports and
expand-property.
"""
result = []
result.append(element.Report(element.ACLPrincipalPropSet(),))
result.append(element.Report(element.PrincipalMatch(),))
result.append(element.Report(element.PrincipalPropertySearch(),))
result.append(element.Report(element.ExpandProperty(),))
result.append(element.Report(customxml.CalendarServerPrincipalSearch(),))
return result
##
# Authentication
##
    def authorize(self, request, privileges, recurse=False):
        """
        See L{IDAVResource.authorize}.
        Authenticates the request, then checks the requested
        privileges. When access is denied, the failure is translated
        into an HTTP error: unauthorized when the client never
        authenticated, need-privileges otherwise.
        @param request: the request being processed.
        @param privileges: the privileges to check.
        @param recurse: C{True} to also check all child resources.
        @return: a L{Deferred} that errbacks with L{HTTPError} on
            authentication or privilege failure.
        """
        def whenAuthenticated(result):
            privilegeCheck = self.checkPrivileges(request, privileges, recurse)
            return privilegeCheck.addErrback(whenAccessDenied)
        def whenAccessDenied(f):
            f.trap(AccessDeniedError)
            # If we were unauthenticated to start with (no
            # Authorization header from client) then we should return
            # an unauthorized response instead to force the client to
            # login if it can.
            # We're not adding the headers here because this response
            # class is supposed to be a FORBIDDEN status code and
            # "Authorization will not help" according to RFC2616
            def translateError(response):
                return Failure(HTTPError(response))
            if request.authnUser == element.Principal(element.Unauthenticated()):
                return UnauthorizedResponse.makeResponse(
                    request.credentialFactories,
                    request.remoteAddr).addCallback(translateError)
            else:
                return translateError(
                    NeedPrivilegesResponse(request.uri, f.value.errors))
        d = self.authenticate(request)
        d.addCallback(whenAuthenticated)
        return d
    def authenticate(self, request):
        """
        Authenticate the given request against the portal, setting
        both C{request.authzUser} (a C{str}, the username for the
        purposes of authorization) and C{request.authnUser} (a C{str},
        the username for the purposes of authentication) when it has
        been authenticated.
        In order to authenticate, the request must have been
        previously prepared by
        L{twext.web2.dav.auth.AuthenticationWrapper.hook} to have the
        necessary authentication metadata.
        If the request was not thusly prepared, both C{authzUser} and
        C{authnUser} will be L{element.Unauthenticated}.
        @param request: the request which may contain authentication
            information and a reference to a portal to authenticate
            against.
        @type request: L{twext.web2.iweb.IRequest}.
        @return: a L{Deferred} which fires with a 2-tuple of
            C{(authnUser, authzUser)} if either the request is
            unauthenticated OR contains valid credentials to
            authenticate as a principal, or errbacks with L{HTTPError}
            if the authentication scheme is unsupported, or the
            credentials provided by the request are not valid.
        """
        # Bypass normal authentication if its already been done (by SACL check)
        if (
            hasattr(request, "authnUser") and
            hasattr(request, "authzUser") and
            request.authnUser is not None and
            request.authzUser is not None
        ):
            return succeed((request.authnUser, request.authzUser))
        # Without portal metadata the request cannot be authenticated:
        # mark it unauthenticated and succeed.
        if not (
            hasattr(request, "portal") and
            hasattr(request, "credentialFactories") and
            hasattr(request, "loginInterfaces")
        ):
            request.authnUser = element.Principal(element.Unauthenticated())
            request.authzUser = element.Principal(element.Unauthenticated())
            return succeed((request.authnUser, request.authzUser))
        authHeader = request.headers.getHeader("authorization")
        if authHeader is not None:
            if authHeader[0] not in request.credentialFactories:
                # Unsupported authentication scheme: challenge again.
                log.debug(
                    "Client authentication scheme %s is not provided by server %s"
                    % (authHeader[0], request.credentialFactories.keys())
                )
                d = UnauthorizedResponse.makeResponse(
                    request.credentialFactories,
                    request.remoteAddr
                )
                return d.addCallback(lambda response: Failure(HTTPError(response)))
            else:
                factory = request.credentialFactories[authHeader[0]]
                def gotCreds(creds):
                    d = self.principalsForAuthID(request, creds.username)
                    d.addCallback(gotDetails, creds)
                    return d
                # Try to match principals in each principal collection
                # on the resource
                def gotDetails(details, creds):
                    if details == (None, None):
                        log.info(
                            "Could not find the principal resource for user id: %s"
                            % (creds.username,)
                        )
                        return Failure(HTTPError(responsecode.UNAUTHORIZED))
                    authnPrincipal = IDAVPrincipalResource(details[0])
                    authzPrincipal = IDAVPrincipalResource(details[1])
                    return PrincipalCredentials(authnPrincipal, authzPrincipal, creds)
                def login(pcreds):
                    return request.portal.login(pcreds, None, *request.loginInterfaces)
                def gotAuth(result):
                    # Cache the principals on the request for later use.
                    request.authnUser = result[1]
                    request.authzUser = result[2]
                    return (request.authnUser, request.authzUser)
                def translateUnauthenticated(f):
                    f.trap(UnauthorizedLogin, LoginFailed)
                    log.info("Authentication failed: %s" % (f.value,))
                    d = UnauthorizedResponse.makeResponse(
                        request.credentialFactories, request.remoteAddr
                    )
                    d.addCallback(lambda response: Failure(HTTPError(response)))
                    return d
                d = factory.decode(authHeader[1], request)
                d.addCallback(gotCreds)
                d.addCallback(login)
                d.addCallbacks(gotAuth, translateUnauthenticated)
                return d
        else:
            if (
                hasattr(request, "checkedWiki") and
                hasattr(request, "authnUser") and
                hasattr(request, "authzUser")
            ):
                # This request has already been authenticated via the wiki
                return succeed((request.authnUser, request.authzUser))
            request.authnUser = element.Principal(element.Unauthenticated())
            request.authzUser = element.Principal(element.Unauthenticated())
            return succeed((request.authnUser, request.authzUser))
##
# ACL
##
def currentPrincipal(self, request):
"""
@param request: the request being processed.
@return: the current authorized principal, as derived from the
given request.
"""
if hasattr(request, "authzUser"):
return request.authzUser
else:
return unauthenticatedPrincipal
def principalCollections(self):
"""
See L{IDAVResource.principalCollections}.
"""
if hasattr(self, "_principalCollections"):
return self._principalCollections
else:
return ()
    def defaultRootAccessControlList(self):
        """
        @return: the L{element.ACL} element containing the default
            access control list for this resource when it is the root
            resource.
        """
        #
        # The default behaviour is to allow GET access to everything
        # and deny any type of write access (PUT, DELETE, etc.) to
        # everything.
        #
        return readonlyACL
    def defaultAccessControlList(self):
        """
        @return: the L{element.ACL} element containing the default
            access control list for this resource.
        """
        #
        # The default behaviour is no ACL; we should inherit from the parent
        # collection.
        #
        return element.ACL()
    def setAccessControlList(self, acl):
        """
        See L{IDAVResource.setAccessControlList}.
        This implementation stores the ACL in the private property
        C{(L{twisted_private_namespace}, "acl")}.
        @param acl: the L{element.ACL} element to store.
        """
        self.writeDeadProperty(acl)
    @inlineCallbacks
    def mergeAccessControlList(self, new_acl, request):
        """
        Merges the supplied access control list with the one on this
        resource. Merging means change all the non-inherited and
        non-protected ace's in the original, and do not allow the new
        one to specify an inherited or protected access control
        entry. This is the behaviour required by the C{ACL}
        request. (RFC 3744, section 8.1).
        @param new_acl: an L{element.ACL} element
        @param request: the request being processed.
        @return: a tuple of the C{DAV:error} precondition element if
            an error occurred, C{None} otherwise.
        This implementation stores the ACL in the private property
        C{(L{twisted_private_namespace}, "acl")}.
        """
        # Steps for ACL evaluation:
        # 1. Check that ace's on incoming do not match a protected ace
        # 2. Check that ace's on incoming do not match an inherited ace
        # 3. Check that ace's on incoming all have deny before grant
        # 4. Check that ace's on incoming do not use abstract privilege
        # 5. Check that ace's on incoming are supported
        #    (and are not inherited themselves)
        # 6. Check that ace's on incoming have valid principals
        # 7. Copy the original
        # 8. Remove all non-inherited and non-protected - and also inherited
        # 9. Add in ace's from incoming
        # 10. Verify that new acl is not in conflict with itself
        # 11. Update acl on the resource
        # Get the current access control list, preserving any private
        # properties on the ACEs as we will need to keep those when we
        # change the ACL.
        old_acl = (yield self.accessControlList(request, expanding=True))
        # Check disabled
        if old_acl is None:
            returnValue(None)
        # Need to get list of supported privileges
        supported = []
        def addSupportedPrivilege(sp):
            """
            Add the element in any DAV:Privilege to our list
            and recurse into any DAV:SupportedPrivilege's
            """
            for item in sp.children:
                if isinstance(item, element.Privilege):
                    supported.append(item.children[0])
                elif isinstance(item, element.SupportedPrivilege):
                    addSupportedPrivilege(item)
        supportedPrivs = (yield self.supportedPrivileges(request))
        for item in supportedPrivs.children:
            assert isinstance(item, element.SupportedPrivilege), (
                "Not a SupportedPrivilege: %r" % (item,)
            )
            addSupportedPrivilege(item)
        # Steps 1 - 6
        got_deny = False
        for ace in new_acl.children:
            for old_ace in old_acl.children:
                if (ace.principal == old_ace.principal):
                    # Step 1
                    if old_ace.protected:
                        log.error("Attempt to overwrite protected ace %r "
                                  "on resource %r"
                                  % (old_ace, self))
                        returnValue((
                            element.dav_namespace,
                            "no-protected-ace-conflict"
                        ))
                    # Step 2
                    #
                    # RFC3744 says that we either enforce the
                    # inherited ace conflict or we ignore it but use
                    # access control evaluation to determine whether
                    # there is any impact. Given that we have the
                    # "inheritable" behavior it does not make sense to
                    # disallow overrides of inherited ACEs since
                    # "inheritable" cannot itself be controlled via
                    # protocol.
                    #
                    # Otherwise, we'd use this logic:
                    #
                    #elif old_ace.inherited:
                    #    log.error("Attempt to overwrite inherited ace %r "
                    #              "on resource %r" % (old_ace, self))
                    #    returnValue((
                    #        element.dav_namespace,
                    #        "no-inherited-ace-conflict"
                    #    ))
            # Step 3
            if ace.allow and got_deny:
                log.error("Attempt to set grant ace %r after deny ace "
                          "on resource %r"
                          % (ace, self))
                returnValue((element.dav_namespace, "deny-before-grant"))
            got_deny = not ace.allow
            # Step 4: ignore as this server has no abstract privileges
            # (FIXME: none yet?)
            # Step 5
            for privilege in ace.privileges:
                if privilege.children[0] not in supported:
                    log.error("Attempt to use unsupported privilege %r "
                              "in ace %r on resource %r"
                              % (privilege.children[0], ace, self))
                    returnValue((
                        element.dav_namespace,
                        "not-supported-privilege"
                    ))
            if ace.protected:
                log.error("Attempt to create protected ace %r on resource %r"
                          % (ace, self))
                returnValue((element.dav_namespace, "no-ace-conflict"))
            if ace.inherited:
                log.error("Attempt to create inherited ace %r on resource %r"
                          % (ace, self))
                returnValue((element.dav_namespace, "no-ace-conflict"))
            # Step 6
            valid = (yield self.validPrincipal(ace.principal, request))
            if not valid:
                log.error("Attempt to use unrecognized principal %r "
                          "in ace %r on resource %r"
                          % (ace.principal, ace, self))
                returnValue((element.dav_namespace, "recognized-principal"))
        # Step 8 & 9
        #
        # Iterate through the old ones and replace any that are in the
        # new set, or remove the non-inherited/non-protected not in
        # the new set
        #
        new_aces = [ace for ace in new_acl.children]
        new_set = []
        for old_ace in old_acl.children:
            for i, new_ace in enumerate(new_aces):
                if self.samePrincipal(new_ace.principal, old_ace.principal):
                    new_set.append(new_ace)
                    del new_aces[i]
                    break
            else:
                if old_ace.protected and not old_ace.inherited:
                    new_set.append(old_ace)
        new_set.extend(new_aces)
        # Step 10
        # FIXME: verify acl is self-consistent
        # Step 11
        yield self.writeNewACEs(new_set)
        returnValue(None)
    def writeNewACEs(self, new_aces):
        """
        Write a new ACL to the resource's property store. This is a
        separate method so that it can be overridden by resources that
        need to do extra processing of ACLs being set via the ACL
        command.
        @param new_aces: C{list} of L{ACE} for ACL being set.
        @return: the result of L{setAccessControlList}.
        """
        return self.setAccessControlList(element.ACL(*new_aces))
def matchPrivilege(self, privilege, ace_privileges, supportedPrivileges):
for ace_privilege in ace_privileges:
if (
privilege == ace_privilege or
ace_privilege.isAggregateOf(privilege, supportedPrivileges)
):
return True
return False
    @inlineCallbacks
    def checkPrivileges(
        self, request, privileges, recurse=False,
        principal=None, inherited_aces=None
    ):
        """
        Check whether the given principal has the given privileges.
        (RFC 3744, section 5.5)
        @param request: the request being processed.
        @param privileges: an iterable of L{WebDAVElement}
            elements denoting access control privileges.
        @param recurse: C{True} if a recursive check on all child
            resources of this resource should be performed as well,
            C{False} otherwise.
        @param principal: the L{element.Principal} to check privileges
            for. If C{None}, it is deduced from C{request} by calling
            L{currentPrincipal}.
        @param inherited_aces: a list of L{element.ACE}s corresponding
            to the pre-computed inheritable aces from the parent
            resource hierarchy.
        @return: a L{Deferred} that callbacks with C{None} or errbacks
            with an L{AccessDeniedError}
        """
        if principal is None:
            principal = self.currentPrincipal(request)
        supportedPrivs = (yield self.supportedPrivileges(request))
        # Other principals types don't make sense as actors.
        assert principal.children[0].name in ("unauthenticated", "href"), (
            "Principal is not an actor: %r" % (principal,)
        )
        errors = []
        resources = [(self, None)]
        if recurse:
            yield self.findChildren(
                "infinity", request,
                lambda x, y: resources.append((x, y))
            )
        for resource, uri in resources:
            acl = (yield
                resource.accessControlList(
                    request,
                    inherited_aces=inherited_aces
                )
            )
            # Check for disabled
            if acl is None:
                errors.append((uri, list(privileges)))
                continue
            pending = list(privileges)
            denied = []
            for ace in acl.children:
                for privilege in tuple(pending):
                    # Skip ACEs that do not mention this privilege,
                    # either directly or as an aggregate.
                    if not self.matchPrivilege(
                        element.Privilege(privilege),
                        ace.privileges, supportedPrivs
                    ):
                        continue
                    match = (yield
                        self.matchPrincipal(principal, ace.principal, request)
                    )
                    if match:
                        if ace.invert:
                            continue
                    else:
                        if not ace.invert:
                            continue
                    # First matching ACE decides this privilege.
                    pending.remove(privilege)
                    if not ace.allow:
                        denied.append(privilege)
            denied += pending # If no matching ACE, then denied
            if denied:
                errors.append((uri, denied))
        if errors:
            raise AccessDeniedError(errors,)
        returnValue(None)
def supportedPrivileges(self, request):
"""
See L{IDAVResource.supportedPrivileges}.
This implementation returns a supported privilege set
containing only the DAV:all privilege.
"""
return succeed(davPrivilegeSet)
def currentPrivileges(self, request):
"""
See L{IDAVResource.currentPrivileges}.
This implementation returns a current privilege set containing
only the DAV:all privilege.
"""
current = self.currentPrincipal(request)
return self.privilegesForPrincipal(current, request)
    @inlineCallbacks
    def accessControlList(
        self, request, inheritance=True,
        expanding=False, inherited_aces=None
    ):
        """
        See L{IDAVResource.accessControlList}.
        This implementation looks up the ACL in the private property
        C{(L{twisted_private_namespace}, "acl")}. If no ACL has been
        stored for this resource, it returns the value returned by
        C{defaultAccessControlList}. If access is disabled it will
        return C{None}.
        @param inheritance: when C{True}, merge in inheritable ACEs
            from parent resources (either computed here or supplied
            via C{inherited_aces}).
        @param expanding: when C{True}, the private inheritable marker
            is left on the returned ACEs so a child can continue the
            expansion; when C{False} the marker is stripped.
        @param inherited_aces: pre-computed inheritable ACEs from the
            parent hierarchy, or C{None} to compute them here.
        """
        #
        # Inheritance is problematic. Here is what we do:
        #
        # 1. A private element is defined for
        # use inside of a . This private element is
        # removed when the ACE is exposed via WebDAV.
        #
        # 2. When checking ACLs with inheritance resolution, the
        # server must examine all parent resources of the current
        # one looking for any elements.
        #
        # If those are defined, the relevant ace is applied to the ACL on the
        # current resource.
        #
        myURL = None
        def getMyURL():
            # Lazily resolve this resource's URL; only needed when the
            # stored ACL is absent or when inheritance is requested.
            url = request.urlForResource(self)
            assert url is not None, (
                "urlForResource(self) returned None for resource %s" % (self,)
            )
            return url
        try:
            acl = self.readDeadProperty(element.ACL)
        except HTTPError, e:
            # Only a missing property is expected here; anything else
            # indicates a real failure.
            assert e.response.code == responsecode.NOT_FOUND, (
                "Expected %s response from readDeadProperty() exception, "
                "not %s"
                % (responsecode.NOT_FOUND, e.response.code)
            )
            # Produce a sensible default for an empty ACL.
            if myURL is None:
                myURL = getMyURL()
            if myURL == "/":
                # If we get to the root without any ACLs, then use the default.
                acl = self.defaultRootAccessControlList()
            else:
                acl = self.defaultAccessControlList()
        # Dynamically update privileges for those ace's that are inherited.
        if inheritance:
            aces = list(acl.children)
            if myURL is None:
                myURL = getMyURL()
            if inherited_aces is None:
                # No pre-computed ACEs supplied: recursively expand the
                # parent's ACL (stopping at the root).
                if myURL != "/":
                    parentURL = parentForURL(myURL)
                    parent = (yield request.locateResource(parentURL))
                    if parent:
                        parent_acl = (yield
                            parent.accessControlList(
                                request, inheritance=True, expanding=True
                            )
                        )
                        # Check disabled
                        if parent_acl is None:
                            returnValue(None)
                        for ace in parent_acl.children:
                            if ace.inherited:
                                aces.append(ace)
                            elif TwistedACLInheritable() in ace.children:
                                # Adjust ACE for inherit on this resource
                                children = list(ace.children)
                                children.remove(TwistedACLInheritable())
                                children.append(
                                    element.Inherited(element.HRef(parentURL))
                                )
                                aces.append(element.ACE(*children))
            else:
                aces.extend(inherited_aces)
            # Always filter out any remaining private properties when we are
            # returning the ACL for the final resource after doing parent
            # inheritance.
            if not expanding:
                aces = [
                    element.ACE(*[
                        c for c in ace.children
                        if c != TwistedACLInheritable()
                    ])
                    for ace in aces
                ]
            acl = element.ACL(*aces)
        returnValue(acl)
def inheritedACEsforChildren(self, request):
"""
Do some optimisation of access control calculation by
determining any inherited ACLs outside of the child resource
loop and supply those to the checkPrivileges on each child.
@param request: the L{IRequest} for the request in progress.
@return: a C{list} of L{Ace}s that child resources of this one
will inherit.
"""
# Get the parent ACLs with inheritance and preserve the
# element.
def gotACL(parent_acl):
# Check disabled
if parent_acl is None:
return None
# Filter out those that are not inheritable (and remove
# the inheritable element from those that are)
aces = []
for ace in parent_acl.children:
if ace.inherited:
aces.append(ace)
elif TwistedACLInheritable() in ace.children:
# Adjust ACE for inherit on this resource
children = list(ace.children)
children.remove(TwistedACLInheritable())
children.append(
element.Inherited(
element.HRef(request.urlForResource(self))
)
)
aces.append(element.ACE(*children))
return aces
d = self.accessControlList(request, inheritance=True, expanding=True)
d.addCallback(gotACL)
return d
def inheritedACLSet(self):
"""
@return: a sequence of L{element.HRef}s from which ACLs are
inherited.
This implementation returns an empty set.
"""
return []
def principalsForAuthID(self, request, authid):
"""
Return authentication and authorization principal identifiers
for the authentication identifier passed in. In this
implementation authn and authz principals are the same.
@param request: the L{IRequest} for the request in progress.
@param authid: a string containing the
authentication/authorization identifier for the principal
to lookup.
@return: a deferred tuple of two tuples. Each tuple is
C{(principal, principalURI)} where: C{principal} is the
L{Principal} that is found; {principalURI} is the C{str}
URI of the principal. The first tuple corresponds to
authentication identifiers, the second to authorization
identifiers. It will errback with an
HTTPError(responsecode.FORBIDDEN) if the principal isn't
found.
"""
authnPrincipal = self.findPrincipalForAuthID(authid)
if authnPrincipal is None:
return succeed((None, None))
d = self.authorizationPrincipal(request, authid, authnPrincipal)
d.addCallback(lambda authzPrincipal: (authnPrincipal, authzPrincipal))
return d
def findPrincipalForAuthID(self, authid):
"""
Return authentication and authorization principal identifiers
for the authentication identifier passed in. In this
implementation authn and authz principals are the same.
@param authid: a string containing the
authentication/authorization identifier for the principal
to lookup.
@return: a tuple of C{(principal, principalURI)} where:
C{principal} is the L{Principal} that is found;
{principalURI} is the C{str} URI of the principal. If not
found return None.
"""
for collection in self.principalCollections():
principal = collection.principalForUser(authid)
if principal is not None:
return principal
return None
def authorizationPrincipal(self, request, authid, authnPrincipal):
"""
Determine the authorization principal for the given request
and authentication principal. This implementation simply uses
that authentication principal as the authorization principal.
@param request: the L{IRequest} for the request in progress.
@param authid: a string containing the
authentication/authorization identifier for the principal
to lookup.
@param authnPrincipal: the L{IDAVPrincipal} for the
authenticated principal
@return: a deferred result C{tuple} of (L{IDAVPrincipal},
C{str}) containing the authorization principal resource
and URI respectively.
"""
return succeed(authnPrincipal)
def samePrincipal(self, principal1, principal2):
"""
Check whether the two principals are exactly the same in terms of
elements and data.
@param principal1: a L{Principal} to test.
@param principal2: a L{Principal} to test.
@return: C{True} if they are the same, C{False} otherwise.
"""
# The interesting part of a principal is it's one child
principal1 = principal1.children[0]
principal2 = principal2.children[0]
if type(principal1) == type(principal2):
if isinstance(principal1, element.Property):
return (
type(principal1.children[0]) ==
type(principal2.children[0])
)
elif isinstance(principal1, element.HRef):
return (
str(principal1.children[0]) ==
str(principal2.children[0])
)
else:
return True
else:
return False
    def matchPrincipal(self, principal1, principal2, request):
        """
        Check whether the principal1 is a principal in the set defined
        by principal2.
        @param principal1: a L{Principal} to test. C{principal1} must
            contain a L{element.HRef} or L{element.Unauthenticated}
            element.
        @param principal2: a L{Principal} to test.
        @param request: the request being processed.
        @return: a deferred firing C{True} if they match, C{False}
            otherwise.
        """
        # See RFC 3744, section 5.5.1
        # The interesting part of a principal is its one child
        principal1 = principal1.children[0]
        principal2 = principal2.children[0]
        # Cache match results on the request object, keyed by the
        # string forms of both principals, since the same pair is
        # typically tested repeatedly while processing one request.
        if not hasattr(request, "matchPrincipals"):
            request.matchPrincipals = {}
        cache_key = (str(principal1), str(principal2))
        match = request.matchPrincipals.get(cache_key, None)
        if match is not None:
            return succeed(match)
        def doMatch():
            # DAV:all matches any actor.
            if isinstance(principal2, element.All):
                return succeed(True)
            elif isinstance(principal2, element.Authenticated):
                # DAV:authenticated matches everything except the
                # unauthenticated actor (DAV:all is not an actor).
                if isinstance(principal1, element.Unauthenticated):
                    return succeed(False)
                elif isinstance(principal1, element.All):
                    return succeed(False)
                else:
                    return succeed(True)
            elif isinstance(principal2, element.Unauthenticated):
                # DAV:unauthenticated matches only the unauthenticated
                # actor.
                if isinstance(principal1, element.Unauthenticated):
                    return succeed(True)
                else:
                    return succeed(False)
            elif isinstance(principal1, element.Unauthenticated):
                return succeed(False)
            assert isinstance(principal1, element.HRef), (
                "Not an HRef: %r" % (principal1,)
            )
            def resolved(principal2):
                assert principal2 is not None, "principal2 is None"
                # Compare two HRefs and do group membership test as well
                if principal1 == principal2:
                    return True
                return self.principalIsGroupMember(
                    str(principal1), str(principal2), request
                )
            # Resolve principal2 (e.g. DAV:self or a property
            # principal) into an HRef before comparing.
            d = self.resolvePrincipal(principal2, request)
            d.addCallback(resolved)
            return d
        def cache(match):
            # Remember the outcome for subsequent identical checks.
            request.matchPrincipals[cache_key] = match
            return match
        d = doMatch()
        d.addCallback(cache)
        return d
@inlineCallbacks
def principalIsGroupMember(self, principal1, principal2, request):
"""
Check whether one principal is a group member of another.
@param principal1: C{str} principalURL for principal to test.
@param principal2: C{str} principalURL for possible group
principal to test against.
@param request: the request being processed.
@return: L{Deferred} with result C{True} if principal1 is a
member of principal2, C{False} otherwise
"""
resource1 = yield request.locateResource(principal1)
resource2 = yield request.locateResource(principal2)
if resource2 and isinstance(resource2, DAVPrincipalResource):
isContained = yield resource2.containsPrincipal(resource1)
returnValue(isContained)
returnValue(False)
def validPrincipal(self, ace_principal, request):
"""
Check whether the supplied principal is valid for this resource.
@param ace_principal: the L{Principal} element to test
@param request: the request being processed.
@return C{True} if C{ace_principal} is valid, C{False} otherwise.
This implementation tests for a valid element type and checks
for an href principal that exists inside of a principal
collection.
"""
def defer():
#
# We know that the element contains a valid element type, so all
# we need to do is check for a valid property and a valid href.
#
real_principal = ace_principal.children[0]
if isinstance(real_principal, element.Property):
# See comments in matchPrincipal(). We probably need
# some common code.
log.error("Encountered a property principal (%s), "
"but handling is not implemented."
% (real_principal,))
return False
if isinstance(real_principal, element.HRef):
return self.validHrefPrincipal(real_principal, request)
return True
return maybeDeferred(defer)
def validHrefPrincipal(self, href_principal, request):
"""
Check whether the supplied principal (in the form of an Href)
is valid for this resource.
@param href_principal: the L{Href} element to test
@param request: the request being processed.
@return C{True} if C{href_principal} is valid, C{False}
otherwise.
This implementation tests for a href element that corresponds
to a principal resource and matches the principal-URL.
"""
# Must have the principal resource type and must match the
# principal-URL
def _matchPrincipalURL(resource):
return (
isPrincipalResource(resource) and
resource.principalURL() == str(href_principal)
)
d = request.locateResource(str(href_principal))
d.addCallback(_matchPrincipalURL)
return d
def resolvePrincipal(self, principal, request):
"""
Resolves a L{element.Principal} element into a L{element.HRef}
element if possible. Specifically, the given C{principal}'s
contained element is resolved.
L{element.Property} is resolved to the URI in the contained
property.
L{element.Self} is resolved to the URI of this resource.
L{element.HRef} elements are returned as-is.
All other principals, including meta-principals
(eg. L{element.All}), resolve to C{None}.
@param principal: the L{element.Principal} child element to
resolve.
@param request: the request being processed.
@return: a deferred L{element.HRef} element or C{None}.
"""
if isinstance(principal, element.Property):
# NotImplementedError("Property principals are not implemented.")
#
# We can't raise here without potentially crippling the
# server in a way that can't be fixed over the wire, so
# let's refuse the match and log an error instead.
#
# Note: When fixing this, also fix validPrincipal()
#
log.error("Encountered a property principal (%s), "
"but handling is not implemented; invalid for ACL use."
% (principal,))
return succeed(None)
#
# FIXME: I think this is wrong - we need to get the
# namespace and name from the first child of DAV:property
#
namespace = principal.attributes.get(["namespace"], dav_namespace)
name = principal.attributes["name"]
def gotPrincipal(principal):
try:
principal = principal.getResult()
except HTTPError, e:
assert e.response.code == responsecode.NOT_FOUND, (
"%s (!= %s) status from readProperty() exception"
% (e.response.code, responsecode.NOT_FOUND)
)
return None
if not isinstance(principal, element.Principal):
log.error("Non-principal value in property %s "
"referenced by property principal."
% (encodeXMLName(namespace, name),))
return None
if len(principal.children) != 1:
return None
# The interesting part of a principal is it's one child
principal = principal.children[0]
# XXXXXX FIXME XXXXXX
d = self.readProperty((namespace, name), request)
d.addCallback(gotPrincipal)
return d
elif isinstance(principal, element.Self):
try:
self = IDAVPrincipalResource(self)
except TypeError:
log.error("DAV:self ACE is set on non-principal resource %r"
% (self,))
return succeed(None)
principal = element.HRef(self.principalURL())
if isinstance(principal, element.HRef):
return succeed(principal)
assert isinstance(principal, (
element.All,
element.Authenticated,
element.Unauthenticated
)), "Not a meta-principal: %r" % (principal,)
return succeed(None)
@inlineCallbacks
def privilegesForPrincipal(self, principal, request):
"""
See L{IDAVResource.privilegesForPrincipal}.
"""
# NB Return aggregate privileges expanded.
acl = (yield self.accessControlList(request))
# Check disabled
if acl is None:
returnValue(())
granted = []
denied = []
for ace in acl.children:
# First see if the ace's principal affects the principal
# being tested. FIXME: support the DAV:invert operation
match = (yield
self.matchPrincipal(principal, ace.principal, request)
)
if match:
# Expand aggregate privileges
ps = []
supportedPrivs = (yield
self.supportedPrivileges(request)
)
for p in ace.privileges:
ps.extend(p.expandAggregate(supportedPrivs))
# Merge grant/deny privileges
if ace.allow:
granted.extend([p for p in ps if p not in granted])
else:
denied.extend([p for p in ps if p not in denied])
# Subtract denied from granted
allowed = tuple(p for p in granted if p not in denied)
returnValue(allowed)
def matchACEinACL(self, acl, ace):
"""
Find an ACE in the ACL that matches the supplied ACE's principal.
@param acl: the L{ACL} to look at.
@param ace: the L{ACE} to try and match
@return: the L{ACE} in acl that matches, None otherwise.
"""
for a in acl.children:
if self.samePrincipal(a.principal, ace.principal):
return a
return None
def principalSearchPropertySet(self):
"""
@return: a L{element.PrincipalSearchPropertySet} element describing the
principal properties that can be searched on this principal collection,
or C{None} if this is not a principal collection.
This implementation returns None. Principal collection resources must
override and return their own suitable response.
"""
return None
##
# Quota
##
"""
The basic policy here is to define a private 'quota-root' property
on a collection. That property will contain the maximum allowed
bytes for the collections and all its contents.
In order to determine the quota property values on a resource, the
server must look for the private property on that resource and any
of its parents. If found on a parent, then that parent should be
queried for quota information. If not found, no quota exists for
the resource.
    To determine the actual quota in use we will cache the used byte
    count on the quota-root collection in another private property. It
    is the server's responsibility to keep that property up to date by
adjusting it after every PUT, DELETE, COPY, MOVE, MKCOL,
PROPPATCH, ACL, POST or any other method that may affect the size
of stored data. If the private property is not present, the server
will fall back to getting the size by iterating over all resources
(this is done in static.py).
"""
def quota(self, request):
"""
Get current available & used quota values for this resource's
quota root collection.
@return: an L{Deferred} with result C{tuple} containing two
C{int}'s the first is quota-available-bytes, the second is
quota-used-bytes, or C{None} if quota is not defined on
the resource.
"""
# See if already cached
if hasattr(request, "quota"):
if self in request.quota:
return succeed(request.quota[self])
else:
request.quota = {}
# Find the quota root for this resource and return its data
def gotQuotaRootResource(qroot_resource):
if qroot_resource:
qroot = qroot_resource.quotaRoot(request)
if qroot is not None:
def gotUsage(used):
available = qroot - used
if available < 0:
available = 0
request.quota[self] = (available, used)
return (available, used)
d = qroot_resource.currentQuotaUse(request)
d.addCallback(gotUsage)
return d
request.quota[self] = None
return None
d = self.quotaRootResource(request)
d.addCallback(gotQuotaRootResource)
return d
def hasQuota(self, request):
"""
Check whether this resource is under quota control by checking
each parent to see if it has a quota root.
@return: C{True} if under quota control, C{False} if not.
"""
def gotQuotaRootResource(qroot_resource):
return qroot_resource is not None
d = self.quotaRootResource(request)
d.addCallback(gotQuotaRootResource)
return d
def hasQuotaRoot(self, request):
"""
@return: a C{True} if this resource has quota root, C{False} otherwise.
"""
return self.hasDeadProperty(TwistedQuotaRootProperty)
def quotaRoot(self, request):
"""
@return: a C{int} containing the maximum allowed bytes if this
collection is quota-controlled, or C{None} if not quota
controlled.
"""
if self.hasDeadProperty(TwistedQuotaRootProperty):
return int(str(self.readDeadProperty(TwistedQuotaRootProperty)))
else:
return None
@inlineCallbacks
def quotaRootResource(self, request):
"""
Return the quota root for this resource.
@return: L{DAVResource} or C{None}
"""
if self.hasQuotaRoot(request):
returnValue(self)
# Check the next parent
try:
url = request.urlForResource(self)
except NoURLForResourceError:
returnValue(None)
while (url != "/"):
url = parentForURL(url)
if url is None:
break
parent = (yield request.locateResource(url))
if parent is None:
break
if parent.hasQuotaRoot(request):
returnValue(parent)
returnValue(None)
def setQuotaRoot(self, request, maxsize):
"""
@param maxsize: a C{int} containing the maximum allowed bytes
for the contents of this collection, or C{None} to remove
quota restriction.
"""
assert self.isCollection(), "Only collections can have a quota root"
assert maxsize is None or isinstance(maxsize, int), (
"maxsize must be an int or None"
)
if maxsize is not None:
self.writeDeadProperty(TwistedQuotaRootProperty(str(maxsize)))
else:
# Remove both the root and the cached used value
self.removeDeadProperty(TwistedQuotaRootProperty)
self.removeDeadProperty(TwistedQuotaUsedProperty)
def quotaSize(self, request):
"""
Get the size of this resource (if its a collection get total
for all children as well). TODO: Take into account size of
dead-properties.
@return: a C{int} containing the size of the resource.
"""
unimplemented(self)
def checkQuota(self, request, available):
"""
Check to see whether all quota roots have sufficient available
bytes. We currently do not use hierarchical quota checks -
i.e. only the most immediate quota root parent is checked for
quota.
@param available: a C{int} containing the additional quota
required.
@return: C{True} if there is sufficient quota remaining on all
quota roots, C{False} otherwise.
"""
def _defer(quotaroot):
if quotaroot:
# Check quota on this root (if it has one)
quota = quotaroot.quotaRoot(request)
if quota is not None:
if available > quota[0]:
return False
return True
d = self.quotaRootResource(request)
d.addCallback(_defer)
return d
def quotaSizeAdjust(self, request, adjust):
"""
Update the quota used value on all quota root parents of this
resource.
@param adjust: a C{int} containing the number of bytes added
(positive) or removed (negative) that should be used to
adjust the cached total.
"""
def _defer(quotaroot):
if quotaroot:
# Check quota on this root (if it has one)
return quotaroot.updateQuotaUse(request, adjust)
d = self.quotaRootResource(request)
d.addCallback(_defer)
return d
def currentQuotaUse(self, request):
"""
Get the cached quota use value, or if not present (or invalid)
determine quota use by brute force.
@return: an L{Deferred} with a C{int} result containing the
current used byte if this collection is quota-controlled,
or C{None} if not quota controlled.
"""
assert self.isCollection(), "Only collections can have a quota root"
assert self.hasQuotaRoot(request), (
"Quota use only on quota root collection"
)
# Try to get the cached value property
if self.hasDeadProperty(TwistedQuotaUsedProperty):
return succeed(
int(str(self.readDeadProperty(TwistedQuotaUsedProperty)))
)
else:
# Do brute force size determination and cache the result
# in the private property
def _defer(result):
self.writeDeadProperty(TwistedQuotaUsedProperty(str(result)))
return result
d = self.quotaSize(request)
d.addCallback(_defer)
return d
def updateQuotaUse(self, request, adjust):
"""
Update the quota used value on this resource.
@param adjust: a C{int} containing the number of bytes added
(positive) or removed (negative) that should be used to
adjust the cached total.
@return: an L{Deferred} with a C{int} result containing the
current used byte if this collection is quota-controlled,
or C{None} if not quota controlled.
"""
assert self.isCollection(), "Only collections can have a quota root"
# Get current value
def _defer(size):
size += adjust
# Sanity check the resulting size
if size >= 0:
self.writeDeadProperty(TwistedQuotaUsedProperty(str(size)))
else:
# Remove the dead property and re-read to do brute
# force quota calc
log.info("Attempt to set quota used to a negative value: %s "
"(adjustment: %s)"
% (size, adjust,))
self.removeDeadProperty(TwistedQuotaUsedProperty)
return self.currentQuotaUse(request)
d = self.currentQuotaUse(request)
d.addCallback(_defer)
return d
##
# HTTP
##
    def renderHTTP(self, request):
        """
        Render the request, adding WebDAV-specific behaviour: requests
        for a collection without a trailing "/" are redirected to the
        "/" form, and every response (including error responses) gets
        a DAV compliance header and, where needed, a Content-Location.
        """
        # FIXME: This is for testing with litmus; comment out when not in use
        #litmus = request.headers.getRawHeaders("x-litmus")
        #if litmus: log.info("*** Litmus test: %s ***" % (litmus,))
        #
        # If this is a collection and the URI doesn't end in "/", redirect.
        #
        if self.isCollection() and request.path[-1:] != "/":
            return RedirectResponse(
                request.unparseURL(
                    path=urllib.quote(
                        urllib.unquote(request.path),
                        safe=':/') + '/'
                )
            )
        def setHeaders(response):
            response = IResponse(response)
            # Advertise the DAV compliance classes for this resource.
            response.headers.setHeader("dav", self.davComplianceClasses())
            #
            # If this is a collection and the URI doesn't end in "/",
            # add a Content-Location header. This is needed even if
            # we redirect such requests (as above) in the event that
            # this resource was created or modified by the request.
            #
            if self.isCollection() and request.path[-1:] != "/" and not response.headers.hasHeader("content-location"):
                response.headers.setHeader(
                    "content-location", request.path + "/"
                )
            return response
        def onError(f):
            # If we get an HTTPError, run its response through
            # setHeaders() as well.
            f.trap(HTTPError)
            return setHeaders(f.value.response)
        d = maybeDeferred(super(DAVResource, self).renderHTTP, request)
        return d.addCallbacks(setHeaders, onError)
class DAVLeafResource (DAVResource, LeafResource):
    """
    A DAV resource that never contains child resources.
    """
    def findChildren(
        self, depth, request, callback,
        privileges=None, inherited_aces=None
    ):
        # A leaf has no children, so there is never anything to visit
        # and the callback is never invoked.
        return succeed(None)
class DAVPrincipalResource (DAVResource):
    """
    Resource representing a WebDAV principal. (RFC 3744, section 2)
    """
    implements(IDAVPrincipalResource)
    ##
    # WebDAV
    ##
    def liveProperties(self):
        # Extend the base live property set with the principal
        # properties defined by RFC 3744.
        return super(DAVPrincipalResource, self).liveProperties() + (
            (dav_namespace, "alternate-URI-set"),
            (dav_namespace, "principal-URL"),
            (dav_namespace, "group-member-set"),
            (dav_namespace, "group-membership"),
        )
    def davComplianceClasses(self):
        return ("1", "access-control",)
    def isCollection(self):
        # Principal resources are not collections by default.
        return False
    def readProperty(self, property, request):
        """
        Read the given property, handling the RFC 3744 principal
        properties directly and delegating everything else to the
        base implementation.
        """
        def defer():
            # Accept either a (namespace, name) tuple or a property
            # element.
            if type(property) is tuple:
                qname = property
            else:
                qname = property.qname()
            namespace, name = qname
            if namespace == dav_namespace:
                if name == "alternate-URI-set":
                    return element.AlternateURISet(*[
                        element.HRef(u) for u in self.alternateURIs()
                    ])
                if name == "principal-URL":
                    return element.PrincipalURL(
                        element.HRef(self.principalURL())
                    )
                if name == "group-member-set":
                    def callback(members):
                        return element.GroupMemberSet(*[
                            element.HRef(p.principalURL())
                            for p in members
                        ])
                    d = self.groupMembers()
                    d.addCallback(callback)
                    return d
                if name == "group-membership":
                    def callback(memberships):
                        return element.GroupMembership(*[
                            element.HRef(g.principalURL())
                            for g in memberships
                        ])
                    d = self.groupMemberships()
                    d.addCallback(callback)
                    return d
                if name == "resourcetype":
                    # Advertise DAV:principal in the resource type,
                    # plus DAV:collection when applicable.
                    if self.isCollection():
                        return element.ResourceType(
                            element.Collection(),
                            element.Principal()
                        )
                    else:
                        return element.ResourceType(element.Principal())
            return super(DAVPrincipalResource, self).readProperty(
                qname, request
            )
        return maybeDeferred(defer)
    ##
    # ACL
    ##
    def alternateURIs(self):
        """
        See L{IDAVPrincipalResource.alternateURIs}.
        This implementation returns C{()}. Subclasses should override
        this method to provide alternate URIs for this resource if
        appropriate.
        """
        return ()
    def principalURL(self):
        """
        See L{IDAVPrincipalResource.principalURL}.
        This implementation raises L{NotImplementedError}. Subclasses
        must override this method to provide the principal URL for
        this resource.
        """
        unimplemented(self)
    def groupMembers(self):
        """
        This implementation returns a Deferred which fires with C{()},
        which is appropriate for non-group principals. Subclasses
        should override this method to provide member URLs for this
        resource if appropriate.
        @see: L{IDAVPrincipalResource.groupMembers}.
        """
        return succeed(())
    def expandedGroupMembers(self):
        """
        This implementation returns a Deferred which fires with C{()},
        which is appropriate for non-group principals. Subclasses
        should override this method to provide expanded member URLs
        for this resource if appropriate.
        @see: L{IDAVPrincipalResource.expandedGroupMembers}
        """
        return succeed(())
    def groupMemberships(self):
        """
        See L{IDAVPrincipalResource.groupMemberships}.
        This implementation raises L{NotImplementedError}. Subclasses
        must override this method to provide the group URLs for this
        resource.
        """
        unimplemented(self)
    def principalMatch(self, href):
        """
        Check whether the supplied principal matches this principal or
        is a member of this principal resource.
        @param href: the L{HRef} to test.
        @return: True if there is a match, False otherwise.
        """
        uri = str(href)
        if self.principalURL() == uri:
            return succeed(True)
        else:
            # Not this principal itself; check the expanded group
            # membership instead.
            d = self.expandedGroupMembers()
            d.addCallback(
                lambda members:
                uri in [member.principalURL() for member in members]
            )
            return d
    @inlineCallbacks
    def containsPrincipal(self, principal):
        """
        Is the given principal contained within our expanded group membership?
        @param principal: The principal to check
        @type principal: L{DirectoryCalendarPrincipalResource}
        @return: True if principal is a member, False otherwise
        @rtype: C{boolean}
        """
        members = yield self.expandedGroupMembers()
        returnValue(principal in members)
class DAVPrincipalCollectionResource (DAVResource):
    """
    WebDAV principal collection resource. (RFC 3744, section 5.8)
    Abstract: subclasses are required to implement
    C{principalForUser} for this class to function properly.
    """
    implements(IDAVPrincipalCollectionResource)
    def __init__(self, url, principalCollections=()):
        """
        @param url: This resource's URL; must end with C{"/"}.
        """
        DAVResource.__init__(self, principalCollections=principalCollections)
        assert url.endswith("/"), "Collection URL must end in '/'"
        self._url = url
    def principalCollectionURL(self):
        """
        @return: the URL of this principal collection.
        """
        return self._url
    def principalForUser(self, user):
        """
        Subclasses must implement this method.
        @see: L{IDAVPrincipalCollectionResource.principalForUser}
        @raise: L{NotImplementedError}
        """
        raise NotImplementedError(
            "%s did not implement principalForUser" % (self.__class__)
        )
class AccessDeniedError(Exception):
    """
    Raised when a request fails to meet sufficient access privileges
    for one or more resources.
    """
    def __init__(self, errors):
        """
        @param errors: sequence of tuples, one for each resource for
            which one or more of the given privileges are not granted,
            in the form C{(uri, privileges)}, where uri is a URL path
            relative to resource or C{None} if the error was in this
            resource, and privileges is a sequence of the privileges
            which are not granted, or a subset thereof.
        """
        super(AccessDeniedError, self).__init__(
            "Access denied for some resources: %r" % (errors,)
        )
        self.errors = errors
##
# Utilities
##
def isPrincipalResource(resource):
    """
    Determine whether C{resource} provides (or can be adapted to)
    L{IDAVPrincipalResource}.
    """
    try:
        IDAVPrincipalResource(resource)
    except TypeError:
        # Adaptation failed: not a principal resource.
        return False
    return True
class TwistedACLInheritable (WebDAVEmptyElement):
    """
    When set on an ACE, this indicates that the ACE privileges should
    be inherited by all child resources within the resource with this
    ACE.
    """
    # Private (non-WebDAV) namespace: this marker is stripped from
    # ACEs before ACLs are exposed over the wire.
    namespace = twisted_dav_namespace
    name = "inheritable"
registerElement(TwistedACLInheritable)
# Permit at most one inheritable marker inside a DAV:ace element.
element.ACE.allowed_children[(twisted_dav_namespace, "inheritable")] = (0, 1)
class TwistedGETContentMD5 (WebDAVTextElement):
    """
    MD5 hash of the resource content.
    """
    namespace = twisted_dav_namespace
    name = "getcontentmd5"
# Make the element parseable/serializable by the WebDAV XML machinery.
registerElement(TwistedGETContentMD5)
class TwistedQuotaRootProperty (WebDAVTextElement):
    """
    When set on a collection, this property indicates that the
    collection has a quota limit for the size of all resources stored
    in the collection (and any associate meta-data such as
    properties). The value is a number - the maximum size in bytes
    allowed.
    """
    # Stored in the private namespace so it is never exposed to
    # WebDAV clients.
    namespace = twisted_private_namespace
    name = "quota-root"
registerElement(TwistedQuotaRootProperty)
class TwistedQuotaUsedProperty (WebDAVTextElement):
    """
    When set on a collection, this property contains the cached
    running total of the size of all resources stored in the
    collection (and any associate meta-data such as properties). The
    value is a number - the size in bytes used.
    """
    # Stored in the private namespace so it is never exposed to
    # WebDAV clients.
    namespace = twisted_private_namespace
    name = "quota-used"
registerElement(TwistedQuotaUsedProperty)
# ACL granting all privileges to everyone, protected (not modifiable via
# the ACL method) and inherited by all child resources.
allACL = element.ACL(
    element.ACE(
        element.Principal(element.All()),
        element.Grant(element.Privilege(element.All())),
        element.Protected(),
        TwistedACLInheritable()
    )
)

# ACL granting read-only access to everyone, protected and inherited by
# all child resources.
readonlyACL = element.ACL(
    element.ACE(
        element.Principal(element.All()),
        element.Grant(element.Privilege(element.Read())),
        element.Protected(),
        TwistedACLInheritable()
    )
)

# Flat supported-privilege set which exposes only DAV:all.
allPrivilegeSet = element.SupportedPrivilegeSet(
    element.SupportedPrivilege(
        element.Privilege(element.All()),
        element.Description("all privileges", **{"xml:lang": "en"})
    )
)
#
# This is one possible graph of the "standard" privileges documented
# in 3744, section 3.
#
# DAV:all is the root; DAV:write aggregates write-properties,
# write-content, bind and unbind.
#
davPrivilegeSet = element.SupportedPrivilegeSet(
    element.SupportedPrivilege(
        element.Privilege(element.All()),
        element.Description(
            "all privileges",
            **{"xml:lang": "en"}
        ),
        element.SupportedPrivilege(
            element.Privilege(element.Read()),
            element.Description(
                "read resource",
                **{"xml:lang": "en"}
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.Write()),
            element.Description(
                "write resource",
                **{"xml:lang": "en"}
            ),
            element.SupportedPrivilege(
                element.Privilege(element.WriteProperties()),
                element.Description(
                    "write resource properties",
                    **{"xml:lang": "en"}
                ),
            ),
            element.SupportedPrivilege(
                element.Privilege(element.WriteContent()),
                element.Description(
                    "write resource content",
                    **{"xml:lang": "en"}
                ),
            ),
            element.SupportedPrivilege(
                element.Privilege(element.Bind()),
                element.Description(
                    "add child resource",
                    **{"xml:lang": "en"}
                ),
            ),
            element.SupportedPrivilege(
                element.Privilege(element.Unbind()),
                element.Description(
                    "remove child resource",
                    **{"xml:lang": "en"}
                ),
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.Unlock()),
            element.Description(
                "unlock resource without ownership of lock",
                **{"xml:lang": "en"}
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.ReadACL()),
            element.Description(
                "read resource access control list",
                **{"xml:lang": "en"}
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.WriteACL()),
            element.Description(
                "write resource access control list",
                **{"xml:lang": "en"}
            ),
        ),
        element.SupportedPrivilege(
            element.Privilege(element.ReadCurrentUserPrivilegeSet()),
            element.Description(
                "read privileges for current principal",
                **{"xml:lang": "en"}
            ),
        ),
    ),
)

# Reusable DAV:principal element representing an unauthenticated user.
unauthenticatedPrincipal = element.Principal(element.Unauthenticated())
class ResourceClass (WebDAVTextElement):
    """
    Property naming the class of a resource.
    """
    namespace = twisted_dav_namespace
    name = "resource-class"
    hidden = False
calendarserver-5.2+dfsg/twext/web2/dav/xattrprops.py 0000644 0001750 0001750 00000023321 12263343324 021676 0 ustar rahul rahul # Copyright (c) 2009 Twisted Matrix Laboratories.
# See LICENSE for details.
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
DAV Property store using file system extended attributes.
This API is considered private to static.py and is therefore subject to
change.
"""
__all__ = ["xattrPropertyStore"]
import urllib
import sys
import zlib
import errno
from operator import setitem
from zlib import compress, decompress
from cPickle import UnpicklingError, loads as unpickle
import xattr
if getattr(xattr, 'xattr', None) is None:
raise ImportError("wrong xattr package imported")
from twisted.python.util import untilConcludes
from twisted.python.failure import Failure
from twisted.python.log import err
from txdav.xml.base import encodeXMLName
from txdav.xml.parser import WebDAVDocument
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from twext.web2.dav.http import statusForFailure
# RFC 2518 Section 12.13.1 says that removal of non-existing property
# is not an error. python-xattr on Linux fails with ENODATA in this
# case. On Darwin and FreeBSD, the xattr library fails with ENOATTR,
# which CPython does not expose. Its value is 93.
#
# _ATTR_MISSING is the tuple of errno values that mean "the extended
# attribute does not exist" on the current platform.
_ATTR_MISSING = (93,)
if hasattr(errno, "ENODATA"):
    _ATTR_MISSING += (errno.ENODATA,)
class xattrPropertyStore (object):
    """
    DAV dead-property store backed by filesystem extended attributes.

    This implementation uses Bob Ippolito's xattr package, available from::

        http://undefined.org/python/#xattr

    Note that the Bob's xattr package is specific to Linux and Darwin, at least
    presently.
    """
    #
    # Dead properties are stored as extended attributes on disk.  In order to
    # avoid conflicts with other attributes, prefix dead property names.
    #
    deadPropertyXattrPrefix = "WebDAV:"

    # Linux seems to require that attribute names use a "user." prefix.
    # FIXME: Is is a system-wide thing, or a per-filesystem thing?
    #        If the latter, how to we detect the file system?
    if sys.platform == "linux2":
        deadPropertyXattrPrefix = "user."

    @classmethod
    def _encode(clazz, name, uid=None):
        """
        Encode the property qname C{name} (a namespace/name tuple) and the
        optional per-user C{uid} into an extended attribute name.
        """
        result = urllib.quote(encodeXMLName(*name), safe='{}:')
        if uid:
            result = uid + result
        r = clazz.deadPropertyXattrPrefix + result
        return r

    @classmethod
    def _decode(clazz, name):
        """
        Decode an extended attribute name produced by L{_encode} back into
        a C{(namespace, name, uid)} triple.

        @raise ValueError: if C{name} is not a validly encoded name.
        """
        name = urllib.unquote(name[len(clazz.deadPropertyXattrPrefix):])

        index1 = name.find("{")
        index2 = name.find("}")

        # Fixed: compare with == rather than "is"; identity comparison of
        # ints only works by accident via CPython's small-integer cache.
        if (index1 == -1 or index2 == -1 or not len(name) > index2):
            raise ValueError("Invalid encoded name: %r" % (name,))
        if index1 == 0:
            uid = None
        else:
            uid = name[:index1]
        propnamespace = name[index1+1:index2]
        propname = name[index2+1:]

        return (propnamespace, propname, uid)

    def __init__(self, resource):
        self.resource = resource
        self.attrs = xattr.xattr(self.resource.fp.path)

    def get(self, qname, uid=None):
        """
        Retrieve the value of a property stored as an extended attribute on the
        wrapped path.
        @param qname: The property to retrieve as a two-tuple of namespace URI
            and local name.
        @param uid: The per-user identifier for per user properties.
        @raise HTTPError: If there is no value associated with the given
            property.
        @return: A L{WebDAVDocument} representing the value associated with the
            given property.
        """
        try:
            data = self.attrs.get(self._encode(qname, uid))
        except KeyError:
            raise HTTPError(StatusResponse(
                responsecode.NOT_FOUND,
                "No such property: %s" % (encodeXMLName(*qname),)
            ))
        except IOError as e:
            if e.errno in _ATTR_MISSING or e.errno == errno.ENOENT:
                raise HTTPError(StatusResponse(
                    responsecode.NOT_FOUND,
                    "No such property: %s" % (encodeXMLName(*qname),)
                ))
            else:
                raise HTTPError(StatusResponse(
                    statusForFailure(Failure()),
                    "Unable to read property: %s" % (encodeXMLName(*qname),)
                ))

        #
        # Unserialize XML data from an xattr.  The storage format has changed
        # over time:
        #
        #  1- Started with XML
        #  2- Started compressing the XML due to limits on xattr size
        #  3- Switched to pickle which is faster, still compressing
        #  4- Back to compressed XML for interoperability, size
        #
        # We only write the current format, but we also read the old
        # ones for compatibility.
        #
        legacy = False

        try:
            data = decompress(data)
        except zlib.error:
            legacy = True

        try:
            doc = WebDAVDocument.fromString(data)
        except ValueError:
            try:
                doc = unpickle(data)
            except UnpicklingError:
                format = "Invalid property value stored on server: %s %s"
                msg = format % (encodeXMLName(*qname), data)
                err(None, msg)
                raise HTTPError(
                    StatusResponse(responsecode.INTERNAL_SERVER_ERROR, msg))
            else:
                legacy = True

        if legacy:
            # Re-save legacy-format values in the current format.
            # Fixed: write back under the same per-user key we read from;
            # previously the uid was dropped, silently migrating a per-user
            # property to the shared (no-uid) attribute.
            self.set(doc.root_element, uid)
        return doc.root_element

    def set(self, property, uid=None):
        """
        Store the given property as an extended attribute on the wrapped path.
        @param uid: The per-user identifier for per user properties.
        @param property: A L{WebDAVElement} to store.
        """
        key = self._encode(property.qname(), uid)
        value = compress(property.toxml(pretty=False))
        # untilConcludes retries the write if it is interrupted by a signal
        # (EINTR).
        untilConcludes(setitem, self.attrs, key, value)

        # Update the resource because we've modified it
        self.resource.fp.restat()

    def delete(self, qname, uid=None):
        """
        Remove the extended attribute from the wrapped path which stores the
        property given by C{qname}.
        @param uid: The per-user identifier for per user properties.
        @param qname: The property to delete as a two-tuple of namespace URI
            and local name.
        """
        key = self._encode(qname, uid)
        try:
            try:
                self.attrs.remove(key)
            except KeyError:
                # RFC 2518 Section 12.13.1: removing a non-existing property
                # is not an error.
                pass
            except IOError as e:
                if e.errno not in _ATTR_MISSING:
                    raise
        except:
            # Fixed: the message was passed as an extra positional argument
            # ("...: %s", (key,)) instead of being %-formatted.
            raise HTTPError(StatusResponse(
                statusForFailure(Failure()),
                "Unable to delete property: %s" % (key,)
            ))

    def contains(self, qname, uid=None):
        """
        Determine whether the property given by C{qname} is stored in an
        extended attribute of the wrapped path.
        @param qname: The property to look up as a two-tuple of namespace URI
            and local name.
        @param uid: The per-user identifier for per user properties.
        @return: C{True} if the property exists, C{False} otherwise.
        """
        key = self._encode(qname, uid)
        try:
            self.attrs.get(key)
        except KeyError:
            return False
        except IOError as e:
            if e.errno in _ATTR_MISSING or e.errno == errno.ENOENT:
                return False
            raise HTTPError(StatusResponse(
                statusForFailure(Failure()),
                "Unable to read property: %s" % (key,)
            ))
        else:
            return True

    def list(self, uid=None, filterByUID=True):
        """
        Enumerate the property names stored in extended attributes of the
        wrapped path.
        @param uid: The per-user identifier for per user properties.
        @param filterByUID: if C{True}, only properties stored for C{uid}
            are returned, as C{(namespace, name)} two-tuples; otherwise all
            properties are returned as C{(namespace, name, uid)} triples.
        @return: A C{list} of property names.
        """
        prefix = self.deadPropertyXattrPrefix
        try:
            attrs = iter(self.attrs)
        except IOError as e:
            if e.errno == errno.ENOENT:
                return []
            # Fixed: %-format the message instead of passing the path as an
            # extra positional argument.
            raise HTTPError(StatusResponse(
                statusForFailure(Failure()),
                "Unable to list properties: %s" % (self.resource.fp.path,)
            ))
        else:
            results = [
                self._decode(name)
                for name in attrs
                if name.startswith(prefix)
            ]
            if filterByUID:
                return [
                    (namespace, name)
                    for namespace, name, propuid in results
                    if propuid == uid
                ]
            else:
                return results
calendarserver-5.2+dfsg/twext/web2/dav/noneprops.py 0000644 0001750 0001750 00000004742 12263343324 021501 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
Empty DAV property store.
This API is considered private to static.py and is therefore subject to
change.
"""
__all__ = ["NonePropertyStore"]
from twext.web2 import responsecode
from twext.web2.http import HTTPError, StatusResponse
from txdav.xml.base import encodeXMLName
class NonePropertyStore (object):
    """
    DAV property store which contains no properties and does not allow
    properties to be set.
    """
    __singleton = None

    def __new__(clazz, resource):
        # The store is stateless, so every resource shares one instance.
        existing = NonePropertyStore.__singleton
        if existing is not None:
            return existing
        instance = object.__new__(clazz)
        NonePropertyStore.__singleton = instance
        return instance

    def __init__(self, resource):
        pass

    def get(self, qname, uid=None):
        """Always fails with NOT_FOUND: this store holds no properties."""
        response = StatusResponse(
            responsecode.NOT_FOUND,
            "No such property: %s" % (encodeXMLName(*qname),)
        )
        raise HTTPError(response)

    def set(self, property, uid=None):
        """Always fails with FORBIDDEN: properties may not be stored."""
        response = StatusResponse(
            responsecode.FORBIDDEN,
            "Permission denied for setting property: %s" % (property,)
        )
        raise HTTPError(response)

    def delete(self, qname, uid=None):
        # RFC 2518 Section 12.13.1 says that removal of
        # non-existing property is not an error.
        pass

    def contains(self, qname, uid=None):
        """No property is ever present."""
        return False

    def list(self, uid=None):
        """There are never any properties to enumerate."""
        return ()
calendarserver-5.2+dfsg/twext/web2/dav/fileop.py 0000644 0001750 0001750 00000046451 12263343324 020737 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV file operations
This API is considered private to static.py and is therefore subject to
change.
"""
__all__ = [
"delete",
"copy",
"move",
"put",
"mkcollection",
"rmdir",
]
import os
import urllib
from urlparse import urlsplit
from twisted.python.failure import Failure
from twisted.internet.defer import succeed, deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.python.filepath import CachingFilePath as FilePath
from twext.web2 import responsecode
from twext.web2.http import StatusResponse, HTTPError
from twext.web2.stream import FileStream, readIntoFile
from twext.web2.dav.http import ResponseQueue, statusForFailure
# Module-level logger for the file operations below.
log = Logger()


def delete(uri, filepath, depth="infinity"):
    """
    Perform a X{DELETE} operation on the given URI, which is backed by the given
    filepath.
    @param uri: the URI of the resource to delete; used to compute the
        relative paths reported in a multi-status response.
    @param filepath: the L{FilePath} to delete.
    @param depth: the recursion X{Depth} for the X{DELETE} operation, which must
        be "infinity".
    @raise HTTPError: (containing a response with a status code of
        L{responsecode.BAD_REQUEST}) if C{depth} is not "infinity".
    @raise HTTPError: (containing an appropriate response) if the
        delete operation fails.  If C{filepath} is a directory, the response
        will be a L{MultiStatusResponse}.
    @return: a deferred response with a status code of L{responsecode.NO_CONTENT}
        if the X{DELETE} operation succeeds.
    """
    #
    # Remove the file(s)
    #
    # FIXME: defer
    if filepath.isdir():
        #
        # RFC 2518, section 8.6 says that we must act as if the Depth header is
        # set to infinity, and that the client must omit the Depth header or set
        # it to infinity, meaning that for collections, we will delete all
        # members.
        #
        # This seems somewhat at odds with the notion that a bad request should
        # be rejected outright; if the client sends a bad depth header, the
        # client is broken, and RFC 2518, section 8 suggests that a bad request
        # should be rejected...
        #
        # Let's play it safe for now and ignore broken clients.
        #
        if depth != "infinity":
            msg = ("Client sent illegal depth header value for DELETE: %s" % (depth,))
            log.error(msg)
            raise HTTPError(StatusResponse(responsecode.BAD_REQUEST, msg))

        #
        # Recursive delete
        #
        # RFC 2518, section 8.6 says that if we get an error deleting a resource
        # other than the collection in the request-URI, that we must respond
        # with a multi-status response containing error statuses for each
        # resource that we fail to delete.  It also says we should not return
        # no-content (success) status, which means that we should continue after
        # errors, rather than aborting right away.  This is interesting in that
        # it's different from how most operating system tools act (eg. rm) when
        # recursive filsystem deletes fail.
        #
        uri_path = urllib.unquote(urlsplit(uri)[2])
        if uri_path[-1] == "/":
            uri_path = uri_path[:-1]

        log.info("Deleting directory %s" % (filepath.path,))

        # NOTE: len(uri_path) is wrong if os.sep is not one byte long... meh.
        request_basename = filepath.path[:-len(uri_path)]

        # Failures below are accumulated rather than raised, so that the
        # final response can report per-resource errors (multi-status).
        errors = ResponseQueue(request_basename, "DELETE", responsecode.NO_CONTENT)

        # FIXME: defer this
        # Bottom-up walk so directories are emptied before they are removed.
        for dir, subdirs, files in os.walk(filepath.path, topdown=False):
            for filename in files:
                path = os.path.join(dir, filename)
                try:
                    os.remove(path)
                except:
                    # Capture the active exception as a Failure for the
                    # multi-status response; do not abort the walk.
                    errors.add(path, Failure())

            for subdir in subdirs:
                path = os.path.join(dir, subdir)
                if os.path.islink(path):
                    # Symlinked directories must be unlinked, not rmdir'd.
                    try:
                        os.remove(path)
                    except:
                        errors.add(path, Failure())
                else:
                    try:
                        os.rmdir(path)
                    except:
                        errors.add(path, Failure())

        try:
            os.rmdir(filepath.path)
        except:
            raise HTTPError(statusForFailure(
                Failure(),
                "deleting directory: %s" % (filepath.path,)
            ))

        response = errors.response()

    else:
        #
        # Delete a file; much simpler, eh?
        #
        log.info("Deleting file %s" % (filepath.path,))
        try:
            os.remove(filepath.path)
        except:
            raise HTTPError(statusForFailure(
                Failure(),
                "deleting file: %s" % (filepath.path,)
            ))

        response = responsecode.NO_CONTENT

    # Remove stat info for filepath since we deleted the backing file
    filepath.changed()

    return succeed(response)
def copy(source_filepath, destination_filepath, destination_uri, depth):
    """
    Perform a X{COPY} from the given source and destination filepaths.
    This will perform a X{DELETE} on the destination if necessary; the caller
    should check and handle the X{overwrite} header before calling L{copy} (as
    in L{COPYMOVE.prepareForCopy}).
    @param source_filepath: a L{FilePath} for the file to copy from.
    @param destination_filepath: a L{FilePath} for the file to copy to.
    @param destination_uri: the URI of the destination resource.
    @param depth: the recursion X{Depth} for the X{COPY} operation, which must
        be one of "0", "1", or "infinity".
    @raise HTTPError: (containing a response with a status code of
        L{responsecode.BAD_REQUEST}) if C{depth} is not "0", "1" or "infinity".
    @raise HTTPError: (containing an appropriate response) if the operation
        fails.  If C{source_filepath} is a directory, the response will be a
        L{MultiStatusResponse}.
    @return: a deferred response with a status code of L{responsecode.CREATED}
        if the destination already exists, or L{responsecode.NO_CONTENT} if the
        destination was created by the X{COPY} operation.

    Note: this is a C{deferredGenerator} (see the wrapping below); values
    are produced via C{yield}/C{waitForDeferred} rather than C{return}.
    """
    if source_filepath.isfile():
        #
        # Copy the file
        #
        log.info("Copying file %s to %s" % (source_filepath.path, destination_filepath.path))
        try:
            source_file = source_filepath.open()
        except:
            raise HTTPError(statusForFailure(
                Failure(),
                "opening file for reading: %s" % (source_filepath.path,)
            ))

        source_stream = FileStream(source_file)
        response = waitForDeferred(put(source_stream, destination_filepath, destination_uri))
        yield response
        try:
            response = response.getResult()
        finally:
            # Close the stream and file even if the put failed.
            source_stream.close()
            source_file.close()
        checkResponse(response, "put", responsecode.NO_CONTENT, responsecode.CREATED)
        yield response
        return

    elif source_filepath.isdir():
        if destination_filepath.exists():
            #
            # Delete the destination
            #
            response = waitForDeferred(delete(destination_uri, destination_filepath))
            yield response
            response = response.getResult()
            checkResponse(response, "delete", responsecode.NO_CONTENT)
            success_code = responsecode.NO_CONTENT
        else:
            success_code = responsecode.CREATED

        #
        # Copy the directory
        #
        log.info("Copying directory %s to %s" % (source_filepath.path, destination_filepath.path))

        source_basename = source_filepath.path
        destination_basename = destination_filepath.path

        # Failures are accumulated so the final response can report
        # per-resource errors (multi-status).
        errors = ResponseQueue(source_basename, "COPY", success_code)

        if destination_filepath.parent().isdir():
            if os.path.islink(source_basename):
                link_destination = os.readlink(source_basename)
                if link_destination[0] != os.path.sep:
                    link_destination = os.path.join(source_basename, link_destination)
                try:
                    # NOTE(review): os.symlink(source, link_name) creates
                    # *link_name* pointing at *source*; as written this
                    # creates a link at link_destination pointing at the
                    # destination, which looks like the arguments are
                    # swapped -- confirm intent before changing.
                    os.symlink(destination_basename, link_destination)
                except:
                    errors.add(source_basename, Failure())
            else:
                try:
                    os.mkdir(destination_basename)
                except:
                    raise HTTPError(statusForFailure(
                        Failure(),
                        "creating directory %s" % (destination_basename,)
                    ))

            if depth == "0":
                # Shallow copy: the collection itself only, no members.
                yield success_code
                return
        else:
            raise HTTPError(StatusResponse(
                responsecode.CONFLICT,
                "Parent collection for destination %s does not exist" % (destination_uri,)
            ))

        #
        # Recursive copy
        #
        # FIXME: When we report errors, do we report them on the source URI
        # or on the destination URI?  We're using the source URI here.
        #
        # FIXME: defer the walk?

        source_basename_len = len(source_basename)

        def paths(basepath, subpath):
            # Map a path under the source tree to the corresponding path
            # under the destination tree.
            source_path = os.path.join(basepath, subpath)
            assert source_path.startswith(source_basename)
            destination_path = os.path.join(destination_basename, source_path[source_basename_len+1:])
            return source_path, destination_path

        # Top-down walk so parent directories exist before their children
        # are copied into them.
        for dir, subdirs, files in os.walk(source_filepath.path, topdown=True):
            for filename in files:
                source_path, destination_path = paths(dir, filename)
                if not os.path.isdir(os.path.dirname(destination_path)):
                    errors.add(source_path, responsecode.NOT_FOUND)
                else:
                    response = waitForDeferred(copy(FilePath(source_path), FilePath(destination_path), destination_uri, depth))
                    yield response
                    response = response.getResult()
                    checkResponse(response, "copy", responsecode.CREATED, responsecode.NO_CONTENT)

            for subdir in subdirs:
                source_path, destination_path = paths(dir, subdir)
                log.info("Copying directory %s to %s" % (source_path, destination_path))

                if not os.path.isdir(os.path.dirname(destination_path)):
                    errors.add(source_path, responsecode.CONFLICT)
                else:
                    if os.path.islink(source_path):
                        link_destination = os.readlink(source_path)
                        if link_destination[0] != os.path.sep:
                            link_destination = os.path.join(source_path, link_destination)
                        try:
                            # NOTE(review): same suspect argument order as
                            # the os.symlink() call above -- confirm intent.
                            os.symlink(destination_path, link_destination)
                        except:
                            errors.add(source_path, Failure())
                    else:
                        try:
                            os.mkdir(destination_path)
                        except:
                            errors.add(source_path, Failure())

        yield errors.response()
        return

    else:
        log.error("Unable to COPY to non-file: %s" % (source_filepath.path,))
        raise HTTPError(StatusResponse(
            responsecode.FORBIDDEN,
            "The requested resource exists but is not backed by a regular file."
        ))

copy = deferredGenerator(copy)
def move(source_filepath, source_uri, destination_filepath, destination_uri, depth):
    """
    Perform a X{MOVE} from the given source and destination filepaths.
    This will perform a X{DELETE} on the destination if necessary; the caller
    should check and handle the X{overwrite} header before calling L{copy} (as
    in L{COPYMOVE.prepareForCopy}).
    Following the X{DELETE}, this will attempt an atomic filesystem move.  If
    that fails, a X{COPY} operation followed by a X{DELETE} on the source will
    be attempted instead.
    @param source_filepath: a L{FilePath} for the file to copy from.
    @param source_uri: the URI of the source resource.
    @param destination_filepath: a L{FilePath} for the file to copy to.
    @param destination_uri: the URI of the destination resource.
    @param depth: the recursion X{Depth} for the X{MOVE} operation, which must
        be "infinity".
    @raise HTTPError: (containing a response with a status code of
        L{responsecode.BAD_REQUEST}) if C{depth} is not "infinity".
    @raise HTTPError: (containing an appropriate response) if the operation
        fails.  If C{source_filepath} is a directory, the response will be a
        L{MultiStatusResponse}.
    @return: a deferred response with a status code of L{responsecode.CREATED}
        if the destination already exists, or L{responsecode.NO_CONTENT} if the
        destination was created by the X{MOVE} operation.

    Note: this is a C{deferredGenerator} (see the wrapping below).
    """
    log.info("Moving %s to %s" % (source_filepath.path, destination_filepath.path))

    #
    # Choose a success status
    #
    if destination_filepath.exists():
        #
        # Delete the destination
        #
        response = waitForDeferred(delete(destination_uri, destination_filepath))
        yield response
        response = response.getResult()
        checkResponse(response, "delete", responsecode.NO_CONTENT)
        success_code = responsecode.NO_CONTENT
    else:
        success_code = responsecode.CREATED

    #
    # See if rename (which is atomic, and fast) works
    #
    try:
        os.rename(source_filepath.path, destination_filepath.path)
    except OSError:
        # Rename can fail across filesystems; fall through to copy+delete.
        pass
    else:
        # Remove stat info from source filepath since we moved it
        source_filepath.changed()
        yield success_code
        return

    #
    # Do a copy, then delete the source
    #
    response = waitForDeferred(copy(source_filepath, destination_filepath, destination_uri, depth))
    yield response
    response = response.getResult()
    checkResponse(response, "copy", responsecode.CREATED, responsecode.NO_CONTENT)

    response = waitForDeferred(delete(source_uri, source_filepath))
    yield response
    response = response.getResult()
    checkResponse(response, "delete", responsecode.NO_CONTENT)

    yield success_code

move = deferredGenerator(move)
def put(stream, filepath, uri=None):
    """
    Perform a PUT of the given data stream into the given filepath.
    @param stream: the stream to write to the destination.
    @param filepath: the L{FilePath} of the destination file.
    @param uri: the URI of the destination resource.
    If the destination exists, if C{uri} is not C{None}, perform a
    X{DELETE} operation on the destination, but if C{uri} is C{None},
    delete the destination directly.
    Note that whether a L{put} deletes the destination directly vs.
    performing a X{DELETE} on the destination affects the response returned
    in the event of an error during deletion.  Specifically, X{DELETE}
    on collections must return a L{MultiStatusResponse} under certain
    circumstances, whereas X{PUT} isn't required to do so.  Therefore,
    if the caller expects X{DELETE} semantics, it must provide a valid
    C{uri}.
    @raise HTTPError: (containing an appropriate response) if the operation
        fails.
    @return: a deferred response with a status code of L{responsecode.CREATED}
        if the destination already exists, or L{responsecode.NO_CONTENT} if the
        destination was created by the X{PUT} operation.

    Note: this is a C{deferredGenerator} (see the wrapping below); despite
    the docstring above, an *existing* destination yields NO_CONTENT and a
    freshly created one yields CREATED, per the assignments below.
    """
    log.info("Writing to file %s" % (filepath.path,))

    if filepath.exists():
        if uri is None:
            # Delete the destination directly (plain PUT semantics).
            try:
                if filepath.isdir():
                    rmdir(filepath.path)
                else:
                    os.remove(filepath.path)
            except:
                raise HTTPError(statusForFailure(
                    Failure(),
                    "writing to file: %s" % (filepath.path,)
                ))
        else:
            # Delete via the DELETE operation so multi-status semantics
            # apply on failure.
            response = waitForDeferred(delete(uri, filepath))
            yield response
            response = response.getResult()
            checkResponse(response, "delete", responsecode.NO_CONTENT)

        success_code = responsecode.NO_CONTENT
    else:
        success_code = responsecode.CREATED

    #
    # Write the contents of the request stream to resource's file
    #
    try:
        resource_file = filepath.open("w")
    except:
        raise HTTPError(statusForFailure(
            Failure(),
            "opening file for writing: %s" % (filepath.path,)
        ))

    try:
        x = waitForDeferred(readIntoFile(stream, resource_file))
        yield x
        x.getResult()
    except:
        raise HTTPError(statusForFailure(
            Failure(),
            "writing to file: %s" % (filepath.path,)
        ))

    # Remove stat info from filepath since we modified the backing file
    filepath.changed()

    yield success_code

put = deferredGenerator(put)
def mkcollection(filepath):
    """
    Perform a X{MKCOL} on the given filepath.
    @param filepath: the L{FilePath} of the collection resource to create.
    @raise HTTPError: (containing an appropriate response) if the operation
        fails.
    @return: a deferred response with a status code of
        L{responsecode.CREATED} on success.
    """
    try:
        os.mkdir(filepath.path)
        # Remove stat info from filepath because we modified it
        filepath.changed()
    except:
        # Any failure (including from filepath.changed()) is wrapped into
        # an HTTPError carrying an appropriate status for the failure.
        raise HTTPError(statusForFailure(
            Failure(),
            "creating directory in MKCOL: %s" % (filepath.path,)
        ))

    return succeed(responsecode.CREATED)
def rmdir(dirname):
    """
    Remove the directory with the given name, as well as its contents.

    @param dirname: the path to the directory to remove.
    """
    # Walk bottom-up so every directory is already empty when we rmdir it.
    for base, subdirnames, filenames in os.walk(dirname, topdown=False):
        for filename in filenames:
            os.remove(os.path.join(base, filename))
        for subdirname in subdirnames:
            target = os.path.join(base, subdirname)
            if os.path.islink(target):
                # A symlink to a directory must be unlinked, not rmdir'd.
                os.remove(target)
            else:
                os.rmdir(target)
    os.rmdir(dirname)
def checkResponse(response, method, *codes):
    """
    Verify that C{response} is one of the expected status C{codes}.

    @param response: the status code returned by the operation.
    @param method: the name of the operation that produced C{response}.
    @param codes: the acceptable status codes.
    @raise AssertionError: if C{response} is not among C{codes}.
    """
    message = (
        "%s() returned %r, but should have returned one of %r instead"
        % (method, response, codes)
    )
    assert response in codes, message
calendarserver-5.2+dfsg/twext/web2/dav/__init__.py 0000644 0001750 0001750 00000003460 12263343324 021211 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test -*-
##
# Copyright (c) 2009 Twisted Matrix Laboratories.
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
WebDAV support for Twext.Web2.
See RFC 2616: http://www.ietf.org/rfc/rfc2616.txt (HTTP)
See RFC 2518: http://www.ietf.org/rfc/rfc2518.txt (WebDAV)
See RFC 3253: http://www.ietf.org/rfc/rfc3253.txt (WebDAV Versioning Extentions)
See RFC 3744: http://www.ietf.org/rfc/rfc3744.txt (WebDAV Access Control Protocol)
See also: http://skrb.org/ietf/http_errata.html (Errata to RFC 2616)
"""
# Package version; "SVN-Trunk" indicates an unreleased development checkout.
__version__ = 'SVN-Trunk'

# Backwards-compatible alias for __version__.
version = __version__

# Public submodules of this package.
__all__ = [
    "auth",
    "fileop",
    "davxml",
    "http",
    "idav",
    "noneprops",
    "resource",
    "static",
    "stream",
    "util",
    "xattrprops",
]
calendarserver-5.2+dfsg/twext/web2/dav/static.py 0000644 0001750 0001750 00000015523 12263343324 020744 0 ustar rahul rahul # -*- test-case-name: twext.web2.dav.test.test_static -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
WebDAV-aware static resources.
"""
__all__ = ["DAVFile"]
from twisted.python.filepath import InsecurePath
from twisted.internet.defer import succeed, deferredGenerator, waitForDeferred
from twext.python.log import Logger
from twext.web2 import http_headers
from twext.web2 import responsecode
from twext.web2.dav.resource import DAVResource, davPrivilegeSet
from twext.web2.dav.resource import TwistedGETContentMD5
from twext.web2.dav.util import bindMethods
from twext.web2.http import HTTPError, StatusResponse
from twext.web2.static import File
log = Logger()
# Prefer the xattr-backed dead property store; if extended-attribute support
# (or its dependencies) is unavailable, fall back to the no-op store, which
# disallows setting dead properties.
try:
    from twext.web2.dav.xattrprops import xattrPropertyStore as DeadPropertyStore
except ImportError:
    log.info("No dead property store available; using nonePropertyStore.")
    log.info("Setting of dead properties will not be allowed.")
    from twext.web2.dav.noneprops import NonePropertyStore as DeadPropertyStore
class DAVFile (DAVResource, File):
    """
    WebDAV-accessible File resource.

    Extends twext.web2.static.File to handle WebDAV methods.
    """
    def __init__(
        self, path,
        defaultType="text/plain", indexNames=None,
        principalCollections=()
    ):
        """
        @param path: the path of the file backing this resource.
        @param defaultType: the default mime type (as a string) for this
            resource and (eg. child) resources derived from it.
        @param indexNames: a sequence of index file names.
        @param principalCollections: a sequence of principal collections,
            passed through to L{DAVResource} for ACL support.
        """
        File.__init__(
            self, path,
            defaultType = defaultType,
            ignoredExts = (),
            processors = None,
            indexNames = indexNames,
        )
        DAVResource.__init__(self, principalCollections=principalCollections)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.fp.path)

    ##
    # WebDAV
    ##

    def etag(self):
        # Use a stored MD5 dead property as the ETag when present; otherwise
        # defer to File's default ETag implementation.
        if not self.fp.exists(): return succeed(None)
        if self.hasDeadProperty(TwistedGETContentMD5):
            return succeed(http_headers.ETag(str(self.readDeadProperty(TwistedGETContentMD5))))
        else:
            return super(DAVFile, self).etag()

    def davComplianceClasses(self):
        # DAV compliance classes advertised to clients.
        return ("1", "access-control") # Add "2" when we have locking

    def deadProperties(self):
        # Lazily create the dead property store (xattr-backed when available,
        # see the DeadPropertyStore import fallback at module level).
        if not hasattr(self, "_dead_properties"):
            self._dead_properties = DeadPropertyStore(self)
        return self._dead_properties

    def isCollection(self):
        """
        See L{IDAVResource.isCollection}.
        """
        return self.fp.isdir()

    ##
    # ACL
    ##

    def supportedPrivileges(self, request):
        # All DAVFile resources share the standard DAV privilege set.
        return succeed(davPrivilegeSet)

    ##
    # Quota
    ##

    def quotaSize(self, request):
        """
        Get the size of this resource.
        TODO: Take into account size of dead-properties. Does stat
            include xattrs size?

        @return: an L{Deferred} with a C{int} result containing the size of the resource.
        """
        if self.isCollection():
            def walktree(top):
                """
                Recursively descend the directory tree rooted at top,
                calling the callback function for each regular file

                @param top: L{FilePath} for the directory to walk.
                """
                total = 0
                for f in top.listdir():
                    child = top.child(f)
                    if child.isdir():
                        # It's a directory, recurse into it
                        result = waitForDeferred(walktree(child))
                        yield result
                        total += result.getResult()
                    elif child.isfile():
                        # It's a file, call the callback function
                        total += child.getsize()
                    else:
                        # Unknown file type, print a message
                        pass
                # NOTE(review): old-style deferredGenerator protocol — the
                # final yielded (non-Deferred) value appears to become the
                # Deferred's result; confirm against twisted docs.
                yield total

            walktree = deferredGenerator(walktree)
            return walktree(self.fp)
        else:
            return succeed(self.fp.getsize())

    ##
    # Workarounds for issues with File
    ##

    def ignoreExt(self, ext):
        """
        Does nothing; doesn't apply to this subclass.
        """
        pass

    def locateChild(self, req, segments):
        """
        See L{IResource}C{.locateChild}.
        """
        # If getChild() finds a child resource, return it
        try:
            child = self.getChild(segments[0])
            if child is not None:
                return (child, segments[1:])
        except InsecurePath:
            # Path traversal attempt (e.g. "..") — refuse outright.
            raise HTTPError(StatusResponse(responsecode.FORBIDDEN, "Invalid URL path"))

        # If we're not backed by a directory, we have no children.
        # But check for existance first; we might be a collection resource
        # that the request wants created.
        self.fp.restat(False)
        if self.fp.exists() and not self.fp.isdir():
            return (None, ())

        # OK, we need to return a child corresponding to the first segment
        path = segments[0]

        if path == "":
            # Request is for a directory (collection) resource
            return (self, ())

        return (self.createSimilarFile(self.fp.child(path).path), segments[1:])

    def createSimilarFile(self, path):
        # indexNames is copied so the child does not share our mutable list.
        return self.__class__(
            path, defaultType=self.defaultType, indexNames=self.indexNames[:],
            principalCollections=self.principalCollections())
#
# Attach method handlers to DAVFile
#
# bindMethods() copies every preconditions_*/http_*/report_* function
# exported by the twext.web2.dav.method submodules onto DAVFile as methods.
import twext.web2.dav.method
bindMethods(twext.web2.dav.method, DAVFile)
calendarserver-5.2+dfsg/twext/web2/dav/util.py 0000644 0001750 0001750 00000015205 12263343324 020427 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_util -*-
##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# DRI: Wilfredo Sanchez, wsanchez@apple.com
##
"""
Utilities
This API is considered private to static.py and is therefore subject to
change.
"""
__all__ = [
"allDataFromStream",
"davXMLFromStream",
"noDataFromStream",
"normalizeURL",
"joinURL",
"parentForURL",
"unimplemented",
"bindMethods",
]
import urllib
from urlparse import urlsplit, urlunsplit
import posixpath # Careful; this module is not documented as public API
from twisted.python.failure import Failure
from twisted.internet.defer import succeed
from twext.python.log import Logger
from twext.web2.stream import readStream
from txdav.xml.parser import WebDAVDocument
log = Logger()
##
# Reading request body
##
def allDataFromStream(stream, filter=None):
    """
    Drain a stream into a single string.

    @param stream: the stream to read.
    @param filter: optional callable applied to the collected string before
        it is returned.
    @return: a L{Deferred} firing with the collected data (or the result of
        C{filter} on it), or C{None} if the stream produced no data.
    """
    chunks = []

    def collected(_):
        if not chunks:
            return None
        joined = "".join([str(chunk) for chunk in chunks])
        return joined if filter is None else filter(joined)

    return readStream(stream, chunks.append).addCallback(collected)
def davXMLFromStream(stream):
    """
    Parse a request body stream as a WebDAV XML document.

    @param stream: the request body stream, or C{None}.
    @return: a L{Deferred} firing with a validated L{WebDAVDocument}, or
        C{None} when C{stream} is C{None}.
    """
    # FIXME: this buffers the entire body before parsing; an incremental
    # parser fed directly from the stream would be better.
    if stream is None:
        return succeed(None)

    def toDocument(xml):
        try:
            document = WebDAVDocument.fromString(xml)
            document.root_element.validate()
        except ValueError:
            log.error("Bad XML:\n%s" % (xml,))
            raise
        return document

    return allDataFromStream(stream, toDocument)
def noDataFromStream(stream):
    """
    Verify that a stream carries no body data.

    @return: a L{Deferred} which errbacks with L{ValueError} if any data
        arrives on C{stream}.
    """
    def checkForData(data):
        if data:
            raise ValueError("Stream contains unexpected data.")
    return readStream(stream, checkForData)
##
# URLs
##
def normalizeURL(url):
    """
    Normalize a URL.

    @param url: a URL.
    @return: the normalized representation of C{url}.  The returned URL will
        never contain a trailing C{"/"}; it is up to the caller to determine
        whether the resource referred to by the URL is a collection and add a
        trailing C{"/"} if so.
    """
    def collapseLeadingSlashes(path):
        # posixpath.normpath does not collapse a leading '//', so reduce any
        # run of leading slashes to a single one here.
        if path[0] == "/":
            count = 0
            for char in path:
                if char != "/":
                    break
                count += 1
            path = path[count - 1:]
        return path

    (scheme, host, path, query, fragment) = urlsplit(collapseLeadingSlashes(url))
    path = collapseLeadingSlashes(posixpath.normpath(urllib.unquote(path)))
    return urlunsplit((scheme, host, urllib.quote(path), query, fragment))
def joinURL(*urls):
    """
    Append URLs in series.

    @param urls: URLs to join.
    @return: the normalized URL formed by combining each URL in C{urls}.  The
        returned URL will contain a trailing C{"/"} if and only if the last
        given URL contains a trailing C{"/"}.
    """
    hasTrailingSlash = bool(urls) and bool(urls[-1]) and urls[-1][-1] == "/"

    joined = normalizeURL("/".join(urls))
    if joined == "/":
        return "/"
    return joined + ("/" if hasTrailingSlash else "")
def parentForURL(url):
    """
    Extracts the URL of the containing collection resource for the resource
    corresponding to a given URL.  This removes any query or fragment pieces.

    @param url: an absolute (server-relative is OK) URL.
    @return: the normalized URL of the collection resource containing the
        resource corresponding to C{url}, always with a trailing C{"/"}, or
        C{None} if C{url} is the root C{"/"}.
    @raise ValueError: if the (normalized) path contains no C{"/"} at all.
    """
    (scheme, host, path, _ignore_query, _ignore_fragment) = urlsplit(normalizeURL(url))

    index = path.rfind("/")
    # Fixed: compare integers with ==, not "is".  "index is 0" only worked
    # because CPython caches small ints; identity comparison with literals is
    # an implementation detail (and a SyntaxWarning on modern Pythons).
    if index == 0:
        if path == "/":
            return None
        else:
            path = "/"
    else:
        if index == -1:
            raise ValueError("Invalid URL: %s" % (url,))
        else:
            path = path[:index] + "/"

    return urlunsplit((scheme, host, path, None, None))
##
# Python magic
##
def unimplemented(obj):
    """
    Raise L{NotImplementedError} naming the calling method and the class of
    C{obj}.  Call this from a method that a subclass was expected to override.
    """
    import inspect
    callerName = inspect.getouterframes(inspect.currentframe())[1][3]
    raise NotImplementedError(
        "Method %s is unimplemented in subclass %s" % (callerName, obj.__class__)
    )
def bindMethods(module, clazz, prefixes=("preconditions_", "http_", "report_")):
    """
    Bind all functions in the given module (as defined by that module's
    C{__all__} attribute) which start with any of the given prefixes as
    methods of the given class.

    @param module: the module in which to search for functions.
    @param clazz: the class to bind found functions to as methods.
    @param prefixes: a sequence of prefixes to match found functions against.
    """
    for submoduleName in module.__all__:
        fullName = module.__name__ + "." + submoduleName
        try:
            __import__(fullName)
        except ImportError:
            log.error("Unable to import module %s" % (fullName,))
            Failure().raiseException()
        submodule = getattr(module, submoduleName)
        for functionName in submodule.__all__:
            for prefix in prefixes:
                if functionName.startswith(prefix):
                    # First matching prefix wins; bind and move on.
                    setattr(clazz, functionName, getattr(submodule, functionName))
                    break
calendarserver-5.2+dfsg/twext/web2/dav/auth.py 0000644 0001750 0001750 00000014411 12263343324 020411 0 ustar rahul rahul ##
# Copyright (c) 2005-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
##
__all__ = [
"IPrincipal",
"DavRealm",
"IPrincipalCredentials",
"PrincipalCredentials",
"AuthenticationWrapper",
]
from zope.interface import implements, Interface
from twisted.internet import defer
from twisted.cred import checkers, error, portal
from twext.web2.resource import WrapperResource
from txdav.xml.element import twisted_private_namespace, registerElement
from txdav.xml.element import WebDAVTextElement, Principal, HRef
class AuthenticationWrapper(WrapperResource):
    """
    Wrapper resource that attaches cred/auth state to each request.
    """
    def __init__(self, resource, portal,
        wireEncryptedCredentialFactories, wireUnencryptedCredentialFactories,
        loginInterfaces):
        """
        Wrap the given resource and use the parameters to set up the request
        to allow anyone to challenge and handle authentication.

        @param resource: L{DAVResource} FIXME: This should get promoted to
            twext.web2.auth
        @param portal: The cred portal
        @param wireEncryptedCredentialFactories: Sequence of credentialFactories
            that can be used to authenticate by resources in this tree over a
            wire-encrypted channel (SSL).
        @param wireUnencryptedCredentialFactories: Sequence of credentialFactories
            that can be used to authenticate by resources in this tree over a
            wire-unencrypted channel (non-SSL).
        @param loginInterfaces: More cred stuff
        """
        super(AuthenticationWrapper, self).__init__(resource)

        self.portal = portal
        self.wireEncryptedCredentialFactories = dict(
            (factory.scheme, factory)
            for factory in wireEncryptedCredentialFactories)
        self.wireUnencryptedCredentialFactories = dict(
            (factory.scheme, factory)
            for factory in wireUnencryptedCredentialFactories)
        self.loginInterfaces = loginInterfaces

        # Some unit tests access self.credentialFactories directly, so keep it
        # pointing at the wire-encrypted set.
        self.credentialFactories = self.wireEncryptedCredentialFactories

    def hook(self, req):
        req.portal = self.portal
        req.loginInterfaces = self.loginInterfaces

        # chanRequest is only None in unit tests; treat that as a secure
        # connection.  On an unencrypted channel use the factory set that
        # excludes e.g. "Basic".
        chanRequest = getattr(req, "chanRequest", None)
        if chanRequest is None:
            secureConnection = True
        else:
            _ignore_host, secureConnection = chanRequest.getHostInfo()

        if secureConnection:
            req.credentialFactories = self.wireEncryptedCredentialFactories
        else:
            req.credentialFactories = self.wireUnencryptedCredentialFactories
# Marker interface for the principal avatars returned by DavRealm.
class IPrincipal(Interface):
    pass
class DavRealm(object):
    """
    Cred realm producing a pair of DAV principal elements for an avatar id.
    """
    implements(portal.IRealm)

    def requestAvatar(self, avatarId, mind, *interfaces):
        if IPrincipal not in interfaces:
            raise NotImplementedError("Only IPrincipal interface is supported")
        # avatarId carries (authentication URI, authorization URI).
        return (IPrincipal,
                Principal(HRef(avatarId[0])),
                Principal(HRef(avatarId[1])))
# Marker interface implemented by PrincipalCredentials below.
class IPrincipalCredentials(Interface):
    pass
class PrincipalCredentials(object):
    """
    Credentials carrying both authentication and authorization principals.

    In most cases the two principals are the same, since HTTP auth makes no
    distinction between them — but additional auth layered on top (e.g. proxy
    auth, cookies, forms) may make authentication and authorization differ.
    """
    implements(IPrincipalCredentials)

    def __init__(self, authnPrincipal, authzPrincipal, credentials):
        """
        @param authnPrincipal: L{IDAVPrincipalResource} for the authenticated
            principal.
        @param authzPrincipal: L{IDAVPrincipalResource} for the authorized
            principal.
        @param credentials: L{ICredentials} for the authentication credentials.
        """
        self.authnPrincipal = authnPrincipal
        self.authzPrincipal = authzPrincipal
        self.credentials = credentials

    def checkPassword(self, password):
        """Delegate password verification to the wrapped credentials."""
        return self.credentials.checkPassword(password)
class TwistedPropertyChecker(object):
    """
    Cred checker that verifies a password against the
    L{TwistedPasswordProperty} dead property stored on the authentication
    principal resource.
    """
    implements(checkers.ICredentialsChecker)

    credentialInterfaces = (IPrincipalCredentials,)

    def _cbPasswordMatch(self, matched, principalURIs):
        if not matched:
            raise error.UnauthorizedLogin(
                "Bad credentials for: %s" % (principalURIs[0],))
        # Hand back the whole (authnURI, authzURI, authn, authz) tuple.
        return principalURIs

    def requestAvatarId(self, credentials):
        pcreds = IPrincipalCredentials(credentials)
        password = str(pcreds.authnPrincipal.readDeadProperty(TwistedPasswordProperty))
        d = defer.maybeDeferred(credentials.checkPassword, password)
        d.addCallback(self._cbPasswordMatch, (
            pcreds.authnPrincipal.principalURL(),
            pcreds.authzPrincipal.principalURL(),
            pcreds.authnPrincipal,
            pcreds.authzPrincipal,
        ))
        return d
##
# Utilities
##
# Dead property (in the Twisted private namespace) holding a principal's
# password; read by TwistedPropertyChecker.requestAvatarId above.
class TwistedPasswordProperty (WebDAVTextElement):
    namespace = twisted_private_namespace
    name = "password"

registerElement(TwistedPasswordProperty)
calendarserver-5.2+dfsg/twext/web2/http_headers.py 0000644 0001750 0001750 00000156207 12263343324 021362 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_http_headers -*-
##
# Copyright (c) 2008 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
from __future__ import print_function
"""
HTTP header representation, parsing, and serialization.
"""
import time
from calendar import timegm
import base64
import re
def dashCapitalize(s):
    """
    Capitalize each dash-separated word of C{s}, treating '-' as a word
    separator: 'content-type' -> 'Content-Type'.
    """
    return '-'.join(word.capitalize() for word in s.split('-'))
# datetime parsing and formatting

weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
weekdayname_lower = [name.lower() for name in weekdayname]
# monthname is 1-based (monthname[1] == 'Jan'); slot 0 is a placeholder so
# list.index() yields the conventional month number directly.
monthname = [None,
             'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
             'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
monthname_lower = [name and name.lower() for name in monthname]
# HTTP Header parsing API
# Lowercased header name -> canonical (mixed-case) spelling, maintained by
# casemappingify().
header_case_mapping = {}

def casemappingify(d):
    """
    Record the canonical spelling of every header name in C{d} under its
    lowercased form in the global C{header_case_mapping}.
    """
    global header_case_mapping
    header_case_mapping.update(
        dict((name.lower(), name) for name in d.keys()))
def lowerify(d):
    """Return a copy of dict C{d} with every key lowercased."""
    return dict((key.lower(), value) for key, value in d.items())
class HeaderHandler(object):
    """HeaderHandler manages header generating and parsing functions.

    NOTE: HTTPParsers and HTTPGenerators are *class* attributes, and
    __init__/updateParsers/updateGenerators mutate them in place, so parser
    and generator registrations are shared by every instance.  Module-level
    registration against DefaultHTTPHandler relies on this behavior.
    """
    # Lowercased header name -> chain of parser callables, applied in order.
    HTTPParsers = {}
    # Lowercased header name -> chain of generator callables, applied in order.
    HTTPGenerators = {}

    def __init__(self, parsers=None, generators=None):
        """
        @param parsers: A map of header names to parsing functions.
        @type parsers: L{dict}

        @param generators: A map of header names to generating functions.
        @type generators: L{dict}
        """
        if parsers:
            self.HTTPParsers.update(parsers)
        if generators:
            self.HTTPGenerators.update(generators)

    def parse(self, name, header):
        """
        Parse the given header based on its given name.

        @param name: The header name to parse.
        @type name: C{str}

        @param header: A list of unparsed headers.
        @type header: C{list} of C{str}

        @return: The return value is the parsed header representation,
            it is dependent on the header.  See the HTTP Headers document.
            Returns C{None} if any stage of the parser chain raises
            L{ValueError}.
        """
        parser = self.HTTPParsers.get(name, None)
        if parser is None:
            raise ValueError("No header parser for header '%s', either add one or use getHeaderRaw." % (name,))

        try:
            # Each stage consumes the previous stage's output, starting from
            # the raw list of header strings.
            for p in parser:
                # print("Parsing %s: %s(%s)" % (name, repr(p), repr(h)))
                header = p(header)
                # if isinstance(h, types.GeneratorType):
                #     h=list(h)
        except ValueError:
            header = None

        return header

    def generate(self, name, header):
        """
        Generate the given header based on its given name.

        @param name: The header name to generate.
        @type name: C{str}

        @param header: A parsed header, such as the output of
            L{HeaderHandler}.parse.

        @return: C{list} of C{str} each representing a generated HTTP header.
        """
        generator = self.HTTPGenerators.get(name, None)

        if generator is None:
            # print(self.generators)
            raise ValueError("No header generator for header '%s', either add one or use setHeaderRaw." % (name,))

        # Each stage consumes the previous stage's output, starting from the
        # parsed header value.
        for g in generator:
            header = g(header)

        # self._raw_headers[name] = h
        return header

    def updateParsers(self, parsers):
        """Update en masse the parser maps.

        @param parsers: Map of header names to parser chains.
        @type parsers: C{dict}
        """
        casemappingify(parsers)
        self.HTTPParsers.update(lowerify(parsers))

    def addParser(self, name, value):
        """Add an individual parser chain for the given header.

        @param name: Name of the header to add
        @type name: C{str}

        @param value: The parser chain
        @type value: C{str}
        """
        self.updateParsers({name: value})

    def updateGenerators(self, generators):
        """Update en masse the generator maps.

        @param generators: Map of header names to generator chains.
        @type generators: C{dict}
        """
        casemappingify(generators)
        self.HTTPGenerators.update(lowerify(generators))

    def addGenerators(self, name, value):
        """Add an individual generator chain for the given header.

        @param name: Name of the header to add
        @type name: C{str}

        @param value: The generator chain
        @type value: C{str}
        """
        self.updateGenerators({name: value})

    def update(self, parsers, generators):
        """Conveniently update parsers and generators all at once.
        """
        self.updateParsers(parsers)
        self.updateGenerators(generators)
# Shared module-level handler instance; header definitions later in this
# module register their parser/generator chains against it.
DefaultHTTPHandler = HeaderHandler()

# # HTTP DateTime parser
def parseDateTime(dateString):
    """
    Convert an HTTP date string to seconds since the epoch.

    Handles the three formats required by RFC 2616 section 3.3.1:
    RFC 1123, RFC 850 (two-digit year), and ANSI C asctime(), with or
    without the leading weekday.

    @raise ValueError: if the string matches none of the formats.
    """
    parts = dateString.split()

    if parts[0][0:3].lower() not in weekdayname_lower:
        # The (useless) weekday may have been omitted; retry with a dummy
        # one prepended.  If that still fails, fall through and try the
        # remaining heuristics on the original string.
        try:
            return parseDateTime("Sun, " + dateString)
        except ValueError:
            pass

    partlen = len(parts)
    if partlen in (5, 6) and parts[1].isdigit():
        # RFC 1123: Sun, 06 Nov 1994 08:49:37 GMT
        # ("GMT" is literal, not a variable timezone, and may be absent)
        day = parts[1]
        month = parts[2]
        year = parts[3]
        clock = parts[4]
    elif partlen in (3, 4) and parts[1].find('-') != -1:
        # RFC 850: Sunday, 06-Nov-94 08:49:37 GMT (two-digit year, yucko)
        day, month, year = parts[1].split('-')
        clock = parts[2]
        year = int(year)
        if year < 69:
            year += 2000
        elif year < 100:
            year += 1900
    elif partlen == 5:
        # ANSI C asctime(): Sun Nov  6 08:49:37 1994
        day = parts[2]
        month = parts[1]
        year = parts[4]
        clock = parts[3]
    else:
        raise ValueError("Unknown datetime format %r" % dateString)

    day = int(day)
    # monthname_lower is 1-based, so index() is the month number directly
    # (raises ValueError for an unknown month name).
    month = int(monthname_lower.index(month.lower()))
    year = int(year)
    hour, minute, second = map(int, clock.split(':'))
    return int(timegm((year, month, day, hour, minute, second)))
##### HTTP tokenizer
class Token(str):
    """
    Interned HTTP separator token.  C{Token(c)} always returns the same
    instance for a given character, so tokens may be compared with C{is}.
    """
    __slots__ = []
    tokens = {}

    def __new__(cls, char):
        existing = Token.tokens.get(char)
        if existing is None:
            existing = str.__new__(cls, char)
            Token.tokens[char] = existing
        return existing

    def __repr__(self):
        return "Token(%s)" % str.__repr__(self)
# RFC 2616 section 2.2
# Separator characters that delimit tokens in header values.
http_tokens = " \t\"()<>@,;:\\/[]?={}"
# Control characters (0x00-0x1F plus DEL), illegal outside quoted-strings.
http_ctls = "\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x7f"
def tokenize(header, foldCase=True):
    """Tokenize a string according to normal HTTP header parsing rules.

    In particular:
     - Whitespace is irrelevant and eaten next to special separator tokens.
       Its existance (but not amount) is important between character strings.
     - Quoted string support including embedded backslashes.
     - Case is insignificant (and thus lowercased), except in quoted strings.
       (unless foldCase=False)
     - Multiple headers are concatenated with ','

    NOTE: not all headers can be parsed with this function.

    Takes a raw header value (list of strings), and
    Returns a generator of strings and Token class instances.
    """
    tokens = http_tokens
    ctls = http_ctls

    # All header lines are joined into one string before scanning.
    string = ",".join(header)
    start = 0           # start of the current run of ordinary characters
    cur = 0             # index of the character currently being examined
    quoted = False      # inside a double-quoted string?
    qpair = False       # previous character was a backslash (quoted-pair)?
    inSpaces = -1       # -1: after a separator; True/False: in/after spaces
    qstring = None      # accumulated contents of the current quoted string

    for x in string:
        if quoted:
            if qpair:
                # Escaped character: flush text up to (excluding) the
                # backslash, then append the escaped character itself.
                qpair = False
                qstring = qstring + string[start:cur - 1] + x
                start = cur + 1
            elif x == '\\':
                qpair = True
            elif x == '"':
                # Closing quote: emit the whole quoted string unfolded.
                quoted = False
                yield qstring + string[start:cur]
                qstring = None
                start = cur + 1
        elif x in tokens:
            # Separator: flush any pending run of ordinary characters.
            if start != cur:
                if foldCase:
                    yield string[start:cur].lower()
                else:
                    yield string[start:cur]

            start = cur + 1
            if x == '"':
                quoted = True
                qstring = ""
                inSpaces = False
            elif x in " \t":
                if inSpaces is False:
                    inSpaces = True
            else:
                inSpaces = -1
                yield Token(x)
        elif x in ctls:
            raise ValueError("Invalid control character: %d in header" % ord(x))
        else:
            if inSpaces is True:
                # Whitespace *between* two character runs is significant:
                # represent it as a single synthetic space token.
                yield Token(' ')
                inSpaces = False

            inSpaces = False
        cur = cur + 1

    if qpair:
        raise ValueError("Missing character after '\\'")
    if quoted:
        raise ValueError("Missing end quote")

    # Flush the final run of ordinary characters, if any.
    if start != cur:
        if foldCase:
            yield string[start:cur].lower()
        else:
            yield string[start:cur]
def split(seq, delim):
    """
    Like str.split, but for arbitrary sequences: generate the sub-lists of
    C{seq} separated by items equal to C{delim}.  Always yields at least one
    (possibly empty) list.
    """
    group = []
    for element in seq:
        if element == delim:
            yield group
            group = []
        else:
            group.append(element)
    yield group
# def find(seq, *args):
# """The same as seq.index but returns -1 if not found, instead
# Too bad it's not builtin to python!"""
# try:
# return seq.index(value, *args)
# except ValueError:
# return -1
def filterTokens(seq):
    """
    Filter out instances of Token, leaving only a list of strings.

    Used instead of a more specific parsing method (e.g. splitting on commas)
    when only strings are expected, so as to be a little lenient.

    Apache does it this way and has some comments about broken clients which
    forget commas (?), so I'm doing it the same way.  It shouldn't
    hurt anything, in any case.
    """
    return [item for item in seq if not isinstance(item, Token)]
##### parser utilities:
def checkSingleToken(tokens):
    """Return the sole element of C{tokens}; raise ValueError otherwise."""
    if len(tokens) == 1:
        return tokens[0]
    raise ValueError("Expected single token, not %s." % (tokens,))
def parseKeyValue(val):
    """
    Parse a [key] or [key, '=', value] token list into a (key, value) tuple;
    value is C{None} when absent.
    """
    if len(val) == 1:
        return val[0], None
    if len(val) == 3 and val[1] == Token('='):
        return val[0], val[2]
    raise ValueError("Expected key or key=value, but got %s." % (val,))
def parseArgs(field):
    """
    Split a token sequence into the leading value and its ';'-separated
    key=value arguments.

    @return: a C{(value_tokens, [(key, value), ...])} tuple.
    """
    segments = split(field, Token(';'))
    # Fixed: use the builtin next() instead of the Python-2-only
    # generator .next() method; next() works on Python 2.6+ and 3.x.
    value = next(segments)
    arguments = [parseKeyValue(segment) for segment in segments]
    return value, arguments
def listParser(fun):
    """
    Return a parser which applies C{fun} to every non-empty element of a
    comma-separated token list.
    """
    def listParserHelper(tokens):
        for field in split(tokens, Token(',')):
            if field:
                yield fun(field)

    return listParserHelper
def last(seq):
    """Return the final element of C{seq} (i.e. C{seq[-1]})."""
    return seq[-1]
##### Generation utilities
def quoteString(s):
    """
    Quote a string according to the rules for the I{quoted-string} production
    in RFC 2616 section 2.2: backslash-escape every backslash and double
    quote, then wrap the result in double quotes.

    @type s: C{str}
    @rtype: C{str}
    """
    escaped = s.replace('\\', '\\\\').replace('"', '\\"')
    return '"' + escaped + '"'
def listGenerator(fun):
    """
    Return a generator function which applies C{fun} to every element of a
    list and joins the results with generateList.
    """
    def listGeneratorHelper(l):
        return generateList([fun(element) for element in l])

    return listGeneratorHelper
def generateList(seq):
    """Join string items with ", " to form an HTTP list header value."""
    return ", ".join(seq)
def singleHeader(item):
    """Wrap a generated value in a one-element list (a single header line)."""
    return [item]
# Matches any RFC 2616 separator character; presence forces quoting below.
_seperators = re.compile('[' + re.escape(http_tokens) + ']')


def generateKeyValues(parameters):
    """
    Format an iterable of key/value pairs as ';'-separated parameters.

    Although each header in HTTP 1.1 redefines the grammar for the formatting
    of its parameters, the grammar defined by almost all headers conforms to
    the specification given in RFC 2046.  Note also that RFC 2616 section
    19.2 note 2 points out that many implementations fail if the value is
    quoted, therefore this function only quotes the value when necessary.

    @param parameters: An iterable of C{tuple} of a C{str} parameter name and
        C{str} or C{None} parameter value which will be formated.
    @return: The formatted result.
    @rtype: C{str}
    """
    formatted = []
    for key, value in parameters:
        if value is None:
            formatted.append('%s' % key)
        else:
            if _seperators.search(value) is not None:
                value = quoteString(value)
            formatted.append('%s=%s' % (key, value))
    return ";".join(formatted)
class MimeType(object):
    """
    Parsed representation of a MIME media type with optional parameters.
    """

    @classmethod
    def fromString(cls, mimeTypeString):
        """Generate a MimeType object from the given string.

        @param mimeTypeString: The mimetype to parse
        @return: L{MimeType}
        """
        return DefaultHTTPHandler.parse('content-type', [mimeTypeString])

    def __init__(self, mediaType, mediaSubtype, params={}, **kwargs):
        """
        @type mediaType: C{str}
        @type mediaSubtype: C{str}
        @type params: C{dict}

        Extra keyword arguments are merged into the parameters.
        """
        self.mediaType = mediaType
        self.mediaSubtype = mediaSubtype
        self.params = dict(params)
        if kwargs:
            self.params.update(kwargs)

    def __eq__(self, other):
        if not isinstance(other, MimeType):
            return NotImplemented
        return (self.mediaType == other.mediaType
                and self.mediaSubtype == other.mediaSubtype
                and self.params == other.params)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return "MimeType(%r, %r, %r)" % (self.mediaType, self.mediaSubtype, self.params)

    def __hash__(self):
        return hash(self.mediaType) ^ hash(self.mediaSubtype) ^ hash(tuple(self.params.iteritems()))
class MimeDisposition(object):
    """
    Parsed representation of a Content-Disposition value with parameters.
    """

    @classmethod
    def fromString(cls, dispositionString):
        """Generate a MimeDisposition object from the given string.

        @param dispositionString: The disposition to parse
        @return: L{MimeDisposition}
        """
        return DefaultHTTPHandler.parse('content-disposition', [dispositionString])

    def __init__(self, dispositionType, params={}, **kwargs):
        """
        @type dispositionType: C{str}
        @type params: C{dict}

        Extra keyword arguments are merged into the parameters.
        """
        self.dispositionType = dispositionType
        self.params = dict(params)
        if kwargs:
            self.params.update(kwargs)

    def __eq__(self, other):
        if not isinstance(other, MimeDisposition):
            return NotImplemented
        return (self.dispositionType == other.dispositionType
                and self.params == other.params)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return "MimeDisposition(%r, %r)" % (self.dispositionType, self.params)

    def __hash__(self):
        return hash(self.dispositionType) ^ hash(tuple(self.params.iteritems()))
##### Specific header parsers.
def parseAccept(field):
    """
    Parse one element of an Accept header into a (L{MimeType}, qvalue) pair.
    """
    atype, args = parseArgs(field)

    if len(atype) != 3 or atype[1] != Token('/'):
        raise ValueError("MIME Type " + str(atype) + " invalid.")

    # Okay, this spec is screwy.  A 'q' parameter is used as the separator
    # between MIME parameters and (as yet undefined) additional HTTP
    # parameters: everything before the first 'q' belongs to the media type.
    for qindex, arg in enumerate(args):
        if arg[0] == 'q':
            mimeparams = tuple(args[:qindex])
            params = args[qindex:]
            break
    else:
        mimeparams = tuple(args)
        params = []

    # Default qvalue per RFC 2616 is 1.
    qval = 1.0
    for key, value in params:
        if key == 'q':
            qval = float(value)
        # Other accept-extension parameters are ignored.

    return MimeType(atype[0], atype[2], mimeparams), qval
def parseAcceptQvalue(field):
    """
    Parse an Accept-Charset/-Encoding/-Language style element into a
    (token, qvalue) pair; qvalue defaults to 1.0.
    """
    atype, args = parseArgs(field)
    atype = checkSingleToken(atype)

    qvalue = 1.0
    for key, value in args:
        if key == 'q':
            qvalue = float(value)
    return atype, qvalue
def addDefaultCharset(charsets):
    """
    Apply RFC 2616's implicit default: iso-8859-1 is acceptable at qvalue 1.0
    unless the parsed Accept-Charset already mentions it or a wildcard.
    """
    if charsets.get('*') is None and charsets.get('iso-8859-1') is None:
        charsets['iso-8859-1'] = 1.0
    return charsets
def addDefaultEncoding(encodings):
    """
    Apply the implicit default for Accept-Encoding.  The RFC doesn't specify
    a default qvalue for identity, only that it "is acceptable" if not
    mentioned — so give it a very low qvalue.
    """
    if encodings.get('*') is None and encodings.get('identity') is None:
        encodings['identity'] = .0001
    return encodings
def parseContentType(header):
    """
    Parse a Content-Type header value into a L{MimeType}.

    Case folding is disabled during tokenization for this header (because
    e.g. multipart/form-data boundaries are case-sensitive), so the
    type/subtype and parameter names are lowercased explicitly here while
    parameter values keep their case.
    """
    ctype, args = parseArgs(header)

    if len(ctype) != 3 or ctype[1] != Token('/'):
        raise ValueError("MIME Type " + str(ctype) + " invalid.")

    loweredArgs = [(key.lower(), value) for key, value in args]
    return MimeType(ctype[0].lower(), ctype[2].lower(), tuple(loweredArgs))
def parseContentDisposition(header):
    """
    Parse a Content-Disposition header value into a L{MimeDisposition}.

    Case folding is disabled during tokenization for this header, so the
    disposition type and parameter names are lowercased explicitly here.
    """
    dtype, args = parseArgs(header)

    if len(dtype) != 1:
        raise ValueError("Content-Disposition " + str(dtype) + " invalid.")

    loweredArgs = [(key.lower(), value) for key, value in args]
    return MimeDisposition(dtype[0].lower(), tuple(loweredArgs))
def parseContentMD5(header):
    # Decode the base64-encoded MD5 digest; any decoding failure is
    # reported to the caller uniformly as a ValueError.
    try:
        return base64.decodestring(header)
    except Exception, e:
        raise ValueError(e)
def parseContentRange(header):
    """Parse a content-range header into (kind, start, end, realLength).

    realLength might be None if real length is not known ('*').
    start and end might be None if start,end unspecified (for response code 416)

    Raises ValueError for units other than "bytes" and for malformed input.
    """
    kind, other = header.strip().split()
    if kind.lower() != "bytes":
        # Bug fix: the %r placeholder was previously never filled in,
        # so the error message showed a literal '%r'.
        raise ValueError("a range of type %r is not supported" % (kind,))
    startend, realLength = other.split("/")
    if startend.strip() == '*':
        # Unsatisfied-range form: no start/end available.
        start, end = None, None
    else:
        start, end = map(int, startend.split("-"))
    if realLength == "*":
        realLength = None
    else:
        realLength = int(realLength)
    return (kind, start, end, realLength)
def parseExpect(field):
    """Parse one Expect element into (token, (value, extra-arg, ...)).

    The value parsed from the key=value pair is prepended to any
    remaining arguments.
    """
    etype, args = parseArgs(field)
    etype = parseKeyValue(etype)
    return (etype[0], (etype[1],) + tuple(args))
def parseExpires(header):
    """Parse an Expires header; any unparseable date is treated as expired."""
    # """HTTP/1.1 clients and caches MUST treat other invalid date formats,
    # especially including the value 0, as in the past (i.e., "already expired")."""
    try:
        return parseDateTime(header)
    except ValueError:
        # Invalid date -> epoch 0, i.e. long expired.
        return 0
def parseIfModifiedSince(header):
    """Parse If-Modified-Since, tolerating a bogus '; length=' suffix.

    Ancient versions of Netscape and current versions of MSIE send
    If-Modified-Since: Thu, 05 Aug 2004 12:57:27 GMT; length=123
    which is blatantly RFC-violating, so everything after the first ';'
    is discarded before parsing the date.
    """
    return parseDateTime(header.partition(';')[0])
def parseIfRange(headers):
    # If-Range carries either an entity tag or an HTTP date: try the
    # ETag form first and fall back to date parsing when that fails.
    try:
        return ETag.parse(tokenize(headers))
    except ValueError:
        return parseDateTime(last(headers))
def parseRange(crange):
    """Parse a tokenized Range header into (rtype, [(start, end), ...]).

    Only the 'bytes' unit is supported. start or end may be None for the
    open-ended forms 'N-' and '-N'; raises ValueError on malformed input.
    """
    crange = list(crange)
    if len(crange) < 3 or crange[1] != Token('='):
        raise ValueError("Invalid range header format: %s" % (crange,))
    rtype = crange[0]
    if rtype != 'bytes':
        raise ValueError("Unknown range unit: %s." % (rtype,))
    rangeset = split(crange[2:], Token(','))
    ranges = []
    for byterangespec in rangeset:
        # Each comma-separated spec must be a single 'N-M' token.
        if len(byterangespec) != 1:
            raise ValueError("Invalid range header format: %s" % (crange,))
        start, end = byterangespec[0].split('-')
        if not start and not end:
            # Bare '-' has neither endpoint.
            raise ValueError("Invalid range header format: %s" % (crange,))
        if start:
            start = int(start)
        else:
            start = None
        if end:
            end = int(end)
        else:
            end = None
        if start and end and start > end:
            raise ValueError("Invalid range header, start > end: %s" % (crange,))
        ranges.append((start, end))
    return rtype, ranges
def parseRetryAfter(header):
    """Parse Retry-After (delta-seconds or HTTP-date) to an absolute time.

    Returns seconds since the epoch at which the client may retry.
    """
    try:
        delta = int(header)
    except ValueError:
        # Not an integer, so it must be an HTTP datetime.
        return parseDateTime(header)
    return time.time() + delta
# WWW-Authenticate and Authorization
def parseWWWAuthenticate(tokenized):
    # Parse a (non-case-folded) tokenized WWW-Authenticate header into a
    # list of (scheme, challenge) pairs. challenge is a dict of
    # auth-params for key=value style challenges, or the bare token
    # itself for schemes that carry a single opaque value.
    headers = []
    tokenList = list(tokenized)
    while tokenList:
        scheme = tokenList.pop(0)
        challenge = {}
        last = None
        kvChallenge = False
        while tokenList:
            token = tokenList.pop(0)
            if token == Token('='):
                # 'last' was the key; the next token is its value.
                kvChallenge = True
                challenge[last] = tokenList.pop(0)
                last = None
            elif token == Token(','):
                if kvChallenge:
                    # A comma inside a k=v challenge only ends the
                    # challenge when what follows is not another k=v pair
                    # (i.e. it starts the next scheme).
                    if len(tokenList) > 1 and tokenList[1] != Token('='):
                        break
                else:
                    break
            else:
                last = token
        if last and scheme and not challenge and not kvChallenge:
            # Scheme followed by one bare token (no '='): the token is the
            # whole challenge (e.g. base64 data).
            challenge = last
            last = None
        headers.append((scheme, challenge))
    # A trailing bare token left over after the loop is a scheme with an
    # empty challenge (e.g. "Basic realm=x, NTLM").
    if last and last not in (Token('='), Token(',')):
        if headers[-1] == (scheme, challenge):
            scheme = last
            challenge = {}
        headers.append((scheme, challenge))
    return headers
def parseAuthorization(header):
    """Split an Authorization header into (lowercased scheme, credentials).

    This header isn't tokenized because tokenizing could eat characters
    in the unquoted base64-encoded credentials.
    """
    scheme, credentials = header.split(' ', 1)
    return scheme.lower(), credentials
def parsePrefer(field):
    # Parse one Prefer element "token[=value][;params]" into
    # (token, value-or-None, params).
    etype, args = parseArgs(field)
    etype = parseKeyValue(etype)
    return (etype[0], etype[1], args)
#### Header generators
def generateAccept(accept):
    # accept is (MimeType, qvalue). Render "type/subtype[;params][;q=x]",
    # omitting q when it is the default 1.0 and trimming trailing zeros.
    mimeType, q = accept
    out = "%s/%s" % (mimeType.mediaType, mimeType.mediaSubtype)
    if mimeType.params:
        out += ';' + generateKeyValues(mimeType.params.iteritems())
    if q != 1.0:
        out += (';q=%.3f' % (q,)).rstrip('0').rstrip('.')
    return out
def removeDefaultEncoding(seq):
    """Drop the implicit identity;q=0.0001 entry added at parse time."""
    return (item for item in seq
            if item[0] != 'identity' or item[1] != .0001)
def generateAcceptQvalue(keyvalue):
    """Render a (token, qvalue) pair; q is omitted when it is exactly 1."""
    token, qvalue = keyvalue
    if qvalue == 1.0:
        return "%s" % (token,)
    return ("%s;q=%.3f" % (token, qvalue)).rstrip('0').rstrip('.')
def parseCacheControl(kv):
    """Parse one Cache-Control directive into (name, typed-value).

    max-age/min-fresh/s-maxage take a required integer (0 when missing);
    max-stale an optional integer; private/no-cache an optional
    comma-separated list of lowercased field names.
    """
    k, v = parseKeyValue(kv)
    if k in ('max-age', 'min-fresh', 's-maxage'):
        # Required integer argument
        v = 0 if v is None else int(v)
    elif k == 'max-stale' and v is not None:
        # Optional integer argument
        v = int(v)
    elif k in ('private', 'no-cache') and v is not None:
        # Optional list argument
        v = [field.strip().lower() for field in v.split(',')]
    return k, v
def generateCacheControl((k, v)):
    # Render one Cache-Control directive. List-valued no-cache/private
    # arguments become a quoted list of canonically-capitalized names.
    if v is None:
        return str(k)
    else:
        if k == 'no-cache' or k == 'private':
            # quoted list of values
            v = quoteString(generateList(
                [header_case_mapping.get(name) or dashCapitalize(name) for name in v]))
        return '%s=%s' % (k, v)
def generateContentRange(tup):
    """Render (rtype, start, end, rlen) as a Content-Range value.

    rlen may be None (unknown length, rendered '*'); start and end both
    None renders the unsatisfied-range form 'rtype */rlen'.
    """
    rtype, start, end, rlen = tup
    lengthPart = '*' if rlen is None else int(rlen)
    if start is None and end is None:
        rangePart = '*'
    else:
        rangePart = '%d-%d' % (start, end)
    return '%s %s/%s' % (rtype, rangePart, lengthPart)
def generateDateTime(secSinceEpoch):
    """Convert seconds since epoch to HTTP datetime string (RFC 1123 form)."""
    year, month, day, hh, mm, ss, wd, _ignore_y, _ignore_z = time.gmtime(secSinceEpoch)
    # weekdayname/monthname are module-level lookup tables.
    s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdayname[wd],
        day, monthname[month], year,
        hh, mm, ss)
    return s
def generateExpect(item):
    # item is (name, (value, extra-args...)); a None value renders the
    # bare token, otherwise "name=value", plus any ';key=value' params.
    if item[1][0] is None:
        out = '%s' % (item[0],)
    else:
        out = '%s=%s' % (item[0], item[1][0])
    if len(item[1]) > 1:
        out += ';' + generateKeyValues(item[1][1:])
    return out
def generateRange(crange):
    """Render (rtype, [(start, end), ...]) as a Range header value.

    A None start or end is rendered as the empty string, producing the
    open forms '-N' and 'N-'. Only the 'bytes' unit is supported.
    """
    rtype, ranges = crange
    if rtype != 'bytes':
        raise ValueError("Unknown range unit: " + rtype + ".")
    specs = []
    for start, end in ranges:
        specs.append('%s-%s' % ('' if start is None else start,
                                '' if end is None else end))
    return rtype + '=' + ','.join(specs)
def generateRetryAfter(when):
    """Render an absolute retry time as delta-seconds from now.

    The delta-seconds form is always generated, never the HTTP-date form.
    """
    delta = int(when - time.time())
    return str(delta)
def generateContentType(mimeType):
    # Render a MimeType as "type/subtype" plus any ";key=value" params.
    out = "%s/%s" % (mimeType.mediaType, mimeType.mediaSubtype)
    if mimeType.params:
        out += ';' + generateKeyValues(mimeType.params.iteritems())
    return out
def generateContentDisposition(disposition):
    # Render a MimeDisposition as its type plus any ";key=value" params.
    out = disposition.dispositionType
    if disposition.params:
        out += ';' + generateKeyValues(disposition.params.iteritems())
    return out
def generateIfRange(dateOrETag):
    """Render an If-Range value from either an ETag or an HTTP date."""
    if not isinstance(dateOrETag, ETag):
        return generateDateTime(dateOrETag)
    return dateOrETag.generate()
# WWW-Authenticate and Authorization
def generateWWWAuthenticate(headers):
    # headers is a sequence of (scheme, challenge) pairs; challenge is
    # either a mapping of auth-params or a single opaque value.
    _generated = []
    for seq in headers:
        scheme, challenge = seq[0], seq[1]
        # If we're going to parse out to something other than a dict
        # we need to be able to generate from something other than a dict
        try:
            l = []
            for k, v in dict(challenge).iteritems():
                l.append("%s=%s" % (k, quoteString(v)))
            _generated.append("%s %s" % (scheme, ", ".join(l)))
        except ValueError:
            # dict() on a non-mapping challenge raises ValueError;
            # emit the challenge verbatim in that case.
            _generated.append("%s %s" % (scheme, challenge))
    return _generated
def generateAuthorization(seq):
    # seq is (scheme, credentials); emit as a single raw header string.
    return [' '.join(seq)]
def generatePrefer(items):
    """Render a parsed Prefer preference (key, value, params)."""
    key, value, params = items
    out = ('%s' % (key,)) if value is None else ('%s=%s' % (key, value))
    if params:
        out += ';' + generateKeyValues(params)
    return out
####
class ETag(object):
    """An HTTP entity tag (RFC 2616 section 3.11), optionally weak."""

    def __init__(self, tag, weak=False):
        self.tag = str(tag)
        self.weak = weak

    def match(self, other, strongCompare):
        """Compare two ETags per RFC 2616 section 13.3.

        Strong comparison: both validators must be identical in every way
        and neither may be weak. Weak comparison: identical tags match
        regardless of weakness.
        """
        if not isinstance(other, ETag):
            return False
        if self.tag != other.tag:
            return False
        return not (strongCompare and (self.weak or other.weak))

    def __eq__(self, other):
        return (isinstance(other, ETag)
                and other.tag == self.tag
                and other.weak == self.weak)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return "Etag(%r, weak=%r)" % (self.tag, self.weak)

    def parse(tokens):
        """Parse a tokenized ETag; 'w/"tag"' (case-folded W/) is weak."""
        tokens = tuple(tokens)
        if len(tokens) == 1 and not isinstance(tokens[0], Token):
            return ETag(tokens[0])
        if (len(tokens) == 3 and tokens[0] == "w"
                and tokens[1] == Token('/')):
            return ETag(tokens[2], weak=True)
        raise ValueError("Invalid ETag.")
    parse = staticmethod(parse)

    def generate(self):
        """Render as a quoted string, with the W/ prefix when weak."""
        if self.weak:
            return 'W/' + quoteString(self.tag)
        return quoteString(self.tag)
def parseStarOrETag(tokens):
    """Parse an If-Match/If-None-Match element: '*' or an entity tag."""
    tokens = tuple(tokens)
    return '*' if tokens == ('*',) else ETag.parse(tokens)
def generateStarOrETag(etag):
    """Render '*' as-is; otherwise delegate to the ETag's generate()."""
    if etag != '*':
        return etag.generate()
    return etag
#### Cookies. Blech!
class Cookie(object):
    """
    A single HTTP cookie, covering both Netscape (version 0) and
    RFC 2965 (version 1) attributes.
    """
    # __slots__ = ['name', 'value', 'path', 'domain', 'ports', 'expires', 'discard', 'secure', 'comment', 'commenturl', 'version']

    def __init__(self, name, value, path=None, domain=None, ports=None, expires=None, discard=False, secure=False, comment=None, commenturl=None, version=0):
        self.name = name
        self.value = value
        self.path = path
        self.domain = domain
        self.ports = ports
        self.expires = expires
        self.discard = discard
        self.secure = secure
        self.comment = comment
        self.commenturl = commenturl
        self.version = version

    def __repr__(self):
        # Only non-default attributes are shown.
        s = "Cookie(%r=%r" % (self.name, self.value)
        if self.path is not None:
            s += ", path=%r" % (self.path,)
        if self.domain is not None:
            s += ", domain=%r" % (self.domain,)
        if self.ports is not None:
            s += ", ports=%r" % (self.ports,)
        if self.expires is not None:
            s += ", expires=%r" % (self.expires,)
        if self.secure is not False:
            s += ", secure=%r" % (self.secure,)
        if self.comment is not None:
            s += ", comment=%r" % (self.comment,)
        if self.commenturl is not None:
            s += ", commenturl=%r" % (self.commenturl,)
        if self.version != 0:
            s += ", version=%r" % (self.version,)
        s += ")"
        return s

    def __eq__(self, other):
        # Bug fix: name and value were previously omitted from the
        # comparison, so cookies with different names/values compared
        # equal whenever their attributes matched.
        return (isinstance(other, Cookie) and
                other.name == self.name and
                other.value == self.value and
                other.path == self.path and
                other.domain == self.domain and
                other.ports == self.ports and
                other.expires == self.expires and
                other.secure == self.secure and
                other.comment == self.comment and
                other.commenturl == self.commenturl and
                other.version == self.version)

    def __ne__(self, other):
        return not self.__eq__(other)
def parseCookie(headers):
    """Bleargh, the cookie spec sucks.
    This surely needs interoperability testing.
    There are two specs that are supported:
    Version 0) http://wp.netscape.com/newsref/std/cookie_spec.html
    Version 1) http://www.faqs.org/rfcs/rfc2965.html
    """
    cookies = []
    # There can't really be multiple cookie headers according to RFC, because
    # if multiple headers are allowed, they must be joinable with ",".
    # Neither new RFC2965 cookies nor old netscape cookies are.
    header = ';'.join(headers)
    if header[0:8].lower() == "$version":
        # RFC2965 cookie
        h = tokenize([header], foldCase=False)
        r_cookies = split(h, Token(','))
        for r_cookie in r_cookies:
            last_cookie = None
            rr_cookies = split(r_cookie, Token(';'))
            for cookie in rr_cookies:
                # Each element is either "name=value" or a bare "name".
                nameval = tuple(split(cookie, Token('=')))
                if len(nameval) == 2:
                    (name,), (value,) = nameval
                else:
                    (name,), = nameval
                    value = None
                name = name.lower()
                if name == '$version':
                    continue
                if name[0] == '$':
                    # $-prefixed names are attributes of the cookie that
                    # preceded them, not cookies themselves.
                    if last_cookie is not None:
                        if name == '$path':
                            last_cookie.path = value
                        elif name == '$domain':
                            last_cookie.domain = value
                        elif name == '$port':
                            if value is None:
                                last_cookie.ports = ()
                            else:
                                last_cookie.ports = tuple([int(s) for s in value.split(',')])
                else:
                    last_cookie = Cookie(name, value, version=1)
                    cookies.append(last_cookie)
    else:
        # Oldstyle cookies don't do quoted strings or anything sensible.
        # All characters are valid for names except ';' and '=', and all
        # characters are valid for values except ';'. Spaces are stripped,
        # however.
        r_cookies = header.split(';')
        for r_cookie in r_cookies:
            name, value = r_cookie.split('=', 1)
            name = name.strip(' \t')
            value = value.strip(' \t')
            cookies.append(Cookie(name, value))
    return cookies
# Valid V1 cookie names: any run of characters that are neither HTTP
# token separators nor control characters.
cookie_validname = "[^" + re.escape(http_tokens + http_ctls) + "]*$"
cookie_validname_re = re.compile(cookie_validname)
# Valid V1 cookie values: a valid name, or a quoted string.
cookie_validvalue = cookie_validname + '|"([^"]|\\\\")*"$'
cookie_validvalue_re = re.compile(cookie_validvalue)
def generateCookie(cookies):
    """Render a list of Cookie objects as a single Cookie header value."""
    # There's a fundamental problem with the two cookie specifications.
    # They both use the "Cookie" header, and the RFC Cookie header only allows
    # one version to be specified. Thus, when you have a collection of V0 and
    # V1 cookies, you have to either send them all as V0 or send them all as
    # V1.
    # I choose to send them all as V1.
    # You might think converting a V0 cookie to a V1 cookie would be lossless,
    # but you'd be wrong. If you do the conversion, and a V0 parser tries to
    # read the cookie, it will see a modified form of the cookie, in cases
    # where quotes must be added to conform to proper V1 syntax.
    # (as a real example: "Cookie: cartcontents=oid:94680,qty:1,auto:0,esp:y")
    # However, that is what we will do, anyways. It has a high probability of
    # breaking applications that only handle oldstyle cookies, where some other
    # application set a newstyle cookie that is applicable over for site
    # (or host), AND where the oldstyle cookie uses a value which is invalid
    # syntax in a newstyle cookie.
    # Also, the cookie name *cannot* be quoted in V1, so some cookies just
    # cannot be converted at all. (e.g. "Cookie: phpAds_capAd[32]=2"). These
    # are just dicarded during conversion.
    # As this is an unsolvable problem, I will pretend I can just say
    # OH WELL, don't do that, or else upgrade your old applications to have
    # newstyle cookie parsers.
    # I will note offhandedly that there are *many* sites which send V0 cookies
    # that are not valid V1 cookie syntax. About 20% for my cookies file.
    # However, they do not generally mix them with V1 cookies, so this isn't
    # an issue, at least right now. I have not tested to see how many of those
    # webapps support RFC2965 V1 cookies. I suspect not many.
    max_version = max([cookie.version for cookie in cookies])
    if max_version == 0:
        # no quoting or anything.
        return ';'.join(["%s=%s" % (cookie.name, cookie.value) for cookie in cookies])
    else:
        str_cookies = ['$Version="1"']
        for cookie in cookies:
            if cookie.version == 0:
                # Version 0 cookie: we make sure the name and value are valid
                # V1 syntax.
                # If they are, we use them as is. This means in *most* cases,
                # the cookie will look literally the same on output as it did
                # on input.
                # If it isn't a valid name, ignore the cookie.
                # If it isn't a valid value, quote it and hope for the best on
                # the other side.
                if cookie_validname_re.match(cookie.name) is None:
                    continue
                value = cookie.value
                if cookie_validvalue_re.match(cookie.value) is None:
                    value = quoteString(value)
                str_cookies.append("%s=%s" % (cookie.name, value))
            else:
                # V1 cookie, nice and easy
                str_cookies.append("%s=%s" % (cookie.name, quoteString(cookie.value)))
            if cookie.path:
                str_cookies.append("$Path=%s" % quoteString(cookie.path))
            if cookie.domain:
                str_cookies.append("$Domain=%s" % quoteString(cookie.domain))
            if cookie.ports is not None:
                if len(cookie.ports) == 0:
                    str_cookies.append("$Port")
                else:
                    str_cookies.append("$Port=%s" % quoteString(",".join([str(x) for x in cookie.ports])))
        return ';'.join(str_cookies)
def parseSetCookie(headers):
    """Parse old-style (version 0) Set-Cookie headers into Cookies.

    An unparseable individual Set-Cookie is skipped rather than failing
    the whole set.
    """
    cookies = []
    for header in headers:
        try:
            pairs = []
            for part in header.split(';'):
                pieces = part.split('=', 1)
                if len(pieces) == 1:
                    name, value = pieces[0], None
                else:
                    name, value = pieces
                    value = value.strip(' \t')
                pairs.append((name.strip(' \t'), value))
            cookies.append(makeCookieFromList(pairs, True))
        except ValueError:
            # If we can't parse one Set-Cookie, ignore it,
            # but not the rest of Set-Cookies.
            pass
    return cookies
def parseSetCookie2(toks):
    # Split the token stream into cookies on ',' and each cookie into
    # key[=value] attributes on ';', then build Cookie objects.
    outCookies = []
    for cookie in [[parseKeyValue(x) for x in split(y, Token(';'))]
                   for y in split(toks, Token(','))]:
        try:
            outCookies.append(makeCookieFromList(cookie, False))
        except ValueError:
            # Again, if we can't handle one cookie -- ignore it.
            pass
    return outCookies
def makeCookieFromList(tup, netscapeFormat):
    # Build a Cookie from [(name, value), ...] pairs: the first pair is
    # the cookie itself, the rest are attributes. netscapeFormat enables
    # stripping of surrounding quotes on expires/port values.
    name, value = tup[0]
    if name is None or value is None:
        raise ValueError("Cookie has missing name or value")
    if name.startswith("$"):
        raise ValueError("Invalid cookie name: %r, starts with '$'." % name)
    cookie = Cookie(name, value)
    # Max-Age takes precedence over Expires regardless of attribute order.
    hadMaxAge = False
    for name, value in tup[1:]:
        name = name.lower()
        if value is None:
            if name in ("discard", "secure"):
                # Boolean attrs
                value = True
            elif name != "port":
                # Can be either boolean or explicit
                continue
        if name in ("comment", "commenturl", "discard", "domain", "path", "secure"):
            # simple cases
            setattr(cookie, name, value)
        elif name == "expires" and not hadMaxAge:
            if netscapeFormat and value[0] == '"' and value[-1] == '"':
                value = value[1:-1]
            cookie.expires = parseDateTime(value)
        elif name == "max-age":
            hadMaxAge = True
            cookie.expires = int(value) + time.time()
        elif name == "port":
            if value is None:
                # Bare "port" means an empty port list.
                cookie.ports = ()
            else:
                if netscapeFormat and value[0] == '"' and value[-1] == '"':
                    value = value[1:-1]
                cookie.ports = tuple([int(s) for s in value.split(',')])
        elif name == "version":
            cookie.version = int(value)
    return cookie
def generateSetCookie(cookies):
    """Render cookies as old-style (version 0) Set-Cookie header strings."""
    headersOut = []
    for cookie in cookies:
        parts = ["%s=%s" % (cookie.name, cookie.value)]
        if cookie.expires:
            parts.append("expires=%s" % generateDateTime(cookie.expires))
        if cookie.path:
            parts.append("path=%s" % cookie.path)
        if cookie.domain:
            parts.append("domain=%s" % cookie.domain)
        if cookie.secure:
            parts.append("secure")
        headersOut.append('; '.join(parts))
    return headersOut
def generateSetCookie2(cookies):
    """Render cookies as RFC 2965 Set-Cookie2 header strings."""
    setCookies = []
    for cookie in cookies:
        out = ["%s=%s" % (cookie.name, quoteString(cookie.value))]
        if cookie.comment:
            out.append("Comment=%s" % quoteString(cookie.comment))
        if cookie.commenturl:
            out.append("CommentURL=%s" % quoteString(cookie.commenturl))
        if cookie.discard:
            out.append("Discard")
        if cookie.domain:
            out.append("Domain=%s" % quoteString(cookie.domain))
        if cookie.expires:
            # Absolute expiry time is converted back to relative Max-Age.
            out.append("Max-Age=%s" % (cookie.expires - time.time()))
        if cookie.path:
            out.append("Path=%s" % quoteString(cookie.path))
        if cookie.ports is not None:
            if len(cookie.ports) == 0:
                out.append("Port")
            else:
                out.append("Port=%s" % quoteString(",".join([str(x) for x in cookie.ports])))
        if cookie.secure:
            out.append("Secure")
        out.append('Version="1"')
        setCookies.append('; '.join(out))
    return setCookies
def parseDepth(depth):
    """Validate a WebDAV Depth header; only '0', '1', 'infinity' are legal."""
    if depth in ("0", "1", "infinity"):
        return depth
    raise ValueError("Invalid depth header value: %s" % (depth,))
def parseOverWrite(overwrite):
    """Parse a WebDAV Overwrite header: 'T' -> True, 'F' -> False."""
    mapping = {"T": True, "F": False}
    if overwrite in mapping:
        return mapping[overwrite]
    raise ValueError("Invalid overwrite header value: %s" % (overwrite,))
def generateOverWrite(overwrite):
    """Render a boolean as a WebDAV Overwrite header value ('T'/'F')."""
    return "T" if overwrite else "F"
def parseBrief(brief):
    """Parse the MS Brief header; T/F are accepted in either case."""
    normalized = brief.upper()
    if normalized == "T":
        return True
    if normalized == "F":
        return False
    raise ValueError("Invalid brief header value: %s" % (brief,))
def generateBrief(brief):
    """Render a boolean Brief header value; the MS definition is lower case."""
    if brief:
        return "t"
    return "f"
##### Random stuff that looks useful.
# def sortMimeQuality(s):
# def sorter(item1, item2):
# if item1[0] == '*':
# if item2[0] == '*':
# return 0
# def sortQuality(s):
# def sorter(item1, item2):
# if item1[1] < item2[1]:
# return -1
# if item1[1] < item2[1]:
# return 1
# if item1[0] == item2[0]:
# return 0
# def getMimeQuality(mimeType, accepts):
# type,args = parseArgs(mimeType)
# type=type.split(Token('/'))
# if len(type) != 2:
# raise ValueError, "MIME Type "+s+" invalid."
# for accept in accepts:
# accept,acceptQual=accept
# acceptType=accept[0:1]
# acceptArgs=accept[2]
# if ((acceptType == type or acceptType == (type[0],'*') or acceptType==('*','*')) and
# (args == acceptArgs or len(acceptArgs) == 0)):
# return acceptQual
# def getQuality(type, accepts):
# qual = accepts.get(type)
# if qual is not None:
# return qual
# return accepts.get('*')
# Headers object
class __RecalcNeeded(object):
    """Sentinel type marking a header whose other representation is stale."""
    def __repr__(self):
        # Bug fix: the angle-bracketed text had been stripped (lost as
        # markup), leaving an empty repr; restore the informative form.
        return "<RecalcNeeded>"

# The single shared sentinel instance.
_RecalcNeeded = __RecalcNeeded()
class Headers(object):
    """
    This class stores the HTTP headers as both a parsed representation
    and the raw string representation. It converts between the two on
    demand.
    """

    def __init__(self, headers=None, rawHeaders=None, handler=DefaultHTTPHandler):
        self._raw_headers = {}
        self._headers = {}
        self.handler = handler
        if headers is not None:
            for key, value in headers.iteritems():
                self.setHeader(key, value)
        if rawHeaders is not None:
            for key, value in rawHeaders.iteritems():
                self.setRawHeaders(key, value)

    def _setRawHeaders(self, headers):
        # Replace the raw dict wholesale and invalidate all parsed values.
        self._raw_headers = headers
        self._headers = {}

    def _toParsed(self, name):
        # Parse the raw value via the handler and cache the result.
        r = self._raw_headers.get(name, None)
        h = self.handler.parse(name, r)
        if h is not None:
            self._headers[name] = h
        return h

    def _toRaw(self, name):
        # Generate the raw value from the parsed form and cache it.
        h = self._headers.get(name, None)
        r = self.handler.generate(name, h)
        if r is not None:
            self._raw_headers[name] = r
        return r

    def hasHeader(self, name):
        """Does a header with the given name exist?"""
        name = name.lower()
        return name in self._raw_headers

    def getRawHeaders(self, name, default=None):
        """Returns a list of headers matching the given name as the raw string given."""
        name = name.lower()
        raw_header = self._raw_headers.get(name, default)
        if raw_header is not _RecalcNeeded:
            return raw_header
        return self._toRaw(name)

    def getHeader(self, name, default=None):
        """Returns the parsed representation of the given header.
        The exact form of the return value depends on the header in question.

        If no parser for the header exists, raise ValueError.

        If the header doesn't exist, return default (or None if not specified)
        """
        name = name.lower()
        parsed = self._headers.get(name, default)
        if parsed is not _RecalcNeeded:
            return parsed
        return self._toParsed(name)

    def setRawHeaders(self, name, value):
        """Sets the raw representation of the given header.
        Value should be a list of strings, each being one header of the
        given name.
        """
        name = name.lower()
        self._raw_headers[name] = value
        self._headers[name] = _RecalcNeeded

    def setHeader(self, name, value):
        """Sets the parsed representation of the given header.
        Value should be a list of objects whose exact form depends
        on the header in question.
        """
        name = name.lower()
        self._raw_headers[name] = _RecalcNeeded
        self._headers[name] = value

    def addRawHeader(self, name, value):
        """
        Add a raw value to a header that may or may not already exist.
        If it exists, add it as a separate header to output; do not
        replace anything.
        """
        name = name.lower()
        raw_header = self._raw_headers.get(name)
        if raw_header is None:
            # No header yet
            raw_header = []
            self._raw_headers[name] = raw_header
        elif raw_header is _RecalcNeeded:
            raw_header = self._toRaw(name)
        raw_header.append(value)
        self._headers[name] = _RecalcNeeded

    def removeHeader(self, name):
        """Removes the header named."""
        name = name.lower()
        if name in self._raw_headers:
            del self._raw_headers[name]
            del self._headers[name]

    def __repr__(self):
        # Bug fix: the angle-bracketed format string had been stripped to
        # '', making this raise "not all arguments converted". Restored.
        return '<Headers: Raw: %s Parsed: %s>' % (self._raw_headers, self._headers)

    def canonicalNameCaps(self, name):
        """Return the name with the canonical capitalization, if known,
        otherwise, Caps-After-Dashes"""
        return header_case_mapping.get(name) or dashCapitalize(name)

    def getAllRawHeaders(self):
        """Return an iterator of key,value pairs of all headers
        contained in this object, as strings. The keys are capitalized
        in canonical capitalization."""
        for k, v in self._raw_headers.iteritems():
            if v is _RecalcNeeded:
                v = self._toRaw(k)
            yield self.canonicalNameCaps(k), v

    def makeImmutable(self):
        """Make this header set immutable. All mutating operations will
        raise an exception."""
        # NOTE(review): addRawHeader is not overridden here, so mutation
        # through it remains possible -- confirm whether that is intended.
        self.setHeader = self.setRawHeaders = self.removeHeader = self._mutateRaise

    def _mutateRaise(self, *args):
        raise AttributeError("This header object is immutable as the headers have already been sent.")
"""The following dicts are all mappings of header to list of operations
to perform. The first operation should generally be 'tokenize' if the
header can be parsed according to the normal tokenization rules. If
it cannot, generally the first thing you want to do is take only the
last instance of the header (in case it was sent multiple times, which
is strictly an error, but we're nice.).
"""
# Shortcut used by generator pipelines that start from a parsed dict.
iteritems = lambda x: x.iteritems()

# Each table maps a header name to a pipeline of callables applied
# left-to-right to convert between raw and parsed representations.
parser_general_headers = {
    'Cache-Control': (tokenize, listParser(parseCacheControl), dict),
    'Connection': (tokenize, filterTokens),
    'Date': (last, parseDateTime),
    # 'Pragma': tokenize
    # 'Trailer': tokenize
    'Transfer-Encoding': (tokenize, filterTokens),
    # 'Upgrade': tokenize
    # 'Via': tokenize,stripComment
    # 'Warning': tokenize
}
# Generators for the general headers defined in parser_general_headers.
generator_general_headers = {
    'Cache-Control': (iteritems, listGenerator(generateCacheControl), singleHeader),
    'Connection': (generateList, singleHeader),
    'Date': (generateDateTime, singleHeader),
    # 'Pragma':
    # 'Trailer':
    'Transfer-Encoding': (generateList, singleHeader),
    # 'Upgrade':
    # 'Via':
    # 'Warning':
}
# Parsers for request-only headers (RFC 2616 section 5.3).
parser_request_headers = {
    'Accept': (tokenize, listParser(parseAccept), dict),
    'Accept-Charset': (tokenize, listParser(parseAcceptQvalue), dict, addDefaultCharset),
    'Accept-Encoding': (tokenize, listParser(parseAcceptQvalue), dict, addDefaultEncoding),
    'Accept-Language': (tokenize, listParser(parseAcceptQvalue), dict),
    'Authorization': (last, parseAuthorization),
    'Cookie': (parseCookie,),
    'Expect': (tokenize, listParser(parseExpect), dict),
    'From': (last,),
    'Host': (last,),
    'If-Match': (tokenize, listParser(parseStarOrETag), list),
    'If-Modified-Since': (last, parseIfModifiedSince),
    'If-None-Match': (tokenize, listParser(parseStarOrETag), list),
    'If-Range': (parseIfRange,),
    'If-Unmodified-Since': (last, parseDateTime),
    'Max-Forwards': (last, int),
    'Prefer': (tokenize, listParser(parsePrefer), list),
    # 'Proxy-Authorization': str, # what is "credentials"
    'Range': (tokenize, parseRange),
    'Referer': (last, str), # TODO: URI object?
    'TE': (tokenize, listParser(parseAcceptQvalue), dict),
    'User-Agent': (last, str),
}
# Generators mirroring parser_request_headers.
generator_request_headers = {
    'Accept': (iteritems, listGenerator(generateAccept), singleHeader),
    'Accept-Charset': (iteritems, listGenerator(generateAcceptQvalue), singleHeader),
    'Accept-Encoding': (iteritems, removeDefaultEncoding, listGenerator(generateAcceptQvalue), singleHeader),
    'Accept-Language': (iteritems, listGenerator(generateAcceptQvalue), singleHeader),
    'Authorization': (generateAuthorization,), # what is "credentials"
    'Cookie': (generateCookie, singleHeader),
    'Expect': (iteritems, listGenerator(generateExpect), singleHeader),
    'From': (str, singleHeader),
    'Host': (str, singleHeader),
    'If-Match': (listGenerator(generateStarOrETag), singleHeader),
    'If-Modified-Since': (generateDateTime, singleHeader),
    'If-None-Match': (listGenerator(generateStarOrETag), singleHeader),
    'If-Range': (generateIfRange, singleHeader),
    'If-Unmodified-Since': (generateDateTime, singleHeader),
    'Max-Forwards': (str, singleHeader),
    'Prefer': (listGenerator(generatePrefer), singleHeader),
    # 'Proxy-Authorization': str, # what is "credentials"
    'Range': (generateRange, singleHeader),
    'Referer': (str, singleHeader),
    'TE': (iteritems, listGenerator(generateAcceptQvalue), singleHeader),
    'User-Agent': (str, singleHeader),
}
# Parsers for response-only headers (RFC 2616 section 6.2).
parser_response_headers = {
    'Accept-Ranges': (tokenize, filterTokens),
    'Age': (last, int),
    'ETag': (tokenize, ETag.parse),
    'Location': (last,), # TODO: URI object?
    # 'Proxy-Authenticate'
    'Retry-After': (last, parseRetryAfter),
    'Server': (last,),
    'Set-Cookie': (parseSetCookie,),
    'Set-Cookie2': (tokenize, parseSetCookie2),
    'Vary': (tokenize, filterTokens),
    # WWW-Authenticate must not be case-folded: challenge data such as
    # nonces is case sensitive.
    'WWW-Authenticate': (lambda h: tokenize(h, foldCase=False),
                         parseWWWAuthenticate,)
}
# Generators mirroring parser_response_headers.
generator_response_headers = {
    'Accept-Ranges': (generateList, singleHeader),
    'Age': (str, singleHeader),
    'ETag': (ETag.generate, singleHeader),
    'Location': (str, singleHeader),
    # 'Proxy-Authenticate'
    'Retry-After': (generateRetryAfter, singleHeader),
    'Server': (str, singleHeader),
    'Set-Cookie': (generateSetCookie,),
    'Set-Cookie2': (generateSetCookie2,),
    'Vary': (generateList, singleHeader),
    'WWW-Authenticate': (generateWWWAuthenticate,)
}
# Parsers for entity headers (RFC 2616 section 7.1). Several of these
# disable case folding because their values are case sensitive.
parser_entity_headers = {
    'Allow': (lambda hdr: tokenize(hdr, foldCase=False), filterTokens),
    'Content-Disposition': (lambda hdr: tokenize(hdr, foldCase=False), parseContentDisposition),
    'Content-Encoding': (tokenize, filterTokens),
    'Content-Language': (tokenize, filterTokens),
    'Content-Length': (last, int),
    'Content-Location': (last,), # TODO: URI object?
    'Content-MD5': (last, parseContentMD5),
    'Content-Range': (last, parseContentRange),
    'Content-Type': (lambda hdr: tokenize(hdr, foldCase=False), parseContentType),
    'Expires': (last, parseExpires),
    'Last-Modified': (last, parseDateTime),
}
# Generators mirroring parser_entity_headers.
generator_entity_headers = {
    'Allow': (generateList, singleHeader),
    'Content-Disposition': (generateContentDisposition, singleHeader),
    'Content-Encoding': (generateList, singleHeader),
    'Content-Language': (generateList, singleHeader),
    'Content-Length': (str, singleHeader),
    'Content-Location': (str, singleHeader),
    # Re-encode the digest and strip the newline base64 appends.
    'Content-MD5': (base64.encodestring, lambda x: x.strip("\n"), singleHeader),
    'Content-Range': (generateContentRange, singleHeader),
    'Content-Type': (generateContentType, singleHeader),
    'Expires': (generateDateTime, singleHeader),
    'Last-Modified': (generateDateTime, singleHeader),
}
# Parsers for WebDAV headers (RFC 4918) and the MS Brief extension.
parser_dav_headers = {
    'Brief' : (last, parseBrief),
    'DAV' : (tokenize, list),
    'Depth' : (last, parseDepth),
    'Destination' : (last,), # TODO: URI object?
    # 'If' : (),
    # 'Lock-Token' : (),
    'Overwrite' : (last, parseOverWrite),
    # 'Status-URI' : (),
    # 'Timeout' : (),
}
# Generators mirroring parser_dav_headers. The generation machinery
# iterates over each value as a sequence of callables.
generator_dav_headers = {
    # Bug fix: Brief and Overwrite had empty pipelines even though
    # generateBrief/generateOverWrite exist (and were otherwise unused);
    # wire them in so parsed values round-trip like the other headers.
    'Brief' : (generateBrief, singleHeader),
    'DAV' : (generateList, singleHeader),
    # Bug fix: these two were bare '(singleHeader)' -- parenthesized
    # expressions, not 1-tuples -- and iterating a function fails when
    # the header is generated. Add the missing trailing commas.
    'Depth' : (singleHeader,),
    'Destination' : (singleHeader,),
    # 'If' : (),
    # 'Lock-Token' : (),
    'Overwrite' : (generateOverWrite, singleHeader),
    # 'Status-URI' : (),
    # 'Timeout' : (),
}
# Register every parser/generator table with the shared default handler.
DefaultHTTPHandler.updateParsers(parser_general_headers)
DefaultHTTPHandler.updateParsers(parser_request_headers)
DefaultHTTPHandler.updateParsers(parser_response_headers)
DefaultHTTPHandler.updateParsers(parser_entity_headers)
DefaultHTTPHandler.updateParsers(parser_dav_headers)
DefaultHTTPHandler.updateGenerators(generator_general_headers)
DefaultHTTPHandler.updateGenerators(generator_request_headers)
DefaultHTTPHandler.updateGenerators(generator_response_headers)
DefaultHTTPHandler.updateGenerators(generator_entity_headers)
DefaultHTTPHandler.updateGenerators(generator_dav_headers)
# casemappingify(DefaultHTTPParsers)
# casemappingify(DefaultHTTPGenerators)
# lowerify(DefaultHTTPParsers)
# lowerify(DefaultHTTPGenerators)
calendarserver-5.2+dfsg/twext/web2/error.py 0000644 0001750 0001750 00000021177 12263343324 020036 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_log -*-
##
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
Default error output filter for twext.web2.
"""
from twext.web2 import stream, http_headers
from twext.web2.responsecode import (
MOVED_PERMANENTLY, FOUND, SEE_OTHER, USE_PROXY, TEMPORARY_REDIRECT,
BAD_REQUEST, UNAUTHORIZED, PAYMENT_REQUIRED, FORBIDDEN, NOT_FOUND,
NOT_ALLOWED, NOT_ACCEPTABLE, PROXY_AUTH_REQUIRED, REQUEST_TIMEOUT, CONFLICT,
GONE, LENGTH_REQUIRED, PRECONDITION_FAILED, REQUEST_ENTITY_TOO_LARGE,
REQUEST_URI_TOO_LONG, UNSUPPORTED_MEDIA_TYPE,
REQUESTED_RANGE_NOT_SATISFIABLE, EXPECTATION_FAILED, INTERNAL_SERVER_ERROR,
NOT_IMPLEMENTED, BAD_GATEWAY, SERVICE_UNAVAILABLE, GATEWAY_TIMEOUT,
HTTP_VERSION_NOT_SUPPORTED, INSUFFICIENT_STORAGE_SPACE, NOT_EXTENDED,
RESPONSES,
)
from twisted.web.template import Element, flattenString, XMLString, renderer
# 300 - Should include entity with choices
# 301 -
# 304 - Must include Date, ETag, Content-Location, Expires, Cache-Control, Vary.
# 401 - Must include WWW-Authenticate.
# 405 - Must include Allow.
# 406 - Should include entity describing allowable characteristics
# 407 - Must include Proxy-Authenticate
# 413 - May include Retry-After
# 416 - Should include Content-Range
# 503 - Should include Retry-After
# Default human-readable bodies for HTTP error responses, keyed by response
# code.  Values are %-style format strings understood by the error-page
# machinery below; available keys are 'uri', 'location' and 'method'.
# NOTE(review): the HTML anchor markup that upstream embeds in several of
# these messages appears to have been stripped from this copy of the file.
ERROR_MESSAGES = {
    # 300
    # no MULTIPLE_CHOICES
    MOVED_PERMANENTLY:
        'The document has permanently moved here'
        '.',
    FOUND:
        'The document has temporarily moved here'
        '.',
    SEE_OTHER:
        'The results are available here'
        '.',
    # no NOT_MODIFIED
    USE_PROXY:
        'Access to this resource must be through the proxy '
        '.',
    # 306 unused
    TEMPORARY_REDIRECT:
        'The document has temporarily moved '
        'here.',

    # 400
    BAD_REQUEST:
        'Your browser sent an invalid request.',
    UNAUTHORIZED:
        'You are not authorized to view the resource at . '
        "Perhaps you entered a wrong password, or perhaps your browser doesn't "
        'support authentication.',
    PAYMENT_REQUIRED:
        'Payment Required (useful result code, this...).',
    FORBIDDEN:
        'You don\'t have permission to access .',
    NOT_FOUND:
        'The resource cannot be found.',
    NOT_ALLOWED:
        'The requested method is not supported by '
        '.',
    NOT_ACCEPTABLE:
        'No representation of that is acceptable to your '
        'client could be found.',
    PROXY_AUTH_REQUIRED:
        'You are not authorized to view the resource at . '
        'Perhaps you entered a wrong password, or perhaps your browser doesn\'t '
        'support authentication.',
    REQUEST_TIMEOUT:
        'Server timed out waiting for your client to finish sending the request.',
    CONFLICT:
        'Conflict (?)',
    GONE:
        'The resource has been permanently removed.',
    LENGTH_REQUIRED:
        'The resource requires a Content-Length header.',
    PRECONDITION_FAILED:
        'A precondition evaluated to false.',
    REQUEST_ENTITY_TOO_LARGE:
        # Typo fix: was "is too longer than".
        'The provided request entity data is longer than the maximum for '
        'the method at .',
    REQUEST_URI_TOO_LONG:
        'The request URL is longer than the maximum on this server.',
    UNSUPPORTED_MEDIA_TYPE:
        'The provided request data has a format not understood by the resource '
        'at .',
    REQUESTED_RANGE_NOT_SATISFIABLE:
        'None of the ranges given in the Range request header are satisfiable by '
        'the resource .',
    EXPECTATION_FAILED:
        # Typo fix: was "does support one of", inverting the intended meaning.
        'The server does not support one of the expectations given in the '
        'Expect header.',

    # 500
    INTERNAL_SERVER_ERROR:
        'An internal error occurred trying to process your request. Sorry.',
    NOT_IMPLEMENTED:
        'Some functionality requested is not implemented on this server.',
    BAD_GATEWAY:
        'An upstream server returned an invalid response.',
    SERVICE_UNAVAILABLE:
        # Typo fix: was "becaues".
        'This server cannot service your request because it is overloaded.',
    GATEWAY_TIMEOUT:
        'An upstream server is not responding.',
    HTTP_VERSION_NOT_SUPPORTED:
        'HTTP Version not supported.',
    INSUFFICIENT_STORAGE_SPACE:
        'There is insufficient storage space available to perform that request.',
    NOT_EXTENDED:
        # Typo fix: was "the a mandatory extension".
        'This server does not support a mandatory extension requested.'
}
class DefaultErrorElement(Element):
    """
    An L{ErrorElement} is an L{Element} that renders some HTML for the default
    rendering of an error page.
    """

    # NOTE(review): the XHTML template content appears to have been stripped
    # from this copy of the file; upstream this string holds the full error
    # page markup with 'code', 'title' and 'message' slots filled by the
    # error() renderer below.
    loader = XMLString("""
""")

    def __init__(self, request, response):
        # Keep the request/response around so the renderer can pull the URI,
        # method and Location header out of them.
        super(DefaultErrorElement, self).__init__()
        self.request = request
        self.response = response

    @renderer
    def error(self, request, tag):
        """
        Top-level renderer for page.
        """
        return tag.fillSlots(
            code=str(self.response.code),
            title=RESPONSES.get(self.response.code),
            message=self.loadMessage(self.response.code).fillSlots(
                uri=self.request.uri,
                location=self.response.headers.getHeader('location'),
                method=self.request.method,
            )
        )

    def loadMessage(self, code):
        # Wrap the per-code message from ERROR_MESSAGES (empty string when
        # the code is unknown) in a template element and return its root tag.
        # NOTE(review): the wrapping markup ('' below) also looks stripped;
        # upstream wraps the message in a transparent template element.
        tag = XMLString(('') +
                        ERROR_MESSAGES.get(code, "") +
                        '').load()[0]
        return tag
def defaultErrorHandler(request, response):
    """
    Handle errors which do not have any stream (i.e. output) associated with
    them, so that users will see a nice message in their browser.

    This is used as a response filter in L{twext.web2.server.Request}.

    @param request: the L{twext.web2.http.Request} being answered.
    @param response: the error L{Response} to decorate.
    @return: C{response}, with an HTML (or error-traceback) body stream and
        a matching content-type header attached when applicable.
    """
    if response.stream is not None:
        # Already got an error message
        return response
    if response.code < 300:
        # We only do error messages
        return response
    if response.code not in ERROR_MESSAGES:
        # No message specified for that code
        return response

    # Bug fix: the original also computed ``message % {...}`` here, but the
    # interpolated result was never used -- the body is rendered entirely by
    # DefaultErrorElement below -- so the dead formatting has been removed.

    data = []
    error = []
    (flattenString(request, DefaultErrorElement(request, response))
     .addCallbacks(data.append, error.append))
    # No deferreds from our renderers above, so this has always already fired.
    if data:
        subtype = 'html'
        body = data[0]
    else:
        subtype = 'error'
        body = 'Error in default error handler:\n' + error[0].getTraceback()

    ctype = http_headers.MimeType('text', subtype, {'charset': 'utf-8'})
    response.headers.setHeader("content-type", ctype)
    response.stream = stream.MemoryStream(body)

    return response

defaultErrorHandler.handleErrors = True


__all__ = ['defaultErrorHandler', ]
calendarserver-5.2+dfsg/twext/web2/test/ 0000755 0001750 0001750 00000000000 12322625325 017302 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/test/test_fileupload.py 0000644 0001750 0001750 00000021000 11337102650 023024 0 ustar rahul rahul # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twext.web2.fileupload} and its different parsing functions.
"""
from twisted.internet import defer
from twisted.trial import unittest
from twisted.internet.defer import waitForDeferred, deferredGenerator
from twext.web2 import stream, fileupload
from twext.web2.http_headers import MimeType
class TestStream(stream.SimpleStream):
    """
    A stream that reads less data at a time than it could.

    Used by the upload tests to exercise the parsers' ability to resume
    across arbitrary chunk boundaries.
    """
    def __init__(self, mem, maxReturn=1000, start=0, length=None):
        """
        @param mem: the backing string to serve reads from.
        @param maxReturn: maximum number of bytes returned by each read().
        @param start: offset into C{mem} at which reading begins.
        @param length: number of bytes to expose, or C{None} for everything
            from C{start} to the end.
        @raise ValueError: if C{start + length} overruns C{mem}.
        """
        self.mem = mem
        self.start = start
        self.maxReturn = maxReturn
        if length is None:
            self.length = len(mem) - start
        else:
            # Bug fix: this check previously compared ``len(mem) < length``,
            # ignoring the start offset, even though the error message below
            # already described the intended condition.
            if len(mem) < start + length:
                raise ValueError("len(mem) < start + length")
            self.length = length

    def read(self):
        """
        Return the next chunk (at most C{maxReturn} bytes) as a buffer view,
        or C{None} when exhausted or closed.
        """
        if self.mem is None:
            return None
        if self.length == 0:
            result = None
        else:
            amtToRead = min(self.maxReturn, self.length)
            # buffer() avoids copying the backing string (Python 2).
            result = buffer(self.mem, self.start, amtToRead)
            self.length -= amtToRead
            self.start += amtToRead
        return result

    def close(self):
        # Drop the backing string so further read() calls return None.
        self.mem = None
        stream.SimpleStream.close(self)
class MultipartTests(unittest.TestCase):
    """
    Tests for multipart/form-data parsing by
    L{fileupload.parseMultipartFormData}.
    """

    def doTestError(self, boundary, data, expected_error):
        """
        Parse C{data} with every chunk size from 1 to 19 bytes and assert
        that each parse fails with C{expected_error}.
        """
        # Test different amounts of data at a time.
        ds = [fileupload.parseMultipartFormData(TestStream(data,
                                                           maxReturn=bytes),
                                                boundary)
              for bytes in range(1, 20)]
        d = defer.DeferredList(ds, consumeErrors=True)
        d.addCallback(self._assertFailures, expected_error)
        return d

    def _assertFailures(self, failures, *expectedFailures):
        # Every entry of the DeferredList result must be a failure of one of
        # the expected types.
        for flag, failure in failures:
            self.failUnlessEqual(flag, defer.FAILURE)
            failure.trap(*expectedFailures)

    def doTest(self, boundary, data, expected_args, expected_files):
        """
        Parse C{data} repeatedly and compare the resulting form arguments
        and uploaded files against the expected values.  Runs as a
        deferredGenerator (see rebinding below).
        """
        #import time, gc, cgi, cStringIO
        for bytes in range(1, 20):
            #s = TestStream(data, maxReturn=bytes)
            s = stream.IStream(data)
            #t=time.time()
            d = waitForDeferred(fileupload.parseMultipartFormData(s, boundary))
            yield d; args, files = d.getResult()
            #e=time.time()
            #print "%.2g"%(e-t)
            self.assertEquals(args, expected_args)

            # Read file data back into memory to compare.
            out = {}
            for name, l in files.items():
                out[name] = [(filename, ctype, f.read()) for (filename, ctype, f) in l]
            self.assertEquals(out, expected_files)

            #data=cStringIO.StringIO(data)
            #t=time.time()
            #d=cgi.parse_multipart(data, {'boundary':boundary})
            #e=time.time()
            #print "CGI: %.2g"%(e-t)
    # Wrap the generator so it runs under the reactor and returns a Deferred.
    doTest = deferredGenerator(doTest)

    def testNormalUpload(self):
        """One form field plus one file upload parse correctly."""
        return self.doTest(
            '---------------------------155781040421463194511908194298',
            """-----------------------------155781040421463194511908194298\r
Content-Disposition: form-data; name="foo"\r
\r
Foo Bar\r
-----------------------------155781040421463194511908194298\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/html\r
\r
Contents of a file
blah
blah\r
-----------------------------155781040421463194511908194298--\r
""",
            {'foo':['Foo Bar']},
            {'file':[('filename', MimeType('text', 'html'),
                      "Contents of a file\nblah\nblah")]})

    def testMultipleUpload(self):
        """Repeated field names accumulate into lists."""
        return self.doTest(
            'xyz',
            """--xyz\r
Content-Disposition: form-data; name="foo"\r
\r
Foo Bar\r
--xyz\r
Content-Disposition: form-data; name="foo"\r
\r
Baz\r
--xyz\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/html\r
\r
blah\r
--xyz\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/plain\r
\r
bleh\r
--xyz--\r
""",
            {'foo':['Foo Bar', 'Baz']},
            {'file':[('filename', MimeType('text', 'html'), "blah"),
                     ('filename', MimeType('text', 'plain'), "bleh")]})

    def testStupidFilename(self):
        """A quote-embedding filename is taken verbatim, not re-split."""
        return self.doTest(
            '----------0xKhTmLbOuNdArY',
            """------------0xKhTmLbOuNdArY\r
Content-Disposition: form-data; name="file"; filename="foo"; name="foobar.txt"\r
Content-Type: text/plain\r
\r
Contents of a file
blah
blah\r
------------0xKhTmLbOuNdArY--\r
""",
            {},
            {'file':[('foo"; name="foobar.txt', MimeType('text', 'plain'),
                      "Contents of a file\nblah\nblah")]})

    def testEmptyFilename(self):
        """Mixed-case headers and an empty filename are accepted."""
        return self.doTest(
            'curlPYafCMnsamUw9kSkJJkSen41sAV',
            """--curlPYafCMnsamUw9kSkJJkSen41sAV\r
cONTENT-tYPE: application/octet-stream\r
cONTENT-dISPOSITION: FORM-DATA; NAME="foo"; FILENAME=""\r
\r
qwertyuiop\r
--curlPYafCMnsamUw9kSkJJkSen41sAV--\r
""",
            {},
            {'foo':[('', MimeType('application', 'octet-stream'),
                     "qwertyuiop")]})

    # Failing parses
    def testMissingContentDisposition(self):
        """A part without Content-Disposition is a MimeFormatError."""
        return self.doTestError(
            '----------0xKhTmLbOuNdArY',
            """------------0xKhTmLbOuNdArY\r
Content-Type: text/html\r
\r
Blah blah I am a stupid webbrowser\r
------------0xKhTmLbOuNdArY--\r
""",
            fileupload.MimeFormatError)

    def testRandomData(self):
        """Garbage input raises MimeFormatError rather than hanging."""
        return self.doTestError(
            'boundary',
            """--sdkjsadjlfjlj skjsfdkljsd
sfdkjsfdlkjhsfadklj sffkj""",
            fileupload.MimeFormatError)

    def test_tooBigUpload(self):
        """
        Test that a too big form post fails.
        """
        boundary = '---------------------------155781040421463194511908194298'
        data = """-----------------------------155781040421463194511908194298\r
Content-Disposition: form-data; name="foo"\r
\r
Foo Bar\r
-----------------------------155781040421463194511908194298\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/html\r
\r
Contents of a file
blah
blah\r
-----------------------------155781040421463194511908194298--\r
"""
        s = stream.IStream(data)
        return self.assertFailure(
            fileupload.parseMultipartFormData(s, boundary, maxSize=200),
            fileupload.MimeFormatError)

    def test_tooManyFields(self):
        """
        Test when breaking the maximum number of fields.
        """
        boundary = 'xyz'
        data = """--xyz\r
Content-Disposition: form-data; name="foo"\r
\r
Foo Bar\r
--xyz\r
Content-Disposition: form-data; name="foo"\r
\r
Baz\r
--xyz\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/html\r
\r
blah\r
--xyz\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/plain\r
\r
bleh\r
--xyz--\r
"""
        s = stream.IStream(data)
        return self.assertFailure(
            fileupload.parseMultipartFormData(s, boundary, maxFields=3),
            fileupload.MimeFormatError)

    def test_maxMem(self):
        """
        An attachment with no filename goes to memory: check that the
        C{maxMem} parameter limits the size of this kind of attachment.
        """
        boundary = '---------------------------155781040421463194511908194298'
        data = """-----------------------------155781040421463194511908194298\r
Content-Disposition: form-data; name="foo"\r
\r
Foo Bar and more content\r
-----------------------------155781040421463194511908194298\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/html\r
\r
Contents of a file
blah
blah\r
-----------------------------155781040421463194511908194298--\r
"""
        s = stream.IStream(data)
        return self.assertFailure(
            fileupload.parseMultipartFormData(s, boundary, maxMem=10),
            fileupload.MimeFormatError)
class TestURLEncoded(unittest.TestCase):
    """
    Tests for L{fileupload.parse_urlencoded}.
    """

    def doTest(self, data, expected_args):
        """
        Parse C{data} with every chunk size from 1 to 19 bytes and compare
        the result to C{expected_args}.  Runs as a deferredGenerator.
        """
        for bytes in range(1, 20):
            s = TestStream(data, maxReturn=bytes)
            d = waitForDeferred(fileupload.parse_urlencoded(s))
            yield d; args = d.getResult()
            self.assertEquals(args, expected_args)
    doTest = deferredGenerator(doTest)

    # NOTE(review): these test methods do not return the Deferred that
    # doTest produces, so trial never waits on it -- confirm whether the
    # assertions actually run before the test completes.
    def test_parseValid(self):
        self.doTest("a=b&c=d&c=e", {'a':['b'], 'c':['d', 'e']})
        self.doTest("a=b&c=d&c=e", {'a':['b'], 'c':['d', 'e']})
        self.doTest("a=b+c%20d", {'a':['b c d']})

    def test_parseInvalid(self):
        self.doTest("a&b=c", {'b':['c']})
calendarserver-5.2+dfsg/twext/web2/test/test_log.py 0000644 0001750 0001750 00000010765 11340046753 021506 0 ustar rahul rahul # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.python.log import addObserver, removeObserver
from twext.web2.log import BaseCommonAccessLoggingObserver, LogWrapperResource
from twext.web2.http import Response
from twext.web2.resource import Resource, WrapperResource
from twext.web2.test.test_server import BaseCase, BaseTestResource
class BufferingLogObserver(BaseCommonAccessLoggingObserver):
    """
    A web2 log observer that buffers messages instead of writing them out.
    """
    def __init__(self):
        # Bug fix: 'messages' used to be a mutable class attribute, so every
        # instance (and every test using this class) shared a single list.
        # Give each observer its own buffer instead.
        BaseCommonAccessLoggingObserver.__init__(self)
        self.messages = []

    def logMessage(self, message):
        """Record C{message} for later inspection by the tests."""
        self.messages.append(message)
class SetDateWrapperResource(WrapperResource):
    """
    A resource wrapper which sets the date header.

    Pinning 'date' to the epoch makes the access-log output deterministic
    for the assertions in these tests.
    """
    def hook(self, req):
        def _stampDate(request, response):
            # Force a fixed timestamp so log lines can be compared exactly.
            response.headers.setHeader('date', 0.0)
            return response

        # Run the filter even for error responses, after all other filters.
        _stampDate.handleErrors = True
        req.addResponseFilter(_stampDate, atEnd=True)
class NoneStreamResource(Resource):
    """
    A basic empty resource: responds 200 with no body stream at all.
    """
    def render(self, req):
        emptyResponse = Response(200)
        return emptyResponse
class TestLogging(BaseCase):
    """
    Tests for common-access logging via L{LogWrapperResource}.
    """

    def setUp(self):
        # Capture log output in memory for inspection.
        self.blo = BufferingLogObserver()
        addObserver(self.blo.emit)

        # some default resource setup
        self.resrc = BaseTestResource()
        self.resrc.child_emptystream = NoneStreamResource()

        # Wrap with the logger, then pin the date header so log lines are
        # deterministic.
        self.root = SetDateWrapperResource(LogWrapperResource(self.resrc))

    def tearDown(self):
        removeObserver(self.blo.emit)

    def assertLogged(self, **expected):
        """
        Check that logged messages matches expected format.

        Keyword arguments fill the common-log-format template below; any
        not supplied get defaults.  Pass C{logged=False} to assert that no
        message was emitted at all.
        """
        if 'date' not in expected:
            epoch = BaseCommonAccessLoggingObserver().logDateString(0)
            expected['date'] = epoch

        if 'user' not in expected:
            expected['user'] = '-'

        if 'referer' not in expected:
            expected['referer'] = '-'

        if 'user-agent' not in expected:
            expected['user-agent'] = '-'

        if 'version' not in expected:
            expected['version'] = '1.1'

        if 'remotehost' not in expected:
            expected['remotehost'] = 'remotehost'

        # Snapshot and clear the buffer so each assertion sees only its own
        # messages.
        messages = self.blo.messages[:]
        del self.blo.messages[:]

        expectedLog = ('%(remotehost)s - %(user)s [%(date)s] "%(method)s '
                       '%(uri)s HTTP/%(version)s" %(status)d %(length)d '
                       '"%(referer)s" "%(user-agent)s"')

        if expected.get('logged', True):
            # Ensure there weren't other messages hanging out
            self.assertEquals(len(messages), 1, "len(%r) != 1" % (messages, ))
            self.assertEquals(messages[0], expectedLog % expected)
        else:
            self.assertEquals(len(messages), 0, "len(%r) != 0" % (messages, ))

    def test_logSimpleRequest(self):
        """
        Check the log for a simple request.
        """
        uri = 'http://localhost/'
        method = 'GET'

        def _cbCheckLog(response):
            self.assertLogged(method=method, uri=uri, status=response[0],
                              length=response[1].getHeader('content-length'))

        d = self.getResponseFor(self.root, uri, method=method)
        d.addCallback(_cbCheckLog)

        return d

    def test_logErrors(self):
        """
        Test the error log.
        """
        def test(_, uri, method, **expected):
            expected['uri'] = uri
            expected['method'] = method

            def _cbCheckLog(response):
                self.assertEquals(response[0], expected['status'])
                self.assertLogged(
                    length=response[1].getHeader('content-length'), **expected)

            return self.getResponseFor(self.root,
                                       uri,
                                       method=method).addCallback(_cbCheckLog)

        uri = 'http://localhost/foo' # doesn't exist
        method = 'GET'

        d = test(None, uri, method, status=404, logged=True)

        # no host. this should result in a 400 which doesn't get logged
        uri = 'http:///'

        d.addCallback(test, uri, method, status=400, logged=False)

        return d

    def test_logNoneResponseStream(self):
        """
        Test the log of an empty resource.
        """
        uri = 'http://localhost/emptystream'
        method = 'GET'

        def _cbCheckLog(response):
            self.assertLogged(method=method, uri=uri, status=200,
                              length=0)

        d = self.getResponseFor(self.root, uri, method=method)
        d.addCallback(_cbCheckLog)

        return d
calendarserver-5.2+dfsg/twext/web2/test/test_static.py 0000644 0001750 0001750 00000011761 11667476304 022224 0 ustar rahul rahul # Copyright (c) 2008-2011 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twext.web2.static}.
"""
import os
from twext.web2.test.test_server import BaseCase
from twext.web2 import static
from twext.web2 import http_headers
from twext.web2 import stream
from twext.web2 import iweb
class TestData(BaseCase):
    """
    Tests for L{static.Data}, an in-memory resource.
    """

    def setUp(self):
        self.text = "Hello, World\n"
        self.data = static.Data(self.text, "text/plain")

    def test_dataState(self):
        """
        Test the internal state of the Data object
        """
        self.assert_(hasattr(self.data, "created_time"))
        self.assertEquals(self.data.data, self.text)
        self.assertEquals(self.data.type, http_headers.MimeType("text", "plain"))
        self.assertEquals(self.data.contentType(), http_headers.MimeType("text", "plain"))

    def test_etag(self):
        """
        Test that we can get an ETag
        """
        def _defer(result):
            # Any non-empty ETag value will do.
            self.failUnless(result)
        d = self.data.etag().addCallback(_defer)
        return d

    def test_render(self):
        """
        Test that the result from Data.render is acceptable, including the
        response code, the content-type header, and the actual response body
        itself.
        """
        response = iweb.IResponse(self.data.render(None))
        self.assertEqual(response.code, 200)
        self.assert_(response.headers.hasHeader("content-type"))
        self.assertEqual(response.headers.getHeader("content-type"),
                         http_headers.MimeType("text", "plain"))
        def checkStream(data):
            # The body stream must round-trip the original text.
            self.assertEquals(str(data), self.text)
        return stream.readStream(iweb.IResponse(self.data.render(None)).stream,
                                 checkStream)
class TestFileSaver(BaseCase):
    """
    Tests for L{static.FileSaver}, which stores POSTed uploads on disk.
    """

    def setUp(self):
        """
        Create an empty directory and a resource which will save uploads to
        that directory.
        """
        self.tempdir = self.mktemp()
        os.mkdir(self.tempdir)

        self.root = static.FileSaver(self.tempdir,
                                     expectedFields=['FileNameOne'],
                                     maxBytes=16)
        self.root.addSlash = True

    def uploadFile(self, fieldname, filename, mimetype, content, resrc=None,
                   host='foo', path='/'):
        """
        POST a single-part multipart/form-data body to C{resrc} (defaults to
        the FileSaver created in setUp) and return the response Deferred.
        """
        # NOTE: the host/path parameters are currently unused; the request
        # is always issued against '/' with host 'foo'.
        if not resrc:
            resrc = self.root

        ctype = http_headers.MimeType('multipart', 'form-data',
                                      (('boundary', '---weeboundary'),))

        return self.getResponseFor(resrc, '/',
                            headers={'host': 'foo',
                                     'content-type': ctype },
                            length=len(content),
                            method='POST',
                            content="""-----weeboundary\r
Content-Disposition: form-data; name="%s"; filename="%s"\r
Content-Type: %s\r
\r
%s\r
-----weeboundary--\r
""" % (fieldname, filename, mimetype, content))

    def _CbAssertInResponse(self, (code, headers, data, failed),
                            expected_response, expectedFailure=False):
        # Compare the (code, headers, body) triple against expectations;
        # expected_data is checked as a substring of the body.
        expected_code, expected_headers, expected_data = expected_response
        self.assertEquals(code, expected_code)
        if expected_data is not None:
            self.failUnlessSubstring(expected_data, data)
        for key, value in expected_headers.iteritems():
            self.assertEquals(headers.getHeader(key), value)
        self.assertEquals(failed, expectedFailure)

    def fileNameFromResponse(self, response):
        """
        Extract the saved file's path from the response body.
        """
        # NOTE(review): the second string literal below was mangled when
        # markup was stripped from this copy of the file -- upstream the
        # slice runs up to a closing HTML tag.  Left byte-for-byte as found.
        (code, headers, data, failure) = response
        return data[data.index('Saved file')+11:data.index('
')]

    def assertInResponse(self, response, expected_response, failure=False):
        d = response
        d.addCallback(self._CbAssertInResponse, expected_response, failure)
        return d

    def test_enforcesMaxBytes(self):
        # 32 bytes > maxBytes=16, so the saver must refuse the upload.
        return self.assertInResponse(
            self.uploadFile('FileNameOne', 'myfilename', 'text/html', 'X'*32),
            (200, {}, 'exceeds maximum length'))

    def test_enforcesMimeType(self):
        return self.assertInResponse(
            self.uploadFile('FileNameOne', 'myfilename',
                            'application/x-python', 'X'),
            (200, {}, 'type not allowed'))

    def test_invalidField(self):
        # Field names outside expectedFields are rejected.
        return self.assertInResponse(
            self.uploadFile('NotARealField', 'myfilename', 'text/html', 'X'),
            (200, {}, 'not a valid field'))

    def test_reportFileSave(self):
        return self.assertInResponse(
            self.uploadFile('FileNameOne', 'myfilename', 'text/plain', 'X'),
            (200, {}, 'Saved file'))

    def test_compareFileContents(self):
        # Upload, locate the saved file from the response, and verify its
        # on-disk contents round-trip exactly.
        def gotFname(fname):
            contents = file(fname, 'rb').read()
            self.assertEquals(contents, 'Test contents\n')

        d = self.uploadFile('FileNameOne', 'myfilename', 'text/plain',
                            'Test contents\n')
        d.addCallback(self.fileNameFromResponse)
        d.addCallback(gotFname)
        return d
calendarserver-5.2+dfsg/twext/web2/test/test_http.py 0000644 0001750 0001750 00000134176 12212514344 021702 0 ustar rahul rahul
from __future__ import nested_scopes
import time, sys, os
from zope.interface import implements
from twisted.trial import unittest
from twext.web2 import http, http_headers, responsecode, iweb, stream
from twext.web2 import channel
from twisted.internet import reactor, protocol, address, interfaces, utils
from twisted.internet import defer
from twisted.internet.defer import waitForDeferred, deferredGenerator
from twisted.protocols import loopback
from twisted.python import util, runtime
from twext.web2.channel.http import SSLRedirectRequest, HTTPFactory, HTTPChannel
from twisted.internet.task import deferLater
class RedirectResponseTestCase(unittest.TestCase):
    def testTemporary(self):
        """
        Verify the "temporary" parameter sets the appropriate response code
        """
        expectations = [
            (False, responsecode.MOVED_PERMANENTLY),
            (True, responsecode.TEMPORARY_REDIRECT),
        ]
        for temporary, expectedCode in expectations:
            response = http.RedirectResponse("http://example.com/",
                                             temporary=temporary)
            self.assertEquals(response.code, expectedCode)
class PreconditionTestCase(unittest.TestCase):
    """
    Tests for L{http.checkPreconditions}: If-Match, If-None-Match,
    If-(Un)Modified-Since handling.
    """

    def checkPreconditions(self, request, response, expectedResult, expectedCode,
                           **kw):
        """
        Run http.checkPreconditions and assert whether it passes
        (C{expectedResult}) and, when it fails, that the raised HTTPError
        carries C{expectedCode}.
        """
        preconditionsPass = True

        try:
            http.checkPreconditions(request, response, **kw)
        except http.HTTPError, e:
            preconditionsPass = False
            self.assertEquals(e.response.code, expectedCode)
        self.assertEquals(preconditionsPass, expectedResult)

    def testWithoutHeaders(self):
        # With no conditional request headers, preconditions always pass
        # regardless of ETag/Last-Modified on the response.
        request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers())
        out_headers = http_headers.Headers()
        response = http.Response(responsecode.OK, out_headers, None)

        self.checkPreconditions(request, response, True, responsecode.OK)

        out_headers.setHeader("ETag", http_headers.ETag('foo'))
        self.checkPreconditions(request, response, True, responsecode.OK)

        out_headers.removeHeader("ETag")
        out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT
        self.checkPreconditions(request, response, True, responsecode.OK)

        out_headers.setHeader("ETag", http_headers.ETag('foo'))
        self.checkPreconditions(request, response, True, responsecode.OK)

    def testIfMatch(self):
        request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers())
        out_headers = http_headers.Headers()
        response = http.Response(responsecode.OK, out_headers, None)

        # Behavior with no ETag set, should be same as with an ETag
        request.headers.setRawHeaders("If-Match", ('*',))
        self.checkPreconditions(request, response, True, responsecode.OK)
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED, entityExists=False)

        # Ask for tag, but no etag set.
        request.headers.setRawHeaders("If-Match", ('"frob"',))
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED)

        ## Actually set the ETag header
        out_headers.setHeader("ETag", http_headers.ETag('foo'))
        out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT

        # behavior of entityExists
        request.headers.setRawHeaders("If-Match", ('*',))
        self.checkPreconditions(request, response, True, responsecode.OK)
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED, entityExists=False)

        # tag matches
        request.headers.setRawHeaders("If-Match", ('"frob", "foo"',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        # none match
        request.headers.setRawHeaders("If-Match", ('"baz", "bob"',))
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED)

        # But if we have an error code already, ignore this header
        response.code = responsecode.INTERNAL_SERVER_ERROR
        self.checkPreconditions(request, response, True, responsecode.INTERNAL_SERVER_ERROR)
        response.code = responsecode.OK

        # Must only compare strong tags
        out_headers.setHeader("ETag", http_headers.ETag('foo', weak=True))
        request.headers.setRawHeaders("If-Match", ('W/"foo"',))
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED)

    def testIfUnmodifiedSince(self):
        request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers())
        out_headers = http_headers.Headers()
        response = http.Response(responsecode.OK, out_headers, None)

        # No Last-Modified => always fail.
        request.headers.setRawHeaders("If-Unmodified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED)

        # Set output headers
        out_headers.setHeader("ETag", http_headers.ETag('foo'))
        out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT

        request.headers.setRawHeaders("If-Unmodified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        request.headers.setRawHeaders("If-Unmodified-Since", ('Sat, 01 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED)

        # But if we have an error code already, ignore this header
        response.code = responsecode.INTERNAL_SERVER_ERROR
        self.checkPreconditions(request, response, True, responsecode.INTERNAL_SERVER_ERROR)
        response.code = responsecode.OK

        # invalid date => header ignored
        request.headers.setRawHeaders("If-Unmodified-Since", ('alalalalalalalalalala',))
        self.checkPreconditions(request, response, True, responsecode.OK)

    def testIfModifiedSince(self):
        # Guard: comparisons below assume "now" is after the fixed test date.
        if time.time() < 946771200:
            self.fail(RuntimeError("Your computer's clock is way wrong, "
                                   "this test will be invalid."))

        request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers())
        out_headers = http_headers.Headers()
        response = http.Response(responsecode.OK, out_headers, None)

        # No Last-Modified => always succeed
        request.headers.setRawHeaders("If-Modified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        # Set output headers
        out_headers.setHeader("ETag", http_headers.ETag('foo'))
        out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT

        request.headers.setRawHeaders("If-Modified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED)

        # With a non-GET method
        request.method="PUT"
        self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED)
        request.method="GET"

        request.headers.setRawHeaders("If-Modified-Since", ('Sat, 01 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        # But if we have an error code already, ignore this header
        response.code = responsecode.INTERNAL_SERVER_ERROR
        self.checkPreconditions(request, response, True, responsecode.INTERNAL_SERVER_ERROR)
        response.code = responsecode.OK

        # invalid date => header ignored
        request.headers.setRawHeaders("If-Modified-Since", ('alalalalalalalalalala',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        # date in the future => assume modified
        request.headers.setHeader("If-Modified-Since", time.time() + 500)
        self.checkPreconditions(request, response, True, responsecode.OK)

    def testIfNoneMatch(self):
        request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers())
        out_headers = http_headers.Headers()
        response = http.Response(responsecode.OK, out_headers, None)

        # No ETag on the response: If-None-Match trivially passes.
        request.headers.setRawHeaders("If-None-Match", ('"foo"',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        out_headers.setHeader("ETag", http_headers.ETag('foo'))
        out_headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT

        # behavior of entityExists
        request.headers.setRawHeaders("If-None-Match", ('*',))
        request.method="PUT"
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED)
        request.method="GET"
        self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED)
        self.checkPreconditions(request, response, True, responsecode.OK, entityExists=False)

        # tag matches
        request.headers.setRawHeaders("If-None-Match", ('"frob", "foo"',))
        request.method="PUT"
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED)
        request.method="GET"
        self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED)

        # now with IMS, also:
        request.headers.setRawHeaders("If-Modified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',))
        request.method="PUT"
        self.checkPreconditions(request, response, False, responsecode.PRECONDITION_FAILED)
        request.method="GET"
        self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED)

        request.headers.setRawHeaders("If-Modified-Since", ('Sat, 01 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        request.headers.removeHeader("If-Modified-Since")

        # none match
        request.headers.setRawHeaders("If-None-Match", ('"baz", "bob"',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        # now with IMS, also:
        request.headers.setRawHeaders("If-Modified-Since", ('Mon, 03 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        request.headers.setRawHeaders("If-Modified-Since", ('Sat, 01 Jan 2000 00:00:00 GMT',))
        self.checkPreconditions(request, response, True, responsecode.OK)

        request.headers.removeHeader("If-Modified-Since")

        # But if we have an error code already, ignore this header
        response.code = responsecode.INTERNAL_SERVER_ERROR
        self.checkPreconditions(request, response, True, responsecode.INTERNAL_SERVER_ERROR)
        response.code = responsecode.OK

        # Weak tags okay for GET
        out_headers.setHeader("ETag", http_headers.ETag('foo', weak=True))
        request.headers.setRawHeaders("If-None-Match", ('W/"foo"',))
        self.checkPreconditions(request, response, False, responsecode.NOT_MODIFIED)

        # Weak tags not okay for other methods
        request.method="PUT"
        out_headers.setHeader("ETag", http_headers.ETag('foo', weak=True))
        request.headers.setRawHeaders("If-None-Match", ('W/"foo"',))
        self.checkPreconditions(request, response, True, responsecode.OK)

    def testNoResponse(self):
        # Ensure that passing etag/lastModified arguments instead of response works.
        request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers())
        request.method="PUT"
        request.headers.setRawHeaders("If-None-Match", ('"foo"',))

        self.checkPreconditions(request, None, True, responsecode.OK)
        self.checkPreconditions(request, None, False, responsecode.PRECONDITION_FAILED,
                                etag=http_headers.ETag('foo'),
                                lastModified=946771200)

        # Make sure that, while you shoudn't do this, that it doesn't cause an error
        request.method="GET"
        self.checkPreconditions(request, None, False, responsecode.NOT_MODIFIED,
                                etag=http_headers.ETag('foo'))
class IfRangeTestCase(unittest.TestCase):
    """
    Tests for L{http.checkIfRange}: If-Range against ETag/Last-Modified.
    """

    def testIfRange(self):
        request = http.Request(None, "GET", "/", "HTTP/1.1", 0, http_headers.Headers())
        # TestResponse is defined elsewhere in this module.
        response = TestResponse()

        # No If-Range header: range applies unconditionally.
        self.assertEquals(http.checkIfRange(request, response), True)

        # ETag in If-Range must match the response's ETag exactly.
        request.headers.setRawHeaders("If-Range", ('"foo"',))
        self.assertEquals(http.checkIfRange(request, response), False)

        response.headers.setHeader("ETag", http_headers.ETag('foo'))
        self.assertEquals(http.checkIfRange(request, response), True)

        request.headers.setRawHeaders("If-Range", ('"bar"',))
        response.headers.setHeader("ETag", http_headers.ETag('foo'))
        self.assertEquals(http.checkIfRange(request, response), False)

        # Weak tags never validate an If-Range.
        request.headers.setRawHeaders("If-Range", ('W/"foo"',))
        response.headers.setHeader("ETag", http_headers.ETag('foo', weak=True))
        self.assertEquals(http.checkIfRange(request, response), False)

        request.headers.setRawHeaders("If-Range", ('"foo"',))
        response.headers.removeHeader("ETag")
        self.assertEquals(http.checkIfRange(request, response), False)

        # Date form: only an exact Last-Modified match validates.
        request.headers.setRawHeaders("If-Range", ('Sun, 02 Jan 2000 00:00:00 GMT',))
        response.headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT
        self.assertEquals(http.checkIfRange(request, response), True)

        request.headers.setRawHeaders("If-Range", ('Sun, 02 Jan 2000 00:00:01 GMT',))
        response.headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT
        self.assertEquals(http.checkIfRange(request, response), False)

        request.headers.setRawHeaders("If-Range", ('Sun, 01 Jan 2000 23:59:59 GMT',))
        response.headers.setHeader("Last-Modified", 946771200) # Sun, 02 Jan 2000 00:00:00 GMT
        self.assertEquals(http.checkIfRange(request, response), False)

        request.headers.setRawHeaders("If-Range", ('Sun, 01 Jan 2000 23:59:59 GMT',))
        response.headers.removeHeader("Last-Modified")
        self.assertEquals(http.checkIfRange(request, response), False)

        # Unparseable If-Range value: treated as not matching.
        request.headers.setRawHeaders("If-Range", ('jwerlqjL#$Y*KJAN',))
        self.assertEquals(http.checkIfRange(request, response), False)
class LoopbackRelay(loopback.LoopbackRelay):
    """
    Loopback transport used as the fake wire between the test client and
    server, extended with producer bookkeeping and abort tracking so tests
    can observe flow control and aborted connections.
    """
    implements(interfaces.IProducer)

    def pauseProducing(self):
        # Record pause requests; tests can inspect self.paused.
        self.paused = True

    def resumeProducing(self):
        self.paused = False

    def stopProducing(self):
        self.loseConnection()

    def loseWriteConnection(self):
        # HACK. Loopback has no half-close; drop the whole connection instead.
        self.loseConnection()

    def abortConnection(self):
        # Tests assert on this flag to verify the server aborted us.
        self.aborted = True

    def getHost(self):
        """
        Synthesize a slightly more realistic 'host' thing.
        """
        return address.IPv4Address('TCP', 'localhost', 4321)
class TestRequestMixin(object):
    """
    Mixin recording every lifecycle event of a request as a tuple in
    C{self.cmds}, so tests can compare the exact sequence of callbacks the
    channel delivered (init, content chunks, completion, connection loss).
    """

    def __init__(self, *args, **kwargs):
        super(TestRequestMixin, self).__init__(*args, **kwargs)
        self.cmds = []
        # Snapshot the request line and (sorted, for determinism) headers
        # at construction time.
        headers = list(self.headers.getAllRawHeaders())
        headers.sort()
        self.cmds.append(('init', self.method, self.uri, self.clientproto, self.stream.length, tuple(headers)))

    def process(self):
        # Tests drive responses explicitly via writeResponse; do nothing here.
        pass

    def handleContentChunk(self, data):
        self.cmds.append(('contentChunk', data))

    def handleContentComplete(self):
        self.cmds.append(('contentComplete',))

    def connectionLost(self, reason):
        self.cmds.append(('connectionLost', reason))

    def _finished(self, x):
        self._reallyFinished(x)
class TestRequest(TestRequestMixin, http.Request):
    """
    Stub request for testing: a plain HTTP request that records its
    lifecycle via L{TestRequestMixin}.
    """
class TestSSLRedirectRequest(TestRequestMixin, SSLRedirectRequest):
    """
    Stub request for HSTS testing: same recording behaviour layered over
    the SSL-redirecting request class.
    """
class TestResponse(object):
    """
    Minimal in-memory IResponse: status defaults to 200 OK, headers start
    empty, and the body is buffered through a ProducerStream.
    """
    implements(iweb.IResponse)

    code = responsecode.OK
    headers = None

    def __init__(self):
        # Fresh header set and body stream per response instance.
        self.headers = http_headers.Headers()
        self.stream = stream.ProducerStream()

    def write(self, data):
        # Append body bytes to the backing stream.
        self.stream.write(data)

    def finish(self):
        # Mark the body stream complete.
        self.stream.finish()
class TestClient(protocol.Protocol):
    """
    Fake HTTP client protocol: accumulates all bytes received in C{data}
    and flips C{done} once the connection goes away.
    """
    data = ""
    done = False

    def dataReceived(self, data):
        # Accumulate everything the server sends.
        self.data = self.data + data

    def write(self, data):
        self.transport.write(data)

    def _shutdown(self):
        # Common path for both voluntary and involuntary disconnect.
        self.done = True
        self.transport.loseConnection()

    def connectionLost(self, reason):
        self._shutdown()

    def loseConnection(self):
        self._shutdown()
class TestConnection:
    """
    Bag of state for one simulated client/server connection: the request
    objects created so far, the client protocol, and any callLater
    callbacks queued via L{fakeCallLater}.
    """

    def __init__(self):
        self.requests, self.callLaters = [], []
        self.client = None

    def fakeCallLater(self, secs, f):
        # The HTTP factory only ever schedules immediate (0-second) calls;
        # queue them for the test's iterate() to run synchronously.
        assert secs == 0
        self.callLaters.append(f)
class HTTPTests(unittest.TestCase):
    """
    Base class for HTTP channel tests: wires a TestClient to an HTTPChannel
    over loopback relays and provides helpers to pump data and compare the
    recorded request command streams and client-visible bytes.
    """
    requestClass = TestRequest

    def setUp(self):
        super(HTTPTests, self).setUp()
        # We always need this set to True - previous tests may have changed it
        HTTPChannel.allowPersistentConnections = True

    def connect(self, logFile=None, **protocol_kwargs):
        """
        Build a fresh client/server pair joined by loopback relays and
        return the TestConnection holding both ends.
        """
        cxn = TestConnection()

        def makeTestRequest(*args):
            # Track every request object the channel creates.
            cxn.requests.append(self.requestClass(*args))
            return cxn.requests[-1]

        factory = channel.HTTPFactory(requestFactory=makeTestRequest,
                                      _callLater=cxn.fakeCallLater,
                                      **protocol_kwargs)

        cxn.client = TestClient()
        cxn.server = factory.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 2345))
        cxn.serverToClient = LoopbackRelay(cxn.client, logFile)
        cxn.clientToServer = LoopbackRelay(cxn.server, logFile)
        cxn.server.makeConnection(cxn.serverToClient)
        cxn.client.makeConnection(cxn.clientToServer)

        return cxn

    def iterate(self, cxn):
        # Run queued "immediate" callLaters, then flush both directions.
        callLaters = cxn.callLaters
        cxn.callLaters = []
        for f in callLaters:
            f()
        cxn.serverToClient.clearBuffer()
        cxn.clientToServer.clearBuffer()
        # A relay flagged to close may have buffered more during the flush.
        if cxn.serverToClient.shouldLose:
            cxn.serverToClient.clearBuffer()
        if cxn.clientToServer.shouldLose:
            cxn.clientToServer.clearBuffer()

    def compareResult(self, cxn, cmds, data):
        """
        Assert that each request recorded exactly the expected command
        tuples and that the client has received exactly C{data}.
        """
        self.iterate(cxn)
        # Python 2 idiom: map(None, a, b) pairs the sequences like
        # itertools.izip_longest, padding the shorter with None.
        for receivedRequest, expectedCommands in map(None, cxn.requests, cmds):
            sortedHeaderCommands = []
            for cmd in expectedCommands:
                if len(cmd) == 6:
                    # 'init' commands carry headers in slot 5; sort them so
                    # comparison is order-independent.
                    sortedHeaders = list(cmd[5])
                    sortedHeaders.sort()
                    sortedHeaderCommands.append(cmd[:5] + (tuple(sortedHeaders),))
                else:
                    sortedHeaderCommands.append(cmd)
            self.assertEquals(receivedRequest.cmds, sortedHeaderCommands)
        self.assertEquals(cxn.client.data, data)

    def assertDone(self, cxn, done=True):
        self.iterate(cxn)
        self.assertEquals(cxn.client.done, done)
class GracefulShutdownTestCase(HTTPTests):
    """
    Tests for HTTPFactory.allConnectionsClosed(): its Deferred fires
    immediately when no channels are connected, otherwise only once the
    last connected channel is removed.
    """

    def _callback(self, result):
        self.callbackFired = True

    def testAllConnectionsClosedWithoutConnectedChannels(self):
        """
        allConnectionsClosed( ) should fire right away if no connected channels
        """
        self.callbackFired = False

        factory = HTTPFactory(None)
        factory.allConnectionsClosed().addCallback(self._callback)
        self.assertTrue(self.callbackFired) # now!

    def testallConnectionsClosedWithConnectedChannels(self):
        """
        allConnectionsClosed( ) should only fire after all connected channels
        have been removed
        """
        self.callbackFired = False

        factory = HTTPFactory(None)
        factory.addConnectedChannel("A")
        factory.addConnectedChannel("B")
        factory.addConnectedChannel("C")
        factory.allConnectionsClosed().addCallback(self._callback)

        factory.removeConnectedChannel("A")
        self.assertFalse(self.callbackFired) # wait for it...

        factory.removeConnectedChannel("B")
        self.assertFalse(self.callbackFired) # wait for it...

        factory.removeConnectedChannel("C")
        self.assertTrue(self.callbackFired) # now!
class CoreHTTPTestCase(HTTPTests):
    """
    End-to-end tests of the HTTP channel's wire behaviour across protocol
    versions (0.9, 1.0, 1.1), keep-alive, pipelining, chunked encoding,
    100-continue, and the various timeout paths.
    """
    # Note: these tests compare the client output using string
    # matching. It is acceptable for this to change and break
    # the test if you know what you are doing.

    def testHTTP0_9(self, nouri=False):
        """HTTP/0.9 request: headerless response body, connection closes."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        if nouri:
            cxn.client.write("GET\r\n")
        else:
            cxn.client.write("GET /\r\n")
        # Second request which should not be handled
        cxn.client.write("GET /two\r\n")

        cmds[0] += [('init', 'GET', '/', (0,9), 0, ()), ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.headers.setRawHeaders("Yo", ("One", "Two"))
        cxn.requests[0].writeResponse(response)
        response.write("")

        # HTTP/0.9 emits no status line or headers at all.
        self.compareResult(cxn, cmds, data)

        response.write("Output")
        data += "Output"
        self.compareResult(cxn, cmds, data)

        response.finish()
        self.compareResult(cxn, cmds, data)

        self.assertDone(cxn)

    def testHTTP0_9_nouri(self):
        self.testHTTP0_9(True)

    def testHTTP1_0(self):
        """HTTP/1.0 without keep-alive: one request, Connection: close."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.0\r\nContent-Length: 5\r\nHost: localhost\r\n\r\nInput")
        # Second request which should not be handled
        cxn.client.write("GET /two HTTP/1.0\r\n\r\n")

        cmds[0] += [('init', 'GET', '/', (1,0), 5,
                     (('Host', ['localhost']),)),
                    ('contentChunk', 'Input'),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.headers.setRawHeaders("Yo", ("One", "Two"))
        cxn.requests[0].writeResponse(response)
        response.write("")

        data += "HTTP/1.1 200 OK\r\nYo: One\r\nYo: Two\r\nConnection: close\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        response.write("Output")
        data += "Output"
        self.compareResult(cxn, cmds, data)

        response.finish()
        self.compareResult(cxn, cmds, data)

        self.assertDone(cxn)

    def testHTTP1_0_keepalive(self):
        """HTTP/1.0 with Connection: keep-alive: two requests on one connection."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.0\r\nConnection: keep-alive\r\nContent-Length: 5\r\nHost: localhost\r\n\r\nInput")
        cxn.client.write("GET /two HTTP/1.0\r\n\r\n")
        # Third request shouldn't be handled
        cxn.client.write("GET /three HTTP/1.0\r\n\r\n")

        cmds[0] += [('init', 'GET', '/', (1,0), 5,
                     (('Host', ['localhost']),)),
                    ('contentChunk', 'Input'),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response0 = TestResponse()
        # Keep-alive requires a definite Content-Length.
        response0.headers.setRawHeaders("Content-Length", ("6", ))
        response0.headers.setRawHeaders("Yo", ("One", "Two"))
        cxn.requests[0].writeResponse(response0)
        response0.write("")

        data += "HTTP/1.1 200 OK\r\nContent-Length: 6\r\nYo: One\r\nYo: Two\r\nConnection: Keep-Alive\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        response0.write("Output")
        data += "Output"
        self.compareResult(cxn, cmds, data)

        response0.finish()

        # Now for second request:
        cmds.append([])
        cmds[1] += [('init', 'GET', '/two', (1,0), 0, ()),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response1 = TestResponse()
        response1.headers.setRawHeaders("Content-Length", ("0", ))
        cxn.requests[1].writeResponse(response1)
        response1.write("")

        # Second HTTP/1.0 request did not ask for keep-alive: close after it.
        data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
        self.compareResult(cxn, cmds, data)
        response1.finish()

        self.assertDone(cxn)

    def testHTTP1_1_pipelining(self):
        """Pipelined HTTP/1.1 requests: responses emitted strictly in order."""
        cxn = self.connect(maxPipeline=2)
        cmds = []
        data = ""

        # Both these show up immediately.
        cxn.client.write("GET / HTTP/1.1\r\nContent-Length: 5\r\nHost: localhost\r\n\r\nInput")
        cxn.client.write("GET /two HTTP/1.1\r\nHost: localhost\r\n\r\n")
        # Doesn't show up until the first is done.
        cxn.client.write("GET /three HTTP/1.1\r\nHost: localhost\r\n\r\n")
        # Doesn't show up until the second is done.
        cxn.client.write("GET /four HTTP/1.1\r\nHost: localhost\r\n\r\n")

        cmds.append([])
        cmds[0] += [('init', 'GET', '/', (1,1), 5,
                     (('Host', ['localhost']),)),
                    ('contentChunk', 'Input'),
                    ('contentComplete',)]
        cmds.append([])
        cmds[1] += [('init', 'GET', '/two', (1,1), 0,
                     (('Host', ['localhost']),)),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response0 = TestResponse()
        response0.headers.setRawHeaders("Content-Length", ("6", ))
        cxn.requests[0].writeResponse(response0)
        response0.write("")

        data += "HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        response0.write("Output")
        data += "Output"
        self.compareResult(cxn, cmds, data)

        response0.finish()

        # Now the third request gets read:
        cmds.append([])
        cmds[2] += [('init', 'GET', '/three', (1,1), 0,
                     (('Host', ['localhost']),)),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        # Let's write out the third request before the second.
        # This should not cause anything to be written to the client.
        response2 = TestResponse()
        response2.headers.setRawHeaders("Content-Length", ("5", ))
        cxn.requests[2].writeResponse(response2)

        response2.write("Three")
        response2.finish()

        self.compareResult(cxn, cmds, data)

        response1 = TestResponse()
        response1.headers.setRawHeaders("Content-Length", ("3", ))
        cxn.requests[1].writeResponse(response1)
        response1.write("Two")

        data += "HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nTwo"
        self.compareResult(cxn, cmds, data)
        response1.finish()

        # Fourth request shows up
        cmds.append([])
        cmds[3] += [('init', 'GET', '/four', (1,1), 0,
                     (('Host', ['localhost']),)),
                    ('contentComplete',)]
        # ...and the buffered third response is flushed now that #2 is done.
        data += "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nThree"
        self.compareResult(cxn, cmds, data)

        response3 = TestResponse()
        response3.headers.setRawHeaders("Content-Length", ("0",))
        cxn.requests[3].writeResponse(response3)
        response3.finish()

        data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        self.assertDone(cxn, done=False)
        cxn.client.loseConnection()
        self.assertDone(cxn)

    def testHTTP1_1_chunking(self, extraHeaders=""):
        """Chunked request body in, chunked response body out."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\nHost: localhost\r\n\r\n5\r\nInput\r\n")

        cmds[0] += [('init', 'GET', '/', (1,1), None,
                     (('Host', ['localhost']),)),
                    ('contentChunk', 'Input')]
        self.compareResult(cxn, cmds, data)

        # Chunk extensions (after ';') must be ignored.
        cxn.client.write("1; blahblahblah\r\na\r\n10\r\nabcdefghijklmnop\r\n")
        cmds[0] += [('contentChunk', 'a'),('contentChunk', 'abcdefghijklmnop')]
        self.compareResult(cxn, cmds, data)

        # Zero-length chunk plus trailers terminates the body.
        cxn.client.write("0\r\nRandom-Ignored-Trailer: foo\r\n\r\n")
        cmds[0] += [('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        cxn.requests[0].writeResponse(response)
        response.write("Output")

        # No Content-Length given: the channel must chunk the response.
        expected = ["HTTP/1.1 200 OK"]
        if extraHeaders:
            expected.append(extraHeaders)
        expected.extend([
            "Transfer-Encoding: chunked",
            "",
            "6",
            "Output",
            "",
        ])
        data += "\r\n".join(expected)
        self.compareResult(cxn, cmds, data)

        response.write("blahblahblah")
        data += "C\r\nblahblahblah\r\n"
        self.compareResult(cxn, cmds, data)

        response.finish()
        data += "0\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        cxn.client.loseConnection()
        self.assertDone(cxn)

    def testHTTP1_1_expect_continue(self):
        """100-continue is sent only once the request body stream is read."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\nContent-Length: 5\r\nHost: localhost\r\nExpect: 100-continue\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1,1), 5,
                     (('Expect', ['100-continue']), ('Host', ['localhost'])))]
        self.compareResult(cxn, cmds, data)

        # Reading the body triggers the interim 100 response.
        cxn.requests[0].stream.read()
        data += "HTTP/1.1 100 Continue\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        cxn.client.write("Input")
        cmds[0] += [('contentChunk', 'Input'),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.headers.setRawHeaders("Content-Length", ("6",))
        cxn.requests[0].writeResponse(response)
        response.write("Output")
        response.finish()

        data += "HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nOutput"
        self.compareResult(cxn, cmds, data)

        cxn.client.loseConnection()
        self.assertDone(cxn)

    def testHTTP1_1_expect_continue_early_reply(self):
        """Replying without reading the body skips 100-continue and closes."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\nContent-Length: 5\r\nHost: localhost\r\nExpect: 100-continue\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1,1), 5,
                     (('Host', ['localhost']), ('Expect', ['100-continue'])))]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.headers.setRawHeaders("Content-Length", ("6",))
        cxn.requests[0].writeResponse(response)
        response.write("Output")
        response.finish()

        # The unread body is discarded; connection must close since the
        # client may still send the 5 bytes it announced.
        cmds[0] += [('contentComplete',)]
        data += "HTTP/1.1 200 OK\r\nContent-Length: 6\r\nConnection: close\r\n\r\nOutput"
        self.compareResult(cxn, cmds, data)

        cxn.client.loseConnection()
        self.assertDone(cxn)

    def testHeaderContinuation(self):
        """A leading-whitespace line continues the previous header value."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\nHost: localhost\r\nFoo: yada\r\n yada\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1,1), 0,
                     (('Host', ['localhost']), ('Foo', ['yada yada']),)),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        cxn.client.loseConnection()
        self.assertDone(cxn)

    def testTimeout_immediate(self):
        # timeout 0 => timeout on first iterate call
        cxn = self.connect(inputTimeOut = 0)
        return deferLater(reactor, 0, self.assertDone, cxn)

    def testTimeout_inRequest(self):
        """An incomplete request line/header block times out."""
        cxn = self.connect(inputTimeOut = 0.3)

        cxn.client.write("GET / HTTP/1.1\r\n")
        return deferLater(reactor, 0.5, self.assertDone, cxn)

    def testTimeout_betweenRequests(self):
        """An idle persistent connection times out after a response."""
        cxn = self.connect(betweenRequestsTimeOut = 0.3)
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1,1), 0, ()),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.headers.setRawHeaders("Content-Length", ("0",))
        cxn.requests[0].writeResponse(response)
        response.finish()

        data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
        self.compareResult(cxn, cmds, data)
        return deferLater(reactor, 0.5, self.assertDone, cxn) # Wait for timeout

    def testTimeout_idleRequest(self):
        """A request the application never answers is timed out."""
        cxn = self.connect(idleTimeOut=0.3)
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1, 1), 0, ()),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        return deferLater(reactor, 0.5, self.assertDone, cxn) # Wait for timeout

    def testTimeout_abortRequest(self):
        """
        If the client never acts on loseConnection, the server must abort
        the transport after closeTimeOut.
        """
        cxn = self.connect(allowPersistentConnections=False, closeTimeOut=0.3)
        # Simulate a client that ignores the close request.
        cxn.client.transport.loseConnection = lambda : None
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1, 1), 0, ()),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.headers.setRawHeaders("Content-Length", ("0",))
        cxn.requests[0].writeResponse(response)
        response.finish()

        data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        def _check(cxn):
            self.assertDone(cxn)
            self.assertTrue(cxn.serverToClient.aborted)
        # BUG FIX: previously this returned
        # deferLater(reactor, 0.5, self.assertDone, cxn), leaving _check
        # unused so the abortConnection() call was never actually verified.
        return deferLater(reactor, 0.5, _check, cxn) # Wait for timeout

    def testConnectionCloseRequested(self):
        """'Connection: close' on a request closes after its response."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1,1), 0, ()),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        cxn.client.write("GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        cmds.append([])
        cmds[1] += [('init', 'GET', '/', (1,1), 0, ()),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.headers.setRawHeaders("Content-Length", ("0",))
        cxn.requests[0].writeResponse(response)
        response.finish()

        data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"

        response = TestResponse()
        response.headers.setRawHeaders("Content-Length", ("0",))
        cxn.requests[1].writeResponse(response)
        response.finish()

        data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        self.assertDone(cxn)

    def testConnectionKeepAliveOff(self):
        """allowPersistentConnections=False forces Connection: close."""
        cxn = self.connect(allowPersistentConnections=False)
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1, 1), 0, ()),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.headers.setRawHeaders("Content-Length", ("0",))
        cxn.requests[0].writeResponse(response)
        response.finish()

        data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        self.assertDone(cxn)

    def testExtraCRLFs(self):
        """A stray CRLF between requests must be tolerated."""
        cxn = self.connect()
        cmds = [[]]
        data = ""

        # Some broken clients (old IEs) send an extra CRLF after post
        cxn.client.write("POST / HTTP/1.1\r\nContent-Length: 5\r\nHost: localhost\r\n\r\nInput\r\n")
        cmds[0] += [('init', 'POST', '/', (1,1), 5,
                     (('Host', ['localhost']),)),
                    ('contentChunk', 'Input'),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        cxn.client.write("GET /two HTTP/1.1\r\n\r\n")
        cmds.append([])
        cmds[1] += [('init', 'GET', '/two', (1,1), 0, ()),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        cxn.client.loseConnection()
        self.assertDone(cxn)

    def testDisallowPersistentConnections(self):
        """With persistence disabled, a pipelined second request is dropped."""
        cxn = self.connect(allowPersistentConnections=False)
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nGET / HTTP/1.1\r\nHost: localhost\r\n\r\n")
        cmds[0] += [('init', 'GET', '/', (1,1), 0,
                     (('Host', ['localhost']),)),
                    ('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.finish()
        cxn.requests[0].writeResponse(response)

        data += 'HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n'
        self.compareResult(cxn, cmds, data)

        self.assertDone(cxn)

    def testIgnoreBogusContentLength(self):
        # Ensure that content-length is ignored when transfer-encoding
        # is also specified.
        cxn = self.connect()
        cmds = [[]]
        data = ""

        cxn.client.write("GET / HTTP/1.1\r\nContent-Length: 100\r\nTransfer-Encoding: chunked\r\nHost: localhost\r\n\r\n5\r\nInput\r\n")
        cmds[0] += [('init', 'GET', '/', (1,1), None,
                     (('Host', ['localhost']),)),
                    ('contentChunk', 'Input')]
        self.compareResult(cxn, cmds, data)

        cxn.client.write("0\r\n\r\n")
        cmds[0] += [('contentComplete',)]
        self.compareResult(cxn, cmds, data)

        response = TestResponse()
        response.finish()
        cxn.requests[0].writeResponse(response)

        data += "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
        self.compareResult(cxn, cmds, data)

        cxn.client.loseConnection()
        self.assertDone(cxn)
class ErrorTestCase(HTTPTests):
    """
    Tests for malformed-request handling: each bad input must produce the
    expected 4xx/5xx status, a Connection: close, and a Content-Length.
    """

    def assertStartsWith(self, first, second, msg=None):
        self.assert_(first.startswith(second), '%r.startswith(%r)' % (first, second))

    def checkError(self, cxn, code):
        """
        Assert the server answered with status C{code}, closed the
        connection, and gave the error body a definite length.
        """
        self.iterate(cxn)
        self.assertStartsWith(cxn.client.data, "HTTP/1.1 %d "%code)
        self.assertIn("\r\nConnection: close\r\n", cxn.client.data)
        # Ensure error messages have a defined content-length.
        self.assertIn("\r\nContent-Length:", cxn.client.data)
        self.assertDone(cxn)

    def testChunkingError1(self):
        # Non-hex chunk size.
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\nasdf\r\n")

        self.checkError(cxn, 400)

    def testChunkingError2(self):
        # Chunk data longer than its declared size.
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n1\r\nblahblah\r\n")

        self.checkError(cxn, 400)

    def testChunkingError3(self):
        # Negative chunk size.
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n-1\r\nasdf\r\n")

        self.checkError(cxn, 400)

    def testTooManyHeaders(self):
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\n")
        cxn.client.write("Foo: Bar\r\n"*5000)

        self.checkError(cxn, 400)

    def testLineTooLong(self):
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\n")
        cxn.client.write("Foo: "+("Bar"*10000))

        self.checkError(cxn, 400)

    def testLineTooLong2(self):
        # Over-long request URI gets the dedicated 414 status.
        cxn = self.connect()
        cxn.client.write("GET "+("/Bar")*10000 +" HTTP/1.1\r\n")

        self.checkError(cxn, 414)

    def testNoColon(self):
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\n")
        cxn.client.write("Blahblah\r\n\r\n")

        self.checkError(cxn, 400)

    def test_nonAsciiHeader(self):
        """
        As per U{RFC 822 section 3,
        <http://www.faqs.org/rfcs/rfc822.html>}, headers are
        ASCII only.
        """
        # Non-ASCII in the header value...
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\nX-Extra-Header: \xff\r\n\r\n")
        self.checkError(cxn, responsecode.BAD_REQUEST)

        # ...and in the header name.
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\nX-E\xfftra-Header: foo\r\n\r\n")
        self.checkError(cxn, responsecode.BAD_REQUEST)

    def testBadRequest(self):
        # Four tokens on the request line.
        cxn = self.connect()
        cxn.client.write("GET / more HTTP/1.1\r\n")

        self.checkError(cxn, 400)

    def testWrongProtocol(self):
        cxn = self.connect()
        cxn.client.write("GET / Foobar/1.0\r\n")

        self.checkError(cxn, 400)

    def testBadProtocolVersion(self):
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1\r\n")

        self.checkError(cxn, 400)

    def testBadProtocolVersion2(self):
        cxn = self.connect()
        cxn.client.write("GET / HTTP/-1.0\r\n")

        self.checkError(cxn, 400)

    def testWrongProtocolVersion(self):
        # Unsupported major version gets 505.
        cxn = self.connect()
        cxn.client.write("GET / HTTP/2.0\r\n")

        self.checkError(cxn, 505)

    def testUnsupportedTE(self):
        # Unknown transfer coding before chunked gets 501.
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\n")
        cxn.client.write("Transfer-Encoding: blahblahblah, chunked\r\n\r\n")
        self.checkError(cxn, 501)

    def testTEWithoutChunked(self):
        # Transfer-Encoding whose final coding is not chunked is malformed.
        cxn = self.connect()
        cxn.client.write("GET / HTTP/1.1\r\n")
        cxn.client.write("Transfer-Encoding: gzip\r\n\r\n")
        self.checkError(cxn, 400)
class PipelinedErrorTestCase(ErrorTestCase):
    # Make sure that even low level reading errors don't corrupt the data stream,
    # but always wait until their turn to respond.
    def connect(self):
        """
        Like ErrorTestCase.connect, but with a valid request already
        pipelined ahead of the bad input each test will write.
        """
        cxn = ErrorTestCase.connect(self)
        cxn.client.write("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")

        cmds = [[('init', 'GET', '/', (1,1), 0,
                  (('Host', ['localhost']),)),
                 ('contentComplete', )]]
        data = ""
        self.compareResult(cxn, cmds, data)
        return cxn

    def checkError(self, cxn, code):
        """
        First verify the error response is withheld until the pipelined
        good request has been answered, then delegate to the base check.
        """
        self.iterate(cxn)
        # Nothing may reach the client before the first response is written.
        self.assertEquals(cxn.client.data, '')

        response = TestResponse()
        response.headers.setRawHeaders("Content-Length", ("0",))
        cxn.requests[0].writeResponse(response)
        response.write('')

        data = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
        self.iterate(cxn)
        self.assertEquals(cxn.client.data, data)

        # Reset the data so the checkError's startswith test can work right.
        cxn.client.data = ""

        response.finish()

        ErrorTestCase.checkError(self, cxn, code)
class SimpleFactory(channel.HTTPFactory):
    """
    HTTPFactory that notifies its testcase (via the C{connlost} Deferred)
    when the built protocol's connection is lost.
    """

    def buildProtocol(self, addr):
        # Do a bunch of crazy crap just so that the test case can know when the
        # connection is done.
        p = channel.HTTPFactory.buildProtocol(self, addr)
        cl = p.connectionLost
        def newCl(reason):
            # Fire the testcase's Deferred on the next reactor turn, after
            # the original connectionLost has run.
            reactor.callLater(0, lambda: self.testcase.connlost.callback(None))
            return cl(reason)
        p.connectionLost = newCl
        # Keep a handle so tearDown can force-close the connection.
        self.conn = p
        return p
class SimpleRequest(http.Request):
    """
    Request that answers immediately with a canned status code chosen by
    URI: /error -> 402, /forbidden -> 403, anything else -> 404 with a
    short explanatory body.
    """

    # URI -> status code dispatch table; unknown URIs fall through to 404.
    _STATUS_BY_URI = {"/error": 402, "/forbidden": 403}

    def process(self):
        response = TestResponse()
        response.code = self._STATUS_BY_URI.get(self.uri, 404)
        if response.code == 404:
            # Only unrecognized URIs get a body.
            response.write("URI %s unrecognized." % self.uri)
        response.finish()
        self.writeResponse(response)
class AbstractServerTestMixin:
    """
    Mixin driving a real listening server with an external client
    subprocess (simple_client.py); subclasses set C{type} ('tcp'/'ssl')
    and provide C{self.port}.
    """
    type = None

    def testBasicWorkingness(self):
        args = ('-u', util.sibpath(__file__, "simple_client.py"), "basic",
                str(self.port), self.type)
        # Old-style Twisted generator deferreds (pre-inlineCallbacks):
        # waitForDeferred + yield + getResult.
        d = waitForDeferred(
            utils.getProcessOutputAndValue(sys.executable, args=args,
                                           env=os.environ)
        )
        yield d; out,err,code = d.getResult()

        self.assertEquals(code, 0, "Error output:\n%s" % (err,))
        self.assertEquals(out, "HTTP/1.1 402 Payment Required\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
    testBasicWorkingness = deferredGenerator(testBasicWorkingness)

    def testLingeringClose(self):
        args = ('-u', util.sibpath(__file__, "simple_client.py"),
                "lingeringClose", str(self.port), self.type)
        d = waitForDeferred(
            utils.getProcessOutputAndValue(sys.executable, args=args,
                                           env=os.environ)
        )
        yield d; out,err,code = d.getResult()

        self.assertEquals(code, 0, "Error output:\n%s" % (err,))
        self.assertEquals(out, "HTTP/1.1 402 Payment Required\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
    testLingeringClose = deferredGenerator(testLingeringClose)
class TCPServerTest(unittest.TestCase, AbstractServerTestMixin):
    """
    AbstractServerTestMixin over a plain TCP listener on an ephemeral port.
    """
    type = 'tcp'

    def setUp(self):
        factory=SimpleFactory(requestFactory=SimpleRequest)

        factory.testcase = self
        self.factory = factory
        self.connlost = defer.Deferred()

        # Port 0: let the OS pick a free port, then record it.
        self.socket = reactor.listenTCP(0, factory)
        self.port = self.socket.getHost().port

    def tearDown(self):
        # Make sure the listening port is closed
        d = defer.maybeDeferred(self.socket.stopListening)

        def finish(v):
            # And make sure the established connection is, too
            self.factory.conn.transport.loseConnection()
            return self.connlost
        return d.addCallback(finish)
# Probe for SSL support; 'ssl' ends up None when unavailable, and the
# module-level conditions below then mark SSLServerTest as skipped.
try:
    from twisted.internet import ssl
    ssl # pyflakes
except ImportError:
    # happens the first time the interpreter tries to import it
    ssl = None
if ssl and not ssl.supported:
    # happens second and later times
    ssl = None

# Self-signed certificate used for both key and cert in the SSL tests.
certPath = util.sibpath(__file__, "server.pem")
class SSLServerTest(unittest.TestCase, AbstractServerTestMixin):
    """
    AbstractServerTestMixin over an SSL listener using the test server.pem
    for both key and certificate.
    """
    type = 'ssl'

    def setUp(self):
        sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
        factory=SimpleFactory(requestFactory=SimpleRequest)

        factory.testcase = self
        self.factory = factory
        self.connlost = defer.Deferred()

        self.socket = reactor.listenSSL(0, factory, sCTX)
        self.port = self.socket.getHost().port

    def tearDown(self):
        # Make sure the listening port is closed
        d = defer.maybeDeferred(self.socket.stopListening)

        def finish(v):
            # And make sure the established connection is, too
            self.factory.conn.transport.loseConnection()
            return self.connlost
        return d.addCallback(finish)

    def testLingeringClose(self):
        # Override exists only so the .todo attribute below can be attached
        # to this class's copy without affecting the TCP variant.
        return super(SSLServerTest, self).testLingeringClose()

    if runtime.platform.isWindows():
        # This may not just be Windows, but all platforms with more recent
        # versions of OpenSSL. Do some more experimentation...
        testLingeringClose.todo = "buffering kills the connection too early; test this some other way"
# Skip the server tests when the reactor lacks the required capabilities
# (process spawning for the client subprocess, SSL for the SSL variant).
if interfaces.IReactorProcess(reactor, None) is None:
    TCPServerTest.skip = SSLServerTest.skip = "Required process support missing from reactor"
elif interfaces.IReactorSSL(reactor, None) is None:
    SSLServerTest.skip = "Required SSL support missing from reactor"
elif ssl is None:
    SSLServerTest.skip = "SSL not available, cannot test SSL."
calendarserver-5.2+dfsg/twext/web2/test/test_resource.py 0000644 0001750 0001750 00000016055 11457335713 022560 0 ustar rahul rahul # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A test harness for twext.web2.resource.
"""
from sets import Set as set
from zope.interface import implements
from twisted.internet.defer import succeed, fail, inlineCallbacks
from twisted.trial import unittest
from twext.web2 import responsecode
from twext.web2.iweb import IResource
from twext.web2.http import Response
from twext.web2.stream import MemoryStream
from twext.web2.resource import RenderMixin, LeafResource
from twext.web2.server import Site, StopTraversal
from twext.web2.test.test_server import SimpleRequest
class PreconditionError (Exception):
    """Precondition Failure: raised by TestResource's BLEARGH precondition."""
class TestResource (RenderMixin):
    """
    Resource exposing several fake HTTP methods, each paired with a
    different style of precondition hook (raising, failing Deferred,
    returning None, returning a fired Deferred).
    """
    implements(IResource)

    def _handler(self, request):
        # Shared handler: 204 on a real request, 500 if given None.
        if request is None:
            return responsecode.INTERNAL_SERVER_ERROR
        return responsecode.NO_CONTENT

    http_BLEARGH = _handler
    http_HUCKHUCKBLORP = _handler
    http_SWEETHOOKUPS = _handler
    http_HOOKUPS = _handler

    def preconditions_BLEARGH(self, request):
        # Precondition that raises synchronously.
        raise PreconditionError()

    # NOTE(review): name is singular 'precondition_' while its siblings use
    # 'preconditions_' — presumably RenderMixin only looks up the plural
    # form, so this failing Deferred is never invoked. Confirm whether the
    # mismatch is intentional before renaming.
    def precondition_HUCKHUCKBLORP(self, request):
        return fail(None)

    def preconditions_SWEETHOOKUPS(self, request):
        # Precondition that passes by returning None.
        return None

    def preconditions_HOOKUPS(self, request):
        # Precondition that passes via an already-fired Deferred.
        return succeed(None)

    renderOutput = "Snootch to the hootch"

    def render(self, request):
        response = Response()
        response.stream = MemoryStream(self.renderOutput)
        return response
def generateResponse(method):
    """
    Invoke TestResource's C{http_<method>} handler for a synthetic request
    and return its result.

    @param method: an HTTP method name, e.g. C{"OPTIONS"}.
    """
    resource = TestResource()
    handler = getattr(resource, "http_" + method)
    # BUG FIX: the original rebound 'method' to the handler and then passed
    # the bound-method object — not the method-name string — into
    # SimpleRequest as the request's HTTP method.
    return handler(SimpleRequest(Site(resource), method, "/"))
class RenderMixInTestCase (unittest.TestCase):
    """
    Test RenderMixin.
    """
    # Methods TestResource is expected to advertise: the standard three
    # plus its four custom http_* handlers.
    _my_allowed_methods = set((
        "HEAD", "OPTIONS", "GET",
        "BLEARGH", "HUCKHUCKBLORP",
        "SWEETHOOKUPS", "HOOKUPS",
    ))

    def test_allowedMethods(self):
        """
        RenderMixin.allowedMethods()
        """
        self.assertEquals(
            set(TestResource().allowedMethods()),
            self._my_allowed_methods
        )

    @inlineCallbacks
    def test_checkPreconditions_raises(self):
        """
        RenderMixin.checkPreconditions()
        Exception raised in checkPreconditions()
        """
        resource = TestResource()
        request = SimpleRequest(Site(resource), "BLEARGH", "/")

        # Check that checkPreconditions raises as expected
        self.assertRaises(
            PreconditionError, resource.checkPreconditions, request
        )

        # Check that renderHTTP calls checkPreconditions
        yield self.failUnlessFailure(
            resource.renderHTTP(request), PreconditionError
        )

    @inlineCallbacks
    def test_checkPreconditions_none(self):
        """
        RenderMixin.checkPreconditions()
        checkPreconditions() returns None
        """
        resource = TestResource()
        request = SimpleRequest(Site(resource), "SWEETHOOKUPS", "/")

        # Check that checkPreconditions without a raise doesn't barf
        self.assertEquals(
            (yield resource.renderHTTP(request)),
            responsecode.NO_CONTENT
        )

    def test_checkPreconditions_deferred(self):
        """
        RenderMixin.checkPreconditions()
        checkPreconditions() returns a deferred
        """
        resource = TestResource()
        request = SimpleRequest(Site(resource), "HOOKUPS", "/")

        # Check that checkPreconditions without a raise doesn't barf
        def checkResponse(response):
            self.assertEquals(response, responsecode.NO_CONTENT)

        d = resource.renderHTTP(request)
        d.addCallback(checkResponse)
        # BUG FIX: the Deferred was previously dropped on the floor, so
        # trial never waited for checkResponse nor reported its failure.
        return d

    def test_OPTIONS_status(self):
        """
        RenderMixin.http_OPTIONS()
        Response code is OK
        """
        response = generateResponse("OPTIONS")
        self.assertEquals(response.code, responsecode.OK)

    def test_OPTIONS_allow(self):
        """
        RenderMixin.http_OPTIONS()
        Allow header indicates allowed methods
        """
        response = generateResponse("OPTIONS")
        self.assertEquals(
            set(response.headers.getHeader("allow")),
            self._my_allowed_methods
        )

    def test_TRACE_status(self):
        """
        RenderMixin.http_TRACE()
        Response code is OK
        """
        response = generateResponse("TRACE")
        self.assertEquals(response.code, responsecode.OK)
    test_TRACE_status.skip = "TRACE is disabled now."

    def test_TRACE_body(self):
        """
        RenderMixin.http_TRACE()
        Check body for traciness
        """
        raise NotImplementedError()
    test_TRACE_body.todo = "Someone should write this test"

    def test_HEAD_status(self):
        """
        RenderMixin.http_HEAD()
        Response code is OK
        """
        response = generateResponse("HEAD")
        self.assertEquals(response.code, responsecode.OK)

    def test_HEAD_body(self):
        """
        RenderMixin.http_HEAD()
        Check body is empty
        """
        response = generateResponse("HEAD")
        self.assertEquals(response.stream.length, 0)
    test_HEAD_body.todo = (
        "http_HEAD is implemented in a goober way that "
        "relies on the server code to clean up after it."
    )

    def test_GET_status(self):
        """
        RenderMixin.http_GET()
        Response code is OK
        """
        response = generateResponse("GET")
        self.assertEquals(response.code, responsecode.OK)

    def test_GET_body(self):
        """
        RenderMixin.http_GET()
        Check body is empty
        """
        response = generateResponse("GET")
        self.assertEquals(
            str(response.stream.read()),
            TestResource.renderOutput
        )
class ResourceTestCase (unittest.TestCase):
    """
    Test Resource.
    """
    # All of these are placeholders: the .todo attribute makes trial report
    # them as expected failures rather than errors until implemented.
    def test_addSlash(self):
        # I think this would include a test of http_GET()
        raise NotImplementedError()

    test_addSlash.todo = "Someone should write this test"

    def test_locateChild(self):
        raise NotImplementedError()

    test_locateChild.todo = "Someone should write this test"

    def test_child_nonsense(self):
        raise NotImplementedError()

    test_child_nonsense.todo = "Someone should write this test"
class PostableResourceTestCase (unittest.TestCase):
    """
    Test PostableResource.
    """
    # Placeholder: the .todo attribute makes trial report this as an
    # expected failure rather than an error until it is implemented.
    def test_POST(self):
        raise NotImplementedError()

    test_POST.todo = "Someone should write this test"
class LeafResourceTestCase (unittest.TestCase):
    """
    Test LeafResource.
    """
    def test_locateChild(self):
        """
        A leaf resource terminates traversal: locateChild() hands back the
        resource itself together with L{StopTraversal}, no matter what path
        segments remain.
        """
        resource = LeafResource()
        request = SimpleRequest(Site(resource), "GET", "/")
        result = resource.locateChild(request, ("", "foo"))
        self.assertEquals(result[0], resource)
        self.assertEquals(result[1], StopTraversal)
class WrapperResourceTestCase (unittest.TestCase):
    """
    Test WrapperResource.
    """
    # Placeholder: the .todo attribute makes trial report this as an
    # expected failure rather than an error until it is implemented.
    def test_hook(self):
        raise NotImplementedError()

    test_hook.todo = "Someone should write this test"
calendarserver-5.2+dfsg/twext/web2/test/simple_client.py 0000644 0001750 0001750 00000001561 11337102650 022502 0 ustar rahul rahul import socket, sys
# Minimal raw-socket HTTP client used by the web2 test suite to poke a
# server with hand-rolled requests.  Python 2 only (print >> syntax and
# the old socket.ssl wrapper).
#
# Usage: simple_client.py <test_type> <port> <socket_type>
test_type = sys.argv[1]
port = int(sys.argv[2])
socket_type = sys.argv[3]

s = socket.socket(socket.AF_INET)
s.connect(("127.0.0.1", port))
# Small receive buffer so the server's lingering-close path is exercised.
s.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 40000)

if socket_type == 'ssl':
    s2 = socket.ssl(s)
    send=s2.write
    recv=s2.read
else:
    send=s.send
    recv=s.recv

print >> sys.stderr, ">> Making %s request to port %d" % (socket_type, port)

send("GET /error HTTP/1.0\r\n")
send("Host: localhost\r\n")

if test_type == "lingeringClose":
    # Advertise a huge body and send it, forcing the server to drain or
    # reset the connection.
    print >> sys.stderr, ">> Sending lots of data"
    send("Content-Length: 1000000\r\n\r\n")
    send("X"*1000000)
else:
    send('\r\n')

#import time
#time.sleep(5)
print >> sys.stderr, ">> Getting data"
# Accumulate the response until ~300KB, EOF, or a socket error.
data=''
while len(data) < 299999:
    try:
        x=recv(10000)
    except:
        break
    if x == '':
        break
    data+=x
sys.stdout.write(data)
calendarserver-5.2+dfsg/twext/web2/test/stream_data.txt 0000644 0001750 0001750 00000000025 11337102650 022320 0 ustar rahul rahul We've got some text!
calendarserver-5.2+dfsg/twext/web2/test/test_metafd.py 0000644 0001750 0001750 00000024520 12306427141 022154 0 ustar rahul rahul ##
# Copyright (c) 2011-2014 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
"""
Tests for twext.web2.metafd.
"""
from socket import error as SocketError, AF_INET
from errno import ENOTCONN
from twext.internet import sendfdport
from twext.web2 import metafd
from twext.web2.channel.http import HTTPChannel
from twext.web2.metafd import ReportingHTTPService, ConnectionLimiter
from twisted.internet.tcp import Server
from twisted.application.service import Service
from twext.internet.test.test_sendfdport import ReaderAdder
from twext.web2.metafd import WorkerStatus
from twisted.trial.unittest import TestCase
class FakeSocket(object):
    """
    A fake socket for testing.

    C{getpeername} consults C{test.peerNameSucceed}, so individual tests
    can simulate a connection that was closed before its peer address
    could be read.
    """
    def __init__(self, test):
        self.test = test

    def fileno(self):
        # Deliberately not an integer, so anything that tries to use this
        # as a real file descriptor blows up immediately.
        return "not a socket"

    def setblocking(self, blocking):
        # Irrelevant for a fake; accept and ignore.
        pass

    def getpeername(self):
        if not self.test.peerNameSucceed:
            raise SocketError(ENOTCONN, "Transport endpoint not connected")
        return ("4.3.2.1", 4321)

    def getsockname(self):
        return ("4.3.2.1", 4321)
class InheritedPortForTesting(sendfdport.InheritedPort):
    """
    L{sendfdport.InheritedPort} variant whose reactor I/O registration
    methods are all no-ops, so it can be exercised without a real reactor.
    """
    def startReading(self):
        pass

    def stopReading(self):
        pass

    def startWriting(self):
        pass

    def stopWriting(self):
        pass
class ServerTransportForTesting(Server):
    """
    tcp.Server replacement for testing: reactor I/O registration is
    disabled and no reactor is attached after construction.
    """
    def __init__(self, *a, **kw):
        super(ServerTransportForTesting, self).__init__(*a, **kw)
        # Detach from any real reactor the base class may have grabbed.
        self.reactor = None

    def startReading(self):
        pass

    def stopReading(self):
        pass

    def startWriting(self):
        pass

    def stopWriting(self):
        pass
class ReportingHTTPServiceTests(TestCase):
    """
    Tests for L{ReportingHTTPService}
    """
    # Consulted by FakeSocket.getpeername(); tests flip this to False to
    # simulate a connection whose peer address can no longer be read.
    peerNameSucceed = True

    def setUp(self):
        """
        Replace the real file-descriptor and socket plumbing used by
        L{ReportingHTTPService} with inert fakes, then start the service.
        """
        def fakefromfd(fd, addressFamily, socketType):
            # Hand back our fake socket instead of adopting a real FD.
            return FakeSocket(self)
        def fakerecvfd(fd):
            return "not an fd", "not a description"
        def fakeclose(fd):
            ""
        def fakegetsockfam(fd):
            return AF_INET
        self.patch(sendfdport, 'recvfd', fakerecvfd)
        self.patch(sendfdport, 'fromfd', fakefromfd)
        self.patch(sendfdport, 'close', fakeclose)
        self.patch(sendfdport, 'getsockfam', fakegetsockfam)
        self.patch(metafd, 'InheritedPort', InheritedPortForTesting)
        self.patch(metafd, 'Server', ServerTransportForTesting)
        # This last stubbed out just to prevent dirty reactor warnings.
        self.patch(HTTPChannel, "callLater", lambda *a, **k: None)
        self.svc = ReportingHTTPService(None, None, None)
        self.svc.startService()

    def test_quickClosedSocket(self):
        """
        If a socket is closed very quickly after being {accept()}ed, requesting
        its peer (or even host) address may fail with C{ENOTCONN}. If this
        happens, its transport should be supplied with a dummy peer address.
        """
        self.peerNameSucceed = False
        # doRead() simulates an inbound connection arriving over the
        # inherited port.
        self.svc.reportingFactory.inheritedPort.doRead()
        channels = self.svc.reportingFactory.connectedChannels
        self.assertEqual(len(channels), 1)
        self.assertEqual(list(channels)[0].transport.getPeer().host, "0.0.0.0")
class ConnectionLimiterTests(TestCase):
    """
    Tests for L{ConnectionLimiter}
    """
    def test_loadReducedStartsReadingAgain(self):
        """
        L{ConnectionLimiter.statusesChanged} determines whether the current
        "load" of all subprocesses - that is, the total outstanding request
        count - is high enough that the listening ports attached to it should
        be suspended.
        """
        builder = LimiterBuilder(self)
        builder.fillUp()
        self.assertEquals(builder.port.reading, False) # sanity check
        self.assertEquals(builder.highestLoad(), builder.requestsPerSocket)
        builder.loadDown()
        self.assertEquals(builder.port.reading, True)

    def test_processRestartedStartsReadingAgain(self):
        """
        L{ConnectionLimiter.statusesChanged} determines whether the current
        number of outstanding requests is above the limit, and either stops or
        resumes reading on the listening port.
        """
        builder = LimiterBuilder(self)
        builder.fillUp()
        self.assertEquals(builder.port.reading, False)
        self.assertEquals(builder.highestLoad(), builder.requestsPerSocket)
        builder.processRestart()
        self.assertEquals(builder.port.reading, True)

    def test_unevenLoadDistribution(self):
        """
        Subprocess sockets should be selected for subsequent socket sends by
        ascending status. Status should sum sent and successfully subsumed
        sockets.
        """
        builder = LimiterBuilder(self)
        # Give one simulated worker a higher acknowledged load than the other.
        builder.fillUp(True, 1)
        # There should still be plenty of spare capacity.
        self.assertEquals(builder.port.reading, True)
        # Then slam it with a bunch of incoming requests.
        builder.fillUp(False, builder.limiter.maxRequests - 1)
        # Now capacity is full.
        self.assertEquals(builder.port.reading, False)
        # And everyone should have an even amount of work.
        self.assertEquals(builder.highestLoad(), builder.requestsPerSocket)

    def test_processStopsReadingEvenWhenConnectionsAreNotAcknowledged(self):
        """
        L{ConnectionLimiter.statusesChanged} determines whether the current
        number of outstanding requests is above the limit.
        """
        builder = LimiterBuilder(self)
        builder.fillUp(acknowledged=False)
        self.assertEquals(builder.highestLoad(), builder.requestsPerSocket)
        self.assertEquals(builder.port.reading, False)
        builder.processRestart()
        self.assertEquals(builder.port.reading, True)

    def test_workerStatusRepr(self):
        """
        L{WorkerStatus.__repr__} will show all the values associated with the
        status of the worker.
        """
        # NOTE(review): the expected value below is an empty string; the
        # original "<WorkerStatus ...>" text appears to have been stripped
        # (likely as markup during repackaging) -- confirm against upstream
        # before relying on this assertion.
        self.assertEquals(repr(WorkerStatus(1, 2, 3, 4, 5, 6, 7, 8)),
                          "")

    def test_workerStatusNonNegative(self):
        """
        L{WorkerStatus.adjust} clamps counters so that they never go
        negative.
        """
        w = WorkerStatus()
        # unacknowledged would go to -1; adjust() should clamp it to 0.
        w.adjust(
            acknowledged=1,
            unacknowledged=-1,
            total=1,
        )
        self.assertEquals(w.acknowledged, 1)
        self.assertEquals(w.unacknowledged, 0)
        self.assertEquals(w.total, 1)
class LimiterBuilder(object):
    """
    A L{LimiterBuilder} can build a L{ConnectionLimiter} and associated objects
    for a given unit test.
    """
    def __init__(self, test, requestsPerSocket=3, socketCount=2):
        # Similar to MaxRequests in the configuration.
        self.requestsPerSocket = requestsPerSocket
        # Similar to ProcessCount in the configuration.
        self.socketCount = socketCount
        self.limiter = ConnectionLimiter(
            2, maxRequests=requestsPerSocket * socketCount
        )
        self.dispatcher = self.limiter.dispatcher
        self.dispatcher.reactor = ReaderAdder()
        self.service = Service()
        self.limiter.addPortService("TCP", 4321, "127.0.0.1", 5,
                                    self.serverServiceMakerMaker(self.service))
        # Simulate each worker subprocess registering with the dispatcher.
        for ignored in xrange(socketCount):
            subskt = self.dispatcher.addSocket()
            subskt.start()
            subskt.restarted()
        # Has to be running in order to add stuff.
        self.limiter.startService()
        self.port = self.service.myPort

    def highestLoad(self):
        """
        Return the largest effective load across all subprocess sockets.
        """
        return max(
            skt.status.effective()
            for skt in self.limiter.dispatcher._subprocessSockets
        )

    def serverServiceMakerMaker(self, s):
        """
        Make a serverServiceMaker for use with
        L{ConnectionLimiter.addPortService}.
        """
        class NotAPort(object):
            # Records whether the limiter currently wants us reading.
            def startReading(self):
                self.reading = True
            def stopReading(self):
                self.reading = False

        def serverServiceMaker(port, factory, *a, **k):
            s.factory = factory
            s.myPort = NotAPort()
            # TODO: technically, the following should wait for startService
            s.myPort.startReading()
            factory.myServer = s
            return s
        return serverServiceMaker

    def fillUp(self, acknowledged=True, count=0):
        """
        Fill up all the slots on the connection limiter.

        @param acknowledged: Should the virtual connections created by this
            method send a message back to the dispatcher indicating that the
            subprocess has acknowledged receipt of the file descriptor?

        @param count: Amount of load to add; defaults to the maximum the
            limiter will accept.
        """
        for _ignore_x in range(count or self.limiter.maxRequests):
            self.dispatcher.sendFileDescriptor(None, "SSL")
            if acknowledged:
                self.dispatcher.statusMessage(
                    self.dispatcher._subprocessSockets[0], "+"
                )

    def processRestart(self):
        # Simulate a worker dying and coming back: stop/start, then a "0"
        # status message resetting its counters.
        self.dispatcher._subprocessSockets[0].stop()
        self.dispatcher._subprocessSockets[0].start()
        self.dispatcher.statusMessage(
            self.dispatcher._subprocessSockets[0], "0"
        )

    def loadDown(self):
        # Simulate one request completing ("-" status message).
        self.dispatcher.statusMessage(
            self.dispatcher._subprocessSockets[0], "-"
        )
calendarserver-5.2+dfsg/twext/web2/test/test_stream.py 0000644 0001750 0001750 00000052427 11736444441 022226 0 ustar rahul rahul # Copyright (c) 2008 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the stream implementations in L{twext.web2}.
"""
import tempfile, sys, os
from zope.interface import implements
from twisted.python.util import sibpath
sibpath # sibpath is *not* unused - the doctests use it.
from twisted.python.hashlib import md5
from twisted.internet import reactor, defer, interfaces
from twisted.trial import unittest
from twext.web2 import stream
def bufstr(data):
    """
    Convert C{data} to C{str} via the (Python 2) C{buffer} builtin.

    @raise TypeError: if C{data} does not support the buffer interface;
        the original TypeError is re-raised with the offending value
        included in the message.
    """
    try:
        return str(buffer(data))
    except TypeError:
        raise TypeError("%s doesn't conform to the buffer interface" % (data,))
class SimpleStreamTests:
    """
    Mixin exercising the basic stream contract -- C{read}, C{split} and
    C{length} -- against whatever stream the subclass's C{makeStream}
    builds over C{self.text}.
    """
    text = '1234567890'

    def test_split(self):
        """
        Splitting at each possible offset partitions the data between the
        two resulting streams, both of which then read to exhaustion.
        """
        # A full-length stream, split at every offset.
        for offset in range(10):
            whole = self.makeStream(0)
            left, right = whole.split(offset)
            if offset > 0:
                self.assertEquals(bufstr(left.read()), self.text[:offset])
            self.assertEquals(left.read(), None)
            if offset < len(self.text):
                self.assertEquals(bufstr(right.read()), self.text[offset:])
            self.assertEquals(right.read(), None)

        # A windowed stream (start=2, length=6), split at every offset.
        for offset in range(7):
            windowed = self.makeStream(2, 6)
            self.assertEquals(windowed.length, 6)
            left, right = windowed.split(offset)
            if offset > 0:
                self.assertEquals(bufstr(left.read()), self.text[2:offset + 2])
            self.assertEquals(left.read(), None)
            if offset < 6:
                self.assertEquals(bufstr(right.read()), self.text[offset + 2:8])
            self.assertEquals(right.read(), None)

    def test_read(self):
        """
        C{read} returns the stream's data and C{length} tracks how much
        remains, reaching zero at exhaustion.
        """
        full = self.makeStream()
        self.assertEquals(full.length, len(self.text))
        self.assertEquals(bufstr(full.read()), self.text)
        self.assertEquals(full.read(), None)

        prefix = self.makeStream(0, 4)
        self.assertEquals(prefix.length, 4)
        self.assertEquals(bufstr(prefix.read()), self.text[0:4])
        self.assertEquals(prefix.read(), None)
        self.assertEquals(prefix.length, 0)

        window = self.makeStream(4, 6)
        self.assertEquals(window.length, 6)
        self.assertEquals(bufstr(window.read()), self.text[4:10])
        self.assertEquals(window.read(), None)
        self.assertEquals(window.length, 0)
class FileStreamTest(SimpleStreamTests, unittest.TestCase):
    """
    L{SimpleStreamTests} run against L{stream.FileStream} over a temp file.
    """
    def makeStream(self, *args, **kw):
        # Wrap the shared temporary file in a FileStream.
        return stream.FileStream(self.f, *args, **kw)

    def setUp(self):
        """
        Create a file containing C{self.text} to be streamed.
        """
        f = tempfile.TemporaryFile('w+')
        f.write(self.text)
        f.seek(0, 0)
        self.f = f

    def test_close(self):
        # Closing the stream zeroes its length but must NOT close the
        # underlying file object (the stream does not own it).
        s = self.makeStream()
        s.close()
        self.assertEquals(s.length, 0)
        # Make sure close doesn't close file
        # would raise exception if f is closed
        self.f.seek(0, 0)

    def test_read2(self):
        # Reads are bounded by CHUNK_SIZE, and an explicit length longer
        # than the file causes RuntimeError once the data runs out.
        s = self.makeStream(0)
        s.CHUNK_SIZE = 6
        self.assertEquals(s.length, 10)
        self.assertEquals(bufstr(s.read()), self.text[0:6])
        self.assertEquals(bufstr(s.read()), self.text[6:10])
        self.assertEquals(s.read(), None)

        s = self.makeStream(0)
        s.CHUNK_SIZE = 5
        self.assertEquals(s.length, 10)
        self.assertEquals(bufstr(s.read()), self.text[0:5])
        self.assertEquals(bufstr(s.read()), self.text[5:10])
        self.assertEquals(s.read(), None)

        s = self.makeStream(0, 20)
        self.assertEquals(s.length, 20)
        self.assertEquals(bufstr(s.read()), self.text)
        self.assertRaises(RuntimeError, s.read) # ran out of data
class MMapFileStreamTest(SimpleStreamTests, unittest.TestCase):
    """
    L{SimpleStreamTests} run against a file big enough that
    L{stream.FileStream} takes its mmap path instead of plain reads.
    """
    # Repeat the base text until it crosses the mmap threshold.
    text = SimpleStreamTests.text
    text = text * (stream.MMAP_THRESHOLD // len(text) + 1)

    def makeStream(self, *args, **kw):
        return stream.FileStream(self.f, *args, **kw)

    def setUp(self):
        """
        Create a file containing C{self.text}, which should be long enough to
        trigger the mmap-case in L{stream.FileStream}.
        """
        f = tempfile.TemporaryFile('w+')
        f.write(self.text)
        f.seek(0, 0)
        self.f = f

    def test_mmapwrapper(self):
        # mmapwrapper requires its file-descriptor argument; keyword-only
        # 'offset' without it must raise TypeError.
        self.assertRaises(TypeError, stream.mmapwrapper)
        self.assertRaises(TypeError, stream.mmapwrapper, offset = 0)
        self.assertRaises(TypeError, stream.mmapwrapper, offset = None)

    if not stream.mmap:
        test_mmapwrapper.skip = 'mmap not supported here'
class MemoryStreamTest(SimpleStreamTests, unittest.TestCase):
    """
    L{SimpleStreamTests} run against L{stream.MemoryStream}.
    """
    def makeStream(self, *args, **kw):
        """Build a MemoryStream over the shared test text."""
        return stream.MemoryStream(self.text, *args, **kw)

    def test_close(self):
        """
        Closing a MemoryStream empties it.
        """
        memStream = self.makeStream()
        memStream.close()
        self.assertEquals(memStream.length, 0)

    def test_read2(self):
        """
        Requesting a range that extends past the end of the in-memory data
        is rejected up front with ValueError.
        """
        self.assertRaises(ValueError, self.makeStream, 0, 20)
# Fixture text (William Blake's "A Poison Tree", typos and all) used by
# TestBufferedStream, which re-joins it with CRLF line endings.  Runtime
# data -- do not "fix" the wording.
testdata = """I was angry with my friend:
I told my wrath, my wrath did end.
I was angry with my foe:
I told it not, my wrath did grow.
And I water'd it in fears,
Night and morning with my tears;
And I sunned it with smiles,
And with soft deceitful wiles.
And it grew both day and night,
Till it bore an apple bright;
And my foe beheld it shine,
And he knew that is was mine,
And into my garden stole
When the night had veil'd the pole:
In the morning glad I see
My foe outstretch'd beneath the tree"""
class TestBufferedStream(unittest.TestCase):
    """
    Tests for L{stream.BufferedStream}: readline() with various size limits
    and readExactly().
    """
    def setUp(self):
        # Use CRLF line endings, as they appear on the wire.
        self.data = testdata.replace('\n', '\r\n')
        s = stream.MemoryStream(self.data)
        self.s = stream.BufferedStream(s)

    def _cbGotData(self, data, expected):
        # Shared callback: the chunk read must match what we expected.
        self.assertEqual(data, expected)

    def test_readline(self):
        """Test that readline reads a line."""
        d = self.s.readline()
        d.addCallback(self._cbGotData, 'I was angry with my friend:\r\n')
        return d

    def test_readlineWithSize(self):
        """Test the size argument to readline"""
        d = self.s.readline(size = 5)
        d.addCallback(self._cbGotData, 'I was')
        return d

    def test_readlineWithBigSize(self):
        """Test the size argument when it's bigger than the length of the line."""
        d = self.s.readline(size = 40)
        d.addCallback(self._cbGotData, 'I was angry with my friend:\r\n')
        return d

    def test_readlineWithZero(self):
        """Test readline with size = 0."""
        d = self.s.readline(size = 0)
        d.addCallback(self._cbGotData, '')
        return d

    def test_readlineFinished(self):
        """Test readline on a finished stream."""
        nolines = len(self.data.split('\r\n'))
        for i in range(nolines):
            self.s.readline()
        d = self.s.readline()
        d.addCallback(self._cbGotData, '')
        return d

    def test_readlineNegSize(self):
        """Ensure that readline with a negative size raises an exception."""
        self.assertRaises(ValueError, self.s.readline, size = -1)

    def test_readlineSizeInDelimiter(self):
        """
        Test behavior of readline when size falls inside the
        delimiter.
        """
        d = self.s.readline(size=28)
        d.addCallback(self._cbGotData, "I was angry with my friend:\r")
        d.addCallback(lambda _: self.s.readline())
        d.addCallback(self._cbGotData, "\nI told my wrath, my wrath did end.\r\n")
        # FIX: return the Deferred so trial waits for the callback chain;
        # previously it was dropped, so a mismatch in either _cbGotData
        # could never fail this test.
        return d

    def test_readExactly(self):
        """Make sure readExactly with no arg reads all the data."""
        d = self.s.readExactly()
        d.addCallback(self._cbGotData, self.data)
        return d

    def test_readExactlyLimited(self):
        """
        Test readExactly with a number.
        """
        d = self.s.readExactly(10)
        d.addCallback(self._cbGotData, self.data[:10])
        return d

    def test_readExactlyBig(self):
        """
        Test readExactly with a number larger than the size of the
        datastream.
        """
        d = self.s.readExactly(100000)
        d.addCallback(self._cbGotData, self.data)
        return d

    def test_read(self):
        """
        Make sure read() also functions. (note that this test uses
        an implementation detail of this particular stream. s.read()
        isn't guaranteed to return self.data on all streams.)
        """
        self.assertEqual(str(self.s.read()), self.data)
class TestStreamer:
    """
    A scripted L{stream.IByteStream}: C{read} pops successive items from
    the sequence it was constructed with (each item may be a string or a
    Deferred), returning None once exhausted.  Read and close calls are
    counted so tests can assert on them.
    """
    implements(stream.IStream, stream.IByteStream)

    length = None

    # Counters inspected by tests.
    readCalled=0
    closeCalled=0

    def __init__(self, list):
        # NOTE: the parameter shadows the builtin 'list'; kept as-is for
        # compatibility with any keyword callers.
        self.list = list

    def read(self):
        self.readCalled+=1
        if self.list:
            return self.list.pop(0)
        return None

    def close(self):
        self.closeCalled+=1
        # Drop any unread items so nothing further can be read.
        self.list = []
class FallbackSplitTest(unittest.TestCase):
    """
    Tests for L{stream.fallbackSplit}, which splits an arbitrary stream at
    a byte offset by reading through it (used when the stream type has no
    native split implementation).
    """
    def test_split(self):
        # Split at 5, which lands mid-way through the second (deferred)
        # chunk 'efgh'.
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        left,right = stream.fallbackSplit(s, 5)
        self.assertEquals(left.length, 5)
        self.assertEquals(right.length, None)

        self.assertEquals(bufstr(left.read()), 'abcd')
        d = left.read()
        d.addCallback(self._cbSplit, left, right)
        return d

    def _cbSplit(self, result, left, right):
        # 'e' completes the left side; 'fgh' and 'ijkl' flow to the right.
        self.assertEquals(bufstr(result), 'e')
        self.assertEquals(left.read(), None)

        self.assertEquals(bufstr(right.read().result), 'fgh')
        self.assertEquals(bufstr(right.read()), 'ijkl')
        self.assertEquals(right.read(), None)

    def test_split2(self):
        # Split exactly on a chunk boundary (offset 4).
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        left,right = stream.fallbackSplit(s, 4)

        self.assertEquals(left.length, 4)
        self.assertEquals(right.length, None)

        self.assertEquals(bufstr(left.read()), 'abcd')
        self.assertEquals(left.read(), None)

        self.assertEquals(bufstr(right.read().result), 'efgh')
        self.assertEquals(bufstr(right.read()), 'ijkl')
        self.assertEquals(right.read(), None)

    def test_splitsplit(self):
        # Splitting the left half again must nest correctly: 'abc' / 'de' /
        # the rest.
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        left,right = stream.fallbackSplit(s, 5)
        left,middle = left.split(3)

        self.assertEquals(left.length, 3)
        self.assertEquals(middle.length, 2)
        self.assertEquals(right.length, None)

        self.assertEquals(bufstr(left.read()), 'abc')
        self.assertEquals(left.read(), None)

        self.assertEquals(bufstr(middle.read().result), 'd')
        self.assertEquals(bufstr(middle.read().result), 'e')
        self.assertEquals(middle.read(), None)

        self.assertEquals(bufstr(right.read().result), 'fgh')
        self.assertEquals(bufstr(right.read()), 'ijkl')
        self.assertEquals(right.read(), None)

    def test_closeboth(self):
        # The underlying stream is closed only once BOTH halves are closed.
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        left,right = stream.fallbackSplit(s, 5)
        left.close()
        self.assertEquals(s.closeCalled, 0)
        right.close()

        # Make sure nothing got read
        self.assertEquals(s.readCalled, 0)
        self.assertEquals(s.closeCalled, 1)

    def test_closeboth_rev(self):
        # Same as test_closeboth, closing the halves in the opposite order.
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        left,right = stream.fallbackSplit(s, 5)
        right.close()
        self.assertEquals(s.closeCalled, 0)
        left.close()

        # Make sure nothing got read
        self.assertEquals(s.readCalled, 0)
        self.assertEquals(s.closeCalled, 1)

    def test_closeleft(self):
        # Closing the left half still lets the right half read its share.
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        left,right = stream.fallbackSplit(s, 5)
        left.close()
        d = right.read()
        d.addCallback(self._cbCloseleft, right)
        return d

    def _cbCloseleft(self, result, right):
        self.assertEquals(bufstr(result), 'fgh')
        self.assertEquals(bufstr(right.read()), 'ijkl')
        self.assertEquals(right.read(), None)

    def test_closeright(self):
        # Closing the right half: the left half drains its 3 bytes, after
        # which the source stream is closed.
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        left,right = stream.fallbackSplit(s, 3)
        right.close()

        self.assertEquals(bufstr(left.read()), 'abc')
        self.assertEquals(left.read(), None)

        self.assertEquals(s.closeCalled, 1)
class ProcessStreamerTest(unittest.TestCase):
    """
    Tests for L{stream.ProcessStreamer}: streaming data to and from a child
    Python process's stdin/stdout/stderr.
    """
    if interfaces.IReactorProcess(reactor, None) is None:
        skip = "Platform lacks spawnProcess support, can't test process streaming."

    def runCode(self, code, inputStream=None):
        # Run 'code' in an unbuffered (-u) child Python, feeding it
        # inputStream (empty by default) on stdin.
        if inputStream is None:
            inputStream = stream.MemoryStream("")
        return stream.ProcessStreamer(inputStream, sys.executable,
                                      [sys.executable, "-u", "-c", code],
                                      os.environ)

    def test_output(self):
        p = self.runCode("import sys\nfor i in range(100): sys.stdout.write('x' * 1000)")
        l = []
        d = stream.readStream(p.outStream, l.append)
        def verify(_):
            self.assertEquals("".join(l), ("x" * 1000) * 100)
        d2 = p.run()
        return d.addCallback(verify).addCallback(lambda _: d2)

    def test_errouput(self):
        p = self.runCode("import sys\nfor i in range(100): sys.stderr.write('x' * 1000)")
        l = []
        d = stream.readStream(p.errStream, l.append)
        def verify(_):
            self.assertEquals("".join(l), ("x" * 1000) * 100)
        p.run()
        return d.addCallback(verify)

    def test_input(self):
        p = self.runCode("import sys\nsys.stdout.write(sys.stdin.read())",
                         "hello world")
        l = []
        d = stream.readStream(p.outStream, l.append)
        d2 = p.run()
        def verify(_):
            self.assertEquals("".join(l), "hello world")
            return d2
        return d.addCallback(verify)

    def test_badexit(self):
        # A child that dies with a traceback: run() errbacks with
        # ProcessTerminated and both output streams end up closed.
        p = self.runCode("raise ValueError")
        l = []
        from twisted.internet.error import ProcessTerminated
        def verify(_):
            self.assertEquals(l, [1])
            self.assert_(p.outStream.closed)
            self.assert_(p.errStream.closed)
        return p.run().addErrback(lambda _: _.trap(ProcessTerminated) and l.append(1)).addCallback(verify)

    def test_inputerror(self):
        # A failure while producing the input stream is logged, and output
        # produced before the failure is still delivered.
        p = self.runCode("import sys\nsys.stdout.write(sys.stdin.read())",
                         TestStreamer(["hello", defer.fail(ZeroDivisionError())]))
        l = []
        d = stream.readStream(p.outStream, l.append)
        d2 = p.run()
        def verify(_):
            self.assertEquals("".join(l), "hello")
            return d2
        def cbVerified(ignored):
            excs = self.flushLoggedErrors(ZeroDivisionError)
            self.assertEqual(len(excs), 1)
        return d.addCallback(verify).addCallback(cbVerified)

    def test_processclosedinput(self):
        # Child closes its stdin midway; leftover input is discarded and
        # the child's remaining output still arrives.
        p = self.runCode("import sys; sys.stdout.write(sys.stdin.read(3));" +
                         "sys.stdin.close(); sys.stdout.write('def')",
                         "abc123")
        l = []
        d = stream.readStream(p.outStream, l.append)
        def verify(_):
            self.assertEquals("".join(l), "abcdef")
        d2 = p.run()
        return d.addCallback(verify).addCallback(lambda _: d2)
class AdapterTestCase(unittest.TestCase):
    """
    Adaptation of plain Python values to L{stream.IByteStream}.
    """
    def test_adapt(self):
        # str, buffer and file objects should all adapt to a stream that
        # yields the same bytes.  Note: 'file' and 'buffer' are Python 2
        # builtins.
        fName = self.mktemp()
        f = file(fName, "w")
        f.write("test")
        f.close()
        for i in ("test", buffer("test"), file(fName)):
            s = stream.IByteStream(i)
            self.assertEquals(str(s.read()), "test")
            self.assertEquals(s.read(), None)
class ReadStreamTestCase(unittest.TestCase):
    """
    Tests for L{stream.readStream}, which drives a stream by repeatedly
    reading and handing each chunk to a callback, returning a Deferred
    that fires when the stream is exhausted.
    """
    def test_pull(self):
        l = []
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        return stream.readStream(s, l.append).addCallback(
            lambda _: self.assertEquals(l, ["abcd", "efgh", "ijkl"]))

    def test_pullFailure(self):
        # A failed Deferred from read() aborts the loop and propagates;
        # chunks read before the failure were still processed.
        l = []
        s = TestStreamer(['abcd', defer.fail(RuntimeError()), 'ijkl'])
        def test(result):
            result.trap(RuntimeError)
            self.assertEquals(l, ["abcd"])
        return stream.readStream(s, l.append).addErrback(test)

    def test_pullException(self):
        # A synchronous exception from read() becomes a Failure on the
        # returned Deferred.
        class Failer:
            def read(self): raise RuntimeError
        return stream.readStream(Failer(), lambda _: None).addErrback(
            lambda _: _.trap(RuntimeError))

    def test_processingException(self):
        # An exception raised by the chunk-processing callback also fails
        # the returned Deferred.
        s = TestStreamer(['abcd', defer.succeed('efgh'), 'ijkl'])
        return stream.readStream(s, lambda x: 1/0).addErrback(
            lambda _: _.trap(ZeroDivisionError))
class ProducerStreamTestCase(unittest.TestCase):
    """
    Tests for L{stream.ProducerStream}.
    """
    def test_failfinish(self):
        # finish(error): data buffered before the error is still readable;
        # the read after that delivers the stored failure.
        p = stream.ProducerStream()
        p.write("hello")
        p.finish(RuntimeError())
        self.assertEquals(p.read(), "hello")
        d = p.read()
        l = []
        d.addErrback(lambda _: (l.append(1), _.trap(RuntimeError))).addCallback(
            lambda _: self.assertEquals(l, [1]))
        return d
# NOTE: this class exists for its docstring only -- the examples in it are
# executed as doctests via the __doctests__ declaration at the bottom of
# this module, so the docstring text below is effectively test code.
class CompoundStreamTest:
    """
    CompoundStream lets you combine many streams into one continuous stream.
    For example, let's make a stream:
    >>> s = stream.CompoundStream()
    Then, add a couple streams:
    >>> s.addStream(stream.MemoryStream("Stream1"))
    >>> s.addStream(stream.MemoryStream("Stream2"))
    The length is the sum of all the streams:
    >>> s.length
    14
    We can read data from the stream:
    >>> str(s.read())
    'Stream1'
    After having read some data, length is now smaller, as you might expect:
    >>> s.length
    7
    So, continue reading...
    >>> str(s.read())
    'Stream2'
    Now that the stream is exhausted:
    >>> s.read() is None
    True
    >>> s.length
    0
    We can also create CompoundStream more easily like so:
    >>> s = stream.CompoundStream(['hello', stream.MemoryStream(' world')])
    >>> str(s.read())
    'hello'
    >>> str(s.read())
    ' world'
    For a more complicated example, let's try reading from a file:
    >>> s = stream.CompoundStream()
    >>> s.addStream(stream.FileStream(open(sibpath(__file__, "stream_data.txt"))))
    >>> s.addStream("================")
    >>> s.addStream(stream.FileStream(open(sibpath(__file__, "stream_data.txt"))))
    Again, the length is the sum:
    >>> int(s.length)
    58
    >>> str(s.read())
    "We've got some text!\\n"
    >>> str(s.read())
    '================'
    What if you close the stream?
    >>> s.close()
    >>> s.read() is None
    True
    >>> s.length
    0
    Error handling works using Deferreds:
    >>> m = stream.MemoryStream("after")
    >>> s = stream.CompoundStream([TestStreamer([defer.fail(ZeroDivisionError())]), m]) # z<
    >>> l = []; x = s.read().addErrback(lambda _: l.append(1))
    >>> l
    [1]
    >>> s.length
    0
    >>> m.length # streams after the failed one got closed
    0
    """
class AsynchronousDummyStream(object):
    """
    An L{IByteStream} implementation which always returns a
    L{defer.Deferred} from C{read} and lets an external driver fire
    them.
    """
    def __init__(self):
        self._readResults = []

    def read(self):
        """
        Hand back a fresh Deferred and remember it so the driver can fire
        it later via C{_write}.
        """
        pending = defer.Deferred()
        self._readResults.append(pending)
        return pending

    def _write(self, bytes):
        """
        Fire the oldest outstanding read Deferred with C{bytes}.
        """
        oldest = self._readResults.pop(0)
        oldest.callback(bytes)
class MD5StreamTest(unittest.TestCase):
    """
    Tests for L{stream.MD5Stream}.
    """
    # Fixture data and its expected MD5 hex digest.
    data = "I am sorry Dave, I can't do that.\n--HAL 9000"
    digest = md5(data).hexdigest()

    def test_synchronous(self):
        """
        L{stream.MD5Stream} computes the MD5 hash of the contents of the stream
        around which it is wrapped. It supports L{IByteStream} providers which
        return C{str} from their C{read} method.
        """
        dataStream = stream.MemoryStream(self.data)
        md5Stream = stream.MD5Stream(dataStream)
        self.assertEquals(str(md5Stream.read()), self.data)
        self.assertIdentical(md5Stream.read(), None)
        # The digest is only available after close().
        md5Stream.close()
        self.assertEquals(self.digest, md5Stream.getMD5())

    def test_asynchronous(self):
        """
        L{stream.MD5Stream} also supports L{IByteStream} providers which return
        L{Deferreds} from their C{read} method.
        """
        dataStream = AsynchronousDummyStream()
        md5Stream = stream.MD5Stream(dataStream)

        result = md5Stream.read()
        dataStream._write(self.data)
        result.addCallback(self.assertEquals, self.data)

        def cbRead(ignored):
            # Second read: the driver delivers None, signalling EOF.
            result = md5Stream.read()
            dataStream._write(None)
            result.addCallback(self.assertIdentical, None)
            return result
        result.addCallback(cbRead)

        def cbClosed(ignored):
            md5Stream.close()
            self.assertEquals(md5Stream.getMD5(), self.digest)
        result.addCallback(cbClosed)
        return result

    def test_getMD5FailsBeforeClose(self):
        """
        L{stream.MD5Stream.getMD5} raises L{RuntimeError} if called before
        L{stream.MD5Stream.close}.
        """
        dataStream = stream.MemoryStream(self.data)
        md5Stream = stream.MD5Stream(dataStream)
        self.assertRaises(RuntimeError, md5Stream.getMD5)

    def test_initializationFailsWithoutStream(self):
        """
        L{stream.MD5Stream.__init__} raises L{ValueError} if passed C{None} as
        the stream to wrap.
        """
        self.assertRaises(ValueError, stream.MD5Stream, None)

    def test_readAfterClose(self):
        """
        L{stream.MD5Stream.read} raises L{RuntimeError} if called after
        L{stream.MD5Stream.close}.
        """
        dataStream = stream.MemoryStream(self.data)
        md5Stream = stream.MD5Stream(dataStream)
        md5Stream.close()
        self.assertRaises(RuntimeError, md5Stream.read)
# Modules whose doctests (e.g. CompoundStreamTest above) should be
# collected by the test runner.
__doctests__ = ['twext.web2.test.test_stream', 'twext.web2.stream']
# TODO:
# CompoundStreamTest
# more tests for ProducerStreamTest
# StreamProducerTest
calendarserver-5.2+dfsg/twext/web2/test/test_httpauth.py 0000644 0001750 0001750 00000102506 12103053166 022553 0 ustar rahul rahul # Copyright (c) 2006-2009 Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.python.hashlib import md5
from twisted.internet import address
from twisted.trial import unittest
from twisted.cred import error
from twext.web2 import http, responsecode
from twext.web2.auth import basic, digest, wrapper
from twext.web2.auth.interfaces import IAuthenticatedRequest, IHTTPUser
from twext.web2.test.test_server import SimpleRequest
from twext.web2.test import test_server
import base64
# A minimal GET request for handing to credential factories that require a
# request object but never inspect its contents.
_trivial_GET = SimpleRequest(None, 'GET', '/')
# Fixed nonce used to make digest challenges reproducible across runs.
FAKE_STATIC_NONCE = '178288758716122392881254770685'


def makeDigestDeterministic(twistedDigestFactory, key="0",
                            nonce=FAKE_STATIC_NONCE, time=0):
    """
    Patch up various bits of private state to make a digest credential factory
    (the one that comes from Twisted) behave deterministically.

    @param twistedDigestFactory: the factory instance to make deterministic.
    @param key: the fixed private key to install.
    @param nonce: the fixed value its nonce generator will return.
    @param time: the fixed value its clock will return.
    """
    twistedDigestFactory.privateKey = key

    # NOTE: _generateNonce and _getTime are private attributes of Twisted's
    # digest factory, patched only to obtain deterministic test output.
    # These tests are somewhat redundant with Twisted's own digest-auth
    # tests; when twext.web2 is removed wholesale, this can go with it.
    twistedDigestFactory._generateNonce = lambda: nonce
    twistedDigestFactory._getTime = lambda: time
class FakeDigestCredentialFactory(digest.DigestCredentialFactory):
    """
    A Fake Digest Credential Factory that generates a predictable
    nonce and opaque
    """
    def __init__(self, *args, **kwargs):
        super(FakeDigestCredentialFactory, self).__init__(*args, **kwargs)
        # Make the wrapped Twisted factory deterministic so generated
        # challenges can be compared against fixed expected values.
        makeDigestDeterministic(self._real, self._fakeStaticPrivateKey)

    # Fixed private key installed in place of a random one.
    _fakeStaticPrivateKey = "0"
class BasicAuthTestCase(unittest.TestCase):
    """
    Tests for decoding HTTP Basic-auth credentials with
    L{basic.BasicCredentialFactory}.
    """
    def setUp(self):
        self.credentialFactory = basic.BasicCredentialFactory('foo')
        self.username = 'dreid'
        self.password = 'S3CuR1Ty'

    def test_usernamePassword(self):
        """
        Test acceptance of username/password in basic auth.
        """
        # base64.encodestring is the Python 2 spelling (encodebytes in 3).
        response = base64.encodestring('%s:%s' % (
            self.username,
            self.password))

        d = self.credentialFactory.decode(response, _trivial_GET)
        return d.addCallback(
            lambda creds: self.failUnless(creds.checkPassword(self.password)))

    def test_incorrectPassword(self):
        """
        Incorrect passwords cause auth to fail.
        """
        response = base64.encodestring('%s:%s' % (
            self.username,
            'incorrectPassword'))

        d = self.credentialFactory.decode(response, _trivial_GET)
        return d.addCallback(
            lambda creds: self.failIf(creds.checkPassword(self.password)))

    def test_incorrectPadding(self):
        """
        Responses that have incorrect padding cause auth to fail.
        """
        response = base64.encodestring('%s:%s' % (
            self.username,
            self.password))
        # Strip the base64 '=' padding; decode() must still cope.
        response = response.strip('=')

        d = self.credentialFactory.decode(response, _trivial_GET)

        def _test(creds):
            self.failUnless(creds.checkPassword(self.password))
        return d.addCallback(_test)

    def test_invalidCredentials(self):
        """
        Auth attempts with no password should fail.
        """
        # No ':' separator at all -> LoginFailed.
        response = base64.encodestring(self.username)
        d = self.credentialFactory.decode(response, _trivial_GET)
        self.assertFailure(d, error.LoginFailed)
# Client address handed to getChallenge() in the digest tests below.
clientAddress = address.IPv4Address('TCP', '127.0.0.1', 80)

# The opaque value the deterministic (Fake) digest factory is expected to
# produce: "<md5 checksum>-<base64(nonce,clientip,timestamp)>".
challengeOpaque = ('75c4bd95b96b7b7341c646c6502f0833-MTc4Mjg4NzU'
                   '4NzE2MTIyMzkyODgxMjU0NzcwNjg1LHJlbW90ZWhvc3Q'
                   'sMA==')

# Matches FAKE_STATIC_NONCE above.
challengeNonce = '178288758716122392881254770685'

# The full www-authenticate challenge tuple the deterministic factory emits.
challengeResponse = ('digest',
                     {'nonce': challengeNonce,
                      'qop': 'auth', 'realm': 'test realm',
                      'algorithm': 'md5',
                      'opaque': challengeOpaque})

# Client nonce used when computing expected digest responses.
cnonce = "29fc54aa1641c6fa0e151419361c8f23"

# Authorization header templates; %s slots are (nonce, response, opaque).
# authRequest1 carries nc=00000001, authRequest2 carries nc=00000002.
authRequest1 = ('username="username", realm="test realm", nonce="%s", '
                'uri="/write/", response="%s", opaque="%s", algorithm="md5", '
                'cnonce="29fc54aa1641c6fa0e151419361c8f23", nc=00000001, '
                'qop="auth"')

authRequest2 = ('username="username", realm="test realm", nonce="%s", '
                'uri="/write/", response="%s", opaque="%s", algorithm="md5", '
                'cnonce="29fc54aa1641c6fa0e151419361c8f23", nc=00000002, '
                'qop="auth"')

# An authorization header with no username field at all.
namelessAuthRequest = 'realm="test realm",nonce="doesn\'t matter"'
class DigestAuthTestCase(unittest.TestCase):
    """
    Test the behavior of DigestCredentialFactory
    """
    def setUp(self):
        """
        Create a DigestCredentialFactory for testing
        """
        self.credentialFactory = digest.DigestCredentialFactory('md5',
                                                                'test realm')

    def getDigestResponse(self, challenge, ncount):
        """
        Calculate the response for the given challenge
        """
        nonce = challenge.get('nonce')
        algo = challenge.get('algorithm').lower()
        qop = challenge.get('qop')
        # Compute the expected digest using the module-level cnonce and the
        # canonical test username/realm/password for a GET of /write/.
        expected = digest.calcResponse(
            digest.calcHA1(algo,
                           "username",
                           "test realm",
                           "password",
                           nonce,
                           cnonce),
            algo, nonce, ncount, cnonce, qop, "GET", "/write/", None
        )
        return expected

    def test_getChallenge(self):
        """
        Test that all the required fields exist in the challenge,
        and that the information matches what we put into our
        DigestCredentialFactory
        """
        d = self.credentialFactory.getChallenge(clientAddress)
        def _test(challenge):
            self.assertEquals(challenge['qop'], 'auth')
            self.assertEquals(challenge['realm'], 'test realm')
            self.assertEquals(challenge['algorithm'], 'md5')
            self.assertTrue(challenge.has_key("nonce"))
            self.assertTrue(challenge.has_key("opaque"))
        return d.addCallback(_test)

    def _createAndDecodeChallenge(self, chalID="00000001", req=_trivial_GET):
        # Issue a fresh challenge, build a client authorization header for it
        # from the authRequest1 template, and decode it back to credentials.
        d = self.credentialFactory.getChallenge(clientAddress)
        def _getChallenge(challenge):
            return authRequest1 % (
                challenge['nonce'],
                self.getDigestResponse(challenge, chalID),
                challenge['opaque'])
        def _getResponse(clientResponse):
            return self.credentialFactory.decode(clientResponse, req)
        return d.addCallback(_getChallenge).addCallback(_getResponse)

    def test_response(self):
        """
        Test that we can decode a valid response to our challenge
        """
        d = self._createAndDecodeChallenge()
        def _test(creds):
            self.failUnless(creds.checkPassword('password'))
        return d.addCallback(_test)

    def test_multiResponse(self):
        """
        Test that multiple responses to a single challenge are handled
        successfully.
        """
        # NOTE(review): _test2 is defined but never chained onto d, so the
        # second ("00000002") response is never actually exercised.  Naively
        # chaining it would not pass as written, because
        # _createAndDecodeChallenge always formats authRequest1 (nc=00000001)
        # while the digest is computed with the "00000002" nonce count --
        # confirm intent before wiring it up.
        d = self._createAndDecodeChallenge()
        def _test(creds):
            self.failUnless(creds.checkPassword('password'))
        def _test2(_):
            d2 = self._createAndDecodeChallenge("00000002")
            return d2.addCallback(_test)
        return d.addCallback(_test)

    def test_failsWithDifferentMethod(self):
        """
        Test that the response fails if made for a different request method
        than it is being issued for.
        """
        d = self._createAndDecodeChallenge(req=SimpleRequest(None, 'POST', '/'))
        def _test(creds):
            self.failIf(creds.checkPassword('password'))
        return d.addCallback(_test)

    def test_noUsername(self):
        """
        Test that login fails when our response does not contain a username,
        or the username field is empty.
        """
        # Check for no username
        e = self.assertRaises(error.LoginFailed,
                              self.credentialFactory.decode,
                              namelessAuthRequest,
                              _trivial_GET)
        self.assertEquals(str(e), "Invalid response, no username given.")

        # Check for an empty username
        e = self.assertRaises(error.LoginFailed,
                              self.credentialFactory.decode,
                              namelessAuthRequest + ',username=""',
                              _trivial_GET)
        self.assertEquals(str(e), "Invalid response, no username given.")

    def test_noNonce(self):
        """
        Test that login fails when our response does not contain a nonce
        """
        e = self.assertRaises(error.LoginFailed,
                              self.credentialFactory.decode,
                              'realm="Test",username="Foo",opaque="bar"',
                              _trivial_GET)
        self.assertEquals(str(e), "Invalid response, no nonce given.")

    def test_noOpaque(self):
        """
        Test that login fails when our response does not contain an opaque
        """
        e = self.assertRaises(error.LoginFailed,
                              self.credentialFactory.decode,
                              'realm="Test",username="Foo"',
                              _trivial_GET)
        self.assertEquals(str(e), "Invalid response, no opaque given.")

    def test_checkHash(self):
        """
        Check that given a hash of the form 'username:realm:password'
        we can verify the digest challenge
        """
        d = self._createAndDecodeChallenge()
        def _test(creds):
            self.failUnless(creds.checkHash(
                md5('username:test realm:password').hexdigest()))
            self.failIf(creds.checkHash(
                md5('username:test realm:bogus').hexdigest()))
        return d.addCallback(_test)

    def test_invalidOpaque(self):
        """
        Test that login fails when the opaque does not contain all the required
        parts.
        """
        credentialFactory = FakeDigestCredentialFactory('md5', 'test realm')
        d = credentialFactory.getChallenge(clientAddress)
        def _test(challenge):
            # Opaque without the checksum-payload structure at all.
            self.assertRaises(
                error.LoginFailed,
                credentialFactory.verifyOpaque,
                'badOpaque',
                challenge['nonce'],
                clientAddress.host)
            # Payload missing the timestamp field.
            badOpaque = ('foo-%s' % (
                'nonce,clientip'.encode('base64').strip('\n'),))
            self.assertRaises(
                error.LoginFailed,
                credentialFactory.verifyOpaque,
                badOpaque,
                challenge['nonce'],
                clientAddress.host)
            # Entirely empty opaque.
            self.assertRaises(
                error.LoginFailed,
                credentialFactory.verifyOpaque,
                '',
                challenge['nonce'],
                clientAddress.host)
        return d.addCallback(_test)

    def test_incompatibleNonce(self):
        """
        Test that login fails when the given nonce from the response, does not
        match the nonce encoded in the opaque.
        """
        credentialFactory = FakeDigestCredentialFactory('md5', 'test realm')
        d = credentialFactory.getChallenge(clientAddress)
        def _test(challenge):
            badNonceOpaque = credentialFactory.generateOpaque(
                '1234567890',
                clientAddress.host)
            self.assertRaises(
                error.LoginFailed,
                credentialFactory.verifyOpaque,
                badNonceOpaque,
                challenge['nonce'],
                clientAddress.host)
            self.assertRaises(
                error.LoginFailed,
                credentialFactory.verifyOpaque,
                badNonceOpaque,
                '',
                clientAddress.host)
        return d.addCallback(_test)

    def test_incompatibleClientIp(self):
        """
        Test that the login fails when the request comes from a client ip
        other than what is encoded in the opaque.
        """
        credentialFactory = FakeDigestCredentialFactory('md5', 'test realm')
        d = credentialFactory.getChallenge(clientAddress)
        def _test(challenge):
            badNonceOpaque = credentialFactory.generateOpaque(
                challenge['nonce'],
                '10.0.0.1')
            self.assertRaises(
                error.LoginFailed,
                credentialFactory.verifyOpaque,
                badNonceOpaque,
                challenge['nonce'],
                clientAddress.host)
        return d.addCallback(_test)

    def test_oldNonce(self):
        """
        Test that the login fails when the given opaque is older than
        DigestCredentialFactory.CHALLENGE_LIFETIME_SECS
        """
        credentialFactory = FakeDigestCredentialFactory('md5', 'test realm')
        d = credentialFactory.getChallenge(clientAddress)
        def _test(challenge):
            # Hand-build an opaque with an ancient timestamp but a valid
            # checksum, using the factory's known static private key.
            key = '%s,%s,%s' % (challenge['nonce'],
                                clientAddress.host,
                                '-137876876')
            digest = (md5(key + credentialFactory._fakeStaticPrivateKey)
                      .hexdigest())
            ekey = key.encode('base64')
            oldNonceOpaque = '%s-%s' % (digest, ekey.strip('\n'))
            self.assertRaises(
                error.LoginFailed,
                credentialFactory.verifyOpaque,
                oldNonceOpaque,
                challenge['nonce'],
                clientAddress.host)
        return d.addCallback(_test)

    def test_mismatchedOpaqueChecksum(self):
        """
        Test that login fails when the opaque checksum fails verification
        """
        credentialFactory = FakeDigestCredentialFactory('md5', 'test realm')
        d = credentialFactory.getChallenge(clientAddress)
        def _test(challenge):
            # Valid payload, but signed with the wrong private key.
            key = '%s,%s,%s' % (challenge['nonce'],
                                clientAddress.host,
                                '0')
            digest = md5(key + 'this is not the right pkey').hexdigest()
            badChecksum = '%s-%s' % (digest,
                                     key.encode('base64').strip('\n'))
            self.assertRaises(
                error.LoginFailed,
                credentialFactory.verifyOpaque,
                badChecksum,
                challenge['nonce'],
                clientAddress.host)
        return d.addCallback(_test)

    def test_incompatibleCalcHA1Options(self):
        """
        Test that the appropriate error is raised when any of the
        pszUsername, pszRealm, or pszPassword arguments are specified with
        the preHA1 keyword argument.
        """
        arguments = (
            ("user", "realm", "password", "preHA1"),
            (None, "realm", None, "preHA1"),
            (None, None, "password", "preHA1"),
        )
        for pszUsername, pszRealm, pszPassword, preHA1 in arguments:
            self.assertRaises(
                TypeError,
                digest.calcHA1,
                "md5",
                pszUsername,
                pszRealm,
                pszPassword,
                "nonce",
                "cnonce",
                preHA1=preHA1
            )

    def test_noNewlineOpaque(self):
        """
        L{digest.DigestCredentialFactory._generateOpaque} returns a value
        without newlines, regardless of the length of the nonce.
        """
        opaque = self.credentialFactory.generateOpaque(
            "long nonce " * 10, None)
        self.assertNotIn('\n', opaque)
from zope.interface import implements
from twisted.cred import portal, checkers
class TestHTTPUser(object):
    """
    Trivial L{IHTTPUser} avatar used to observe the outcome of cred-based
    HTTP authentication in the wrapper tests.
    """
    implements(IHTTPUser)

    # The authenticated username; None until __init__ assigns it.
    username = None

    def __init__(self, username):
        """
        @param username: The str username sent as part of the HTTP auth
            response.
        """
        self.username = username
class TestAuthRealm(object):
    """
    Realm that hands out L{TestHTTPUser} avatars for the IHTTPUser
    interface, mapping the anonymous avatar id to the username 'anonymous'.
    """
    implements(portal.IRealm)

    def requestAvatar(self, avatarId, mind, *interfaces):
        # Only IHTTPUser avatars are supported by this realm.
        if IHTTPUser not in interfaces:
            raise NotImplementedError("Only IHTTPUser interface is supported")
        if avatarId == checkers.ANONYMOUS:
            return IHTTPUser, TestHTTPUser('anonymous')
        return IHTTPUser, TestHTTPUser(avatarId)
class ProtectedResource(test_server.BaseTestResource):
    """
    Wrapped-resource stand-in that records the request and URL segments it
    is handed, so tests can make assertions about them afterwards.
    """
    addSlash = True

    # Captured by render() / locateChild() for later inspection.
    request = None
    segments = None

    def render(self, req):
        # Remember the (possibly authenticated) request before delegating.
        self.request = req
        return super(ProtectedResource, self).render(req)

    def locateChild(self, req, segments):
        # Remember the segments the wrapper forwarded before delegating.
        self.segments = segments
        return super(ProtectedResource, self).locateChild(req, segments)
class NonAnonymousResource(test_server.BaseTestResource):
    """
    A resource that refuses anonymous access: for an anonymous avatar it
    either raises an UNAUTHORIZED HTTPError, or (when sendOwnHeaders is
    set) returns its own 401 response carrying a www-authenticate header.
    """
    addSlash = True
    sendOwnHeaders = False

    def render(self, req):
        # Authenticated users get the normal response.
        if req.avatar.username != 'anonymous':
            return super(NonAnonymousResource, self).render(req)
        # Anonymous: either supply our own challenge headers...
        if self.sendOwnHeaders:
            return http.Response(
                responsecode.UNAUTHORIZED,
                {'www-authenticate': [('basic', {'realm': 'foo'})]})
        # ...or let the auth wrapper add them via the raised error.
        raise http.HTTPError(responsecode.UNAUTHORIZED)
class HTTPAuthResourceTest(test_server.BaseCase):
    """
    Tests for the HTTPAuthWrapper Resource
    """
    def setUp(self):
        """
        Create a portal and add an in memory checker to it.
        Then set up a protectedResource that will be wrapped in each test.
        """
        self.portal = portal.Portal(TestAuthRealm())
        c = checkers.InMemoryUsernamePasswordDatabaseDontUse()
        c.addUser('username', 'password')
        self.portal.registerChecker(c)
        self.credFactory = basic.BasicCredentialFactory('test realm')
        self.protectedResource = ProtectedResource()
        self.protectedResource.responseText = "You shouldn't see me."

    def tearDown(self):
        """
        Clean up by getting rid of the portal, credentialFactory, and
        protected resource
        """
        del self.portal
        del self.credFactory
        del self.protectedResource

    def test_authenticatedRequest(self):
        """
        Test that after successful authentication the request provides
        IAuthenticatedRequest and that the request.avatar implements
        the proper interfaces for this realm and has the proper values
        for this request.
        """
        self.protectedResource.responseText = "I hope you can see me."
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        credentials = base64.encodestring('username:password')
        d = self.assertResponse((root, 'http://localhost/',
                                 {'authorization': ('basic', credentials)}),
                                (200,
                                 {}, 'I hope you can see me.'))
        def checkRequest(result):
            resource = self.protectedResource
            self.failUnless(hasattr(resource, "request"))
            request = resource.request
            self.failUnless(IAuthenticatedRequest.providedBy(request))
            self.failUnless(hasattr(request, "avatar"))
            self.failUnless(IHTTPUser.providedBy(request.avatar))
            self.failUnless(hasattr(request, "avatarInterface"))
            self.assertEquals(request.avatarInterface, IHTTPUser)
            self.assertEquals(request.avatar.username, 'username')
        d.addCallback(checkRequest)
        return d

    def test_allowedMethods(self):
        """
        Test that unknown methods result in a 401 instead of a 405 when
        authentication hasn't been completed.
        """
        # NOTE(review): self.method is presumably consumed by
        # BaseCase.assertResponse -- confirm against test_server.
        self.method = 'PROPFIND'
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        d = self.assertResponse(
            (root, 'http://localhost/'),
            (401,
             {'WWW-Authenticate': [('basic',
                                    {'realm': "test realm"})]},
             None))
        self.method = 'GET'
        return d

    def test_unauthorizedResponse(self):
        """
        Test that a request with no credentials results in a
        valid Unauthorized response.
        """
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        def makeDeepRequest(res):
            return self.assertResponse(
                (root,
                 'http://localhost/foo/bar/baz/bax'),
                (401,
                 {'WWW-Authenticate': [('basic',
                                        {'realm': "test realm"})]},
                 None))
        d = self.assertResponse(
            (root, 'http://localhost/'),
            (401,
             {'WWW-Authenticate': [('basic',
                                    {'realm': "test realm"})]},
             None))
        return d.addCallback(makeDeepRequest)

    def test_badCredentials(self):
        """
        Test that a request with bad credentials results in a valid
        Unauthorized response
        """
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        credentials = base64.encodestring('bad:credentials')
        d = self.assertResponse(
            (root, 'http://localhost/',
             {'authorization': [('basic', credentials)]}),
            (401,
             {'WWW-Authenticate': [('basic',
                                    {'realm': "test realm"})]},
             None))
        return d

    def test_successfulLogin(self):
        """
        Test that a request with good credentials results in the
        appropriate response from the protected resource
        """
        self.protectedResource.responseText = "I hope you can see me."
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        credentials = base64.encodestring('username:password')
        d = self.assertResponse((root, 'http://localhost/',
                                 {'authorization': ('basic', credentials)}),
                                (200,
                                 {}, 'I hope you can see me.'))
        return d

    def test_wrongScheme(self):
        """
        Test that a request with credentials for a scheme that is not
        advertised by this resource results in the appropriate
        unauthorized response.
        """
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        d = self.assertResponse((root, 'http://localhost/',
                                 {'authorization':
                                  [('digest',
                                    'realm="foo", response="crap"')]}),
                                (401,
                                 {'www-authenticate':
                                  [('basic', {'realm': 'test realm'})]},
                                 None))
        return d

    def test_multipleWWWAuthenticateSchemes(self):
        """
        Test that our unauthorized response can contain challenges for
        multiple authentication schemes.
        """
        root = wrapper.HTTPAuthResource(
            self.protectedResource,
            (basic.BasicCredentialFactory('test realm'),
             FakeDigestCredentialFactory('md5', 'test realm')),
            self.portal,
            interfaces=(IHTTPUser,))
        d = self.assertResponse((root, 'http://localhost/', {}),
                                (401,
                                 {'www-authenticate':
                                  [challengeResponse,
                                   ('basic', {'realm': 'test realm'})]},
                                 None))
        return d

    def test_authorizationAgainstMultipleSchemes(self):
        """
        Test that we can successfully authenticate when presented
        with multiple WWW-Authenticate headers
        """
        root = wrapper.HTTPAuthResource(
            self.protectedResource,
            (basic.BasicCredentialFactory('test realm'),
             FakeDigestCredentialFactory('md5', 'test realm')),
            self.portal,
            interfaces=(IHTTPUser,))

        # NOTE(review): respondBasic and respond are defined but never
        # chained onto d below, so only the initial 401 challenge is
        # actually exercised.  Wiring respond up as written would also send
        # authRequest1 with its %s placeholders unfilled -- confirm intent
        # before enabling these callbacks.
        def respondBasic(ign):
            credentials = base64.encodestring('username:password')
            d = self.assertResponse((root, 'http://localhost/',
                                     {'authorization':
                                      ('basic', credentials)}),
                                    (200,
                                     {}, None))
            return d

        def respond(ign):
            d = self.assertResponse((root, 'http://localhost/',
                                     {'authorization': authRequest1}),
                                    (200,
                                     {},
                                     None))
            return d.addCallback(respondBasic)

        d = self.assertResponse((root, 'http://localhost/', {}),
                                (401,
                                 {'www-authenticate':
                                  [challengeResponse,
                                   ('basic', {'realm': 'test realm'})]},
                                 None))
        return d

    def test_wrappedResourceGetsFullSegments(self):
        """
        Test that the wrapped resource gets all the URL segments in it's
        locateChild.
        """
        self.protectedResource.responseText = "I hope you can see me."
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        credentials = base64.encodestring('username:password')
        d = self.assertResponse((root, 'http://localhost/foo/bar/baz/bax',
                                 {'authorization': ('basic', credentials)}),
                                (404,
                                 {}, None))
        def checkSegments(ign):
            resource = self.protectedResource
            self.assertEquals(resource.segments, ['foo', 'bar', 'baz', 'bax'])
        d.addCallback(checkSegments)
        return d

    def test_invalidCredentials(self):
        """
        Malformed or otherwise invalid credentials (as determined by
        the credential factory) should result in an Unauthorized response
        """
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        credentials = base64.encodestring('Not Good Credentials')
        d = self.assertResponse((root, 'http://localhost/',
                                 {'authorization': ('basic', credentials)}),
                                (401,
                                 {'WWW-Authenticate': [('basic',
                                                        {'realm': "test realm"})]},
                                 None))
        return d

    def test_anonymousAuthentication(self):
        """
        If our portal has a credentials checker for IAnonymous credentials
        authentication succeeds if no Authorization header is present
        """
        self.portal.registerChecker(checkers.AllowAnonymousAccess())
        self.protectedResource.responseText = "Anonymous access allowed"
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces=(IHTTPUser,))
        def _checkRequest(ign):
            self.assertEquals(
                self.protectedResource.request.avatar.username,
                'anonymous')
        d = self.assertResponse((root, 'http://localhost/',
                                 {}),
                                (200,
                                 {},
                                 "Anonymous access allowed"))
        d.addCallback(_checkRequest)
        return d

    def test_forceAuthentication(self):
        """
        Test that if an HTTPError with an Unauthorized status code is raised
        from within our protected resource, we add the WWW-Authenticate
        headers if they do not already exist.
        """
        self.portal.registerChecker(checkers.AllowAnonymousAccess())
        nonAnonResource = NonAnonymousResource()
        nonAnonResource.responseText = "We don't like anonymous users"
        root = wrapper.HTTPAuthResource(nonAnonResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces = (IHTTPUser,))
        def _tryAuthenticate(result):
            credentials = base64.encodestring('username:password')
            d2 = self.assertResponse(
                (root, 'http://localhost/',
                 {'authorization': ('basic', credentials)}),
                (200,
                 {},
                 "We don't like anonymous users"))
            return d2
        d = self.assertResponse(
            (root, 'http://localhost/',
             {}),
            (401,
             {'WWW-Authenticate': [('basic',
                                    {'realm': "test realm"})]},
             None))
        d.addCallback(_tryAuthenticate)
        return d

    def test_responseFilterDoesntClobberHeaders(self):
        """
        Test that if an UNAUTHORIZED response is returned and
        already has 'WWW-Authenticate' headers we don't add them.
        """
        self.portal.registerChecker(checkers.AllowAnonymousAccess())
        nonAnonResource = NonAnonymousResource()
        nonAnonResource.responseText = "We don't like anonymous users"
        nonAnonResource.sendOwnHeaders = True
        root = wrapper.HTTPAuthResource(nonAnonResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces = (IHTTPUser,))
        d = self.assertResponse(
            (root, 'http://localhost/',
             {}),
            (401,
             {'WWW-Authenticate': [('basic',
                                    {'realm': "foo"})]},
             None))
        return d

    def test_renderHTTP(self):
        """
        Test that if the renderHTTP method is ever called we authenticate
        the request and delegate rendering to the wrapper.
        """
        self.protectedResource.responseText = "I hope you can see me."
        self.protectedResource.addSlash = True
        root = wrapper.HTTPAuthResource(self.protectedResource,
                                        [self.credFactory],
                                        self.portal,
                                        interfaces = (IHTTPUser,))
        request = SimpleRequest(None, "GET", "/")
        request.prepath = ['']
        def _gotSecondResponse(response):
            self.assertEquals(response.code, 200)
            self.assertEquals(str(response.stream.read()),
                              "I hope you can see me.")
        def _gotResponse(exception):
            response = exception.response
            self.assertEquals(response.code, 401)
            self.failUnless(response.headers.hasHeader('WWW-Authenticate'))
            self.assertEquals(response.headers.getHeader('WWW-Authenticate'),
                              [('basic', {'realm': "test realm"})])
            # Retry the same request with valid credentials attached.
            credentials = base64.encodestring('username:password')
            request.headers.setHeader('authorization',
                                      ['basic', credentials])
            d = root.renderHTTP(request)
            d.addCallback(_gotSecondResponse)
        d = self.assertFailure(root.renderHTTP(request),
                               http.HTTPError)
        d.addCallback(_gotResponse)
        return d
calendarserver-5.2+dfsg/twext/web2/test/test_http_headers.py 0000644 0001750 0001750 00000100706 12107006303 023357 0 ustar rahul rahul # Copyright (c) 2008 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twext.web2.http_headers}.
"""
from twisted.trial import unittest
import random
import time
from twext.web2 import http_headers
from twext.web2.http_headers import Cookie, HeaderHandler, quoteString, generateKeyValues
from twisted.python import util
class parsedvalue:
    """
    Marker wrapper distinguishing a parsed header value from its raw form.

    Two instances compare equal iff their C{raw} payloads are equal.
    """
    def __init__(self, raw):
        # raw: the underlying raw header value (typically a tuple of strings).
        self.raw = raw

    def __eq__(self, other):
        return isinstance(other, parsedvalue) and other.raw == self.raw

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__; without this,
        # '!=' (and assertNotEquals) would fall back to identity comparison.
        return not self.__eq__(other)

    def __repr__(self):
        # Gives readable failure messages when header assertions fail.
        return 'parsedvalue(%r)' % (self.raw,)
class HeadersAPITest(unittest.TestCase):
    """Make sure the public API exists and works."""

    def testRaw(self):
        # Raw (unparsed) header values should round-trip untouched.
        value = ("value1", "value2")
        headers = http_headers.Headers(handler=HeaderHandler(parsers={}, generators={}))
        headers.setRawHeaders("test", value)
        self.assertEquals(headers.hasHeader("test"), True)
        self.assertEquals(headers.getRawHeaders("test"), value)
        self.assertEquals(list(headers.getAllRawHeaders()), [('Test', value)])
        self.assertEquals(headers.getRawHeaders("foobar"), None)
        headers.removeHeader("test")
        self.assertEquals(headers.getRawHeaders("test"), None)

    def testParsed(self):
        # Parsed values should be stored and retrieved as-is.
        value = parsedvalue(("value1", "value2"))
        headers = http_headers.Headers(handler=HeaderHandler(parsers={}, generators={}))
        headers.setHeader("test", value)
        self.assertEquals(headers.hasHeader("test"), True)
        self.assertEquals(headers.getHeader("test"), value)
        self.assertEquals(headers.getHeader("foobar"), None)
        headers.removeHeader("test")
        self.assertEquals(headers.getHeader("test"), None)

    def testParsedAndRaw(self):
        # A handler with matched parser/generator converts between forms.
        def parse(raw):
            return parsedvalue(raw)

        def generate(parsed):
            return parsed.raw

        raw1 = ("value1", "value2")
        raw2 = ("value3", "value4")
        handler = HeaderHandler(parsers={'test': (parse,)},
                                generators={'test': (generate,)})

        headers = http_headers.Headers(handler=handler)
        headers.setRawHeaders("test", raw1)
        self.assertEquals(headers.getHeader("test"), parsedvalue(raw1))
        headers.setHeader("test", parsedvalue(raw2))
        self.assertEquals(headers.getRawHeaders("test"), raw2)

        # Check the initializers
        headers = http_headers.Headers(rawHeaders={"test": raw1},
                                       handler=handler)
        self.assertEquals(headers.getHeader("test"), parsedvalue(raw1))
        headers = http_headers.Headers({"test": parsedvalue(raw2)},
                                       handler=handler)
        self.assertEquals(headers.getRawHeaders("test"), raw2)

    def testImmutable(self):
        # After makeImmutable(), all mutation entry points must raise.
        headers = http_headers.Headers(handler=HeaderHandler(parsers={}, generators={}))
        headers.makeImmutable()
        self.assertRaises(AttributeError, headers.setRawHeaders, "test", [1])
        self.assertRaises(AttributeError, headers.setHeader, "test", 1)
        self.assertRaises(AttributeError, headers.removeHeader, "test")
class TokenizerTest(unittest.TestCase):
    """Test header list parsing functions."""
    def testParse(self):
        parser = lambda val: list(http_headers.tokenize([val, ]))
        Token = http_headers.Token
        # Each pair: raw header string -> expected token stream (bare words
        # are lowercased, separators become Token instances, quoted strings
        # are unquoted).
        tests = (('foo,bar', ['foo', Token(','), 'bar']),
                 ('FOO,BAR', ['foo', Token(','), 'bar']),
                 (' \t foo \t bar \t , \t baz ', ['foo', Token(' '), 'bar', Token(','), 'baz']),
                 ('()<>@,;:\\/[]?={}', [Token('('), Token(')'), Token('<'), Token('>'), Token('@'), Token(','), Token(';'), Token(':'), Token('\\'), Token('/'), Token('['), Token(']'), Token('?'), Token('='), Token('{'), Token('}')]),
                 (' "foo" ', ['foo']),
                 ('"FOO(),\\"BAR,"', ['FOO(),"BAR,']))
        # Inputs the tokenizer must reject: unterminated quotes/escapes and
        # raw control characters.
        raiseTests = ('"open quote', '"ending \\', "control character: \x127", "\x00", "\x1f")
        for test, result in tests:
            self.assertEquals(parser(test), result)
        for test in raiseTests:
            self.assertRaises(ValueError, parser, test)
    def testGenerate(self):
        # TODO: generation is currently only covered indirectly via
        # the roundtrip tests below.
        pass
    def testRoundtrip(self):
        # TODO: see HeaderParsingTestBase.runRoundtripTest for coverage.
        pass
def atSpecifiedTime(when, func):
    """
    Return a wrapper around C{func} that runs it with C{time.time} patched
    to return C{when}, restoring the real clock afterwards.
    """
    def inner(*a, **kw):
        saved = time.time
        time.time = lambda: when
        try:
            return func(*a, **kw)
        finally:
            # Always restore the real clock, even if func raises.
            time.time = saved
    return util.mergeFunctionMetadata(func, inner)
def parseHeader(name, val):
    """
    Parse C{val} as the raw value of header C{name} with the default HTTP
    handler and return the structured (parsed) form.
    """
    head = http_headers.Headers(handler=http_headers.DefaultHTTPHandler)
    head.setRawHeaders(name, val)
    return head.getHeader(name)
# Pin the clock so date-relative header parsing is deterministic.
parseHeader = atSpecifiedTime(999999990, parseHeader) # Sun, 09 Sep 2001 01:46:30 GMT
def generateHeader(name, val):
    """
    Generate the raw wire form of header C{name} from its parsed value
    C{val} using the default HTTP handler.
    """
    head = http_headers.Headers(handler=http_headers.DefaultHTTPHandler)
    head.setHeader(name, val)
    return head.getRawHeaders(name)
# Pin the clock so date-relative header generation is deterministic.
generateHeader = atSpecifiedTime(999999990, generateHeader) # Sun, 09 Sep 2001 01:46:30 GMT
class HeaderParsingTestBase(unittest.TestCase):
    def runRoundtripTest(self, headername, table):
        """
        Perform some assertions about the behavior of parsing and
        generating HTTP headers.  Specifically: parse an HTTP header
        value, assert that the parsed form contains all the available
        information with the correct structure; generate the HTTP
        header value from the parsed form, assert that it contains
        certain literal strings; finally, re-parse the generated HTTP
        header value and assert that the resulting structured data is
        the same as the first-pass parsed form.

        @type headername: C{str}
        @param headername: The name of the HTTP header L{table} contains values for.

        @type table: A sequence of tuples describing inputs to and
        outputs from header parsing and generation.  The tuples may be
        either 2 or 3 elements long.  In either case: the first
        element is a string representing an HTTP-format header value;
        the second element is a dictionary mapping names of parameters
        to values of those parameters (the parsed form of the header).
        If there is a third element, it is a list of strings which
        must occur exactly in the HTTP header value
        string which is re-generated from the parsed form.
        """
        for row in table:
            if len(row) == 2:
                rawHeaderInput, parsedHeaderData = row
                requiredGeneratedElements = []
            elif len(row) == 3:
                rawHeaderInput, parsedHeaderData, requiredGeneratedElements = row
                assert isinstance(requiredGeneratedElements, list)

            # parser
            parsed = parseHeader(headername, [rawHeaderInput, ])
            self.assertEquals(parsed, parsedHeaderData)

            regeneratedHeaderValue = generateHeader(headername, parsed)
            if requiredGeneratedElements:
                # generator
                # NOTE(review): .index() finds the first occurrence, so the
                # pairing assumes regenerated elements are unique per header.
                for regeneratedElement in regeneratedHeaderValue:
                    reqEle = requiredGeneratedElements[regeneratedHeaderValue.index(regeneratedElement)]
                    elementIndex = regeneratedElement.find(reqEle)
                    self.assertNotEqual(
                        elementIndex, -1,
                        "%r did not appear in generated HTTP header %r: %r" % (reqEle,
                                                                               headername,
                                                                               regeneratedElement))

            # parser/generator
            reparsed = parseHeader(headername, regeneratedHeaderValue)
            self.assertEquals(parsed, reparsed)

    def invalidParseTest(self, headername, values):
        # Each value in C{values} must parse to None (i.e. be rejected).
        for val in values:
            parsed = parseHeader(headername, val)
            self.assertEquals(parsed, None)
class GeneralHeaderParsingTests(HeaderParsingTestBase):
    """
    Roundtrip tests for general HTTP headers (those shared by requests and
    responses): Cache-Control, Connection, Date, Transfer-Encoding.
    """
    def testCacheControl(self):
        table = (
            ("no-cache",
             {'no-cache': None}),
            ("no-cache, no-store, max-age=5, max-stale=3, min-fresh=5, no-transform, only-if-cached, blahblah-extension-thingy",
             {'no-cache': None,
              'no-store': None,
              'max-age': 5,
              'max-stale': 3,
              'min-fresh': 5,
              'no-transform': None,
              'only-if-cached': None,
              'blahblah-extension-thingy': None}),
            ("max-stale",
             {'max-stale': None}),
            ("public, private, no-cache, no-store, no-transform, must-revalidate, proxy-revalidate, max-age=5, s-maxage=10, blahblah-extension-thingy",
             {'public': None,
              'private': None,
              'no-cache': None,
              'no-store': None,
              'no-transform': None,
              'must-revalidate': None,
              'proxy-revalidate': None,
              'max-age': 5,
              's-maxage': 10,
              'blahblah-extension-thingy': None}),
            # Quoted field-name lists are parsed to lowercased lists but
            # regenerated with canonical header capitalization.
            ('private="Set-Cookie, Set-Cookie2", no-cache="PROXY-AUTHENTICATE"',
             {'private': ['set-cookie', 'set-cookie2'],
              'no-cache': ['proxy-authenticate']},
             ['private="Set-Cookie, Set-Cookie2"', 'no-cache="Proxy-Authenticate"']),
        )
        self.runRoundtripTest("Cache-Control", table)

    def testConnection(self):
        table = (
            ("close", ['close', ]),
            ("close, foo-bar", ['close', 'foo-bar'])
        )
        self.runRoundtripTest("Connection", table)

    def testDate(self):
        # Don't need major tests since the datetime parser has its own tests
        self.runRoundtripTest("Date", (("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),))

#     def testPragma(self):
#         fail

#     def testTrailer(self):
#         fail

    def testTransferEncoding(self):
        table = (
            ('chunked', ['chunked']),
            ('gzip, chunked', ['gzip', 'chunked'])
        )
        self.runRoundtripTest("Transfer-Encoding", table)

#     def testUpgrade(self):
#         fail

#     def testVia(self):
#         fail

#     def testWarning(self):
#         fail
class RequestHeaderParsingTests(HeaderParsingTestBase):
    """
    Round-trip tests for the HTTP *request* headers supported by
    L{http_headers}.

    Each test builds a table of rows of the form
    C{(wire-format string, expected parsed value[, expected generated
    elements])} and hands it to C{runRoundtripTest} (inherited from
    L{HeaderParsingTestBase}), which exercises both the parser and the
    generator for the named header.
    """
    # FIXME test ordering too.
    def testAccept(self):
        # Media ranges with q-values; absent q defaults to 1.0.
        table = (
            ("audio/*;q=0.2, audio/basic",
             {http_headers.MimeType('audio', '*'): 0.2,
              http_headers.MimeType('audio', 'basic'): 1.0}),
            ("text/plain;q=0.5, text/html, text/x-dvi;q=0.8, text/x-c",
             {http_headers.MimeType('text', 'plain'): 0.5,
              http_headers.MimeType('text', 'html'): 1.0,
              http_headers.MimeType('text', 'x-dvi'): 0.8,
              http_headers.MimeType('text', 'x-c'): 1.0}),
            ("text/*, text/html, text/html;level=1, */*",
             {http_headers.MimeType('text', '*'): 1.0,
              http_headers.MimeType('text', 'html'): 1.0,
              http_headers.MimeType('text', 'html', (('level', '1'),)): 1.0,
              http_headers.MimeType('*', '*'): 1.0}),
            ("text/*;q=0.3, text/html;q=0.7, text/html;level=1, text/html;level=2;q=0.4, */*;q=0.5",
             {http_headers.MimeType('text', '*'): 0.3,
              http_headers.MimeType('text', 'html'): 0.7,
              http_headers.MimeType('text', 'html', (('level', '1'),)): 1.0,
              http_headers.MimeType('text', 'html', (('level', '2'),)): 0.4,
              http_headers.MimeType('*', '*'): 0.5}),
        )
        self.runRoundtripTest("Accept", table)

    def testAcceptCharset(self):
        # Note: iso-8859-1 is implicitly present at q=1.0 unless listed.
        table = (
            ("iso-8859-5, unicode-1-1;q=0.8",
             {'iso-8859-5': 1.0, 'iso-8859-1': 1.0, 'unicode-1-1': 0.8},
             ["iso-8859-5", "unicode-1-1;q=0.8", "iso-8859-1"]),
            ("iso-8859-1;q=0.7",
             {'iso-8859-1': 0.7}),
            ("*;q=.7",
             {'*': 0.7},
             ["*;q=0.7"]),
            ("",
             {'iso-8859-1': 1.0},
             ["iso-8859-1"]),  # Yes this is an actual change -- we'll say that's okay. :)
        )
        self.runRoundtripTest("Accept-Charset", table)

    def testAcceptEncoding(self):
        # 'identity' is always present with a tiny default qvalue.
        table = (
            ("compress, gzip",
             {'compress': 1.0, 'gzip': 1.0, 'identity': 0.0001}),
            ("",
             {'identity': 0.0001}),
            ("*",
             {'*': 1}),
            ("compress;q=0.5, gzip;q=1.0",
             {'compress': 0.5, 'gzip': 1.0, 'identity': 0.0001},
             ["compress;q=0.5", "gzip"]),
            ("gzip;q=1.0, identity;q=0.5, *;q=0",
             {'gzip': 1.0, 'identity': 0.5, '*': 0},
             ["gzip", "identity;q=0.5", "*;q=0"]),
        )
        self.runRoundtripTest("Accept-Encoding", table)

    def testAcceptLanguage(self):
        table = (
            ("da, en-gb;q=0.8, en;q=0.7",
             {'da': 1.0, 'en-gb': 0.8, 'en': 0.7}),
            ("*",
             {'*': 1}),
        )
        self.runRoundtripTest("Accept-Language", table)

    def testAuthorization(self):
        # Scheme is lowercased on parse; Digest params stay a raw string.
        table = (
            ("Basic dXNlcm5hbWU6cGFzc3dvcmQ=",
             ("basic", "dXNlcm5hbWU6cGFzc3dvcmQ="),
             ["basic dXNlcm5hbWU6cGFzc3dvcmQ="]),
            ('Digest nonce="bar", realm="foo", username="baz", response="bax"',
             ('digest', 'nonce="bar", realm="foo", username="baz", response="bax"'),
             ['digest', 'nonce="bar"', 'realm="foo"', 'username="baz"', 'response="bax"'])
        )
        self.runRoundtripTest("Authorization", table)

    def testCookie(self):
        # Oldstyle (Netscape) cookies.
        table = (
            ('name=value', [Cookie('name', 'value')]),
            ('"name"="value"', [Cookie('"name"', '"value"')]),
            ('name,"blah=value,"', [Cookie('name,"blah', 'value,"')]),
            ('name,"blah = value," ', [Cookie('name,"blah', 'value,"')], ['name,"blah=value,"']),
            ("`~!@#$%^&*()-_+[{]}\\|:'\",<.>/?=`~!@#$%^&*()-_+[{]}\\|:'\",<.>/?", [Cookie("`~!@#$%^&*()-_+[{]}\\|:'\",<.>/?", "`~!@#$%^&*()-_+[{]}\\|:'\",<.>/?")]),
            ('name,"blah = value," ; name2=val2',
             [Cookie('name,"blah', 'value,"'), Cookie('name2', 'val2')],
             ['name,"blah=value,"', 'name2=val2']),
        )
        self.runRoundtripTest("Cookie", table)

        # newstyle RFC2965 Cookie
        table2 = (
            ('$Version="1";'
             'name="value";$Path="/foo";$Domain="www.local";$Port="80,8000";'
             'name2="value"',
             [Cookie('name', 'value', path='/foo', domain='www.local', ports=(80, 8000), version=1), Cookie('name2', 'value', version=1)]),
            ('$Version="1";'
             'name="value";$Port',
             [Cookie('name', 'value', ports=(), version=1)]),
            ('$Version = 1, NAME = "qq\\"qq",Frob=boo',
             [Cookie('name', 'qq"qq', version=1), Cookie('frob', 'boo', version=1)],
             ['$Version="1";name="qq\\"qq";frob="boo"']),
        )
        self.runRoundtripTest("Cookie", table2)

        # Generate only!
        # make headers by combining oldstyle and newstyle cookies
        table3 = (
            ([Cookie('name', 'value'), Cookie('name2', 'value2', version=1)],
             '$Version="1";name=value;name2="value2"'),
            ([Cookie('name', 'value', path="/foo"), Cookie('name2', 'value2', domain="bar.baz", version=1)],
             '$Version="1";name=value;$Path="/foo";name2="value2";$Domain="bar.baz"'),
            ([Cookie('invalid,"name', 'value'), Cookie('name2', 'value2', version=1)],
             '$Version="1";name2="value2"'),
            ([Cookie('name', 'qq"qq'), Cookie('name2', 'value2', version=1)],
             '$Version="1";name="qq\\"qq";name2="value2"'),
        )
        for row in table3:
            self.assertEquals(generateHeader("Cookie", row[0]), [row[1], ])

    def testSetCookie(self):
        table = (
            ('name,"blah=value,; expires=Sun, 09 Sep 2001 01:46:40 GMT; path=/foo; domain=bar.baz; secure',
             [Cookie('name,"blah', 'value,', expires=1000000000, path="/foo", domain="bar.baz", secure=True)]),
            ('name,"blah = value, ; expires="Sun, 09 Sep 2001 01:46:40 GMT"',
             [Cookie('name,"blah', 'value,', expires=1000000000)],
             ['name,"blah=value,', 'expires=Sun, 09 Sep 2001 01:46:40 GMT']),
        )
        self.runRoundtripTest("Set-Cookie", table)

    def testSetCookie2(self):
        table = (
            ('name="value"; Comment="YadaYada"; CommentURL="http://frobnotz/"; Discard; Domain="blah.blah"; Max-Age=10; Path="/foo"; Port="80,8080"; Secure; Version="1"',
             [Cookie("name", "value", comment="YadaYada", commenturl="http://frobnotz/", discard=True, domain="blah.blah", expires=1000000000, path="/foo", ports=(80, 8080), secure=True, version=1)]),
        )
        self.runRoundtripTest("Set-Cookie2", table)

    def testExpect(self):
        table = (
            ("100-continue",
             {"100-continue": (None,)}),
            ('foobar=twiddle',
             {'foobar': ('twiddle',)}),
            ("foo=bar;a=b;c",
             {'foo': ('bar', ('a', 'b'), ('c', None))})
        )
        self.runRoundtripTest("Expect", table)

    def testPrefer(self):
        # Each preference parses to (token, value-or-None, [param pairs]).
        table = (
            ("wait",
             [("wait", None, [])]),
            ("return = representation",
             [("return", "representation", [])]),
            ("return =minimal;arg1;arg2=val2",
             [("return", "minimal", [("arg1", None), ("arg2", "val2")])]),
        )
        self.runRoundtripTest("Prefer", table)

    def testFrom(self):
        self.runRoundtripTest("From", (("webmaster@w3.org", "webmaster@w3.org"),))

    def testHost(self):
        self.runRoundtripTest("Host", (("www.w3.org", "www.w3.org"),))

    def testIfMatch(self):
        table = (
            ('"xyzzy"', [http_headers.ETag('xyzzy')]),
            ('"xyzzy", "r2d2xxxx", "c3piozzzz"', [http_headers.ETag('xyzzy'),
                                                 http_headers.ETag('r2d2xxxx'),
                                                 http_headers.ETag('c3piozzzz')]),
            ('*', ['*']),
        )
        self.runRoundtripTest("If-Match", table)

    def testIfModifiedSince(self):
        # Don't need major tests since the datetime parser has its own test
        # Just test stupid ; length= brokenness.
        table = (
            ("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),
            ("Sun, 09 Sep 2001 01:46:40 GMT; length=500", 1000000000, ["Sun, 09 Sep 2001 01:46:40 GMT"]),
        )
        self.runRoundtripTest("If-Modified-Since", table)

    def testIfNoneMatch(self):
        # W/ prefix marks a weak entity tag.
        table = (
            ('"xyzzy"', [http_headers.ETag('xyzzy')]),
            ('W/"xyzzy", "r2d2xxxx", "c3piozzzz"', [http_headers.ETag('xyzzy', weak=True),
                                                   http_headers.ETag('r2d2xxxx'),
                                                   http_headers.ETag('c3piozzzz')]),
            ('W/"xyzzy", W/"r2d2xxxx", W/"c3piozzzz"', [http_headers.ETag('xyzzy', weak=True),
                                                       http_headers.ETag('r2d2xxxx', weak=True),
                                                       http_headers.ETag('c3piozzzz', weak=True)]),
            ('*', ['*']),
        )
        self.runRoundtripTest("If-None-Match", table)

    def testIfRange(self):
        # Value may be either an entity tag or an HTTP date.
        table = (
            ('"xyzzy"', http_headers.ETag('xyzzy')),
            ('W/"xyzzy"', http_headers.ETag('xyzzy', weak=True)),
            ('W/"xyzzy"', http_headers.ETag('xyzzy', weak=True)),
            ("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),
        )
        self.runRoundtripTest("If-Range", table)

    def testIfUnmodifiedSince(self):
        self.runRoundtripTest("If-Unmodified-Since", (("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),))

    def testMaxForwards(self):
        self.runRoundtripTest("Max-Forwards", (("15", 15),))

    # def testProxyAuthorize(self):
    #     fail

    def testRange(self):
        # Ranges parse to (units, [(start-or-None, end-or-None), ...]).
        table = (
            ("bytes=0-499", ('bytes', [(0, 499), ])),
            ("bytes=500-999", ('bytes', [(500, 999), ])),
            ("bytes=-500", ('bytes', [(None, 500), ])),
            ("bytes=9500-", ('bytes', [(9500, None), ])),
            ("bytes=0-0,-1", ('bytes', [(0, 0), (None, 1)])),
        )
        self.runRoundtripTest("Range", table)

    def testReferer(self):
        self.runRoundtripTest("Referer", (("http://www.w3.org/hypertext/DataSources/Overview.html",
                                           "http://www.w3.org/hypertext/DataSources/Overview.html"),))

    def testTE(self):
        table = (
            ("deflate", {'deflate': 1}),
            ("", {}),
            ("trailers, deflate;q=0.5", {'trailers': 1, 'deflate': 0.5}),
        )
        self.runRoundtripTest("TE", table)

    def testUserAgent(self):
        self.runRoundtripTest("User-Agent", (("CERN-LineMode/2.15 libwww/2.17b3", "CERN-LineMode/2.15 libwww/2.17b3"),))
class ResponseHeaderParsingTests(HeaderParsingTestBase):
    """
    Round-trip tests for the HTTP *response* headers supported by
    L{http_headers}, using the same table-driven C{runRoundtripTest}
    machinery as the request-header tests.
    """
    def testAcceptRanges(self):
        self.runRoundtripTest("Accept-Ranges", (("bytes", ["bytes"]), ("none", ["none"])))

    def testAge(self):
        self.runRoundtripTest("Age", (("15", 15),))

    def testETag(self):
        table = (
            ('"xyzzy"', http_headers.ETag('xyzzy')),
            ('W/"xyzzy"', http_headers.ETag('xyzzy', weak=True)),
            ('""', http_headers.ETag('')),
        )
        self.runRoundtripTest("ETag", table)

    def testLocation(self):
        self.runRoundtripTest("Location", (("http://www.w3.org/pub/WWW/People.htm",
                                            "http://www.w3.org/pub/WWW/People.htm"),))

    # def testProxyAuthenticate(self):
    #     fail

    def testRetryAfter(self):
        # time() is always 999999990 when being tested.
        table = (
            ("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000, ["10"]),
            ("120", 999999990 + 120),
        )
        self.runRoundtripTest("Retry-After", table)

    def testServer(self):
        self.runRoundtripTest("Server", (("CERN/3.0 libwww/2.17", "CERN/3.0 libwww/2.17"),))

    def testVary(self):
        # Field names are lowercased on parse.
        table = (
            ("*", ["*"]),
            ("Accept, Accept-Encoding", ["accept", "accept-encoding"], ["accept", "accept-encoding"])
        )
        self.runRoundtripTest("Vary", table)

    def testWWWAuthenticate(self):
        # Each challenge row is (wire value, parsed [(scheme, fields)],
        # substrings that must appear in the regenerated header).
        digest = ('Digest realm="digest realm", nonce="bAr", qop="auth"',
                  [('Digest', {'realm': 'digest realm', 'nonce': 'bAr',
                               'qop': 'auth'})],
                  ['Digest', 'realm="digest realm"',
                   'nonce="bAr"', 'qop="auth"'])
        basic = ('Basic realm="foo"',
                 [('Basic', {'realm': 'foo'})], ['Basic', 'realm="foo"'])
        ntlm = ('NTLM',
                [('NTLM', {})], ['NTLM', ''])
        negotiate = ('Negotiate SomeGssAPIData',
                     [('Negotiate', 'SomeGssAPIData')],
                     ['Negotiate', 'SomeGssAPIData'])
        # NOTE(review): the duplicated-Negotiate row wraps a single
        # concatenated list ([a + a]) where the other duplicated rows use
        # [a, a] -- looks inconsistent; confirm against the generator's
        # grouping behaviour before "fixing".
        table = (digest,
                 basic,
                 (digest[0] + ', ' + basic[0],
                  digest[1] + basic[1],
                  [digest[2], basic[2]]),
                 ntlm,
                 negotiate,
                 (ntlm[0] + ', ' + basic[0],
                  ntlm[1] + basic[1],
                  [ntlm[2], basic[2]]),
                 (digest[0] + ', ' + negotiate[0],
                  digest[1] + negotiate[1],
                  [digest[2], negotiate[2]]),
                 (negotiate[0] + ', ' + negotiate[0],
                  negotiate[1] + negotiate[1],
                  [negotiate[2] + negotiate[2]]),
                 (ntlm[0] + ', ' + ntlm[0],
                  ntlm[1] + ntlm[1],
                  [ntlm[2], ntlm[2]]),
                 (basic[0] + ', ' + ntlm[0],
                  basic[1] + ntlm[1],
                  [basic[2], ntlm[2]]),
                 )

        # runRoundtripTest doesn't work because we don't generate a single
        # header
        headername = 'WWW-Authenticate'

        for row in table:
            rawHeaderInput, parsedHeaderData, requiredGeneratedElements = row

            parsed = parseHeader(headername, [rawHeaderInput, ])
            self.assertEquals(parsed, parsedHeaderData)

            regeneratedHeaderValue = generateHeader(headername, parsed)

            # For each generated header, every required substring (looked up
            # by position) must appear somewhere in it.
            for regeneratedElement in regeneratedHeaderValue:
                requiredElements = requiredGeneratedElements[
                    regeneratedHeaderValue.index(
                        regeneratedElement)]

                for reqEle in requiredElements:
                    elementIndex = regeneratedElement.find(reqEle)

                    self.assertNotEqual(
                        elementIndex, -1,
                        "%r did not appear in generated HTTP header %r: %r" % (reqEle,
                                                                               headername,
                                                                               regeneratedElement))

            # parser/generator
            reparsed = parseHeader(headername, regeneratedHeaderValue)
            self.assertEquals(parsed, reparsed)
class EntityHeaderParsingTests(HeaderParsingTestBase):
    """
    Round-trip tests for the HTTP *entity* headers (those describing the
    message body) supported by L{http_headers}.
    """
    def testAllow(self):
        # Allow is a silly case-sensitive header unlike all the rest
        table = (
            ("GET", ['GET', ]),
            ("GET, HEAD, PUT", ['GET', 'HEAD', 'PUT']),
        )
        self.runRoundtripTest("Allow", table)

    def testContentEncoding(self):
        table = (
            ("gzip", ['gzip', ]),
        )
        self.runRoundtripTest("Content-Encoding", table)

    def testContentLanguage(self):
        table = (
            ("da", ['da', ]),
            ("mi, en", ['mi', 'en']),
        )
        self.runRoundtripTest("Content-Language", table)

    def testContentLength(self):
        self.runRoundtripTest("Content-Length", (("15", 15),))
        # Non-numeric lengths must be rejected by the parser.
        self.invalidParseTest("Content-Length", ("asdf",))

    def testContentLocation(self):
        self.runRoundtripTest("Content-Location",
                              (("http://www.w3.org/pub/WWW/People.htm",
                                "http://www.w3.org/pub/WWW/People.htm"),))

    def testContentMD5(self):
        # Value is base64; parses to the raw digest bytes.
        self.runRoundtripTest("Content-MD5", (("Q2hlY2sgSW50ZWdyaXR5IQ==", "Check Integrity!"),))
        self.invalidParseTest("Content-MD5", ("sdlaksjdfhlkaj",))

    def testContentRange(self):
        # Parses to (units, start, end, instance-length); '*' becomes None.
        table = (
            ("bytes 0-499/1234", ("bytes", 0, 499, 1234)),
            ("bytes 500-999/1234", ("bytes", 500, 999, 1234)),
            ("bytes 500-1233/1234", ("bytes", 500, 1233, 1234)),
            ("bytes 734-1233/1234", ("bytes", 734, 1233, 1234)),
            ("bytes 734-1233/*", ("bytes", 734, 1233, None)),
            ("bytes */1234", ("bytes", None, None, 1234)),
            ("bytes */*", ("bytes", None, None, None))
        )
        self.runRoundtripTest("Content-Range", table)

    def testContentType(self):
        table = (
            ("text/html;charset=iso-8859-4", http_headers.MimeType('text', 'html', (('charset', 'iso-8859-4'),))),
            ("text/html", http_headers.MimeType('text', 'html')),
        )
        self.runRoundtripTest("Content-Type", table)

    def testContentDisposition(self):
        table = (
            ("attachment;filename=foo.txt", http_headers.MimeDisposition('attachment', (('filename', 'foo.txt'),))),
            ("inline", http_headers.MimeDisposition('inline')),
        )
        self.runRoundtripTest("Content-Disposition", table)

    def testExpires(self):
        self.runRoundtripTest("Expires", (("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),))
        # Invalid expires MUST return date in the past.
        self.assertEquals(parseHeader("Expires", ["0"]), 0)
        self.assertEquals(parseHeader("Expires", ["wejthnaljn"]), 0)

    def testLastModified(self):
        # Don't need major tests since the datetime parser has its own test
        self.runRoundtripTest("Last-Modified", (("Sun, 09 Sep 2001 01:46:40 GMT", 1000000000),))
class DateTimeTest(unittest.TestCase):
    """Test date parsing functions."""

    def testParse(self):
        """
        Every supported textual representation of one instant parses to the
        identical POSIX timestamp.
        """
        expected = 784111777
        representations = (
            'Sun, 06 Nov 1994 08:49:37 GMT',
            'Sunday, 06-Nov-94 08:49:37 GMT',
            'Sun Nov 6 08:49:37 1994',
            # Also some non-RFC formats, for good measure.
            'Somefakeday 6 Nov 1994 8:49:37',
            '6 Nov 1994 8:49:37',
            'Sun, 6 Nov 1994 8:49:37',
            '6 Nov 1994 8:49:37 GMT',
            '06-Nov-94 08:49:37',
            'Sunday, 06-Nov-94 08:49:37',
            '06-Nov-94 08:49:37 GMT',
            'Nov 6 08:49:37 1994',
        )
        for text in representations:
            self.assertEquals(http_headers.parseDateTime(text), expected)

        # Two-digit years have to wrap into the correct century.
        self.assertEquals(
            http_headers.parseDateTime('Monday, 11-Oct-04 14:56:50 GMT'),
            1097506610)
        self.assertEquals(
            http_headers.parseDateTime('Monday, 11-Oct-2004 14:56:50 GMT'),
            1097506610)

    def testGenerate(self):
        """A known timestamp renders in the canonical HTTP form."""
        self.assertEquals(http_headers.generateDateTime(784111777),
                          'Sun, 06 Nov 1994 08:49:37 GMT')

    def testRoundtrip(self):
        """generateDateTime followed by parseDateTime is the identity."""
        for _ignore in range(2000):
            moment = random.randint(0, 2000000000)
            rendered = http_headers.generateDateTime(moment)
            self.assertEquals(moment, http_headers.parseDateTime(rendered))
class TestMimeType(unittest.TestCase):
    def testEquality(self):
        """
        Constructing a MimeType from keyword arguments, a dict, a tuple of
        pairs, or a string all yields equal objects.
        """
        viaKwargs = http_headers.MimeType('text', 'plain',
                                          key='value',
                                          param=None)
        viaDict = http_headers.MimeType('text', 'plain',
                                        {'param': None,
                                         'key': 'value'})
        viaTuples = http_headers.MimeType('text', 'plain',
                                          (('param', None),
                                           ('key', 'value')))
        viaString = http_headers.MimeType.fromString('text/plain;key=value;param')

        for left, right in ((viaKwargs, viaDict),
                            (viaDict, viaTuples),
                            (viaKwargs, viaTuples),
                            (viaKwargs, viaString)):
            self.assertEquals(left, right)
class TestMimeDisposition(unittest.TestCase):
    def testEquality(self):
        """
        Constructing a MimeDisposition from keyword arguments, a dict, a
        tuple of pairs, or a string all yields equal objects.
        """
        viaKwargs = http_headers.MimeDisposition('attachment',
                                                 key='value')
        viaDict = http_headers.MimeDisposition('attachment',
                                               {'key': 'value'})
        viaTuples = http_headers.MimeDisposition('attachment',
                                                 (('key', 'value'),))
        viaString = http_headers.MimeDisposition.fromString('attachment;key=value')

        for left, right in ((viaKwargs, viaDict),
                            (viaDict, viaTuples),
                            (viaKwargs, viaTuples),
                            (viaKwargs, viaString)):
            self.assertEquals(left, right)
class FormattingUtilityTests(unittest.TestCase):
    """
    Tests for various string formatting functionality required to generate
    headers.
    """
    def test_quoteString(self):
        """
        L{quoteString} wraps its input in double quotes, backslash-escaping
        embedded backslashes and double quotes, so the result reads back as
        the original under RFC 2616 section 2.2 quoted-string rules.
        """
        self.assertEqual(quoteString('a\\b"c'), '"a\\\\b\\"c"')

    def test_generateKeyValues(self):
        """
        L{generateKeyValues} accepts an iterable of parameter pairs and
        joins them as C{key=value} separated by semicolons (RFC 2045
        section 5.1).
        """
        pairs = iter([("foo", "bar"), ("baz", "quux")])
        self.assertEqual(generateKeyValues(pairs), "foo=bar;baz=quux")

    def test_generateKeyValuesNone(self):
        """
        A C{None} value emits the bare key, with no C{"="} at all.
        """
        self.assertEqual(generateKeyValues([("foo", None), ("bar", "baz")]),
                         "foo;bar=baz")

    def test_generateKeyValuesQuoting(self):
        """
        Values containing any character that cannot appear in an HTTP token
        (RFC 2616 section 2.2) are emitted as quoted-strings.
        """
        specials = (' ', '\t', '(', ')', '<', '>', '@', ',', ';', ':',
                    '\\', '"', '/', '[', ']', '?', '=', '{', '}')
        for ch in specials:
            self.assertEqual(generateKeyValues([("foo", ch)]),
                             'foo=%s' % (quoteString(ch),))
calendarserver-5.2+dfsg/twext/web2/test/test_client.py 0000644 0001750 0001750 00000035615 12113213176 022176 0 ustar rahul rahul # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import print_function
"""
Tests for HTTP client.
"""
from twisted.internet import protocol, defer
from twext.web2.client import http
from twext.web2 import http_headers
from twext.web2 import stream
from twext.web2.test.test_http import LoopbackRelay, HTTPTests, TestConnection
class TestServer(protocol.Protocol):
    """
    An in-memory peer for client tests: accumulates everything received in
    C{self.data} and flags C{self.done} once the connection ends.
    """
    data = ""
    done = False

    def dataReceived(self, data):
        # Appending to the class-level default creates an instance attribute.
        self.data = self.data + data

    def write(self, data):
        self.transport.write(data)

    def connectionLost(self, reason):
        # Same bookkeeping as an explicit close.
        self.loseConnection()

    def loseConnection(self):
        self.done = True
        self.transport.loseConnection()
class ClientTests(HTTPTests):
    """
    Base class for HTTP client tests: wires an L{http.HTTPClientProtocol}
    to a L{TestServer} through two L{LoopbackRelay} transports and provides
    assertion helpers for the bytes the client sends and the responses it
    produces.
    """
    def connect(self, logFile=None, maxPipeline=4,
                inputTimeOut=60000, betweenRequestsTimeOut=600000):
        """
        Build and return a connected client/server pair on a TestConnection.

        NOTE(review): maxPipeline and betweenRequestsTimeOut are accepted
        but never used in this body.
        """
        cxn = TestConnection()

        cxn.client = http.HTTPClientProtocol()
        cxn.client.inputTimeOut = inputTimeOut
        cxn.server = TestServer()

        cxn.serverToClient = LoopbackRelay(cxn.client, logFile)
        cxn.clientToServer = LoopbackRelay(cxn.server, logFile)

        # Order matters: transports must exist before makeConnection fires.
        cxn.server.makeConnection(cxn.serverToClient)
        cxn.client.makeConnection(cxn.clientToServer)

        return cxn

    def writeToClient(self, cxn, data):
        """Send raw bytes from the fake server and pump the loopback."""
        cxn.server.write(data)
        self.iterate(cxn)

    def writeLines(self, cxn, lines):
        """Send CRLF-joined lines from the fake server."""
        self.writeToClient(cxn, '\r\n'.join(lines))

    def assertReceived(self, cxn, expectedStatus, expectedHeaders,
                       expectedContent=None):
        """
        Assert the request line, headers and body the server has received.
        Only checks that each received header is among expectedHeaders, not
        that every expected header arrived.
        """
        self.iterate(cxn)

        headers, content = cxn.server.data.split('\r\n\r\n', 1)
        status, headers = headers.split('\r\n', 1)
        headers = headers.split('\r\n')

        # check status line
        self.assertEquals(status, expectedStatus)

        # check headers (header order isn't guaranteed so we use
        # self.assertIn)
        for x in headers:
            self.assertIn(x, expectedHeaders)

        if not expectedContent:
            expectedContent = ''
        self.assertEquals(content, expectedContent)

    def assertDone(self, cxn):
        """Assert the fake server saw the connection close."""
        self.iterate(cxn)
        self.assertEquals(cxn.server.done, True, 'Connection not closed.')

    def assertHeaders(self, resp, expectedHeaders):
        """Assert the response's raw headers, ignoring order."""
        headers = list(resp.headers.getAllRawHeaders())
        headers.sort()
        self.assertEquals(headers, expectedHeaders)

    def checkResponse(self, resp, code, headers, length, data):
        """
        Assert various things about a response: http code, headers, stream
        length, and data in stream.
        """
        def gotData(gotdata):
            self.assertEquals(gotdata, data)

        self.assertEquals(resp.code, code)
        self.assertHeaders(resp, headers)
        self.assertEquals(resp.stream.length, length)

        return defer.maybeDeferred(resp.stream.read).addCallback(gotData)
class TestHTTPClient(ClientTests):
    """
    Test that the http client works.
    """
    def test_simpleRequest(self):
        """
        Your basic simple HTTP Request.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)
        d = cxn.client.submitRequest(req).addCallback(self.checkResponse, 200, [], 10, '1234567890')

        self.assertReceived(cxn, 'GET / HTTP/1.1',
                            ['Connection: close'])

        self.writeLines(cxn, ('HTTP/1.1 200 OK',
                              'Content-Length: 10',
                              'Connection: close',
                              '',
                              '1234567890'))

        return d.addCallback(lambda _: self.assertDone(cxn))

    def test_delayedContent(self):
        """
        Make sure that the client returns the response object as soon as the
        headers are received, even if the data hasn't arrived yet.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)

        def gotData(data):
            self.assertEquals(data, '1234567890')

        def gotResp(resp):
            # Headers are in; the body is written only after this fires.
            self.assertEquals(resp.code, 200)
            self.assertHeaders(resp, [])
            self.assertEquals(resp.stream.length, 10)

            self.writeToClient(cxn, '1234567890')

            return defer.maybeDeferred(resp.stream.read).addCallback(gotData)

        d = cxn.client.submitRequest(req).addCallback(gotResp)

        self.assertReceived(cxn, 'GET / HTTP/1.1',
                            ['Connection: close'])

        self.writeLines(cxn, ('HTTP/1.1 200 OK',
                              'Content-Length: 10',
                              'Connection: close',
                              '\r\n'))

        return d.addCallback(lambda _: self.assertDone(cxn))

    def test_prematurePipelining(self):
        """
        Ensure that submitting a second request before it's allowed results
        in an AssertionError.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)
        req2 = http.ClientRequest('GET', '/bar', None, None)

        d = cxn.client.submitRequest(req, closeAfter=False).addCallback(
            self.checkResponse, 200, [], 0, None)
        # Second submit before the first response completes must be refused.
        self.assertRaises(AssertionError,
                          cxn.client.submitRequest, req2)

        self.assertReceived(cxn, 'GET / HTTP/1.1',
                            ['Connection: Keep-Alive'])
        self.writeLines(cxn, ('HTTP/1.1 200 OK',
                              'Content-Length: 0',
                              'Connection: close',
                              '\r\n'))

        return d

    def test_userHeaders(self):
        """
        Make sure that headers get through in both directions.
        """
        cxn = self.connect(inputTimeOut=None)

        def submitNext(_):
            # Second request on the same connection, this time closing it.
            headers = http_headers.Headers(
                headers={'Accept-Language': {'en': 1.0}},
                rawHeaders={'X-My-Other-Header': ['socks']})
            req = http.ClientRequest('GET', '/', headers, None)

            cxn.server.data = ''

            d = cxn.client.submitRequest(req, closeAfter=True)

            self.assertReceived(cxn, 'GET / HTTP/1.1',
                                ['Connection: close',
                                 'X-My-Other-Header: socks',
                                 'Accept-Language: en'])

            self.writeLines(cxn, ('HTTP/1.1 200 OK',
                                  'Content-Length: 0',
                                  'Connection: close',
                                  '\r\n'))

            return d

        req = http.ClientRequest('GET', '/',
                                 {'Accept-Language': {'en': 1.0}}, None)

        d = cxn.client.submitRequest(req, closeAfter=False).addCallback(
            self.checkResponse, 200, [('X-Foobar', ['Yes'])], 0, None).addCallback(
            submitNext)

        self.assertReceived(cxn, 'GET / HTTP/1.1',
                            ['Connection: Keep-Alive',
                             'Accept-Language: en'])

        self.writeLines(cxn, ('HTTP/1.1 200 OK',
                              'Content-Length: 0',
                              'X-Foobar: Yes',
                              '\r\n'))

        return d.addCallback(lambda _: self.assertDone(cxn))

    def test_streamedUpload(self):
        """
        Make sure that sending request content works.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('PUT', '/foo', None, 'Helloooo content')

        d = cxn.client.submitRequest(req).addCallback(self.checkResponse, 202, [], 0, None)

        self.assertReceived(cxn, 'PUT /foo HTTP/1.1',
                            ['Connection: close',
                             'Content-Length: 16'],
                            'Helloooo content')

        self.writeLines(cxn, ('HTTP/1.1 202 Accepted',
                              'Content-Length: 0',
                              'Connection: close',
                              '\r\n'))

        return d.addCallback(lambda _: self.assertDone(cxn))

    def test_sentHead(self):
        """
        Ensure that HEAD requests work, and return Content-Length.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('HEAD', '/', None, None)

        # HEAD responses report Content-Length but carry no body.
        d = cxn.client.submitRequest(req).addCallback(self.checkResponse, 200, [('Content-Length', ['5'])], 0, None)

        self.assertReceived(cxn, 'HEAD / HTTP/1.1',
                            ['Connection: close'])

        self.writeLines(cxn, ('HTTP/1.1 200 OK',
                              'Connection: close',
                              'Content-Length: 5',
                              '',
                              'Pants'))  # bad server

        return d.addCallback(lambda _: self.assertDone(cxn))

    def test_sentHeadKeepAlive(self):
        """
        Ensure that keepalive works right after a HEAD request.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('HEAD', '/', None, None)

        didIt = [0]

        def gotData(data):
            self.assertEquals(data, None)

        def gotResp(resp):
            self.assertEquals(resp.code, 200)
            self.assertEquals(resp.stream.length, 0)
            self.assertHeaders(resp, [])

            return defer.maybeDeferred(resp.stream.read).addCallback(gotData)

        def submitRequest(second):
            # Runs twice: first with keep-alive, then closing the connection.
            if didIt[0]:
                return
            didIt[0] = second

            if second:
                keepAlive='close'
            else:
                keepAlive='Keep-Alive'

            cxn.server.data = ''

            d = cxn.client.submitRequest(req, closeAfter=second).addCallback(
                self.checkResponse, 200, [('Content-Length', ['5'])], 0, None)

            self.assertReceived(cxn, 'HEAD / HTTP/1.1',
                                ['Connection: '+ keepAlive])

            self.writeLines(cxn, ('HTTP/1.1 200 OK',
                                  'Connection: '+ keepAlive,
                                  'Content-Length: 5',
                                  '\r\n'))

            return d.addCallback(lambda _: submitRequest(1))

        d = submitRequest(0)

        return d.addCallback(lambda _: self.assertDone(cxn))

    def test_chunkedUpload(self):
        """
        Ensure chunked data is correctly decoded on upload.
        """
        cxn = self.connect(inputTimeOut=None)

        data = 'Foo bar baz bax'
        s = stream.ProducerStream(length=None)
        s.write(data)

        req = http.ClientRequest('PUT', '/', None, s)

        d = cxn.client.submitRequest(req)

        s.finish()

        # Unknown length forces Transfer-Encoding: chunked on the wire.
        self.assertReceived(cxn, 'PUT / HTTP/1.1',
                            ['Connection: close',
                             'Transfer-Encoding: chunked'],
                            '%X\r\n%s\r\n0\r\n\r\n' % (len(data), data))

        self.writeLines(cxn, ('HTTP/1.1 200 OK',
                              'Connection: close',
                              'Content-Length: 0',
                              '\r\n'))

        return d.addCallback(lambda _: self.assertDone(cxn))
class TestEdgeCases(ClientTests):
    """
    Client behaviour for malformed or abruptly-terminated server responses.
    """
    def test_serverDoesntSendConnectionClose(self):
        """
        Check that a lost connection is treated as end of response, if we
        requested connection: close, even if the server didn't respond with
        connection: close.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)

        d = cxn.client.submitRequest(req).addCallback(self.checkResponse, 200, [], None, 'Some Content')

        self.assertReceived(cxn, 'GET / HTTP/1.1',
                            ['Connection: close'])

        self.writeLines(cxn, ('HTTP/1.1 200 OK',
                              '',
                              'Some Content'))

        return d.addCallback(lambda _: self.assertDone(cxn))

    def test_serverIsntHttp(self):
        """
        Check that an error is returned if the server doesn't talk HTTP.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)

        def gotResp(r):
            # NOTE(review): debugging leftover; gotResp is only reached if
            # the expected failure does not occur.
            print(r)

        d = cxn.client.submitRequest(req).addCallback(gotResp)

        # NOTE(review): the deferred returned by assertFailure is not
        # returned from the test method -- confirm trial still detects a
        # failure here.
        self.assertFailure(d, http.ProtocolError)

        self.writeLines(cxn, ('HTTP-NG/1.1 200 OK',
                              '\r\n'))

    def test_newServer(self):
        """
        Check that an error is returned if the server is a new major version.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)

        d = cxn.client.submitRequest(req)

        self.assertFailure(d, http.ProtocolError)

        self.writeLines(cxn, ('HTTP/2.3 200 OK',
                              '\r\n'))

    def test_shortStatus(self):
        """
        Check that an error is returned if the response line is invalid.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)

        d = cxn.client.submitRequest(req)

        self.assertFailure(d, http.ProtocolError)

        self.writeLines(cxn, ('HTTP/1.1 200',
                              '\r\n'))

    def test_errorReadingRequestStream(self):
        """
        Ensure that stream errors are propagated to the response.
        """
        cxn = self.connect(inputTimeOut=None)

        s = stream.ProducerStream()
        s.write('Foo')

        req = http.ClientRequest('GET', '/', None, s)

        d = cxn.client.submitRequest(req)

        s.finish(IOError('Test Error'))

        return self.assertFailure(d, IOError)

    def test_connectionLost(self):
        """
        Check that closing the connection is propagated to the response
        deferred.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)

        d = cxn.client.submitRequest(req)

        self.assertReceived(cxn, 'GET / HTTP/1.1',
                            ['Connection: close'])
        cxn.client.connectionLost(ValueError("foo"))
        return self.assertFailure(d, ValueError)

    def test_connectionLostAfterHeaders(self):
        """
        Test that closing the connection after headers are sent is propagated
        to the response stream.
        """
        cxn = self.connect(inputTimeOut=None)
        req = http.ClientRequest('GET', '/', None, None)

        d = cxn.client.submitRequest(req)

        self.assertReceived(cxn, 'GET / HTTP/1.1',
                            ['Connection: close'])
        self.writeLines(cxn, ('HTTP/1.1 200 OK',
                              'Content-Length: 10',
                              'Connection: close',
                              '\r\n'))
        cxn.client.connectionLost(ValueError("foo"))

        def cb(response):
            # The headers arrived, so the error surfaces on the body stream.
            return self.assertFailure(response.stream.read(), ValueError)

        d.addCallback(cb)
        return d
calendarserver-5.2+dfsg/twext/web2/test/server.pem 0000644 0001750 0001750 00000004000 11337102650 021301 0 ustar rahul rahul -----BEGIN CERTIFICATE-----
MIIDBjCCAm+gAwIBAgIBATANBgkqhkiG9w0BAQQFADB7MQswCQYDVQQGEwJTRzER
MA8GA1UEChMITTJDcnlwdG8xFDASBgNVBAsTC00yQ3J5cHRvIENBMSQwIgYDVQQD
ExtNMkNyeXB0byBDZXJ0aWZpY2F0ZSBNYXN0ZXIxHTAbBgkqhkiG9w0BCQEWDm5n
cHNAcG9zdDEuY29tMB4XDTAwMDkxMDA5NTEzMFoXDTAyMDkxMDA5NTEzMFowUzEL
MAkGA1UEBhMCU0cxETAPBgNVBAoTCE0yQ3J5cHRvMRIwEAYDVQQDEwlsb2NhbGhv
c3QxHTAbBgkqhkiG9w0BCQEWDm5ncHNAcG9zdDEuY29tMFwwDQYJKoZIhvcNAQEB
BQADSwAwSAJBAKy+e3dulvXzV7zoTZWc5TzgApr8DmeQHTYC8ydfzH7EECe4R1Xh
5kwIzOuuFfn178FBiS84gngaNcrFi0Z5fAkCAwEAAaOCAQQwggEAMAkGA1UdEwQC
MAAwLAYJYIZIAYb4QgENBB8WHU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRl
MB0GA1UdDgQWBBTPhIKSvnsmYsBVNWjj0m3M2z0qVTCBpQYDVR0jBIGdMIGagBT7
hyNp65w6kxXlxb8pUU/+7Sg4AaF/pH0wezELMAkGA1UEBhMCU0cxETAPBgNVBAoT
CE0yQ3J5cHRvMRQwEgYDVQQLEwtNMkNyeXB0byBDQTEkMCIGA1UEAxMbTTJDcnlw
dG8gQ2VydGlmaWNhdGUgTWFzdGVyMR0wGwYJKoZIhvcNAQkBFg5uZ3BzQHBvc3Qx
LmNvbYIBADANBgkqhkiG9w0BAQQFAAOBgQA7/CqT6PoHycTdhEStWNZde7M/2Yc6
BoJuVwnW8YxGO8Sn6UJ4FeffZNcYZddSDKosw8LtPOeWoK3JINjAk5jiPQ2cww++
7QGG/g5NDjxFZNDJP1dGiLAxPW6JXwov4v0FmdzfLOZ01jDcgQQZqEpYlgpuI5JE
WUQ9Ho4EzbYCOQ==
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIIBPAIBAAJBAKy+e3dulvXzV7zoTZWc5TzgApr8DmeQHTYC8ydfzH7EECe4R1Xh
5kwIzOuuFfn178FBiS84gngaNcrFi0Z5fAkCAwEAAQJBAIqm/bz4NA1H++Vx5Ewx
OcKp3w19QSaZAwlGRtsUxrP7436QjnREM3Bm8ygU11BjkPVmtrKm6AayQfCHqJoT
ZIECIQDW0BoMoL0HOYM/mrTLhaykYAVqgIeJsPjvkEhTFXWBuQIhAM3deFAvWNu4
nklUQ37XsCT2c9tmNt1LAT+slG2JOTTRAiAuXDtC/m3NYVwyHfFm+zKHRzHkClk2
HjubeEgjpj32AQIhAJqMGTaZVOwevTXvvHwNEH+vRWsAYU/gbx+OQB+7VOcBAiEA
oolb6NMg/R3enNPvS1O4UU1H8wpaF77L4yiSWlE0p4w=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE REQUEST-----
MIIBDTCBuAIBADBTMQswCQYDVQQGEwJTRzERMA8GA1UEChMITTJDcnlwdG8xEjAQ
BgNVBAMTCWxvY2FsaG9zdDEdMBsGCSqGSIb3DQEJARYObmdwc0Bwb3N0MS5jb20w
XDANBgkqhkiG9w0BAQEFAANLADBIAkEArL57d26W9fNXvOhNlZzlPOACmvwOZ5Ad
NgLzJ1/MfsQQJ7hHVeHmTAjM664V+fXvwUGJLziCeBo1ysWLRnl8CQIDAQABoAAw
DQYJKoZIhvcNAQEEBQADQQA7uqbrNTjVWpF6By5ZNPvhZ4YdFgkeXFVWi5ao/TaP
Vq4BG021fJ9nlHRtr4rotpgHDX1rr+iWeHKsx4+5DRSy
-----END CERTIFICATE REQUEST-----
calendarserver-5.2+dfsg/twext/web2/test/test_server.py 0000644 0001750 0001750 00000076160 12165665515 022245 0 ustar rahul rahul # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A test harness for the twext.web2 server.
"""
from zope.interface import implementer
from twisted.python import components
from twext.web2 import http, http_headers, iweb, server
from twext.web2 import resource, stream
from twext.web2.dav.test.util import SimpleRequest
from twisted.trial import unittest
from twisted.internet import reactor, defer, address
class NotResource(object):
    """
    Class which does not implement IResource.

    Used as an adaptee by L{AdaptionTestCase.test_registered} to test that
    if an object which does not provide IResource is adapted to IResource
    and there is an adapter to IResource registered, that adapter is used.
    """
@implementer(iweb.IResource)
class ResourceAdapter(object):
    """
    Adapter to IResource.

    Registered as an adapter from NotResource to IResource so that
    L{AdaptionTestCase.test_registered} can test that such an adapter will
    be used.

    @param original: the adapted object; ignored here.
    """
    def __init__(self, original):
        pass
# Let iweb.IResource(NotResource()) find ResourceAdapter.
components.registerAdapter(ResourceAdapter, NotResource, iweb.IResource)
class NotOldResource(object):
    """
    Class which does not implement IOldNevowResource or IResource.

    Used as an adaptee by L{AdaptionTestCase.test_transitive} to test that
    if an object which does not provide IResource or IOldNevowResource is
    adapted to IResource and there is an adapter to IOldNevowResource
    registered, first that adapter is used, then the included adapter from
    IOldNevowResource to IResource is used.
    """
@implementer(iweb.IOldNevowResource)
class OldResourceAdapter(object):
    """
    Adapter to IOldNevowResource.

    Registered as an adapter from NotOldResource to IOldNevowResource so
    that L{AdaptionTestCase.test_transitive} can test that such an adapter
    will be used to allow the initial input to be adapted to IResource.

    @param original: the adapted object; ignored here.
    """
    def __init__(self, original):
        pass
# Let NotOldResource reach IResource transitively via IOldNevowResource.
components.registerAdapter(OldResourceAdapter, NotOldResource, iweb.IOldNevowResource)
class AdaptionTestCase(unittest.TestCase):
    """
    Test the adaption of various objects to IResource.

    Necessary due to the special implementation of __call__ on IResource
    which extends the behavior provided by the base Interface.__call__.
    """
    def test_unadaptable(self):
        """
        Adapting an object with no route to IResource raises TypeError, or
        hands back the caller-supplied alternate when one is given.
        """
        class Hopeless(object):
            pass

        self.assertRaises(TypeError, iweb.IResource, Hopeless())

        fallback = object()
        self.assertIdentical(iweb.IResource(Hopeless(), fallback), fallback)

    def test_redundant(self):
        """
        Adapting an object which already provides IResource hands back the
        very same object.
        """
        @implementer(iweb.IResource)
        class Resource(object):
            pass

        provider = Resource()
        self.assertIdentical(iweb.IResource(provider), provider)

    def test_registered(self):
        """
        A registered adapter is used when the adaptee itself does not
        provide IResource.
        """
        adaptee = NotResource()
        self.failUnless(isinstance(iweb.IResource(adaptee), ResourceAdapter))
@implementer(iweb.IChanRequest)
class TestChanRequest:
    """
    In-memory L{iweb.IChanRequest} implementation for server tests.

    Builds a L{server.Request} in C{__init__} (optionally feeding it a
    request body) and records the response the server writes back:
    C{self.code}, C{self.responseHeaders} and C{self.data}.  When the
    response completes, C{self.deferredFinish} fires with the tuple
    C{(code, responseHeaders, data, failed)}.
    """
    hostInfo = address.IPv4Address('TCP', 'host', 80), False
    remoteHost = address.IPv4Address('TCP', 'remotehost', 34567)
    finished = False

    def __init__(self, site, method, prepath, uri, length=None,
                 headers=None, version=(1, 1), content=None):
        """
        @param site: the site object handed to the created Request.
        @param method: HTTP method string, e.g. 'GET'.
        @param prepath: pre-path URI segment passed as C{prepathuri}.
        @param uri: the request URI.
        @param length: optional content length for the Request.
        @param headers: optional L{http_headers.Headers}; a fresh empty
            Headers is used when omitted.
        @param version: HTTP version tuple, default (1, 1).
        @param content: optional request body; when given it is delivered
            to the Request immediately and marked complete.
        """
        self.producer = None
        self.site = site
        self.method = method
        self.prepath = prepath
        self.uri = uri
        if headers is None:
            headers = http_headers.Headers()
        self.headers = headers
        self.http_version = version
        # Anything below here we do not pass as arguments
        self.request = server.Request(self,
                                      self.method,
                                      self.uri,
                                      self.http_version,
                                      length,
                                      self.headers,
                                      site=self.site,
                                      prepathuri=self.prepath)

        if content is not None:
            self.request.handleContentChunk(content)
            self.request.handleContentComplete()

        self.code = None
        self.responseHeaders = None
        self.data = ''
        self.deferredFinish = defer.Deferred()

    def writeIntermediateResponse(self, code, headers=None):
        # BUG FIX: 'self' was missing from this method's signature, so a
        # normal bound call shifted 'code' into the instance slot.  The
        # body is intentionally a no-op.
        pass

    def writeHeaders(self, code, headers):
        # Record the final status code and headers for inspection.
        self.responseHeaders = headers
        self.code = code

    def write(self, data):
        # Accumulate the response body.
        self.data += data

    def finish(self, failed=False):
        # Snapshot the response and fire deferredFinish exactly once.
        result = self.code, self.responseHeaders, self.data, failed
        self.finished = True
        self.deferredFinish.callback(result)

    def abortConnection(self):
        self.finish(failed=True)

    def registerProducer(self, producer, streaming):
        # Mirror transport semantics: only one producer at a time.
        if self.producer is not None:
            raise ValueError("Producer still set: " + repr(self.producer))
        self.producer = producer

    def unregisterProducer(self):
        self.producer = None

    def getHostInfo(self):
        return self.hostInfo

    def getRemoteHost(self):
        return self.remoteHost
class BaseTestResource(resource.Resource):
    """
    A simple resource which renders a canned response (status 200, text
    'This is a fake resource.') and accepts child resources at
    construction time.
    """
    responseCode = 200
    responseText = 'This is a fake resource.'
    responseHeaders = {}
    addSlash = False

    def __init__(self, children=None):
        """
        @type children: C{list} of C{tuple}
        @param children: a list of ('path', resource) tuples, or None.
        """
        # Fixed: the default used to be a shared mutable list ([]).  It was
        # never mutated here, but None-as-default is the safe idiom.
        if children is not None:
            for name, child in children:
                self.putChild(name, child)

    def render(self, req):
        return http.Response(self.responseCode, headers=self.responseHeaders,
                             stream=self.responseStream())

    def responseStream(self):
        # Overridable hook producing the body stream.
        return stream.MemoryStream(self.responseText)
class MyRenderError(Exception):
    """Marker exception raised by test resources to simulate render failure."""
class ErrorWithProducerResource(BaseTestResource):
    """
    A resource which registers a producer on its channel request and then
    fails to render, for exercising the error-cleanup path in the server.
    """
    addSlash = True
    def render(self, req):
        # Leave a producer registered so the error path has to cope with it.
        req.chanRequest.registerProducer(object(), None)
        return defer.fail(MyRenderError())

    def child_(self, request):
        # Serve ourselves for the trailing-slash child lookup.
        return self
# Sentinel distinguishing "length not supplied" from an explicit None.
_unset = object()

class BaseCase(unittest.TestCase):
    """
    Base class for test cases that involve testing the result
    of arbitrary HTTP(S) queries.
    """
    method = 'GET'
    version = (1, 1)
    wait_timeout = 5.0

    def chanrequest(self, root, uri, length, headers, method, version, prepath, content):
        """Build a L{TestChanRequest} for a site rooted at C{root}."""
        site = server.Site(root)
        return TestChanRequest(site, method, prepath, uri, length, headers, version, content)

    def getResponseFor(self, root, uri, headers=None,
                       method=None, version=None, prepath='', content=None,
                       length=_unset):
        """
        Issue a request for C{uri} against C{root}; return a Deferred firing
        with the recorded (code, headers, data, failed) tuple.
        """
        # Fixed: C{headers} previously defaulted to a shared mutable {}.
        if headers is None:
            headers = {}
        if not isinstance(headers, http_headers.Headers):
            headers = http_headers.Headers(headers)

        if length is _unset:
            if content is not None:
                length = len(content)
            else:
                length = 0

        if method is None:
            method = self.method
        if version is None:
            version = self.version

        cr = self.chanrequest(root, uri, length, headers, method, version, prepath, content)
        cr.request.process()
        return cr.deferredFinish

    def assertResponse(self, request_data, expected_response, failure=False):
        """
        @type request_data: C{tuple}
        @type expected_response: C{tuple}
        @param request_data: A tuple of arguments to pass to L{getResponseFor}:
                             (root, uri, headers, method, version, prepath).
                             Root resource and requested URI are required,
                             and everything else is optional.
        @param expected_response: A 3-tuple of the expected response:
                                  (responseCode, headers, htmlData)
        """
        d = self.getResponseFor(*request_data)
        d.addCallback(self._cbGotResponse, expected_response, failure)
        return d

    def _cbGotResponse(self, result, expected_response, expectedfailure=False):
        # Fixed: the original used Python-2-only tuple unpacking in the
        # parameter list; unpack explicitly instead (works on 2 and 3).
        code, headers, data, failed = result
        expected_code, expected_headers, expected_data = expected_response
        self.assertEquals(code, expected_code)
        if expected_data is not None:
            self.assertEquals(data, expected_data)
        # .items() behaves identically here; iteritems() was Python-2-only.
        for key, value in expected_headers.items():
            self.assertEquals(headers.getHeader(key), value)
        self.assertEquals(failed, expectedfailure)
class ErrorHandlingTest(BaseCase):
    """
    Tests for error handling.
    """
    def test_processingReallyReallyReallyFailed(self):
        """
        The HTTP connection will be shut down if there's really no way to relay
        any useful information about the error to the HTTP client.
        """
        root = ErrorWithProducerResource()
        site = server.Site(root)
        tcr = TestChanRequest(site, "GET", "/", "http://localhost/")
        request = server.Request(tcr, "GET", "/", (1, 1),
                                 0, http_headers.Headers(
                                     {"host": "localhost"}),
                                 site=site)
        proc = request.process()
        done = []
        proc.addBoth(done.append)
        # process() must have completed synchronously, with a None result.
        self.assertEquals(done, [None])
        # The stale producer left by the failing resource is reported...
        errs = self.flushLoggedErrors(ValueError)
        self.assertIn('producer', str(errs[0]).lower())
        # ...and so is the original render error.
        errs = self.flushLoggedErrors(MyRenderError)
        self.assertEquals(bool(errs), True)
        # The channel must have been shut down.
        self.assertEquals(tcr.finished, True)
class SampleWebTest(BaseCase):
    """
    End-to-end request-processing tests against a small resource tree,
    including child lookup and RedirectResource behavior.
    """
    class SampleTestResource(BaseTestResource):
        # Root resource: canned BaseTestResource response plus a few
        # child_* lookups exercised by the tests below.
        addSlash = True
        def child_validChild(self, req):
            f = BaseTestResource()
            f.responseCode = 200
            f.responseText = 'This is a valid child resource.'
            return f

        def child_missingChild(self, req):
            # NOTE(review): as an *instance* attribute this lambda is not
            # bound, so self.responseStream() would call it with zero args
            # (TypeError).  It is never exercised: test_invalidChild
            # requests 'invalidChild', not 'missingChild'.
            f = BaseTestResource()
            f.responseCode = 404
            f.responseStream = lambda self: None
            return f

        def child_remoteAddr(self, req):
            f = BaseTestResource()
            f.responseCode = 200
            f.responseText = 'Remote Addr: %r' % req.remoteAddr.host
            return f

    def setUp(self):
        self.root = self.SampleTestResource()

    def test_root(self):
        return self.assertResponse(
            (self.root, 'http://host/'),
            (200, {}, 'This is a fake resource.'))

    def test_validChild(self):
        return self.assertResponse(
            (self.root, 'http://host/validChild'),
            (200, {}, 'This is a valid child resource.'))

    def test_invalidChild(self):
        # Unknown child names yield a 404 (body not checked).
        return self.assertResponse(
            (self.root, 'http://host/invalidChild'),
            (404, {}, None))

    def test_remoteAddrExposure(self):
        # TestChanRequest advertises 'remotehost' as the peer address.
        return self.assertResponse(
            (self.root, 'http://host/remoteAddr'),
            (200, {}, "Remote Addr: 'remotehost'"))

    def test_leafresource(self):
        # A LeafResource consumes no path segments: all land in postpath.
        class TestResource(resource.LeafResource):
            def render(self, req):
                return http.Response(stream="prepath:%s postpath:%s" % (
                        req.prepath,
                        req.postpath))

        return self.assertResponse(
            (TestResource(), 'http://host/consumed/path/segments'),
            (200, {}, "prepath:[] postpath:['consumed', 'path', 'segments']"))

    def test_redirectResource(self):
        """
        Make sure a redirect response has the correct status and Location header.
        """
        redirectResource = resource.RedirectResource(scheme='https',
                                                     host='localhost',
                                                     port=443,
                                                     path='/foo',
                                                     querystring='bar=baz')

        return self.assertResponse(
            (redirectResource, 'http://localhost/'),
            (301, {'location': 'https://localhost/foo?bar=baz'}, None))

    def test_redirectResourceWithSchemeRemapping(self):
        """
        Make sure a redirect response has the correct status and Location header, when
        SSL is on, and the client request uses scheme http with the SSL port.
        """
        def chanrequest2(root, uri, length, headers, method, version, prepath, content):
            # Same as BaseCase.chanrequest, but with SSL enabled on 8443.
            site = server.Site(root)
            site.EnableSSL = True
            site.SSLPort = 8443
            site.BindSSLPorts = []
            return TestChanRequest(site, method, prepath, uri, length, headers, version, content)
        self.patch(self, "chanrequest", chanrequest2)

        redirectResource = resource.RedirectResource(path='/foo')

        return self.assertResponse(
            (redirectResource, 'http://localhost:8443/'),
            (301, {'location': 'https://localhost:8443/foo'}, None))

    def test_redirectResourceWithoutSchemeRemapping(self):
        """
        Make sure a redirect response has the correct status and Location header, when
        SSL is on, and the client request uses scheme http with the non-SSL port.
        """
        def chanrequest2(root, uri, length, headers, method, version, prepath, content):
            site = server.Site(root)
            site.EnableSSL = True
            site.SSLPort = 8443
            site.BindSSLPorts = []
            return TestChanRequest(site, method, prepath, uri, length, headers, version, content)
        self.patch(self, "chanrequest", chanrequest2)

        redirectResource = resource.RedirectResource(path='/foo')

        return self.assertResponse(
            (redirectResource, 'http://localhost:8008/'),
            (301, {'location': 'http://localhost:8008/foo'}, None))

    def test_redirectResourceWithoutSSLSchemeRemapping(self):
        """
        Make sure a redirect response has the correct status and Location header, when
        SSL is off, and the client request uses scheme http with the SSL port.
        """
        def chanrequest2(root, uri, length, headers, method, version, prepath, content):
            site = server.Site(root)
            site.EnableSSL = False
            site.SSLPort = 8443
            site.BindSSLPorts = []
            return TestChanRequest(site, method, prepath, uri, length, headers, version, content)
        self.patch(self, "chanrequest", chanrequest2)

        redirectResource = resource.RedirectResource(path='/foo')

        return self.assertResponse(
            (redirectResource, 'http://localhost:8443/'),
            (301, {'location': 'http://localhost:8443/foo'}, None))
class URLParsingTest(BaseCase):
    """
    Tests that the request host and path are parsed as expected from both
    relative and absolute request URIs.
    """
    class TestResource(resource.LeafResource):
        # Echo back the host and path the server parsed out of the request.
        def render(self, req):
            return http.Response(stream="Host:%s, Path:%s"%(req.host, req.path))

    def setUp(self):
        self.root = self.TestResource()

    def test_normal(self):
        # Relative request URI; host comes from the Host header.
        return self.assertResponse(
            (self.root, '/path', {'Host':'host'}),
            (200, {}, 'Host:host, Path:/path'))

    def test_fullurl(self):
        # Absolute request URI; host comes from the URL itself.
        return self.assertResponse(
            (self.root, 'http://host/path'),
            (200, {}, 'Host:host, Path:/path'))

    def test_strangepath(self):
        # Ensure that the double slashes don't confuse it
        return self.assertResponse(
            (self.root, '//path', {'Host':'host'}),
            (200, {}, 'Host:host, Path://path'))

    def test_strangepathfull(self):
        # Same, with an absolute request URI.
        return self.assertResponse(
            (self.root, 'http://host//path'),
            (200, {}, 'Host:host, Path://path'))
class TestDeferredRendering(BaseCase):
    """
    Both render() and child lookup may return a Deferred; the server must
    wait for each before producing the response.
    """
    class ResourceWithDeferreds(BaseTestResource):
        addSlash=True
        responseText = 'I should be wrapped in a Deferred.'
        def render(self, req):
            # Deliver the canned response asynchronously via the reactor.
            d = defer.Deferred()
            reactor.callLater(
                0, d.callback, BaseTestResource.render(self, req))
            return d

        def child_deferred(self, req):
            # Child lookup may also complete asynchronously.
            d = defer.Deferred()
            reactor.callLater(0, d.callback, BaseTestResource())
            return d

    def test_deferredRootResource(self):
        return self.assertResponse(
            (self.ResourceWithDeferreds(), 'http://host/'),
            (200, {}, 'I should be wrapped in a Deferred.'))

    def test_deferredChild(self):
        return self.assertResponse(
            (self.ResourceWithDeferreds(), 'http://host/deferred'),
            (200, {}, 'This is a fake resource.'))
class RedirectResourceTest(BaseCase):
    """
    Tests for L{resource.RedirectResource}: each request must produce a 301
    with the computed Location header and the standard redirect body.
    """
    def html(url):
        """
        Return the redirect body expected for C{url}.

        NOTE(review): this literal was reconstructed -- the archived copy
        had its markup stripped.  It mirrors the body built by
        twext.web2.http's StatusResponse; confirm against that module.
        """
        return ("<html><head><title>Moved Permanently</title></head>"
                "<body><h1>Moved Permanently</h1>"
                "<p>Document moved to %s.</p></body></html>" % (url,))
    html = staticmethod(html)

    def test_noRedirect(self):
        # This is useless, since it's a loop, but hey
        ds = []
        for url in ("http://host/", "http://host/foo"):
            ds.append(self.assertResponse(
                (resource.RedirectResource(), url),
                (301, {"location": url}, self.html(url))
            ))
        return defer.DeferredList(ds, fireOnOneErrback=True)

    def test_hostRedirect(self):
        ds = []
        for url1, url2 in (
            ("http://host/", "http://other/"),
            ("http://host/foo", "http://other/foo"),
        ):
            ds.append(self.assertResponse(
                (resource.RedirectResource(host="other"), url1),
                (301, {"location": url2}, self.html(url2))
            ))
        return defer.DeferredList(ds, fireOnOneErrback=True)

    def test_pathRedirect(self):
        # Fixed: dropped an unused BaseTestResource/putChild setup that the
        # loop below never consulted (it builds fresh redirect resources).
        ds = []
        for url1, url2 in (
            ("http://host/r", "http://host/other"),
            ("http://host/r/foo", "http://host/other"),
        ):
            ds.append(self.assertResponse(
                (resource.RedirectResource(path="/other"), url1),
                (301, {"location": url2}, self.html(url2))
            ))
        return defer.DeferredList(ds, fireOnOneErrback=True)
class EmptyResource(resource.Resource):
    """
    A childless resource which asserts (against the owning test case) that
    the request maps it to C{self.expectedURI}, then responds 201.
    """
    def __init__(self, test):
        # The test case whose assert* methods we borrow in render().
        self.test = test

    def render(self, request):
        self.test.assertEquals(request.urlForResource(self), self.expectedURI)
        return 201
class RememberURIs(BaseCase):
    """
    Tests for URI memory and lookup mechanism in server.Request.
    """
    def test_requestedResource(self):
        """
        Test urlForResource() on deeply nested resource looked up via
        request processing.
        """
        root = EmptyResource(self)
        root.expectedURI = "/"

        foo = EmptyResource(self)
        foo.expectedURI = "/foo"
        root.putChild("foo", foo)

        bar = EmptyResource(self)
        bar.expectedURI = foo.expectedURI + "/bar"
        foo.putChild("bar", bar)

        baz = EmptyResource(self)
        baz.expectedURI = bar.expectedURI + "/baz"
        bar.putChild("baz", baz)

        ds = []
        for uri in (foo.expectedURI, bar.expectedURI, baz.expectedURI):
            ds.append(self.assertResponse(
                (root, uri, {'Host':'host'}),
                (201, {}, None),  # 201 comes from EmptyResource.render
            ))

        return defer.DeferredList(ds, fireOnOneErrback=True)

    def test_urlEncoding(self):
        """
        Test to make sure that URL encoding is working.
        """
        root = EmptyResource(self)
        root.expectedURI = "/"

        child = EmptyResource(self)
        child.expectedURI = "/foo%20bar"

        root.putChild("foo bar", child)

        return self.assertResponse(
            (root, child.expectedURI, {'Host':'host'}),
            (201, {}, None)
        )

    def test_locateResource(self):
        """
        Test urlForResource() on resource looked up via a locateResource() call.
        """
        root = resource.Resource()
        child = resource.Resource()
        root.putChild("foo", child)

        request = SimpleRequest(server.Site(root), "GET", "/")

        def gotResource(resource):
            self.assertEquals("/foo", request.urlForResource(resource))

        d = defer.maybeDeferred(request.locateResource, "/foo")
        d.addCallback(gotResource)

        return d

    def test_unknownResource(self):
        """
        Test urlForResource() on unknown resource.
        """
        root = resource.Resource()
        child = resource.Resource()  # deliberately never attached to the tree
        request = SimpleRequest(server.Site(root), "GET", "/")

        self.assertRaises(server.NoURLForResourceError, request.urlForResource, child)

    def test_locateChildResource(self):
        """
        Test urlForResource() on deeply nested resource looked up via
        locateChildResource().
        """
        root = EmptyResource(self)
        root.expectedURI = "/"

        foo = EmptyResource(self)
        foo.expectedURI = "/foo"
        root.putChild("foo", foo)

        bar = EmptyResource(self)
        bar.expectedURI = "/foo/bar"
        foo.putChild("bar", bar)

        baz = EmptyResource(self)
        baz.expectedURI = "/foo/bar/b%20a%20z"
        bar.putChild("b a z", baz)

        request = SimpleRequest(server.Site(root), "GET", "/")

        def gotResource(resource):
            # Make sure locateChildResource() gave us the right answer
            self.assertEquals(resource, bar)

            return request.locateChildResource(resource, "b a z").addCallback(gotChildResource)

        def gotChildResource(resource):
            # Make sure locateChildResource() gave us the right answer
            self.assertEquals(resource, baz)
            self.assertEquals(resource.expectedURI, request.urlForResource(resource))

        d = request.locateResource(bar.expectedURI)
        d.addCallback(gotResource)

        return d

    def test_deferredLocateChild(self):
        """
        Test deferred value from locateChild()
        """
        class DeferredLocateChild(resource.Resource):
            def locateChild(self, req, segments):
                # Wrap the base implementation's result in a Deferred.
                return defer.maybeDeferred(
                    super(DeferredLocateChild, self).locateChild,
                    req, segments
                )

        root = DeferredLocateChild()
        child = resource.Resource()
        root.putChild("foo", child)

        request = SimpleRequest(server.Site(root), "GET", "/foo")

        def gotResource(resource):
            self.assertEquals("/foo", request.urlForResource(resource))

        d = request.locateResource("/foo")
        d.addCallback(gotResource)

        return d
class ParsePostDataTests(unittest.TestCase):
    """
    Tests for L{server.parsePOSTData}.
    """

    def test_noData(self):
        """
        Parsing a request without data should succeed but should not fill the
        C{args} and C{files} attributes of the request.
        """
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/")
        def cb(ign):
            self.assertEquals(request.args, {})
            self.assertEquals(request.files, {})
        return server.parsePOSTData(request).addCallback(cb)

    def test_noContentType(self):
        """
        Parsing a request without content-type should succeed but should not
        fill the C{args} and C{files} attributes of the request.
        """
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/", content="foo")
        def cb(ign):
            self.assertEquals(request.args, {})
            self.assertEquals(request.files, {})
        return server.parsePOSTData(request).addCallback(cb)

    def test_urlencoded(self):
        """
        Test parsing data in urlencoded format: it should end in the C{args}
        attribute.
        """
        ctype = http_headers.MimeType('application', 'x-www-form-urlencoded')
        content = "key=value&multiple=two+words&multiple=more%20words"
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        def cb(ign):
            self.assertEquals(request.files, {})
            self.assertEquals(request.args,
                {'multiple': ['two words', 'more words'], 'key': ['value']})
        return server.parsePOSTData(request).addCallback(cb)

    def test_multipart(self):
        """
        Test parsing data in multipart format: it should fill the C{files}
        attribute.
        """
        ctype = http_headers.MimeType('multipart', 'form-data',
            (('boundary', '---weeboundary'),))
        content="""-----weeboundary\r
Content-Disposition: form-data; name="FileNameOne"; filename="myfilename"\r
Content-Type: text/html\r
\r
my great content wooo\r
-----weeboundary--\r
"""
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        def cb(ign):
            self.assertEquals(request.args, {})
            self.assertEquals(request.files.keys(), ['FileNameOne'])
            # Each files entry is a (filename, MimeType, fileobj) tuple.
            self.assertEquals(request.files.values()[0][0][:2],
                  ('myfilename', http_headers.MimeType('text', 'html', {})))
            f = request.files.values()[0][0][2]
            self.assertEquals(f.read(), "my great content wooo")
        return server.parsePOSTData(request).addCallback(cb)

    def test_multipartWithNoBoundary(self):
        """
        If the boundary type is not specified, parsing should fail with a
        C{http.HTTPError}.
        """
        ctype = http_headers.MimeType('multipart', 'form-data')
        content="""-----weeboundary\r
Content-Disposition: form-data; name="FileNameOne"; filename="myfilename"\r
Content-Type: text/html\r
\r
my great content wooo\r
-----weeboundary--\r
"""
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        return self.assertFailure(server.parsePOSTData(request),
            http.HTTPError)

    def test_wrongContentType(self):
        """
        Check that a content-type not handled raise a C{http.HTTPError}.
        """
        ctype = http_headers.MimeType('application', 'foobar')
        content = "key=value&multiple=two+words&multiple=more%20words"
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        return self.assertFailure(server.parsePOSTData(request),
            http.HTTPError)

    def test_mimeParsingError(self):
        """
        A malformed content should result in a C{http.HTTPError}.

        The tested content has an invalid closing boundary.
        """
        ctype = http_headers.MimeType('multipart', 'form-data',
            (('boundary', '---weeboundary'),))
        content="""-----weeboundary\r
Content-Disposition: form-data; name="FileNameOne"; filename="myfilename"\r
Content-Type: text/html\r
\r
my great content wooo\r
-----weeoundary--\r
"""
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        return self.assertFailure(server.parsePOSTData(request),
            http.HTTPError)

    def test_multipartMaxMem(self):
        """
        Check that the C{maxMem} parameter makes the parsing raise an
        exception if the value is reached.
        """
        ctype = http_headers.MimeType('multipart', 'form-data',
            (('boundary', '---weeboundary'),))
        content="""-----weeboundary\r
Content-Disposition: form-data; name="FileNameOne"\r
Content-Type: text/html\r
\r
my great content wooo
and even more and more\r
-----weeboundary--\r
"""
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        def cb(res):
            self.assertEquals(res.response.description,
                "Maximum length of 10 bytes exceeded.")
        return self.assertFailure(server.parsePOSTData(request, maxMem=10),
            http.HTTPError).addCallback(cb)

    def test_multipartMaxSize(self):
        """
        Check that the C{maxSize} parameter makes the parsing raise an
        exception if the data is too big.
        """
        ctype = http_headers.MimeType('multipart', 'form-data',
            (('boundary', '---weeboundary'),))
        content="""-----weeboundary\r
Content-Disposition: form-data; name="FileNameOne"; filename="myfilename"\r
Content-Type: text/html\r
\r
my great content wooo
and even more and more\r
-----weeboundary--\r
"""
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        def cb(res):
            self.assertEquals(res.response.description,
                "Maximum length of 10 bytes exceeded.")
        return self.assertFailure(server.parsePOSTData(request, maxSize=10),
            http.HTTPError).addCallback(cb)

    def test_maxFields(self):
        """
        Check that the C{maxFields} parameter makes the parsing raise an
        exception if the data contains too many fields.
        """
        ctype = http_headers.MimeType('multipart', 'form-data',
            (('boundary', '---xyz'),))
        content = """-----xyz\r
Content-Disposition: form-data; name="foo"\r
\r
Foo Bar\r
-----xyz\r
Content-Disposition: form-data; name="foo"\r
\r
Baz\r
-----xyz\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/html\r
\r
blah\r
-----xyz\r
Content-Disposition: form-data; name="file"; filename="filename"\r
Content-Type: text/plain\r
\r
bleh\r
-----xyz--\r
"""
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        def cb(res):
            self.assertEquals(res.response.description,
                "Maximum number of fields 3 exceeded")
        return self.assertFailure(server.parsePOSTData(request, maxFields=3),
            http.HTTPError).addCallback(cb)

    def test_otherErrors(self):
        """
        Test that errors during parsing other than C{MimeFormatError} are
        propagated.
        """
        ctype = http_headers.MimeType('multipart', 'form-data',
            (('boundary', '---weeboundary'),))
        # XXX: maybe this is not a good example
        # parseContentDispositionFormData could handle this problem
        content="""-----weeboundary\r
Content-Disposition: form-data; name="FileNameOne"; filename="myfilename and invalid data \r
-----weeboundary--\r
"""
        root = resource.Resource()
        request = SimpleRequest(server.Site(root), "GET", "/",
                http_headers.Headers({'content-type': ctype}), content)
        return self.assertFailure(server.parsePOSTData(request),
            ValueError)
calendarserver-5.2+dfsg/twext/web2/test/__init__.py 0000644 0001750 0001750 00000000242 11337102650 021405 0 ustar rahul rahul # Copyright (c) 2001-2006 Twisted Matrix Laboratories.
# See LICENSE for details.
"""
twext.web2.test: unittests for the Twext.Web2, Web Server Framework
"""
calendarserver-5.2+dfsg/twext/web2/fileupload.py 0000644 0001750 0001750 00000032127 12263343324 021026 0 ustar rahul rahul ##
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
from __future__ import print_function
import re
from zope.interface import implements
import urllib
import tempfile
from twisted.internet import defer
from twext.web2.stream import IStream, FileStream, BufferedStream, readStream
from twext.web2.stream import generatorToStream, readAndDiscard
from twext.web2 import http_headers
from cStringIO import StringIO
###################################
##### Multipart MIME Reader #####
###################################
class MimeFormatError(Exception):
    """Raised when multipart/urlencoded input is malformed or exceeds a limit."""
    pass
# parseContentDispositionFormData is absolutely horrible, but as
# browsers don't seem to believe in sensible quoting rules, it's
# really the only way to handle the header. (Quotes can be in the
# filename, unescaped)
cd_regexp = re.compile(
    ' *form-data; *name="([^"]*)"(?:; *filename="(.*)")?$',
    re.IGNORECASE)

def parseContentDispositionFormData(value):
    """
    Parse a multipart/form-data Content-Disposition header value.

    @return: a (name, filename) tuple; filename is None for a plain form
        field (no filename parameter).
    @raise ValueError: if the value does not look like form-data.
    """
    matched = cd_regexp.match(value)
    if matched is None:
        # Error parsing.
        raise ValueError("Unknown content-disposition format.")
    fieldName, uploadName = matched.groups()
    return fieldName, uploadName
#@defer.deferredGenerator
def _readHeaders(stream):
    """Read the MIME headers. Assumes we've just finished reading in the
    boundary string.  Yields (fieldname, filename, ctype) when done."""

    ctype = fieldname = filename = None
    headers = []

    # Now read headers
    while 1:
        line = stream.readline(size=1024)
        if isinstance(line, defer.Deferred):
            line = defer.waitForDeferred(line)
            yield line
            line = line.getResult()
        #print("GOT", line)
        if not line.endswith('\r\n'):
            if line == "":
                raise MimeFormatError("Unexpected end of stream.")
            else:
                raise MimeFormatError("Header line too long")

        line = line[:-2] # strip \r\n
        if line == "":
            break # End of headers

        parts = line.split(':', 1)
        if len(parts) != 2:
            raise MimeFormatError("Header did not have a :")
        name, value = parts
        name = name.lower()
        headers.append((name, value))

        if name == "content-type":
            ctype = http_headers.parseContentType(http_headers.tokenize((value,), foldCase=False))
        elif name == "content-disposition":
            fieldname, filename = parseContentDispositionFormData(value)

    if ctype is None:
        # Fixed: this was "ctype == ..." -- a no-op comparison -- so parts
        # without a Content-Type header yielded ctype=None instead of the
        # intended application/octet-stream default.
        ctype = http_headers.MimeType('application', 'octet-stream')
    if fieldname is None:
        raise MimeFormatError('Content-disposition invalid or omitted.')

    # End of headers, return (field name, content-type, filename)
    yield fieldname, filename, ctype
    return
_readHeaders = defer.deferredGenerator(_readHeaders)
class _BoundaryWatchingStream(object):
    """
    Wraps the shared multipart stream, passing data through until the
    boundary is seen; anything after the boundary is pushed back onto the
    underlying stream for the next part.  C{self.deferred} fires (with
    None) once the boundary has been consumed and this substream is read
    to exhaustion.
    """
    def __init__(self, stream, boundary):
        self.stream = stream
        self.boundary = boundary
        self.data = ''
        self.deferred = defer.Deferred()

    length = None # unknown
    def read(self):
        if self.stream is None:
            # Boundary already found; signal completion exactly once.
            if self.deferred is not None:
                deferred = self.deferred
                self.deferred = None
                deferred.callback(None)
            return None

        newdata = self.stream.read()
        if isinstance(newdata, defer.Deferred):
            return newdata.addCallbacks(self._gotRead, self._gotError)

        return self._gotRead(newdata)

    def _gotRead(self, newdata):
        if not newdata:
            raise MimeFormatError("Unexpected EOF")
        # BLECH, converting buffer back into string.
        self.data += str(newdata)
        data = self.data
        boundary = self.boundary
        off = data.find(boundary)

        if off == -1:
            # No full boundary, check for the first character
            off = data.rfind(boundary[0], max(0, len(data)-len(boundary)))
            if off != -1:
                # We could have a partial boundary, store it for next time
                self.data = data[off:]
                return data[:off]
            else:
                self.data = ''
                return data
        else:
            # Full boundary found: push everything after it back onto the
            # shared stream and return only the data preceding it.
            self.stream.pushback(data[off+len(boundary):])
            self.stream = None
            return data[:off]

    def _gotError(self, err):
        # Propagate error back to MultipartMimeStream also
        if self.deferred is not None:
            deferred = self.deferred
            self.deferred = None
            deferred.errback(err)
        return err

    def close(self):
        # Assume error will be raised again and handled by MMS?
        readAndDiscard(self).addErrback(lambda _: None)
class MultipartMimeStream(object):
    """
    Splits a multipart/form-data stream into its parts; each part is
    delivered as (fieldname, filename, ctype, dataStream).
    """
    implements(IStream)
    def __init__(self, stream, boundary):
        self.stream = BufferedStream(stream)
        self.boundary = "--"+boundary
        self.first = True

    def read(self):
        """
        Return a deferred which will fire with a tuple of:
        (fieldname, filename, ctype, dataStream)
        or None when all done.

        Format errors will be sent to the errback.

        Returns None when all done.

        IMPORTANT: you *must* exhaust dataStream returned by this call
        before calling .read() again!
        """
        if self.first:
            self.first = False
            d = self._readFirstBoundary()
        else:
            d = self._readBoundaryLine()
        d.addCallback(self._doReadHeaders)
        d.addCallback(self._gotHeaders)
        return d

    def _readFirstBoundary(self):
        #print("_readFirstBoundary")
        line = self.stream.readline(size=1024)
        if isinstance(line, defer.Deferred):
            line = defer.waitForDeferred(line)
            yield line
            line = line.getResult()
        if line != self.boundary + '\r\n':
            raise MimeFormatError("Extra data before first boundary: %r looking for: %r" % (line, self.boundary + '\r\n'))

        # Every subsequent boundary is preceded by CRLF.
        self.boundary = "\r\n"+self.boundary
        yield True
        return
    _readFirstBoundary = defer.deferredGenerator(_readFirstBoundary)

    def _readBoundaryLine(self):
        #print("_readBoundaryLine")
        line = self.stream.readline(size=1024)
        if isinstance(line, defer.Deferred):
            line = defer.waitForDeferred(line)
            yield line
            line = line.getResult()

        if line == "--\r\n":
            # THE END!
            yield False
            return
        elif line != "\r\n":
            raise MimeFormatError("Unexpected data on same line as boundary: %r" % (line,))
        yield True
        return
    _readBoundaryLine = defer.deferredGenerator(_readBoundaryLine)

    def _doReadHeaders(self, morefields):
        #print("_doReadHeaders", morefields)
        if not morefields:
            return None
        return _readHeaders(self.stream)

    def _gotHeaders(self, headers):
        if headers is None:
            return None
        # Append the part's body stream to the (fieldname, filename, ctype)
        # tuple from _readHeaders.
        bws = _BoundaryWatchingStream(self.stream, self.boundary)
        self.deferred = bws.deferred
        ret=list(headers)
        ret.append(bws)
        return tuple(ret)
def readIntoFile(stream, outFile, maxlen):
    """
    Copy C{stream} into C{outFile}, raising L{MimeFormatError} as soon as
    more than C{maxlen} bytes have been seen.

    @return: a Deferred which fires when the copy is complete.
    """
    seen = [0]  # mutable cell so the closure can accumulate a byte count

    def _passthrough(result):
        return result

    def _sink(data):
        seen[0] += len(data)
        if seen[0] > maxlen:
            raise MimeFormatError("Maximum length of %d bytes exceeded." %
                                  maxlen)
        outFile.write(data)

    return readStream(stream, _sink).addBoth(_passthrough)
#@defer.deferredGenerator
def parseMultipartFormData(stream, boundary,
                           maxMem=100*1024, maxFields=1024, maxSize=10*1024*1024):
    """
    Parse a multipart/form-data stream, yielding an (args, files) pair of
    dicts keyed by field name.

    @param maxMem: budget for in-memory (non-file) field data.
    @param maxFields: cap on the number of parts.  Note the check fires
        when the count *reaches* maxFields, so at most maxFields - 1 parts
        are accepted -- the existing tests depend on this exact behavior.
    @param maxSize: budget for total uploaded data.
    """
    # If the stream length is known to be too large upfront, abort immediately

    if stream.length is not None and stream.length > maxSize:
        raise MimeFormatError("Maximum length of %d bytes exceeded." %
                                 maxSize)

    mms = MultipartMimeStream(stream, boundary)
    numFields = 0
    args = {}
    files = {}

    while 1:
        datas = mms.read()
        if isinstance(datas, defer.Deferred):
            datas = defer.waitForDeferred(datas)
            yield datas
            datas = datas.getResult()
        if datas is None:
            break

        numFields+=1
        if numFields == maxFields:
            raise MimeFormatError("Maximum number of fields %d exceeded"%maxFields)

        # Parse data
        fieldname, filename, ctype, stream = datas
        if filename is None:
            # Not a file: buffer in memory, bounded by both budgets.
            outfile = StringIO()
            maxBuf = min(maxSize, maxMem)
        else:
            # File upload: spool to a temporary file, bounded by maxSize.
            outfile = tempfile.NamedTemporaryFile()
            maxBuf = maxSize
        x = readIntoFile(stream, outfile, maxBuf)
        if isinstance(x, defer.Deferred):
            x = defer.waitForDeferred(x)
            yield x
            x = x.getResult()
        if filename is None:
            # Is a normal form field
            outfile.seek(0)
            data = outfile.read()
            args.setdefault(fieldname, []).append(data)
            maxMem -= len(data)
            maxSize -= len(data)
        else:
            # Is a file upload
            maxSize -= outfile.tell()
            outfile.seek(0)
            files.setdefault(fieldname, []).append((filename, ctype, outfile))

    yield args, files
    return
parseMultipartFormData = defer.deferredGenerator(parseMultipartFormData)
###################################
##### x-www-urlencoded reader #####
###################################
def parse_urlencoded_stream(input, maxMem=100*1024,
                            keep_blank_values=False, strict_parsing=False):
    """
    Incrementally parse an x-www-form-urlencoded stream, yielding decoded
    (name, value) pairs.  A field split across reads is carried over in
    C{lastdata} until its terminating '&' (or end of stream) arrives.
    """
    lastdata = ''
    still_going = 1

    while still_going:
        try:
            yield input.wait
            data = input.next()
        except StopIteration:
            # End of stream: whatever is buffered is the final field.
            pairs = [lastdata]
            still_going = 0
        else:
            maxMem -= len(data)
            if maxMem < 0:
                raise MimeFormatError("Maximum length of %d bytes exceeded." %
                                      maxMem)
            pairs = str(data).split('&')
            pairs[0] = lastdata + pairs[0]
            # The last piece may be incomplete; hold it for the next read.
            lastdata = pairs.pop()

        for name_value in pairs:
            nv = name_value.split('=', 1)
            if len(nv) != 2:
                if strict_parsing:
                    # Fixed: this used to read
                    #   raise MimeFormatError("bad query field: %s") % `name_value`
                    # which applied % to the exception *instance* (a
                    # TypeError at raise time) and used Python-2-only
                    # backticks.  Format the message properly instead.
                    raise MimeFormatError("bad query field: %s"
                                          % (repr(name_value),))
                continue
            if len(nv[1]) or keep_blank_values:
                name = urllib.unquote(nv[0].replace('+', ' '))
                value = urllib.unquote(nv[1].replace('+', ' '))
                yield name, value
parse_urlencoded_stream = generatorToStream(parse_urlencoded_stream)
def parse_urlencoded(stream, maxMem=100*1024, maxFields=1024,
                     keep_blank_values=False, strict_parsing=False):
    """
    Parse a complete application/x-www-form-urlencoded stream into a
    dict mapping each field name to the list of its values.

    Follows the deferredGenerator protocol: waits on Deferreds from the
    underlying field stream and finally yields the result dict.

    @raise MimeFormatError: when C{maxFields} fields are seen.
    """
    fields = {}
    fieldCount = 0
    source = parse_urlencoded_stream(stream, maxMem,
                                     keep_blank_values, strict_parsing)
    while True:
        item = source.read()
        if isinstance(item, defer.Deferred):
            item = defer.waitForDeferred(item)
            yield item
            item = item.getResult()
        if item is None:
            # Field stream exhausted.
            break
        name, value = item
        fieldCount += 1
        if fieldCount == maxFields:
            raise MimeFormatError("Maximum number of fields %d exceeded"%maxFields)
        fields.setdefault(name, []).append(value)
    yield fields
    return
parse_urlencoded = defer.deferredGenerator(parse_urlencoded)
if __name__ == '__main__':
    # Ad-hoc manual test: parse a multipart body read from ./upload.txt
    # with a hard-coded boundary and print the resulting (args, files).
    d = parseMultipartFormData(
        FileStream(open("upload.txt")), "----------0xKhTmLbOuNdArY")
    from twext.python.log import Logger
    log = Logger()
    # NOTE(review): twext's Logger may not expose an ``err`` method the way
    # twisted.python.log does -- confirm before relying on this errback.
    d.addErrback(log.err)
    def pr(s):
        print(s)
    d.addCallback(pr)
__all__ = ['parseMultipartFormData', 'parse_urlencoded', 'parse_urlencoded_stream', 'MultipartMimeStream', 'MimeFormatError']
calendarserver-5.2+dfsg/twext/web2/http.py 0000644 0001750 0001750 00000046712 12263343324 017666 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_http -*-
##
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""HyperText Transfer Protocol implementation.
The second coming.
Maintainer: James Y Knight
"""
# import traceback; log.info(''.join(traceback.format_stack()))
import json
import time
from twisted.internet import interfaces, error
from twisted.python import components
from twisted.web.template import Element, XMLString, renderer, flattenString
from zope.interface import implements
from twext.python.log import Logger
from twext.web2 import responsecode
from twext.web2 import http_headers
from twext.web2 import iweb
from twext.web2 import stream
from twext.web2.stream import IByteStream, readAndDiscard
log = Logger()
# Well-known default ports per URI scheme.
defaultPortForScheme = {'http': 80, 'https': 443, 'ftp': 21}


def splitHostPort(scheme, hostport):
    """
    Split a network location of the form "host:port" into its parts.

    @param scheme: the URI scheme, used to look up a default port when
        none is present or the port is not a valid integer.
    @param hostport: either "host" or "host:port".
    @return: a (hostname, portnumber) tuple; the port is the default for
        C{scheme} (0 for unknown schemes) when absent or unparseable.
    """
    host, sep, port = hostport.partition(':')
    if sep:
        try:
            return host, int(port)
        except ValueError:
            # Non-numeric port: fall back to the scheme default below.
            pass
    return host, defaultPortForScheme.get(scheme, 0)
def parseVersion(strversion):
    """
    Parse a version string of the form Protocol '/' Major '.' Minor,
    e.g. 'HTTP/1.1'.

    @return: a (protocol, major, minor) tuple; the protocol name is
        lowercased and the version numbers are ints.
    @raise ValueError: on bad syntax or negative version numbers.
    """
    protocol, digits = strversion.split('/')
    majorStr, minorStr = digits.split('.')
    major = int(majorStr)
    minor = int(minorStr)
    if major < 0 or minor < 0:
        raise ValueError("negative number")
    return (protocol.lower(), major, minor)
class HTTPError(Exception):
    """
    An exception that carries a complete HTTP error response to be
    propagated back to the client.
    """
    def __init__(self, codeOrResponse):
        """
        @param codeOrResponse: a numeric HTTP status code or a complete
            response object; either is adapted to L{iweb.IResponse}.
        @type codeOrResponse: C{int} or L{http.Response}
        """
        self.response = iweb.IResponse(codeOrResponse)
        Exception.__init__(self, str(self.response))

    def __repr__(self):
        return "<{0} {1}>".format(self.__class__.__name__, self.response)
class Response(object):
    """
    An HTTP response destined for the client.

    @ivar code: the HTTP status code (class default: 200 OK).
    @ivar headers: the L{http_headers.Headers} to send.
    @ivar stream: the response body stream, or C{None}.
    """
    implements(iweb.IResponse)

    code = responsecode.OK
    headers = None
    stream = None

    def __init__(self, code=None, headers=None, stream=None):
        """
        @param code: The HTTP status code for this Response
        @type code: C{int}
        @param headers: Headers to be sent to the client.
        @type headers: C{dict}, L{twext.web2.http_headers.Headers}, or
            C{None}
        @param stream: Content body to send to the HTTP client
        @type stream: L{twext.web2.stream.IByteStream}
        """
        if code is not None:
            self.code = int(code)

        if headers is None:
            # No headers supplied: start from an empty header set.
            self.headers = http_headers.Headers()
        else:
            if isinstance(headers, dict):
                headers = http_headers.Headers(headers)
            self.headers = headers

        if stream is not None:
            self.stream = IByteStream(stream)

    def __repr__(self):
        streamlen = None if self.stream is None else self.stream.length
        return "<%s.%s code=%d, streamlen=%s>" % (
            self.__module__, self.__class__.__name__, self.code, streamlen)
class StatusResponseElement(Element):
    """
    Render the HTML for a L{StatusResponse}
    """
    # Template used to render the status page.  NOTE(review): the template
    # markup appears empty/truncated in this copy of the source -- confirm
    # the original XML (it should define 'title' and 'description' slots,
    # filled by the renderer below) before modifying.
    loader = XMLString("""
""")

    def __init__(self, title, description):
        """
        @param title: the page/status title text.
        @param description: human-readable description of what happened.
        """
        super(StatusResponseElement, self).__init__()
        self.title = title
        self.description = description

    @renderer
    def response(self, request, tag):
        """
        Top-level renderer.

        Fills the template's 'title' and 'description' slots from the
        values given at construction time.
        """
        return tag.fillSlots(title=self.title, description=self.description)
class StatusResponse (Response):
    """
    A L{Response} object which simply contains a status code and a description
    of what happened.
    """
    def __init__(self, code, description, title=None):
        """
        @param code: a response code in L{responsecode.RESPONSES}.
        @param description: a string description.
        @param title: the message title. If not specified or C{None}, defaults
            to C{responsecode.RESPONSES[code]}.
        """
        if title is None:
            title = responsecode.RESPONSES[code]

        element = StatusResponseElement(title, description)
        out = []
        # NOTE(review): this assumes flattenString fires synchronously
        # (the template contains no Deferreds), so out[0] is populated by
        # the time it is used below; an IndexError here would mean that
        # assumption broke.
        flattenString(None, element).addCallback(out.append)
        mime_params = {"charset": "utf-8"}

        super(StatusResponse, self).__init__(code=code, stream=out[0])

        self.headers.setHeader(
            "content-type", http_headers.MimeType("text", "html", mime_params)
        )

        self.description = description

    def __repr__(self):
        return "<%s %s %s>" % (self.__class__.__name__, self.code, self.description)
class RedirectResponse (StatusResponse):
    """
    A L{Response} that redirects the client to another network location.
    """
    def __init__(self, location, temporary=False):
        """
        @param location: the URI to redirect to.
        @param temporary: if true, send 307 Temporary Redirect; otherwise
            301 Moved Permanently.
        """
        if temporary:
            code = responsecode.TEMPORARY_REDIRECT
        else:
            code = responsecode.MOVED_PERMANENTLY
        super(RedirectResponse, self).__init__(
            code,
            "Document moved to %s." % (location,)
        )
        self.headers.setHeader("location", location)
def NotModifiedResponse(oldResponse=None):
    """
    Build a 304 Not Modified response, optionally copying from a
    previously computed response the headers that RFC 2616 sec 10.3.5
    requires, plus a few other useful ones.
    """
    if oldResponse is None:
        headers = None
    else:
        copied = (
            # Required from sec 10.3.5:
            'date', 'etag', 'content-location', 'expires',
            'cache-control', 'vary',
            # Others:
            'server', 'proxy-authenticate', 'www-authenticate', 'warning',
        )
        headers = http_headers.Headers()
        for name in copied:
            value = oldResponse.headers.getRawHeaders(name)
            if value is not None:
                headers.setRawHeaders(name, value)
    return Response(code=responsecode.NOT_MODIFIED, headers=headers)
def checkPreconditions(request, response=None, entityExists=True, etag=None, lastModified=None):
    """Check to see if this request passes the conditional checks specified
    by the client. May raise an HTTPError with result codes L{NOT_MODIFIED}
    or L{PRECONDITION_FAILED}, as appropriate.

    This function is called automatically as an output filter for GET and
    HEAD requests. With GET/HEAD, it is not important for the precondition
    check to occur before doing the action, as the method is non-destructive.

    However, if you are implementing other request methods, like PUT
    for your resource, you will need to call this after determining
    the etag and last-modified time of the existing resource but
    before actually doing the requested action. In that case,

    This examines the appropriate request headers for conditionals,
    (If-Modified-Since, If-Unmodified-Since, If-Match, If-None-Match,
    or If-Range), compares with the etag and last and
    and then sets the response code as necessary.

    @param response: This should be provided for GET/HEAD methods. If
        it is specified, the etag and lastModified arguments will
        be retrieved automatically from the response headers and
        shouldn't be separately specified. Not providing the
        response with a GET request may cause the emitted
        "Not Modified" responses to be non-conformant.
    @param entityExists: Set to False if the entity in question doesn't
        yet exist. Necessary for PUT support with 'If-None-Match: *'.
    @param etag: The etag of the resource to check against, or None.
    @param lastModified: The last modified date of the resource to check
        against, or None.
    @raise: HTTPError: Raised when the preconditions fail, in order to
        abort processing and emit an error page.
    """
    if response:
        assert etag is None and lastModified is None
        # if the code is some sort of error code, don't do anything
        if not ((response.code >= 200 and response.code <= 299)
                or response.code == responsecode.PRECONDITION_FAILED):
            return False
        etag = response.headers.getHeader("etag")
        lastModified = response.headers.getHeader("last-modified")

    def matchETag(tags, allowWeak):
        # '*' matches any existing entity (used by If-Match/If-None-Match).
        if entityExists and '*' in tags:
            return True
        if etag is None:
            return False
        # Weak tags are only acceptable where the caller allows them;
        # the comparison strength mirrors that choice.
        return ((allowWeak or not etag.weak) and
                ([etagmatch for etagmatch in tags if etag.match(etagmatch, strongCompare=not allowWeak)]))

    # First check if-match/if-unmodified-since
    # If either one fails, we return PRECONDITION_FAILED
    match = request.headers.getHeader("if-match")
    if match:
        if not matchETag(match, False):
            raise HTTPError(StatusResponse(responsecode.PRECONDITION_FAILED, "Requested resource does not have a matching ETag."))

    unmod_since = request.headers.getHeader("if-unmodified-since")
    if unmod_since:
        if not lastModified or lastModified > unmod_since:
            raise HTTPError(StatusResponse(responsecode.PRECONDITION_FAILED, "Requested resource has changed."))

    # Now check if-none-match/if-modified-since.
    # This bit is tricky, because of the requirements when both IMS and INM
    # are present. In that case, you can't return a failure code
    # unless *both* checks think it failed.
    # Also, if the INM check succeeds, ignore IMS, because INM is treated
    # as more reliable.

    # I hope I got the logic right here...the RFC is quite poorly written
    # in this area. Someone might want to verify the testcase against
    # RFC wording.

    # If IMS header is later than current time, ignore it.
    # notModified is tri-state: None (no IMS header), True, or False --
    # the explicit "!= False" / "== True" comparisons below depend on it.
    notModified = None
    ims = request.headers.getHeader('if-modified-since')
    if ims:
        notModified = (ims < time.time() and lastModified and lastModified <= ims)

    inm = request.headers.getHeader("if-none-match")
    if inm:
        if request.method in ("HEAD", "GET"):
            # If it's a range request, don't allow a weak ETag, as that
            # would break.
            canBeWeak = not request.headers.hasHeader('Range')
            if notModified != False and matchETag(inm, canBeWeak):
                raise HTTPError(NotModifiedResponse(response))
        else:
            if notModified != False and matchETag(inm, False):
                raise HTTPError(StatusResponse(responsecode.PRECONDITION_FAILED, "Requested resource has a matching ETag."))
    else:
        if notModified == True:
            raise HTTPError(NotModifiedResponse(response))
def checkIfRange(request, response):
    """
    Evaluate the If-Range header of C{request} against C{response}.

    @return: True when the server should honour the Range header and
        send partial content (including when no If-Range header is
        present); False otherwise.
    """
    ifrange = request.headers.getHeader("if-range")
    if ifrange is None:
        return True
    # An ETag condition must match strongly; otherwise the value is a
    # date and must equal the response's Last-Modified time exactly.
    if isinstance(ifrange, http_headers.ETag):
        return ifrange.match(response.headers.getHeader("etag"),
                             strongCompare=True)
    return ifrange == response.headers.getHeader("last-modified")
class _NotifyingProducerStream(stream.ProducerStream):
    """
    A L{stream.ProducerStream} that fires a one-shot callback the first
    time it is read, unless data is written (or the stream finished)
    before any read happens.
    """
    # One-shot callable invoked on first read; cleared after use.
    doStartReading = None

    def __init__(self, length=None, doStartReading=None):
        stream.ProducerStream.__init__(self, length=length)
        self.doStartReading = doStartReading

    def read(self):
        callback = self.doStartReading
        if callback is not None:
            self.doStartReading = None
            callback()
        return stream.ProducerStream.read(self)

    def write(self, data):
        # Data arrived before anyone read: the notification is moot.
        self.doStartReading = None
        stream.ProducerStream.write(self, data)

    def finish(self):
        self.doStartReading = None
        stream.ProducerStream.finish(self)
# Response codes that must not carry a message body.
NO_BODY_CODES = (responsecode.NO_CONTENT, responsecode.NOT_MODIFIED)
class Request(object):
    """A HTTP request.

    Subclasses should override the process() method to determine how
    the request will be processed.

    @ivar method: The HTTP method that was used.
    @ivar uri: The full URI that was requested (includes arguments).
    @ivar headers: All received headers
    @ivar clientproto: client HTTP version
    @ivar stream: incoming data stream.
    """
    implements(iweb.IRequest, interfaces.IConsumer)

    # The only Expect-header token this implementation can satisfy.
    known_expects = ('100-continue',)

    def __init__(self, chanRequest, command, path, version, contentLength, headers):
        """
        @param chanRequest: the channel request we're associated with.
        @param command: the request method (e.g. 'GET').
        @param path: the full requested URI.
        @param version: the client's HTTP version.
        @param contentLength: expected body length, if known.
        @param headers: the parsed request headers.
        """
        self.chanRequest = chanRequest
        self.method = command
        self.uri = path
        self.clientproto = version

        self.headers = headers

        # For "Expect: 100-continue", delay sending the interim 100
        # response until the application first reads the body stream.
        if '100-continue' in self.headers.getHeader('expect', ()):
            doStartReading = self._sendContinue
        else:
            doStartReading = None
        self.stream = _NotifyingProducerStream(contentLength, doStartReading)
        self.stream.registerProducer(self.chanRequest, True)

    def checkExpect(self):
        """Ensure there are no expectations that cannot be met.
        Checks Expect header against self.known_expects.

        Raises HTTPError(EXPECTATION_FAILED) for any unknown expectation.
        """
        expects = self.headers.getHeader('expect', ())
        for expect in expects:
            if expect not in self.known_expects:
                raise HTTPError(responsecode.EXPECTATION_FAILED)

    def process(self):
        """Called by channel to let you process the request.

        Can be overridden by a subclass to do something useful."""
        pass

    def handleContentChunk(self, data):
        """Callback from channel when a piece of data has been received.

        Puts the data in .stream"""
        self.stream.write(data)

    def handleContentComplete(self):
        """Callback from channel when all data has been received. """
        self.stream.unregisterProducer()
        self.stream.finish()

    def connectionLost(self, reason):
        """connection was lost"""
        pass

    def __repr__(self):
        return '<%s %s %s>' % (self.method, self.uri, self.clientproto)

    def _sendContinue(self):
        # Fired on first body read while "Expect: 100-continue" is
        # outstanding: tell the client to go ahead with the body.
        self.chanRequest.writeIntermediateResponse(responsecode.CONTINUE)

    def _reallyFinished(self, x):
        """We are finished writing data."""
        self.chanRequest.finish()

    def _finished(self, x):
        """
        We are finished writing data.
        But we need to check that we have also finished reading all data as we
        might have sent a, for example, 401 response before we read any data.
        To make sure that the stream/producer sequencing works properly we need
        to discard the remaining data in the request.
        """
        if self.stream.length != 0:
            # Unread request body remains: drain it before finishing.
            return readAndDiscard(self.stream).addCallback(self._reallyFinished).addErrback(self._error)
        else:
            self._reallyFinished(x)

    def _error(self, reason):
        # Errback for response delivery: log, and abort the connection
        # for anything other than a plain lost connection.
        if reason.check(error.ConnectionLost):
            log.info("Request error: {message}", message=reason.getErrorMessage())
        else:
            log.failure("Request error", reason)
            # Only bother with cleanup on errors other than lost connection.
            self.chanRequest.abortConnection()

    def writeResponse(self, response):
        """
        Write a response.

        @param response: the response object to deliver to the client.
        """
        if self.stream.doStartReading is not None:
            # Expect: 100-continue was requested, but 100 response has not been
            # sent, and there's a possibility that data is still waiting to be
            # sent.
            #
            # Ideally this means the remote side will not send any data.
            # However, because of compatibility requirements, it might timeout,
            # and decide to do so anyways at the same time we're sending back
            # this response. Thus, the read state is unknown after this.
            # We must close the connection.
            self.chanRequest.channel.setReadPersistent(False)
            # Nothing more will be read
            self.chanRequest.allContentReceived()

        if response.code != responsecode.NOT_MODIFIED:
            # Not modified response is *special* and doesn't get a content-length.
            if response.stream is None:
                response.headers.setHeader('content-length', 0)
            elif response.stream.length is not None:
                response.headers.setHeader('content-length', response.stream.length)
        self.chanRequest.writeHeaders(response.code, response.headers)

        # if this is a "HEAD" request, or a special response code,
        # don't return any data.
        if self.method == "HEAD" or response.code in NO_BODY_CODES:
            if response.stream is not None:
                response.stream.close()
            self._finished(None)
            return

        d = stream.StreamProducer(response.stream).beginProducing(self.chanRequest)
        d.addCallback(self._finished).addErrback(self._error)
class XMLResponse (Response):
    """
    XML L{Response} object.

    Renders itself as an XML document.
    """
    def __init__(self, code, element):
        """
        @param code: the HTTP status code for this response.
        @param element: a DAV XML element whose C{toxml()} serialization
            becomes the response body.
        """
        Response.__init__(self, code, stream=element.toxml())
        self.headers.setHeader(
            "content-type", http_headers.MimeType("text", "xml"))
class JSONResponse (Response):
    """
    JSON L{Response} object.

    Renders itself as a JSON document.
    """
    def __init__(self, code, jobj):
        """
        @param code: the HTTP status code for this response.
        @param jobj: a JSON-serializable object whose C{json.dumps}
            serialization becomes the response body.
        """
        Response.__init__(self, code, stream=json.dumps(jobj))
        self.headers.setHeader(
            "content-type",
            http_headers.MimeType("application", "json"))
components.registerAdapter(Response, int, iweb.IResponse)
__all__ = ['HTTPError', 'NotModifiedResponse', 'Request', 'Response', 'StatusResponse', 'RedirectResponse', 'checkIfRange', 'checkPreconditions', 'defaultPortForScheme', 'parseVersion', 'splitHostPort', "XMLResponse", "JSONResponse"]
calendarserver-5.2+dfsg/twext/web2/auth/ 0000755 0001750 0001750 00000000000 12322625325 017264 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/auth/basic.py 0000644 0001750 0001750 00000004425 12263343324 020724 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_httpauth -*-
##
# Copyright (c) 2006-2009 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
from twisted.cred import credentials, error
from twisted.internet.defer import succeed, fail
from twext.web2.auth.interfaces import ICredentialFactory
from zope.interface import implements
class BasicCredentialFactory(object):
    """
    Credential Factory for HTTP Basic Authentication
    """
    implements(ICredentialFactory)

    scheme = 'basic'

    def __init__(self, realm):
        """
        @param realm: the authentication realm advertised in challenges.
        """
        self.realm = realm

    def getChallenge(self, peer):
        """
        @see L{ICredentialFactory.getChallenge}
        """
        return succeed({'realm': self.realm})

    def decode(self, response, request):
        """
        Decode the credentials for basic auth.

        @see L{ICredentialFactory.decode}
        @raise error.LoginFailed: (via the returned Deferred or directly)
            if the response is not valid base64 or lacks a ':' separator.
        """
        try:
            # Extra '=' padding lets under-padded input still decode; the
            # base64 codec tolerates surplus padding characters.
            creds = (response + '===').decode('base64')
        except Exception:
            # BUG FIX: was a bare "except:", which also swallowed
            # SystemExit/KeyboardInterrupt and reported them as bad
            # credentials; narrowed to Exception.
            raise error.LoginFailed('Invalid credentials')

        creds = creds.split(':', 1)
        if len(creds) == 2:
            return succeed(credentials.UsernamePassword(*creds))
        else:
            return fail(error.LoginFailed('Invalid credentials'))
calendarserver-5.2+dfsg/twext/web2/auth/interfaces.py 0000644 0001750 0001750 00000006111 12263343324 021760 0 ustar rahul rahul ##
# Copyright (c) 2004-2007 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
from zope.interface import Interface, Attribute
class ICredentialFactory(Interface):
    """
    A credential factory provides state between stages in HTTP
    authentication.  It is ultimately in charge of creating an
    ICredential for the specified scheme, that will be used by
    cred to complete authentication.
    """
    # e.g. 'basic' or 'digest'; matched against the Authorization header.
    scheme = Attribute(("string indicating the authentication scheme "
                        "this factory is associated with."))

    def getChallenge(peer):
        """
        Generate a challenge the client may respond to.

        @type peer: L{twisted.internet.interfaces.IAddress}
        @param peer: The client's address
        @rtype: C{dict}
        @return: Deferred returning dictionary of challenge arguments
        """

    def decode(response, request):
        """
        Create a credentials object from the given response.

        May raise twisted.cred.error.LoginFailed if the response is invalid.

        @type response: C{str}
        @param response: scheme specific response string
        @type request: L{twext.web2.server.Request}
        @param request: the request being processed
        @return: Deferred returning ICredentials
        """
class IAuthenticatedRequest(Interface):
    """
    A request that has been authenticated with the use of Cred,
    and holds a reference to the avatar returned by portal.login.
    """
    avatarInterface = Attribute(("The credential interface implemented by "
                                 "the avatar"))

    avatar = Attribute("The application specific avatar returned by "
                       "the application's realm")
class IHTTPUser(Interface):
    """
    A generic interface that can be implemented by an avatar to provide
    access to the username used when authenticating.
    """
    username = Attribute(("A string representing the username portion of "
                          "the credentials used for authentication"))
calendarserver-5.2+dfsg/twext/web2/auth/wrapper.py 0000644 0001750 0001750 00000022566 12263343324 021331 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_httpauth -*-
##
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
Wrapper Resources for rfc2617 HTTP Auth.
"""
from zope.interface import implements, directlyProvides
from twisted.cred import error, credentials
from twisted.internet.defer import gatherResults, succeed
from twisted.python import failure
from twext.web2 import responsecode
from twext.web2 import http
from twext.web2 import iweb
from twext.web2.auth.interfaces import IAuthenticatedRequest
class UnauthorizedResponse(http.StatusResponse):
    """
    A 401 response that knows how to generate www-authenticate headers
    from the given L{CredentialFactory} instances.
    """

    def __init__(self):
        super(UnauthorizedResponse, self).__init__(
            responsecode.UNAUTHORIZED,
            "You are not authorized to access this resource.")

    def _generateHeaders(self, factories, remoteAddr=None):
        """
        Set up the response's www-authenticate headers.

        @param factories: A L{dict} of {'scheme': ICredentialFactory}
        @param remoteAddr: An L{IAddress} for the connecting client.
        @return: a Deferred firing once the headers have been set.
        """
        factoryList = list(factories.itervalues())
        schemes = [factory.scheme for factory in factoryList]
        challengeDs = [factory.getChallenge(remoteAddr)
                       for factory in factoryList]

        def _setAuthHeader(challenges):
            # Pair each scheme with its computed challenge parameters.
            self.headers.setHeader('www-authenticate',
                                   zip(schemes, challenges))

        return gatherResults(challengeDs).addCallback(_setAuthHeader)

    @classmethod
    def makeResponse(cls, factories, remoteAddr=None):
        """
        Create an Unauthorized response.

        @param factories: A L{dict} of {'scheme': ICredentialFactory}
        @param remoteAddr: An L{IAddress} for the connecting client.
        @return: a Deferred that fires with the L{UnauthorizedResponse}
            instance.
        """
        response = UnauthorizedResponse()
        d = response._generateHeaders(factories, remoteAddr)
        d.addCallback(lambda _: response)
        return d
class HTTPAuthResource(object):
    """I wrap a resource to prevent it being accessed unless the authentication
    can be completed using the credential factory, portal, and interfaces
    specified.
    """
    implements(iweb.IResource)

    def __init__(self, wrappedResource, credentialFactories,
                 portal, interfaces):
        """
        @param wrappedResource: A L{twext.web2.iweb.IResource} to be returned
                                from locateChild and render upon successful
                                authentication.
        @param credentialFactories: A list of instances that implement
                                    L{ICredentialFactory}.
        @type credentialFactories: L{list}
        @param portal: Portal to handle logins for this resource.
        @type portal: L{twisted.cred.portal.Portal}
        @param interfaces: the interfaces that are allowed to log in via the
                           given portal
        @type interfaces: L{tuple}
        """
        self.wrappedResource = wrappedResource

        # Index factories by auth scheme for Authorization-header dispatch.
        self.credentialFactories = dict([(factory.scheme, factory)
                                         for factory in credentialFactories])
        self.portal = portal
        self.interfaces = interfaces

    def _loginSucceeded(self, avatar, request):
        """
        Callback for successful login.

        @param avatar: A tuple of the form (interface, avatar) as
            returned by your realm.

        @param request: L{IRequest} that encapsulates this auth
            attempt.

        @return: the IResource in C{self.wrappedResource}
        """
        request.avatarInterface, request.avatar = avatar

        # Mark the request so downstream code can detect authentication.
        directlyProvides(request, IAuthenticatedRequest)

        def _addAuthenticateHeaders(request, response):
            """
            A response filter that adds www-authenticate headers
            to an outgoing response if its code is UNAUTHORIZED (401)
            and it does not already have them.
            """
            if response.code == responsecode.UNAUTHORIZED:
                if not response.headers.hasHeader('www-authenticate'):
                    d = UnauthorizedResponse.makeResponse(
                        self.credentialFactories,
                        request.remoteAddr)

                    def _respond(newResp):
                        # Copy only the challenge header onto the original
                        # response, keeping its body and other headers.
                        response.headers.setHeader(
                            'www-authenticate',
                            newResp.headers.getHeader('www-authenticate'))
                        return response

                    d.addCallback(_respond)
                    return d

            return succeed(response)

        # Allow this filter to run even for error responses.
        _addAuthenticateHeaders.handleErrors = True

        request.addResponseFilter(_addAuthenticateHeaders)

        return self.wrappedResource

    def _loginFailed(self, ignored, request):
        """
        Errback for failed login.

        @param request: L{IRequest} that encapsulates this auth
            attempt.

        @return: A Deferred L{Failure} containing an L{HTTPError} containing the
            L{UnauthorizedResponse} if C{result} is an L{UnauthorizedLogin}
            or L{UnhandledCredentials} error
        """
        d = UnauthorizedResponse.makeResponse(self.credentialFactories,
                                              request.remoteAddr)

        def _fail(response):
            return failure.Failure(http.HTTPError(response))

        return d.addCallback(_fail)

    def login(self, factory, response, request):
        """
        @param factory: An L{ICredentialFactory} that understands the given
            response.

        @param response: The client's authentication response as a string.

        @param request: The request that prompted this authentication attempt.

        @return: A L{Deferred} that fires with the wrappedResource on success
            or a failure containing an L{UnauthorizedResponse}
        """
        d = factory.decode(response, request)

        def _decodeFailure(err):
            # Only credential-format failures become 401s; anything else
            # propagates unchanged.
            err.trap(error.LoginFailed)
            d = UnauthorizedResponse.makeResponse(self.credentialFactories,
                                                  request.remoteAddr)

            def _respond(response):
                return failure.Failure(http.HTTPError(response))

            return d.addCallback(_respond)

        def _login(creds):
            return self.portal.login(creds, None, *self.interfaces
                                     ).addCallbacks(self._loginSucceeded,
                                                    self._loginFailed,
                                                    (request,), None,
                                                    (request,), None)

        return d.addErrback(_decodeFailure).addCallback(_login)

    def authenticate(self, request):
        """
        Attempt to authenticate the given request.

        With no Authorization header, try an anonymous login; with an
        unknown scheme, fail; otherwise delegate to the matching factory.

        @param request: An L{IRequest} to be authenticated.
        """
        authHeader = request.headers.getHeader('authorization')

        if authHeader is None:
            return self.portal.login(credentials.Anonymous(),
                                     None,
                                     *self.interfaces
                                     ).addCallbacks(self._loginSucceeded,
                                                    self._loginFailed,
                                                    (request,), None,
                                                    (request,), None)
        elif authHeader[0] not in self.credentialFactories:
            return self._loginFailed(None, request)
        else:
            return self.login(self.credentialFactories[authHeader[0]],
                              authHeader[1], request)

    def locateChild(self, request, seg):
        """
        Authenticate the request then return the C{self.wrappedResource}
        and the unmodified segments.
        """
        return self.authenticate(request), seg

    def renderHTTP(self, request):
        """
        Authenticate the request then return the result of calling renderHTTP
        on C{self.wrappedResource}
        """
        def _renderResource(resource):
            return resource.renderHTTP(request)

        d = self.authenticate(request)
        d.addCallback(_renderResource)

        return d
calendarserver-5.2+dfsg/twext/web2/auth/digest.py 0000644 0001750 0001750 00000010550 12263343324 021116 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_httpauth -*-
##
# Copyright (c) 2006-2009 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
Implementation of RFC2617: HTTP Digest Authentication
http://www.faqs.org/rfcs/rfc2617.html
"""
from zope.interface import implements
from twisted.python.hashlib import md5, sha1
from twisted.cred import credentials
# FIXME: Technically speaking - although you can't tell from looking at them -
# these APIs are private, they're defined within twisted.cred._digest. There
# should probably be some upstream bugs agains Twisted to more aggressively hide
# implementation details like these if they're not supposed to be used, so we
# can see the private-ness more clearly. The fix is really just to eliminate
# this whole module though, and use the Twisted stuff via the public interface,
# which should be sufficient to do digest auth.
from twisted.cred.credentials import (calcHA1 as _origCalcHA1,
calcResponse as _origCalcResponse,
calcHA2 as _origCalcHA2)
from twisted.internet.defer import maybeDeferred
from twext.web2.auth.interfaces import ICredentialFactory
# The digest math
# The digest math
# Map of digest algorithm tokens (as they appear in the challenge) to the
# hash constructors used for the digest computation.
algorithms = dict([
    ('md5', md5),
    ('md5-sess', md5),
    ('sha', sha1),
])
# DigestCalcHA1
def calcHA1(pszAlg, pszUserName, pszRealm, pszPassword, pszNonce, pszCNonce,
            preHA1=None):
    """
    Compute the HA1 half of an RFC 2617 digest, delegating to Twisted's
    implementation.

    @param pszAlg: The name of the algorithm to use to calculate the digest.
        Currently supported are md5 md5-sess and sha.
    @param pszUserName: The username
    @param pszRealm: The realm
    @param pszPassword: The password
    @param pszNonce: The nonce
    @param pszCNonce: The cnonce
    @param preHA1: If available this is a str containing a previously
        calculated HA1 as a hex string. If this is given then the values for
        pszUserName, pszRealm, and pszPassword are ignored.
    """
    return _origCalcHA1(
        pszAlg, pszUserName, pszRealm, pszPassword,
        pszNonce, pszCNonce, preHA1,
    )
# DigestCalcResponse
def calcResponse(
    HA1,
    algo,
    pszNonce,
    pszNonceCount,
    pszCNonce,
    pszQop,
    pszMethod,
    pszDigestUri,
    pszHEntity,
):
    """
    Compute the RFC 2617 digest response value for the given HA1 and request
    parameters, delegating the HA2 and response math to Twisted.
    """
    ha2 = _origCalcHA2(algo, pszMethod, pszDigestUri, pszQop, pszHEntity)
    return _origCalcResponse(HA1, ha2, algo, pszNonce, pszNonceCount,
                             pszCNonce, pszQop)
# Re-export Twisted's DigestedCredentials under this module's historical name.
DigestedCredentials = credentials.DigestedCredentials
class DigestCredentialFactory(object):
    """
    An C{ICredentialFactory} for HTTP digest auth that delegates all real
    work to L{twisted.cred.credentials.DigestCredentialFactory}.
    """
    implements(ICredentialFactory)

    # Mirror the Twisted factory's lifetime so both agree on challenge expiry.
    CHALLENGE_LIFETIME_SECS = (
        credentials.DigestCredentialFactory.CHALLENGE_LIFETIME_SECS
    )

    def __init__(self, algorithm, realm):
        # The wrapped Twisted factory does the actual digest bookkeeping.
        self._real = credentials.DigestCredentialFactory(algorithm, realm)

    # Authentication scheme token used when issuing challenges.
    scheme = 'digest'

    def getChallenge(self, peer):
        # Twisted's API wants the client IP, not the whole address object.
        return maybeDeferred(self._real.getChallenge, peer.host)

    def generateOpaque(self, *a, **k):
        # NOTE(review): reaches into a private Twisted API (_generateOpaque);
        # see the FIXME at the top of this module.
        return self._real._generateOpaque(*a, **k)

    def verifyOpaque(self, opaque, nonce, clientip):
        # NOTE(review): private Twisted API (_verifyOpaque), as above.
        return self._real._verifyOpaque(opaque, nonce, clientip)

    def decode(self, response, request):
        # Prefer the pre-tunnelling method when the request was rewritten
        # via X-HTTP-Method-Override (set as originalMethod by the server).
        method = getattr(request, "originalMethod", request.method)
        host = request.remoteAddr.host
        return self._real.decode(response, method, host)
calendarserver-5.2+dfsg/twext/web2/auth/__init__.py 0000644 0001750 0001750 00000002346 12263343324 021402 0 ustar rahul rahul ##
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
Client and server implementations of http authentication
"""
calendarserver-5.2+dfsg/twext/web2/server.py 0000644 0001750 0001750 00000064600 12263343324 020211 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_server -*-
##
# Copyright (c) 2001-2008 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
This is a web-server which integrates with the twisted.internet
infrastructure.
"""
from __future__ import print_function
import cgi, time, urlparse
from urllib import quote, unquote
from urlparse import urlsplit
import weakref
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from twext.python.log import Logger
from twext.web2 import http, iweb, fileupload, responsecode
from twext.web2 import http_headers
from twext.web2.filter.range import rangefilter
from twext.web2 import error
from twext.web2 import __version__ as web2_version
from twisted import __version__ as twisted_version
# Advertised in the default 'Server' response header.
VERSION = "Twisted/%s TwistedWeb/%s" % (twisted_version, web2_version)
# Sentinel object (compared by identity elsewhere in this package).
_errorMarker = object()

log = Logger()
def defaultHeadersFilter(request, response):
    """
    Response filter that fills in the 'server' and 'date' headers when the
    response does not already carry them.
    """
    headers = response.headers
    if not headers.hasHeader('server'):
        headers.setHeader('server', VERSION)
    if not headers.hasHeader('date'):
        headers.setHeader('date', time.time())
    return response
# Run this filter for error responses as well as successful ones.
defaultHeadersFilter.handleErrors = True
def preconditionfilter(request, response):
    """
    Response filter applying HTTP precondition checks to safe (GET/HEAD)
    requests.
    """
    if request.method == "GET" or request.method == "HEAD":
        http.checkPreconditions(request, response)
    return response
def doTrace(request):
    """
    Build the TRACE response: echo the request line and headers back to the
    client as a message/http entity.
    """
    request = iweb.IRequest(request)
    lines = ["%s %s HTTP/%d.%d\r\n" % (request.method, request.uri,
                                       request.clientproto[0],
                                       request.clientproto[1])]
    for name, valuelist in request.headers.getAllRawHeaders():
        lines.extend("%s: %s\r\n" % (name, value) for value in valuelist)
    body = ''.join(lines)
    return http.Response(
        responsecode.OK,
        {'content-type': http_headers.MimeType('message', 'http')},
        body)
def parsePOSTData(request, maxMem=100*1024, maxFields=1024,
                  maxSize=10*1024*1024):
    """
    Parse data of a POST request.

    @param request: the request to parse.
    @type request: L{twext.web2.http.Request}.
    @param maxMem: maximum memory used during the parsing of the data.
    @type maxMem: C{int}
    @param maxFields: maximum number of form fields allowed.
    @type maxFields: C{int}
    @param maxSize: maximum size of file upload allowed.
    @type maxSize: C{int}

    @return: a deferred that will fire when the parsing is done. The deferred
        itself doesn't hold a return value, the request is modified directly.
    @rtype: C{defer.Deferred}
    """
    if request.stream.length == 0:
        return defer.succeed(None)

    ctype = request.headers.getHeader('content-type')
    if ctype is None:
        return defer.succeed(None)

    def updateArgs(data):
        # urlencoded bodies yield just a mapping of arguments.
        args = data
        request.args.update(args)

    def updateArgsAndFiles(data):
        # multipart bodies yield (arguments, uploaded files).
        args, files = data
        request.args.update(args)
        request.files.update(files)

    def _mimeError(f):
        # Renamed from 'error' so it no longer shadows this module's
        # top-level 'twext.web2.error' import.
        f.trap(fileupload.MimeFormatError)
        raise http.HTTPError(
            http.StatusResponse(responsecode.BAD_REQUEST, str(f.value)))

    if (ctype.mediaType == 'application'
            and ctype.mediaSubtype == 'x-www-form-urlencoded'):
        d = fileupload.parse_urlencoded(request.stream)
        d.addCallbacks(updateArgs, _mimeError)
        return d
    elif (ctype.mediaType == 'multipart'
            and ctype.mediaSubtype == 'form-data'):
        boundary = ctype.params.get('boundary')
        if boundary is None:
            return defer.fail(http.HTTPError(
                http.StatusResponse(
                    responsecode.BAD_REQUEST,
                    "Boundary not specified in Content-Type.")))
        d = fileupload.parseMultipartFormData(request.stream, boundary,
                                              maxMem, maxFields, maxSize)
        d.addCallbacks(updateArgsAndFiles, _mimeError)
        return d
    else:
        return defer.fail(http.HTTPError(
            http.StatusResponse(
                responsecode.BAD_REQUEST,
                "Invalid content-type: %s/%s" % (
                    ctype.mediaType, ctype.mediaSubtype))))
class StopTraversal(object):
    """
    Sentinel returned (as the path half of a locateChild result) to tell
    Request._handleSegment to stop walking path segments.
    """
    pass
class Request(http.Request):
"""
vars:
site
remoteAddr
scheme
host
port
path
params
querystring
args
files
prepath
postpath
@ivar path: The path only (arguments not included).
@ivar args: All of the arguments, including URL and POST arguments.
@type args: A mapping of strings (the argument names) to lists of values.
i.e., ?foo=bar&foo=baz&quux=spam results in
{'foo': ['bar', 'baz'], 'quux': ['spam']}.
"""
implements(iweb.IRequest)
site = None
_initialprepath = None
responseFilters = [rangefilter, preconditionfilter,
error.defaultErrorHandler, defaultHeadersFilter]
def __init__(self, *args, **kw):
self.timeStamps = [("t", time.time(),)]
if kw.has_key('site'):
self.site = kw['site']
del kw['site']
if kw.has_key('prepathuri'):
self._initialprepath = kw['prepathuri']
del kw['prepathuri']
self._resourcesByURL = {}
self._urlsByResource = {}
# Copy response filters from the class
self.responseFilters = self.responseFilters[:]
self.files = {}
self.resources = []
http.Request.__init__(self, *args, **kw)
try:
self.serverInstance = self.chanRequest.channel.transport.server.port
except AttributeError:
self.serverInstance = "Unknown"
def timeStamp(self, tag):
self.timeStamps.append((tag, time.time(),))
def addResponseFilter(self, filter, atEnd=False, onlyOnce=False):
"""
Add a response filter to this request.
Response filters are applied to the response to this request in order.
@param filter: a callable which takes an response argument and returns
a response object.
@param atEnd: if C{True}, C{filter} is added at the end of the list of
response filters; if C{False}, it is added to the beginning.
@param onlyOnce: if C{True}, C{filter} is not added to the list of
response filters if it already in the list.
"""
if onlyOnce and filter in self.responseFilters:
return
if atEnd:
self.responseFilters.append(filter)
else:
self.responseFilters.insert(0, filter)
def unparseURL(self, scheme=None, host=None, port=None,
path=None, params=None, querystring=None, fragment=None):
"""Turn the request path into a url string. For any pieces of
the url that are not specified, use the value from the
request. The arguments have the same meaning as the same named
attributes of Request."""
if scheme is None: scheme = self.scheme
if host is None: host = self.host
if port is None: port = self.port
if path is None: path = self.path
if params is None: params = self.params
if querystring is None: querystring = self.querystring
if fragment is None: fragment = ''
if port == http.defaultPortForScheme.get(scheme, 0):
hostport = host
else:
hostport = host + ':' + str(port)
return urlparse.urlunparse((
scheme, hostport, path,
params, querystring, fragment))
    def _parseURL(self):
        """
        Split self.uri into scheme/host/path/params/querystring, populate
        self.args from the query string, and initialise prepath/postpath.
        """
        if self.uri[0] == '/':
            # Can't use urlparse for request_uri because urlparse
            # wants to be given an absolute or relative URI, not just
            # an abs_path, and thus gets '//foo' wrong.
            self.scheme = self.host = self.path = self.params = self.querystring = ''
            if '?' in self.uri:
                self.path, self.querystring = self.uri.split('?', 1)
            else:
                self.path = self.uri
            if ';' in self.path:
                self.path, self.params = self.path.split(';', 1)
        else:
            # It is an absolute uri, use standard urlparse
            (self.scheme, self.host, self.path,
             self.params, self.querystring, fragment) = urlparse.urlparse(self.uri)

        if self.querystring:
            # Second arg keeps blank values (?a=&b=1 -> {'a': [''], ...}).
            self.args = cgi.parse_qs(self.querystring, True)
        else:
            self.args = {}

        # Unquoted path segments, with the leading '/' dropped.
        path = map(unquote, self.path[1:].split('/'))
        if self._initialprepath:
            # We were given an initial prepath -- this is for supporting
            # CGI-ish applications where part of the path has already
            # been processed
            prepath = map(unquote, self._initialprepath[1:].split('/'))

            if path[:len(prepath)] == prepath:
                self.prepath = prepath
                self.postpath = path[len(prepath):]
            else:
                # The given prepath doesn't actually prefix this URI; ignore it.
                self.prepath = []
                self.postpath = path
        else:
            self.prepath = []
            self.postpath = path
        #print("_parseURL", self.uri, (self.uri, self.scheme, self.host, self.path, self.params, self.querystring))
def _schemeFromPort(self, port):
"""
Try to determine the scheme matching the supplied server port. This is needed in case
where a device in front of the server is changing the scheme (e.g. decoding SSL) but not
rewriting the scheme in URIs returned in responses (e.g. in Location headers). This could trick
clients into using an inappropriate scheme for subsequent requests. What we should do is
take the port number from the Host header or request-URI and map that to the scheme that
matches the service we configured to listen on that port.
@param port: the port number to test
@type port: C{int}
@return: C{True} if scheme is https (secure), C{False} otherwise
@rtype: C{bool}
"""
#from twistedcaldav.config import config
if hasattr(self.site, "EnableSSL") and self.site.EnableSSL:
if port == self.site.SSLPort:
return True
elif port in self.site.BindSSLPorts:
return True
return False
    def _fixupURLParts(self):
        """
        Fill in scheme/host/port from the transport and the Host header when
        the request line did not carry an absolute URL.
        """
        hostaddr, secure = self.chanRequest.getHostInfo()
        if not self.scheme:
            self.scheme = ('http', 'https')[secure]

        if self.host:
            self.host, self.port = http.splitHostPort(self.scheme, self.host)
            # Re-derive the scheme from the service configured on this port
            # (see _schemeFromPort).
            self.scheme = ('http', 'https')[self._schemeFromPort(self.port)]
        else:
            # If GET line wasn't an absolute URL
            host = self.headers.getHeader('host')
            if host:
                self.host, self.port = http.splitHostPort(self.scheme, host)
                self.scheme = ('http', 'https')[self._schemeFromPort(self.port)]
            else:
                # When no hostname specified anywhere, either raise an
                # error, or use the interface hostname, depending on
                # protocol version
                if self.clientproto >= (1,1):
                    # HTTP/1.1 requires a Host header.
                    raise http.HTTPError(responsecode.BAD_REQUEST)
                self.host = hostaddr.host
                self.port = hostaddr.port
    def process(self):
        """
        Process a request: parse the URL, traverse to the target resource,
        render it, and write the (filtered) response.
        """
        log.info("%s %s %s" % (
            self.method,
            self.uri,
            "HTTP/%s.%s" % self.clientproto
            ))

        try:
            self.checkExpect()
            resp = self.preprocessRequest()
            if resp is not None:
                # Short-circuit response (e.g. "OPTIONS *") -- render it now.
                self._cbFinishRender(resp).addErrback(self._processingFailed)
                return

            self._parseURL()
            self._fixupURLParts()
            self.remoteAddr = self.chanRequest.getRemoteHost()
        except:
            # NOTE(review): deliberately broad -- any setup failure is
            # routed to the standard error-response machinery.
            self._processingFailed(failure.Failure())
            return

        # Traverse from the site root through postpath, then render.
        d = defer.Deferred()
        d.addCallback(self._getChild, self.site.resource, self.postpath)
        d.addCallback(self._rememberResource, "/" + "/".join(quote(s) for s in self.postpath))
        d.addCallback(self._processTimeStamp)
        d.addCallback(lambda res, req: res.renderHTTP(req), self)
        d.addCallback(self._cbFinishRender)
        d.addErrback(self._processingFailed)
        d.callback(None)
        return d
    def _processTimeStamp(self, res):
        """Record the end of pre-render processing; passes C{res} through."""
        self.timeStamp("t-req-proc")
        return res
def preprocessRequest(self):
"""Do any request processing that doesn't follow the normal
resource lookup procedure. "OPTIONS *" is handled here, for
example. This would also be the place to do any CONNECT
processing."""
if self.method == "OPTIONS" and self.uri == "*":
response = http.Response(responsecode.OK)
response.headers.setHeader('allow', ('GET', 'HEAD', 'OPTIONS', 'TRACE'))
return response
elif self.method == "POST":
# Allow other methods to tunnel through using POST and a request header.
# See http://code.google.com/apis/gdata/docs/2.0/basics.html
if self.headers.hasHeader("X-HTTP-Method-Override"):
intendedMethod = self.headers.getRawHeaders("X-HTTP-Method-Override")[0];
if intendedMethod:
self.originalMethod = self.method
self.method = intendedMethod
# This is where CONNECT would go if we wanted it
return None
def _getChild(self, _, res, path, updatepaths=True):
"""Call res.locateChild, and pass the result on to _handleSegment."""
self.resources.append(res)
if not path:
return res
result = res.locateChild(self, path)
if isinstance(result, defer.Deferred):
return result.addCallback(self._handleSegment, res, path, updatepaths)
else:
return self._handleSegment(result, res, path, updatepaths)
    def _handleSegment(self, result, res, path, updatepaths):
        """Handle the result of a locateChild call done in _getChild."""
        newres, newpath = result
        # If the child resource is None then display a error page
        if newres is None:
            raise http.HTTPError(responsecode.NOT_FOUND)

        # If we got a deferred then we need to call back later, once the
        # child is actually available.
        if isinstance(newres, defer.Deferred):
            return newres.addCallback(
                lambda actualRes: self._handleSegment(
                    (actualRes, newpath), res, path, updatepaths)
                )

        if path:
            url = quote("/" + "/".join(path))
        else:
            url = "/"

        if newpath is StopTraversal:
            # We need to rethink how to do this.
            #if newres is res:
                return res
            #else:
            #    raise ValueError("locateChild must not return StopTraversal with a resource other than self.")

        newres = iweb.IResource(newres)
        if newres is res:
            # Same resource returned: it must have consumed some segments,
            # otherwise traversal would never terminate.
            assert not newpath is path, "URL traversal cycle detected when attempting to locateChild %r from resource %r." % (path, res)
            assert len(newpath) < len(path), "Infinite loop impending..."

        if updatepaths:
            # We found a Resource... update the request.prepath and postpath
            for x in xrange(len(path) - len(newpath)):
                self.prepath.append(self.postpath.pop(0))
            url = quote("/" + "/".join(self.prepath) + ("/" if self.prepath and self.prepath[-1] else ""))
            self._rememberResource(newres, url)
        else:
            # Out-of-band lookup (locateResource): derive the child URL from
            # the parent's remembered URL instead of mutating prepath.
            try:
                previousURL = self.urlForResource(res)
                url = quote(previousURL + path[0] + ("/" if path[0] and len(path) > 1 else ""))
                self._rememberResource(newres, url)
            except NoURLForResourceError:
                pass

        child = self._getChild(None, newres, newpath, updatepaths=updatepaths)

        return child
    # Class-level fallback; normally shadowed by the per-request dict created
    # in __init__.  WeakKeyDictionary so cached resources can be collected.
    _urlsByResource = weakref.WeakKeyDictionary()

    def _rememberResource(self, resource, url):
        """
        Remember the URL of a visited resource.
        """
        self._resourcesByURL[url] = resource
        self._urlsByResource[resource] = url
        return resource
def _forgetResource(self, resource, url):
"""
Remember the URL of a visited resource.
"""
del self._resourcesByURL[url]
del self._urlsByResource[resource]
def urlForResource(self, resource):
"""
Looks up the URL of the given resource if this resource was found while
processing this request. Specifically, this includes the requested
resource, and resources looked up via L{locateResource}.
Note that a resource may be found at multiple URIs; if the same resource
is visited at more than one location while processing this request,
this method will return one of those URLs, but which one is not defined,
nor whether the same URL is returned in subsequent calls.
@param resource: the resource to find a URI for. This resource must
have been obtained from the request (i.e. via its C{uri} attribute, or
through its C{locateResource} or C{locateChildResource} methods).
@return: a valid URL for C{resource} in this request.
@raise NoURLForResourceError: if C{resource} has no URL in this request
(because it was not obtained from the request).
"""
url = self._urlsByResource.get(resource, None)
if url is None:
raise NoURLForResourceError(resource)
return url
    def locateResource(self, url):
        """
        Looks up the resource with the given URL.
        @param url: The URL of the desired resource.
        @return: a L{Deferred} resulting in the L{IResource} at the
            given URL or C{None} if no such resource can be located.
        @raise HTTPError: If C{url} contains a query or fragment.
            The contained response will have a status code of
            L{responsecode.BAD_REQUEST}.
        """
        if url is None:
            return defer.succeed(None)

        #
        # Parse the URL
        #
        (scheme, host, path, query, fragment) = urlsplit(url)

        if query or fragment:
            raise http.HTTPError(http.StatusResponse(
                responsecode.BAD_REQUEST,
                "URL may not contain a query or fragment: %s" % (url,)
            ))

        # Look for cached value
        cached = self._resourcesByURL.get(path, None)
        if cached is not None:
            return defer.succeed(cached)

        segments = unquote(path).split("/")
        assert segments[0] == "", "URL path didn't begin with '/': %s" % (path,)

        # Walk the segments up to see if we can find a cached resource to start from
        preSegments = segments[:-1]
        postSegments = segments[-1:]
        cachedParent = None
        while(len(preSegments)):
            parentPath = "/".join(preSegments) + "/"
            cachedParent = self._resourcesByURL.get(parentPath, None)
            if cachedParent is not None:
                break
            else:
                # Shift the last pre-segment into the traversal remainder.
                postSegments.insert(0, preSegments.pop())

        if cachedParent is None:
            # Nothing cached; traverse all the way from the site root.
            cachedParent = self.site.resource
            postSegments = segments[1:]

        def notFound(f):
            # A 404 during traversal becomes a None result; any other error
            # propagates unchanged.
            f.trap(http.HTTPError)
            if f.value.response.code != responsecode.NOT_FOUND:
                return f
            return None

        d = defer.maybeDeferred(self._getChild, None, cachedParent, postSegments, updatepaths=False)
        d.addCallback(self._rememberResource, path)
        d.addErrback(notFound)
        return d
    def locateChildResource(self, parent, childName):
        """
        Looks up the child resource with the given name given the parent
        resource.  This is similar to locateResource(), but doesn't have to
        start the lookup from the root resource, so it is potentially faster.
        @param parent: the parent of the resource being looked up.  This
            resource must have been obtained from the request (i.e. via its
            C{uri} attribute, or through its C{locateResource} or
            C{locateChildResource} methods).
        @param childName: the name of the child of C{parent} to look up.
        @return: a L{Deferred} resulting in the L{IResource} at the
            given URL or C{None} if no such resource can be located.
        @raise NoURLForResourceError: if C{parent} was not obtained from the
            request.
        """
        if parent is None or childName is None:
            return None

        assert "/" not in childName, "Child name may not contain '/': %s" % (childName,)

        parentURL = self.urlForResource(parent)
        if not parentURL.endswith("/"):
            parentURL += "/"
        url = parentURL + quote(childName)

        segment = childName

        def notFound(f):
            # A 404 during lookup becomes a None result; other errors
            # propagate unchanged.
            f.trap(http.HTTPError)
            if f.value.response.code != responsecode.NOT_FOUND:
                return f
            return None

        d = defer.maybeDeferred(self._getChild, None, parent, [segment], updatepaths=False)
        d.addCallback(self._rememberResource, url)
        d.addErrback(notFound)
        return d
    def _processingFailed(self, reason):
        """
        Convert a failure raised during request processing into an error
        response written to the client.
        """
        if reason.check(http.HTTPError) is not None:
            # If the exception was an HTTPError, leave it alone
            d = defer.succeed(reason.value.response)
        else:
            # Otherwise, it was a random exception, so give a
            # ICanHandleException implementer a chance to render the page.
            def _processingFailed_inner(reason):
                handler = iweb.ICanHandleException(self, self)
                return handler.renderHTTP_exception(self, reason)
            d = defer.maybeDeferred(_processingFailed_inner, reason)

        d.addCallback(self._cbFinishRender)
        # If even the error page fails, fall back to _processingReallyFailed.
        d.addErrback(self._processingReallyFailed, reason)
        return d
def _processingReallyFailed(self, reason, origReason):
"""
An error occurred when attempting to report an error to the HTTP
client.
"""
log.failure("Exception rendering error page", reason)
log.failure("Original exception", origReason)
try:
body = (
"Internal Server Error"
"Internal Server Error
"
"An error occurred rendering the requested page. "
"Additionally, an error occurred rendering the error page."
""
)
response = http.Response(
responsecode.INTERNAL_SERVER_ERROR,
{'content-type': http_headers.MimeType('text','html')},
body
)
self.writeResponse(response)
except:
log.failure(
"An error occurred. We tried to report that error. "
"Reporting that error caused an error. "
"In the process of reporting the error-reporting error to "
"the client, there was *yet another* error. Here it is. "
"I give up."
)
self.chanRequest.abortConnection()
    def _cbFinishRender(self, result):
        """
        Apply the response filters to a rendered result and write it out.
        If the result adapts to IResource rather than IResponse, render it
        first and recurse.
        """
        def filterit(response, f):
            # Filters without a 'handleErrors' attribute only run on 2xx
            # responses.
            if (hasattr(f, 'handleErrors') or
                (response.code >= 200 and response.code < 300)):
                return f(self, response)
            else:
                return response

        response = iweb.IResponse(result, None)
        if response:
            d = defer.Deferred()
            for f in self.responseFilters:
                d.addCallback(filterit, f)
            d.addCallback(self.writeResponse)
            d.callback(response)
            return d

        resource = iweb.IResource(result, None)
        if resource:
            self.resources.append(resource)
            d = defer.maybeDeferred(resource.renderHTTP, self)
            d.addCallback(self._cbFinishRender)
            return d

        raise TypeError("html is not a resource or a response")
def renderHTTP_exception(self, req, reason):
log.failure("Exception rendering request: {request}", reason, request=req)
body = ("Internal Server Error"
"Internal Server Error
An error occurred rendering the requested page. More information is available in the server log.")
return http.Response(
responsecode.INTERNAL_SERVER_ERROR,
{'content-type': http_headers.MimeType('text','html')},
body)
class Site(object):
    """
    Factory wrapping a root resource; calling the site constructs a Request
    bound to it.
    """

    def __init__(self, resource):
        """Adapt C{resource} to IResource and store it as the site root."""
        self.resource = iweb.IResource(resource)

    def __call__(self, *args, **kwargs):
        """Build a Request for this site (used as the channel's request factory)."""
        return Request(site=self, *args, **kwargs)
class NoURLForResourceError(RuntimeError):
    """
    Raised by Request.urlForResource when the given resource was never
    visited while processing this request.
    """
    def __init__(self, resource):
        message = "Resource %r has no URL in this request." % (resource,)
        RuntimeError.__init__(self, message)
        self.resource = resource
# Public API of this module.
__all__ = [
    'Request',
    'Site',
    'StopTraversal',
    'VERSION',
    'defaultHeadersFilter',
    'doTrace',
    'parsePOSTData',
    'preconditionfilter',
    'NoURLForResourceError',
]
calendarserver-5.2+dfsg/twext/web2/stream.py 0000644 0001750 0001750 00000111266 12263343324 020177 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_stream -*-
##
# Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
The stream module provides a simple abstraction of streaming
data. While Twisted already has some provisions for handling this in
its Producer/Consumer model, the rather complex interactions between
producer and consumer makes it difficult to implement something like
the CompoundStream object. Thus, this API.
The IStream interface is very simple. It consists of two methods:
read, and close. The read method should either return some data, None
if there is no data left to read, or a Deferred. Close frees up any
underlying resources and causes read to return None forevermore.
IByteStream adds a bit more to the API:
1) read is required to return objects conforming to the buffer interface.
2) .length, which may either an integer number of bytes remaining, or
None if unknown
3) .split(position). Split takes a position, and splits the
stream in two pieces, returning the two new streams. Using the
original stream after calling split is not allowed.
There are two builtin source stream classes: FileStream and
MemoryStream. The first produces data from a file object, the second
from a buffer in memory. Any number of these can be combined into one
stream with the CompoundStream object. Then, to interface with other
parts of Twisted, there are two transcievers: StreamProducer and
ProducerStream. The first takes a stream and turns it into an
IPushProducer, which will write to a consumer. The second is a
consumer which is a stream, so that other producers can write to it.
"""
from __future__ import generators
import copy, os, types, sys
from zope.interface import Interface, Attribute, implements
from twisted.internet.defer import Deferred
from twisted.internet import interfaces as ti_interfaces, defer, reactor, protocol, error as ti_error
from twisted.python import components
from twisted.python.failure import Failure
from hashlib import md5
from twext.python.log import Logger
log = Logger()
# Python 2.4.2 (only) has a broken mmap that leaks a fd every time you call it.
# Everywhere else we use mmap when available; FileStream checks 'bool(mmap)'.
if sys.version_info[0:3] != (2,4,2):
    try:
        import mmap
    except ImportError:
        # Platform without mmap support.
        mmap = None
else:
    mmap = None
##############################
#### Interfaces ####
##############################
class IStream(Interface):
    """A stream of arbitrary data."""

    def read():
        """Read some data.

        Returns some object representing the data.
        If there is no more data available, returns None.
        Can also return a Deferred resulting in one of the above.

        Errors may be indicated by exception or by a Deferred of a Failure.
        """

    def close():
        """Prematurely close. Should also cause further reads to
        return None."""
class IByteStream(IStream):
    """A stream which is of bytes."""

    # Remaining byte count; None when unknown.
    length = Attribute("""How much data is in this stream. Can be None if unknown.""")

    def read():
        """Read some data.

        Returns an object conforming to the buffer interface, or
        if there is no more data available, returns None.
        Can also return a Deferred resulting in one of the above.

        Errors may be indicated by exception or by a Deferred of a Failure.
        """

    def split(point):
        """Split this stream into two, at byte position 'point'.

        Returns a tuple of (before, after). After calling split, no other
        methods should be called on this stream. Doing so will have undefined
        behavior.

        If you cannot implement split easily, you may implement it as::

            return fallbackSplit(self, point)
        """

    def close():
        """Prematurely close this stream. Should also cause further reads to
        return None. Additionally, .length should be set to 0.
        """
class ISendfileableStream(Interface):
    """A byte stream whose read() can also produce sendfile-ready buffers."""

    def read(sendfile=False):
        """
        Read some data.

        If sendfile == False, returns an object conforming to the buffer
        interface, or else a Deferred.

        If sendfile == True, returns either the above, or a SendfileBuffer.
        """
class SimpleStream(object):
    """
    Superclass of simple byte streams backed by a single buffer, tracking an
    offset ('start') and remaining 'length' into that buffer.
    """
    implements(IByteStream)

    length = None
    start = None

    def read(self):
        """Base implementation: no data available."""
        return None

    def close(self):
        """Mark the stream exhausted."""
        self.length = 0

    def split(self, point):
        """
        Split into (head, tail) at byte offset 'point': this stream becomes
        the head, truncated to 'point' bytes, and a shallow copy covering
        the remainder is returned as the tail.
        """
        if self.length is not None and point > self.length:
            raise ValueError("split point (%d) > length (%d)" % (point, self.length))
        tail = copy.copy(self)
        self.length = point
        if tail.length is not None:
            tail.length -= point
        tail.start += point
        return (self, tail)
##############################
#### FileStream ####
##############################
# maximum mmap size (bytes) per read
MMAP_LIMIT = 4*1024*1024
# minimum remaining length (bytes) before mmap is worth using
MMAP_THRESHOLD = 8*1024

# maximum sendfile length (bytes)
SENDFILE_LIMIT = 16777216
# minimum remaining length (bytes) before sendfile is worth using
SENDFILE_THRESHOLD = 256
def mmapwrapper(*args, **kwargs):
    """
    Wrapper around mmap.mmap intended to accept an 'offset' keyword.

    Python's mmap call omitted the "offset" argument; this wrapper tolerates
    offset=None/0 (by dropping the keyword) and raises mmap.error for any
    other offset value.
    """
    offset = kwargs.pop('offset', None)
    if offset not in (None, 0):
        raise mmap.error("mmap: Python sucks and does not support offset.")
    return mmap.mmap(*args, **kwargs)
class FileStream(SimpleStream):
    implements(ISendfileableStream)
    # NOTE(review): the string below is *not* a docstring -- the implements()
    # call precedes it -- but it is kept in place as class documentation.
    """A stream that reads data from a file. File must be a normal
    file that supports seek, (e.g. not a pipe or device or socket)."""
    # 65K, minus some slack
    CHUNK_SIZE = 2 ** 2 ** 2 ** 2 - 32

    # The underlying file object; None once the stream is exhausted/closed.
    f = None
    def __init__(self, f, start=0, length=None, useMMap=bool(mmap)):
        """
        Create the stream from file f. If you specify start and length,
        use only that portion of the file.
        """
        self.f = f
        self.start = start
        if length is None:
            # Default to the file size as reported by fstat.
            self.length = os.fstat(f.fileno()).st_size
        else:
            self.length = length
        self.useMMap = useMMap

    def read(self, sendfile=False):
        # Exhausted or closed stream reads as None.
        if self.f is None:
            return None

        length = self.length
        if length == 0:
            self.f = None
            return None

        #if sendfile and length > SENDFILE_THRESHOLD:
        #    # XXX: Yay using non-existent sendfile support!
        #    # FIXME: if we return a SendfileBuffer, and then sendfile
        #    #        fails, then what? Or, what if file is too short?
        #    readSize = min(length, SENDFILE_LIMIT)
        #    res = SendfileBuffer(self.f, self.start, readSize)
        #    self.length -= readSize
        #    self.start += readSize
        #    return res

        if self.useMMap and length > MMAP_THRESHOLD:
            # Large remainder: try a zero-copy mmap window first.
            readSize = min(length, MMAP_LIMIT)
            try:
                res = mmapwrapper(self.f.fileno(), readSize,
                                  access=mmap.ACCESS_READ, offset=self.start)
                #madvise(res, MADV_SEQUENTIAL)
                self.length -= readSize
                self.start += readSize
                return res
            except mmap.error:
                # mmap failed (e.g. unsupported offset); fall through to a
                # plain buffered read below.
                pass

        # Fall back to standard read.
        readSize = min(length, self.CHUNK_SIZE)

        self.f.seek(self.start)
        b = self.f.read(readSize)
        bytesRead = len(b)
        if not bytesRead:
            # EOF before the promised length was delivered.
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, length))
        else:
            self.length -= bytesRead
            self.start += bytesRead
            return b

    def close(self):
        self.f = None
        SimpleStream.close(self)
# Adapt plain built-in file objects to IByteStream wherever one is expected.
components.registerAdapter(FileStream, file, IByteStream)
##############################
#### MemoryStream ####
##############################
class MemoryStream(SimpleStream):
    """A stream that reads data from a buffer object."""

    def __init__(self, mem, start=0, length=None):
        """
        Create the stream from buffer object mem. If you specify start and length,
        use only that portion of the buffer.

        @raise ValueError: if the requested start/length window extends past
            the end of the buffer.
        """
        self.mem = mem
        self.start = start
        if length is None:
            self.length = len(mem) - start
        else:
            # Validate the whole window, not just the length: the usable
            # data begins at ``start``, so the buffer must hold at least
            # start + length bytes.  (The previous check compared len(mem)
            # with length alone -- contradicting its own error message --
            # and so accepted out-of-range windows.)
            if len(mem) < start + length:
                raise ValueError("len(mem) < start + length")
            self.length = length

    def read(self):
        if self.mem is None:
            return None
        if self.length == 0:
            result = None
        else:
            # buffer() gives a zero-copy view of the remaining window.
            result = buffer(self.mem, self.start, self.length)
        # The entire window is returned in a single read; exhaust the stream.
        self.mem = None
        self.length = 0
        return result

    def close(self):
        self.mem = None
        SimpleStream.close(self)
# Adapt strings and buffer objects to IByteStream wherever one is expected.
components.registerAdapter(MemoryStream, str, IByteStream)
components.registerAdapter(MemoryStream, types.BufferType, IByteStream)
##############################
#### CompoundStream ####
##############################
class CompoundStream(object):
    """A stream which is composed of many other streams.

    Call addStream to add substreams.
    """
    implements(IByteStream, ISendfileableStream)
    deferred = None
    length = 0

    def __init__(self, buckets=()):
        self.buckets = [IByteStream(s) for s in buckets]

    def addStream(self, bucket):
        """Add a stream to the output"""
        bucket = IByteStream(bucket)
        self.buckets.append(bucket)
        if self.length is not None:
            if bucket.length is None:
                # One unknown-length substream makes the total unknown.
                self.length = None
            else:
                self.length += bucket.length

    def read(self, sendfile=False):
        if self.deferred is not None:
            raise RuntimeError("Call to read while read is already outstanding")

        if not self.buckets:
            return None

        if sendfile and ISendfileableStream.providedBy(self.buckets[0]):
            try:
                result = self.buckets[0].read(sendfile)
            except:
                return self._gotFailure(Failure())
        else:
            try:
                result = self.buckets[0].read()
            except:
                return self._gotFailure(Failure())

        if isinstance(result, Deferred):
            self.deferred = result
            result.addCallbacks(self._gotRead, self._gotFailure, (sendfile,))
            return result

        return self._gotRead(result, sendfile)

    def _gotFailure(self, f):
        self.deferred = None
        del self.buckets[0]
        self.close()
        return f

    def _gotRead(self, result, sendfile):
        self.deferred = None
        if result is None:
            del self.buckets[0]
            # Next bucket
            return self.read(sendfile)

        if self.length is not None:
            self.length -= len(result)
        return result

    def split(self, point):
        """
        Split at ``point`` bytes, returning (before, after) streams.

        BUGFIX: the running index was previously incremented at the *top*
        of the loop, so every slice pointed one bucket past the bucket
        being examined; splits dropped or duplicated buckets.  Also fixed:
        a split at exactly the end of the stream used to fall off the loop
        and return None, and neither half had its ``length`` updated.
        """
        origPoint = point
        for num, bucket in enumerate(self.buckets):
            if point == 0:
                # The split falls exactly on this bucket boundary.
                b = CompoundStream()
                b.buckets = self.buckets[num:]
                b.length = None if self.length is None else self.length - origPoint
                del self.buckets[num:]
                self.length = origPoint
                return self, b

            if bucket.length is None:
                # Indeterminate length bucket.
                # give up and use fallback splitter.
                return fallbackSplit(self, origPoint)

            if point < bucket.length:
                # The split falls inside this bucket.
                before, after = bucket.split(point)
                b = CompoundStream()
                b.buckets = self.buckets[num:]
                b.buckets[0] = after
                b.length = None if self.length is None else self.length - origPoint
                del self.buckets[num + 1:]
                self.buckets[num] = before
                self.length = origPoint
                return self, b

            point -= bucket.length

        if point == 0:
            # Split at exactly the end of the stream: second half is empty.
            return self, CompoundStream()
        raise ValueError("split point (%d) > length (%d)" % (origPoint, origPoint - point))

    def close(self):
        for bucket in self.buckets:
            bucket.close()
        self.buckets = []
        self.length = 0
##############################
#### readStream ####
##############################
class _StreamReader(object):
    """Process a stream's data using callbacks for data and stream finish."""

    def __init__(self, stream, gotDataCallback):
        self.stream = stream
        self.gotDataCallback = gotDataCallback
        # Fires with None on clean EOF, or with a Failure from either the
        # stream read or the data callback.
        self.result = Deferred()

    def run(self):
        # self.result may be del'd in _read()
        result = self.result
        self._read()
        return result

    def _read(self):
        try:
            result = self.stream.read()
        except:
            self._gotError(Failure())
            return
        if isinstance(result, Deferred):
            result.addCallbacks(self._gotData, self._gotError)
        else:
            self._gotData(result)

    def _gotError(self, failure):
        # Drop all references *before* firing so this reader cannot be
        # reused and the stream can be garbage collected promptly.
        result = self.result
        del self.result, self.gotDataCallback, self.stream
        result.errback(failure)

    def _gotData(self, data):
        if data is None:
            # End of stream: release references, then signal completion.
            result = self.result
            del self.result, self.gotDataCallback, self.stream
            result.callback(None)
            return
        try:
            self.gotDataCallback(data)
        except:
            self._gotError(Failure())
            return
        # Schedule the next read through the reactor instead of recursing,
        # so a long synchronous stream cannot exhaust the stack.
        reactor.callLater(0, self._read)
def readStream(stream, gotDataCallback):
    """Pass a stream's data to a callback.

    Returns Deferred which will be triggered on finish.  Errors in
    reading the stream or in processing it will be returned via this
    Deferred.
    """
    reader = _StreamReader(stream, gotDataCallback)
    return reader.run()
def readAndDiscard(stream):
    """Read all the data from the given stream, and throw it out.

    Returns Deferred which will be triggered on finish.
    """
    def _ignore(chunk):
        # Consume and drop each chunk.
        pass
    return readStream(stream, _ignore)
def readIntoFile(stream, outFile):
    """Read a stream and write it into a file.

    Returns Deferred which will be triggered on finish.
    """
    def _closeFile(passthrough):
        # Close the file on success *and* failure, then let the result
        # (or Failure) continue down the callback chain unchanged.
        outFile.close()
        return passthrough
    d = readStream(stream, outFile.write)
    return d.addBoth(_closeFile)
def connectStream(inputStream, factory):
    """Connect a protocol constructed from a factory to stream.

    Returns an output stream from the protocol.

    The protocol's transport will have a finish() method it should
    call when done writing.
    """
    # XXX deal better with addresses
    p = factory.buildProtocol(None)
    out = ProducerStream()
    out.disconnecting = False # XXX for LineReceiver suckage
    p.makeConnection(out)
    # Feed the input stream into the protocol; a clean EOF becomes
    # connectionLost(ConnectionDone), a read error becomes
    # connectionLost(<failure>).
    readStream(inputStream, lambda _: p.dataReceived(_)).addCallbacks(
        lambda _: p.connectionLost(ti_error.ConnectionDone()), lambda _: p.connectionLost(_))
    return out
##############################
#### fallbackSplit ####
##############################
def fallbackSplit(stream, point):
    """
    Generic splitter for streams that cannot split themselves: the first
    ``point`` bytes are served by a TruncaterStream, and everything after
    by a PostTruncaterStream; the two cooperate to share the single
    underlying stream.
    """
    secondHalf = PostTruncaterStream(stream, point)
    firstHalf = TruncaterStream(stream, point, secondHalf)
    return (firstHalf, secondHalf)
class TruncaterStream(object):
    """
    The first half of a fallbackSplit(): serves at most ``point`` bytes of
    the shared underlying stream, handing any overrun (the chunk straddling
    the split) to its partner PostTruncaterStream.
    """

    def __init__(self, stream, point, postTruncater):
        self.stream = stream
        # Number of bytes still owed to this (first) half.
        self.length = point
        self.postTruncater = postTruncater

    def read(self):
        if self.length == 0:
            if self.postTruncater is not None:
                postTruncater = self.postTruncater
                self.postTruncater = None
                # Our half is exhausted exactly on a chunk boundary: start
                # the second half with the next read of the shared stream.
                postTruncater.sendInitialSegment(self.stream.read())
                self.stream = None
            return None

        result = self.stream.read()
        if isinstance(result, Deferred):
            return result.addCallback(self._gotRead)
        else:
            return self._gotRead(result)

    def _gotRead(self, data):
        if data is None:
            raise ValueError("Ran out of data for a split of a indeterminate length source")

        if self.length >= len(data):
            # Chunk fits entirely within our half.
            self.length -= len(data)
            return data
        else:
            # Chunk straddles the split point: keep the head, pass the
            # tail to the second half.
            before = buffer(data, 0, self.length)
            after = buffer(data, self.length)
            self.length = 0
            if self.postTruncater is not None:
                postTruncater = self.postTruncater
                self.postTruncater = None
                postTruncater.sendInitialSegment(after)
                self.stream = None
            return before

    def split(self, point):
        if point > self.length:
            raise ValueError("split point (%d) > length (%d)" % (point, self.length))

        # Insert another truncation layer between us and our partner.
        post = PostTruncaterStream(self.stream, point)
        trunc = TruncaterStream(post, self.length - point, self.postTruncater)
        self.length = point
        self.postTruncater = post
        return self, trunc

    def close(self):
        if self.postTruncater is not None:
            # Let the second half decide what to do with the shared stream.
            self.postTruncater.notifyClosed(self)
        else:
            # Nothing cares about the rest of the stream
            self.stream.close()
            self.stream = None
            self.length = 0
class PostTruncaterStream(object):
    """
    The second half of a fallbackSplit(): serves everything after the split
    point, once its partner TruncaterStream has consumed (or abandoned) the
    first ``point`` bytes of the shared stream.
    """
    # Fires with the first chunk that belongs to this half.
    deferred = None
    sentInitialSegment = False
    # Holds the first-half stream if it was closed while we were idle.
    truncaterClosed = None
    closed = False
    length = None

    def __init__(self, stream, point):
        self.stream = stream
        self.deferred = Deferred()
        if stream.length is not None:
            self.length = stream.length - point

    def read(self):
        if not self.sentInitialSegment:
            self.sentInitialSegment = True
            if self.truncaterClosed is not None:
                # The first half was closed early: drain its remaining data
                # so the shared stream advances to our split point.
                readAndDiscard(self.truncaterClosed)
                self.truncaterClosed = None
            # The first read resolves with the straddling chunk delivered
            # via sendInitialSegment().
            return self.deferred

        return self.stream.read()

    def split(self, point):
        return fallbackSplit(self, point)

    def close(self):
        self.closed = True
        if self.truncaterClosed is not None:
            # have first half close itself
            self.truncaterClosed.postTruncater = None
            self.truncaterClosed.close()
        elif self.sentInitialSegment:
            # first half already finished up
            self.stream.close()

        self.deferred = None

    # Callbacks from TruncaterStream

    def sendInitialSegment(self, data):
        if self.closed:
            # First half finished, we don't want data.
            self.stream.close()
            self.stream = None
        if self.deferred is not None:
            if isinstance(data, Deferred):
                data.chainDeferred(self.deferred)
            else:
                self.deferred.callback(data)

    def notifyClosed(self, truncater):
        if self.closed:
            # we are closed, have first half really close
            truncater.postTruncater = None
            truncater.close()
        elif self.sentInitialSegment:
            # We are trying to read, read up first half
            readAndDiscard(truncater)
        else:
            # Idle, store closed info.
            self.truncaterClosed = truncater
########################################
#### ProducerStream/StreamProducer ####
########################################
class ProducerStream(object):
    """Turns producers into a IByteStream.
    Thus, implements IConsumer and IByteStream."""
    implements(IByteStream, ti_interfaces.IConsumer)
    length = None
    closed = False
    failed = False
    producer = None
    producerPaused = False
    deferred = None

    # Ask a streaming producer to pause once this many chunks are queued.
    bufferSize = 5

    def __init__(self, length=None):
        self.buffer = []
        self.length = length

    # IByteStream implementation
    def read(self):
        if self.buffer:
            return self.buffer.pop(0)
        elif self.closed:
            self.length = 0
            if self.failed:
                # Deliver the stored failure exactly once.
                f = self.failure
                del self.failure
                return defer.fail(f)
            return None
        else:
            # No data yet: hand out a Deferred and (re)start a paused or
            # pull-style producer so it can supply some.
            deferred = self.deferred = Deferred()
            if self.producer is not None and (not self.streamingProducer
                                              or self.producerPaused):
                self.producerPaused = False
                self.producer.resumeProducing()
            return deferred

    def split(self, point):
        return fallbackSplit(self, point)

    def close(self):
        """Called by reader of stream when it is done reading."""
        self.buffer = []
        self.closed = True
        if self.producer is not None:
            self.producer.stopProducing()
            self.producer = None
        self.deferred = None

    # IConsumer implementation
    def write(self, data):
        if self.closed:
            return
        if self.deferred:
            # A reader is already waiting: deliver the chunk directly.
            deferred = self.deferred
            self.deferred = None
            deferred.callback(data)
        else:
            self.buffer.append(data)
            # Backpressure: pause a streaming (push) producer once the
            # buffer grows past bufferSize chunks.
            if(self.producer is not None and self.streamingProducer
               and len(self.buffer) > self.bufferSize):
                self.producer.pauseProducing()
                self.producerPaused = True

    def finish(self, failure=None):
        """Called by producer when it is done.

        If the optional failure argument is passed a Failure instance,
        the stream will return it as errback on next Deferred.
        """
        self.closed = True
        if not self.buffer:
            self.length = 0
        if self.deferred is not None:
            deferred = self.deferred
            self.deferred = None
            if failure is not None:
                self.failed = True
                deferred.errback(failure)
            else:
                deferred.callback(None)
        else:
            if failure is not None:
                # Remember the failure until the reader drains the buffer.
                self.failed = True
                self.failure = failure

    def registerProducer(self, producer, streaming):
        if self.producer is not None:
            raise RuntimeError("Cannot register producer %s, because producer %s was never unregistered." % (producer, self.producer))

        if self.closed:
            producer.stopProducing()
        else:
            self.producer = producer
            self.streamingProducer = streaming
            if not streaming:
                # Pull producers must be prodded for their first chunk.
                producer.resumeProducing()

    def unregisterProducer(self):
        self.producer = None
class StreamProducer(object):
    """A push producer which gets its data by reading a stream."""
    implements(ti_interfaces.IPushProducer)

    deferred = None
    finishedCallback = None
    paused = False
    consumer = None

    def __init__(self, stream, enforceStr=True):
        # enforceStr: coerce every chunk to str before writing, since some
        # transports cannot accept buffer objects.
        self.stream = stream
        self.enforceStr = enforceStr

    def beginProducing(self, consumer):
        if self.stream is None:
            return defer.succeed(None)

        self.consumer = consumer
        # Fires when the stream is exhausted, or errbacks on failure.
        finishedCallback = self.finishedCallback = Deferred()
        self.consumer.registerProducer(self, True)
        self.resumeProducing()
        return finishedCallback

    def resumeProducing(self):
        self.paused = False
        if self.deferred is not None:
            # A read is already outstanding; its callback will resume us.
            return

        try:
            data = self.stream.read()
        except:
            self.stopProducing(Failure())
            return

        if isinstance(data, Deferred):
            self.deferred = data
            self.deferred.addCallbacks(self._doWrite, self.stopProducing)
        else:
            self._doWrite(data)

    def _doWrite(self, data):
        if self.consumer is None:
            return
        if data is None:
            # The end.
            if self.consumer is not None:
                self.consumer.unregisterProducer()
            if self.finishedCallback is not None:
                self.finishedCallback.callback(None)
            self.finishedCallback = self.deferred = self.consumer = self.stream = None
            return

        self.deferred = None
        if self.enforceStr:
            # XXX: sucks that we have to do this. make transport.write(buffer) work!
            data = str(buffer(data))
        self.consumer.write(data)

        if not self.paused:
            self.resumeProducing()

    def pauseProducing(self):
        self.paused = True

    # NOTE(review): the default failure is instantiated once, at class
    # definition time, and shared across calls; callers wanting a clean
    # finish pass failure=None explicitly (see _ProcessStreamerProtocol).
    def stopProducing(self, failure=ti_error.ConnectionLost()):
        if self.consumer is not None:
            self.consumer.unregisterProducer()

        if self.finishedCallback is not None:
            if failure is not None:
                self.finishedCallback.errback(failure)
            else:
                self.finishedCallback.callback(None)
            self.finishedCallback = None

        self.paused = True
        if self.stream is not None:
            self.stream.close()

        self.finishedCallback = self.deferred = self.consumer = self.stream = None
##############################
#### ProcessStreamer ####
##############################
class _ProcessStreamerProtocol(protocol.ProcessProtocol):
    """
    Glue between a spawned child process and three streams: feeds
    inputStream into the child's stdin, and routes the child's stdout and
    stderr into outStream and errStream respectively.
    """

    def __init__(self, inputStream, outStream, errStream):
        self.inputStream = inputStream
        self.outStream = outStream
        self.errStream = errStream
        # Fires when the child exits -- always as an errback; see processEnded.
        self.resultDeferred = defer.Deferred()

    def connectionMade(self):
        p = StreamProducer(self.inputStream)
        # if the process stopped reading from the input stream,
        # this is not an error condition, so it oughtn't result
        # in a ConnectionLost() from the input stream:
        p.stopProducing = lambda err=None: StreamProducer.stopProducing(p, err)

        d = p.beginProducing(self.transport)
        d.addCallbacks(lambda _: self.transport.closeStdin(),
                       self._inputError)

    def _inputError(self, f):
        # Log the stdin failure but still close stdin so the child can exit.
        log.failure("Error in input stream for transport {transport}", f, transport=self.transport)
        self.transport.closeStdin()

    def outReceived(self, data):
        self.outStream.write(data)

    def errReceived(self, data):
        self.errStream.write(data)

    def outConnectionLost(self):
        self.outStream.finish()

    def errConnectionLost(self):
        self.errStream.finish()

    def processEnded(self, reason):
        # Always errback; ProcessStreamer.run() traps ProcessDone so a
        # clean exit becomes a successful callback.
        self.resultDeferred.errback(reason)
        del self.resultDeferred
class ProcessStreamer(object):
    """Runs a process hooked up to streams.

    Requires an input stream, has attributes 'outStream' and 'errStream'
    for stdout and stderr.

    outStream and errStream are public attributes providing streams
    for stdout and stderr of the process.
    """

    def __init__(self, inputStream, program, args, env=None):
        """
        @param inputStream: stream fed to the child's stdin.
        @param program: path of the executable to spawn.
        @param args: argument list for the child (argv, including argv[0]).
        @param env: environment mapping for the child; defaults to an empty
            environment.  (The previous signature used a mutable ``{}``
            default -- the classic shared-default pitfall; ``None`` is now
            the sentinel and behaves identically for all callers.)
        """
        self.outStream = ProducerStream()
        self.errStream = ProducerStream()
        self._protocol = _ProcessStreamerProtocol(IByteStream(inputStream), self.outStream, self.errStream)
        self._program = program
        self._args = args
        self._env = {} if env is None else env

    def run(self):
        """Run the process.

        Returns Deferred which will eventually have errback for non-clean (exit code > 0)
        exit, with ProcessTerminated, or callback with None on exit code 0.
        """
        # XXX what happens if spawn fails?
        reactor.spawnProcess(self._protocol, self._program, self._args, env=self._env)
        # The environment is only needed for spawning; drop the reference.
        del self._env
        return self._protocol.resultDeferred.addErrback(lambda _: _.trap(ti_error.ProcessDone))
##############################
#### generatorToStream ####
##############################
class _StreamIterator(object):
done=False
def __iter__(self):
return self
def next(self):
if self.done:
raise StopIteration
return self.value
wait=object()
class _IteratorStream(object):
    """
    Stream adapter produced by generatorToStream(): pulls output values
    from a generator, feeding it chunks of the wrapped stream on demand.
    """
    # Length of the generated output cannot be known in advance.
    length = None

    def __init__(self, fun, stream, args, kwargs):
        self._stream = stream
        self._streamIterator = _StreamIterator()
        self._gen = fun(self._streamIterator, *args, **kwargs)

    def read(self):
        try:
            val = self._gen.next()
        except StopIteration:
            # Generator finished: stream is exhausted.
            return None
        else:
            if val is _StreamIterator.wait:
                # The generator wants more input before producing output.
                newdata = self._stream.read()
                if isinstance(newdata, defer.Deferred):
                    return newdata.addCallback(self._gotRead)
                else:
                    return self._gotRead(newdata)
            return val

    def _gotRead(self, data):
        if data is None:
            self._streamIterator.done = True
        else:
            self._streamIterator.value = data
        # Re-enter read() so the generator sees the new input.
        return self.read()

    def close(self):
        self._stream.close()
        del self._gen, self._stream, self._streamIterator

    def split(self, point):
        # BUGFIX: split() previously took no ``point`` parameter and called
        # fallbackSplit(self) with too few arguments, so any attempt to
        # split this stream raised TypeError.
        return fallbackSplit(self, point)
def generatorToStream(fun):
    """Converts a generator function into a stream.

    The function should take an iterator as its first argument,
    which will be converted *from* a stream by this wrapper, and
    yield items which are turned *into* the results from the
    stream's 'read' call.

    One important point: before every call to input.next(), you
    *MUST* do a "yield input.wait" first. Yielding this magic value
    takes care of ensuring that the input is not a deferred before
    you see it.

    >>> from twext.web2 import stream
    >>> from string import maketrans
    >>> alphabet = 'abcdefghijklmnopqrstuvwxyz'
    >>>
    >>> def encrypt(input, key):
    ...     code = alphabet[key:] + alphabet[:key]
    ...     translator = maketrans(alphabet+alphabet.upper(), code+code.upper())
    ...     yield input.wait
    ...     for s in input:
    ...         yield str(s).translate(translator)
    ...         yield input.wait
    ...
    >>> encrypt = stream.generatorToStream(encrypt)
    >>>
    >>> plaintextStream = stream.MemoryStream('SampleSampleSample')
    >>> encryptedStream = encrypt(plaintextStream, 13)
    >>> encryptedStream.read()
    'FnzcyrFnzcyrFnzcyr'
    >>>
    >>> plaintextStream = stream.MemoryStream('SampleSampleSample')
    >>> encryptedStream = encrypt(plaintextStream, 13)
    >>> evenMoreEncryptedStream = encrypt(encryptedStream, 13)
    >>> evenMoreEncryptedStream.read()
    'SampleSampleSample'
    """
    def _bound(stream, *args, **kwargs):
        # Late-bind the input stream: each call produces a fresh
        # _IteratorStream driving a fresh generator instance.
        return _IteratorStream(fun, stream, args, kwargs)
    return _bound
##############################
#### BufferedStream ####
##############################
class BufferedStream(object):
    """A stream which buffers its data to provide operations like
    readline and readExactly."""

    # Bytes already read from the wrapped stream but not yet consumed.
    data = ""

    def __init__(self, stream):
        self.stream = stream

    def _readUntil(self, f):
        """Internal helper function which repeatedly calls f each time
        after more data has been received, until it returns non-None."""
        while True:
            r = f()
            if r is not None:
                yield r; return

            newdata = self.stream.read()
            if isinstance(newdata, defer.Deferred):
                newdata = defer.waitForDeferred(newdata)
                yield newdata; newdata = newdata.getResult()

            if newdata is None:
                # End Of File: hand back whatever is buffered, even if
                # the caller's condition was never satisfied.
                newdata = self.data
                self.data = ''
                yield newdata; return
            self.data += str(newdata)
    _readUntil = defer.deferredGenerator(_readUntil)

    def readExactly(self, size=None):
        """Read exactly size bytes of data, or, if size is None, read
        the entire stream into a string.

        @raise ValueError: if size is negative.
        """
        if size is not None and size < 0:
            # BUGFIX: the message was previously passed as a second
            # ValueError argument ("...: %s", size) instead of being
            # %-formatted, so it was never interpolated (compare readline).
            raise ValueError("readExactly: size cannot be negative: %s" % (size, ))

        def gotdata():
            data = self.data
            if size is not None and len(data) >= size:
                pre, post = data[:size], data[size:]
                self.data = post
                return pre
        return self._readUntil(gotdata)

    def readline(self, delimiter='\r\n', size=None):
        """
        Read a line of data from the string, bounded by
        delimiter. The delimiter is included in the return value.

        If size is specified, read and return at most that many bytes,
        even if the delimiter has not yet been reached. If the size
        limit falls within a delimiter, the rest of the delimiter, and
        the next line will be returned together.
        """
        if size is not None and size < 0:
            raise ValueError("readline: size cannot be negative: %s" % (size, ))

        def gotdata():
            data = self.data
            if size is not None:
                splitpoint = data.find(delimiter, 0, size)
                if splitpoint == -1:
                    if len(data) >= size:
                        # Size limit reached before a delimiter appeared.
                        splitpoint = size
                else:
                    splitpoint += len(delimiter)
            else:
                splitpoint = data.find(delimiter)
                if splitpoint != -1:
                    splitpoint += len(delimiter)

            if splitpoint != -1:
                pre = data[:splitpoint]
                self.data = data[splitpoint:]
                return pre
        return self._readUntil(gotdata)

    def pushback(self, pushed):
        """Push data back into the buffer."""
        self.data = pushed + self.data

    def read(self):
        data = self.data
        if data:
            self.data = ""
            return data
        return self.stream.read()

    def _len(self):
        # Total remaining length: wrapped stream plus our buffer, or
        # unknown if the wrapped stream's length is unknown.
        l = self.stream.length
        if l is None:
            return None
        return l + len(self.data)
    length = property(_len)

    def split(self, offset):
        off = offset - len(self.data)

        pre, post = self.stream.split(max(0, off))
        pre = BufferedStream(pre)
        post = BufferedStream(post)
        if off < 0:
            # The split point falls inside our buffered data: divide it.
            pre.data = self.data[:-off]
            post.data = self.data[-off:]
        else:
            pre.data = self.data

        return pre, post
#########################
#### MD5Stream ####
#########################
class MD5Stream(SimpleStream):
    """
    A wrapper which computes the MD5 hash of the data read from the
    wrapped stream.
    """

    def __init__(self, wrap):
        if wrap is None:
            raise ValueError("Stream to wrap must be provided")
        self._stream = wrap
        self._md5 = md5()

    def _update(self, value):
        """
        Feed ``value`` into the running digest and hand it back unchanged,
        so this method can double as a Deferred callback.

        @param value: L{None} or a L{str} of stream data.
        @return: C{value}
        """
        if value is not None:
            self._md5.update(value)
        return value

    def read(self):
        """
        Read from the wrapped stream, updating the digest with whatever
        comes back -- directly, or via a callback for Deferred results.
        """
        if self._stream is None:
            raise RuntimeError("Cannot read after stream is closed")
        data = self._stream.read()
        if isinstance(data, Deferred):
            data.addCallback(self._update)
            return data
        self._update(data)
        return data

    def close(self):
        """
        Finalize: record the hex digest, then release the wrapped stream
        and the hash object.
        """
        SimpleStream.close(self)
        self._md5value = self._md5.hexdigest()
        self._stream = None
        self._md5 = None

    def getMD5(self):
        """
        Return the hex encoded MD5 digest of the contents of the wrapped
        stream.  This may only be called after C{close}.

        @rtype: C{str}
        @raise RuntimeError: If C{close} has not yet been called.
        """
        if self._md5 is not None:
            raise RuntimeError("Cannot get MD5 value until stream is closed")
        return self._md5value
# Public API of this module.  (ISendfileableStream, SimpleStream and
# connectStream are documented public names that were previously missing.)
__all__ = ['IStream', 'IByteStream', 'ISendfileableStream', 'SimpleStream',
           'FileStream', 'MemoryStream', 'CompoundStream',
           'readAndDiscard', 'fallbackSplit', 'ProducerStream', 'StreamProducer',
           'BufferedStream', 'MD5Stream', 'readStream', 'ProcessStreamer', 'readIntoFile',
           'generatorToStream', 'connectStream']
calendarserver-5.2+dfsg/twext/web2/resource.py 0000644 0001750 0001750 00000030047 12263343324 020530 0 ustar rahul rahul # -*- test-case-name: twext.web2.test.test_server,twext.web2.test.test_resource -*-
##
# Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# Copyright (c) 2010-2014 Apple Computer, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##
"""
I hold the lowest-level L{Resource} class and related mix-in classes.
"""
# System Imports
from zope.interface import implements
from twisted.internet.defer import inlineCallbacks, returnValue
from twext.web2 import iweb, http, server, responsecode
from twisted.internet.defer import maybeDeferred
class RenderMixin(object):
    """
    Mix-in class for L{iweb.IResource} which provides a dispatch mechanism for
    handling HTTP methods.
    """

    def allowedMethods(self):
        """
        @return: A tuple of HTTP methods that are allowed to be invoked on this resource.
        """
        if not hasattr(self, "_allowed_methods"):
            # Cache the result: introspect the instance for http_* handlers once.
            self._allowed_methods = tuple([name[5:] for name in dir(self)
                                           if name.startswith('http_') and getattr(self, name) is not None])
        return self._allowed_methods

    def checkPreconditions(self, request):
        """
        Checks all preconditions imposed by this resource upon a request made
        against it.

        @param request: the request to process.
        @raise http.HTTPError: if any precondition fails.
        @return: C{None} or a deferred whose callback value is C{request}.
        """
        #
        # http.checkPreconditions() gets called by the server after every
        # GET or HEAD request.
        #
        # For other methods, we need to know to bail out before request
        # processing, especially for methods that modify server state (eg. PUT).
        # We also would like to do so even for methods that don't, if those
        # methods might be expensive to process. We're assuming that GET and
        # HEAD are not expensive.
        #
        if request.method not in ("GET", "HEAD"):
            http.checkPreconditions(request)

        # Check per-method preconditions
        method = getattr(self, "preconditions_" + request.method, None)
        if method:
            return method(request)

    @inlineCallbacks
    def renderHTTP(self, request):
        """
        See L{iweb.IResource.renderHTTP}.

        This implementation will dispatch the given C{request} to another method
        of C{self} named C{http_}METHOD, where METHOD is the HTTP method used by
        C{request} (eg. C{http_GET}, C{http_POST}, etc.).

        Generally, a subclass should implement those methods instead of
        overriding this one.

        C{http_*} methods are expected provide the same interface and return the
        same results as L{iweb.IResource}C{.renderHTTP} (and therefore this method).

        C{etag} and C{last-modified} are added to the response returned by the
        C{http_*} header, if known.

        If an appropriate C{http_*} method is not found, a
        L{responsecode.NOT_ALLOWED}-status response is returned, with an
        appropriate C{allow} header.

        @param request: the request to process.
        @return: an object adaptable to L{iweb.IResponse}.
        """
        method = getattr(self, "http_" + request.method, None)
        if method is None:
            # No handler for this verb: 405 with the verbs we do support.
            response = http.Response(responsecode.NOT_ALLOWED)
            response.headers.setHeader("allow", self.allowedMethods())
            returnValue(response)

        # May raise HTTPError (or return a failing Deferred) to abort early.
        yield self.checkPreconditions(request)

        result = maybeDeferred(method, request)
        result.addErrback(self.methodRaisedException)
        returnValue((yield result))

    def methodRaisedException(self, failure):
        """
        An C{http_METHOD} method raised an exception; this is an errback for
        that exception. By default, simply propagate the error up; subclasses
        may override this for top-level exception handling.
        """
        return failure

    def http_OPTIONS(self, request):
        """
        Respond to a OPTIONS request.

        @param request: the request to process.
        @return: an object adaptable to L{iweb.IResponse}.
        """
        response = http.Response(responsecode.OK)
        response.headers.setHeader("allow", self.allowedMethods())
        return response

    # def http_TRACE(self, request):
    #     """
    #     Respond to a TRACE request.
    #     @param request: the request to process.
    #     @return: an object adaptable to L{iweb.IResponse}.
    #     """
    #     return server.doTrace(request)

    def http_HEAD(self, request):
        """
        Respond to a HEAD request.

        @param request: the request to process.
        @return: an object adaptable to L{iweb.IResponse}.
        """
        # The framework strips the body from HEAD responses; GET does the work.
        return self.http_GET(request)

    def http_GET(self, request):
        """
        Respond to a GET request.

        This implementation validates that the request body is empty and then
        dispatches the given C{request} to L{render} and returns its result.

        @param request: the request to process.
        @return: an object adaptable to L{iweb.IResponse}.
        """
        if request.stream.length != 0:
            # GET requests are not expected to carry a body.
            return responsecode.REQUEST_ENTITY_TOO_LARGE

        return self.render(request)

    def render(self, request):
        """
        Subclasses should implement this method to do page rendering.
        See L{http_GET}.

        @param request: the request to process.
        @return: an object adaptable to L{iweb.IResponse}.
        """
        raise NotImplementedError("Subclass must implement render method.")
class Resource(RenderMixin):
    """
    An L{iweb.IResource} implementation with some convenient mechanisms for
    locating children.
    """
    implements(iweb.IResource)

    # True for directory-like resources whose canonical URL ends in "/".
    addSlash = False

    def locateChild(self, request, segments):
        """
        Locates a child resource of this resource.

        @param request: the request to process.
        @param segments: a sequence of URL path segments.
        @return: a tuple of C{(child, segments)} containing the child
            of this resource which matches one or more of the given C{segments} in
            sequence, and a list of remaining segments.
        """
        # First, try a statically registered child (a child_<name> attribute,
        # as installed by putChild()).
        w = getattr(self, 'child_%s' % (segments[0],), None)

        if w:
            # Adapt the attribute to IResource if possible; otherwise it is
            # assumed to be a callable taking the request and returning one.
            r = iweb.IResource(w, None)
            if r:
                return r, segments[1:]
            return w(request), segments[1:]

        # Second, defer to a dynamic childFactory, if the subclass has one.
        factory = getattr(self, 'childFactory', None)
        if factory is not None:
            r = factory(request, segments[0])
            if r:
                return r, segments[1:]

        # No child found; terminate traversal with no resource.
        return None, []

    def child_(self, request):
        """
        This method locates a child with a trailing C{"/"} in the URL.

        @param request: the request to process.
        """
        if self.addSlash and len(request.postpath) == 1:
            return self
        return None

    def getChild(self, path):
        """
        Get a static child - when registered using L{putChild}.

        @param path: the name of the child to get
        @type path: C{str}
        @return: the child or C{None} if not present
        @rtype: L{iweb.IResource}
        """
        return getattr(self, 'child_%s' % (path,), None)

    def putChild(self, path, child):
        """
        Register a static child.

        This implementation registers children by assigning them to attributes
        with a C{child_} prefix. C{resource.putChild("foo", child)} is
        therefore same as C{o.child_foo = child}.

        @param path: the name of the child to register. You almost certainly
            don't want C{"/"} in C{path}. If you want to add a "directory"
            resource (e.g. C{/foo/}) specify C{path} as C{""}.
        @param child: an object adaptable to L{iweb.IResource}.
        """
        setattr(self, 'child_%s' % (path,), child)

    def http_GET(self, request):
        if self.addSlash and request.prepath[-1] != '':
            # If this is a directory-ish resource...
            # ...redirect to the canonical URL with a trailing slash.
            return http.RedirectResponse(request.unparseURL(path=request.path + '/'))

        return super(Resource, self).http_GET(request)
class PostableResource(Resource):
    """
    A L{Resource} capable of handling the POST request method.

    @cvar maxMem: maximum memory used during the parsing of the data.
    @type maxMem: C{int}
    @cvar maxFields: maximum number of form fields allowed.
    @type maxFields: C{int}
    @cvar maxSize: maximum size of the whole post allowed.
    @type maxSize: C{int}
    """
    maxMem = 100 * 1024
    maxFields = 1024
    maxSize = 10 * 1024 * 1024

    def http_POST(self, request):
        """
        Respond to a POST request.

        The incoming body is read and parsed, subject to the class size
        limits, before control is handed to L{render}.

        @param request: the request to process.
        @return: an object adaptable to L{iweb.IResponse}.
        """
        parsed = server.parsePOSTData(
            request, self.maxMem, self.maxFields, self.maxSize
        )

        def _thenRender(ignored):
            return self.render(request)

        return parsed.addCallback(_thenRender)
class LeafResource(RenderMixin):
    """
    A L{Resource} with no children.
    """
    implements(iweb.IResource)

    def locateChild(self, request, segments):
        # Terminal resource: halt URL traversal here, leaving any remaining
        # segments for this resource to interpret itself.
        return self, server.StopTraversal
class RedirectResource(LeafResource):
    """
    A L{LeafResource} which always performs a redirect.
    """
    implements(iweb.IResource)

    def __init__(self, *args, **kwargs):
        """
        Parameters are URL components and are the same as those for
        L{urlparse.urlunparse}. URL components which are not specified will
        default to the corresponding component of the URL of the request being
        redirected.
        """
        self._args = args
        self._kwargs = kwargs

    def renderHTTP(self, request):
        # Build the target URL from the stored components, filling in any
        # unspecified parts from the request's own URL.
        location = request.unparseURL(*self._args, **self._kwargs)
        return http.RedirectResponse(location)
class WrapperResource(object):
    """
    An L{iweb.IResource} implementation that wraps a L{RenderMixin}
    instance, giving subclasses a hook that runs before any request
    processing on the wrapped L{Resource}.
    """
    implements(iweb.IResource)

    def __init__(self, resource):
        # The wrapped resource to which all processing is delegated.
        self.resource = resource

    def hook(self, request):
        """
        Override this to perform work before control is handed to the
        wrapped resource's C{renderHTTP} and C{locateChild} methods.

        @return: C{None} or a L{Deferred}. If a deferred is returned,
            its result is ignored, but C{renderHTTP} and C{locateChild}
            are chained onto it as callbacks.
        """
        raise NotImplementedError()

    def locateChild(self, request, segments):
        deferred = self.hook(request)
        if deferred is None:
            return self.resource, segments
        return deferred.addCallback(lambda _: (self.resource, segments))

    def renderHTTP(self, request):
        deferred = self.hook(request)
        if deferred is None:
            return self.resource
        return deferred.addCallback(lambda _: self.resource)

    def getChild(self, name):
        return self.resource.getChild(name)
__all__ = ['RenderMixin', 'Resource', 'PostableResource', 'LeafResource', 'WrapperResource']
calendarserver-5.2+dfsg/twext/web2/filter/ 0000755 0001750 0001750 00000000000 12322625325 017610 5 ustar rahul rahul calendarserver-5.2+dfsg/twext/web2/filter/gzip.py 0000644 0001750 0001750 00000005364 11340001243 021125 0 ustar rahul rahul from __future__ import generators
import struct
import zlib
from twext.web2 import stream
# TODO: ungzip (can any browsers actually generate gzipped
# upload data?) Decompression is still needed for the client side anyway.
def gzipStream(input, compressLevel=6):
crc, size = zlib.crc32(''), 0
# magic header, compression method, no flags
header = '\037\213\010\000'
# timestamp
header += struct.pack('