// OAuth 1.0 consumer implementation.
// See http://www.oauth.net and RFC 5849
//
// There are typically three parties involved in an OAuth exchange:
//   (1) The "Service Provider" (e.g. Google, Twitter, NetFlix) who operates the
//       service where the data resides.
//   (2) The "End User" who owns that data, and wants to grant access to a third-party.
//   (3) That third-party who wants access to the data (after first being authorized by
//       the user). This third-party is referred to as the "Consumer" in OAuth
//       terminology.
//
// This library is designed to help implement the third-party consumer by handling the
// low-level authentication tasks, and allowing for authenticated requests to the
// service provider on behalf of the user.
//
// Caveats:
//  - Currently only supports HMAC and RSA signatures.
//  - Currently only supports SHA1 and SHA256 hashes.
//  - Currently only supports OAuth 1.0
//
// Overview of how to use this library:
//   (1) First create a new Consumer instance with the NewConsumer function
//   (2) Get a RequestToken, and "authorization url" from GetRequestTokenAndUrl()
//   (3) Save the RequestToken, you will need it again in step 6.
//   (4) Redirect the user to the "authorization url" from step 2, where they will
//       authorize your access to the service provider.
//   (5) Wait. You will be called back on the CallbackUrl that you provide, and you
//       will receive a "verification code".
//   (6) Call AuthorizeToken() with the RequestToken from step 2 and the
//       "verification code" from step 5.
//   (7) You will get back an AccessToken.  Save this for as long as you need access
//       to the user's data, and treat it like a password; it is a secret.
//   (8) You can now throw away the RequestToken from step 2, it is no longer
//       necessary.
//   (9) Call "MakeHttpClient" using the AccessToken from step 7 to get an
//       HTTP client which can access protected resources.
package oauth

import (
	"bytes"
	"crypto"
	"crypto/hmac"
	cryptoRand "crypto/rand"
	"crypto/rsa"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"mime/multipart"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"
)

// Protocol constants: the OAuth version string, the signature-method name
// prefixes, the HTTP header name and header-value prefix used for
// authorization, and the names of the oauth_* protocol parameters.
const (
	OAUTH_VERSION         = "1.0"
	SIGNATURE_METHOD_HMAC = "HMAC-"
	SIGNATURE_METHOD_RSA  = "RSA-"

	HTTP_AUTH_HEADER       = "Authorization"
	OAUTH_HEADER           = "OAuth "
	BODY_HASH_PARAM        = "oauth_body_hash"
	CALLBACK_PARAM         = "oauth_callback"
	CONSUMER_KEY_PARAM     = "oauth_consumer_key"
	NONCE_PARAM            = "oauth_nonce"
	SESSION_HANDLE_PARAM   = "oauth_session_handle"
	SIGNATURE_METHOD_PARAM = "oauth_signature_method"
	SIGNATURE_PARAM        = "oauth_signature"
	TIMESTAMP_PARAM        = "oauth_timestamp"
	TOKEN_PARAM            = "oauth_token"
	TOKEN_SECRET_PARAM     = "oauth_token_secret"
	VERIFIER_PARAM         = "oauth_verifier"
	VERSION_PARAM          = "oauth_version"
)

// HASH_METHOD_MAP maps each supported hash to the suffix that the signers
// append to their signature-method prefix (e.g. "HMAC-" + "SHA1").
var HASH_METHOD_MAP = map[crypto.Hash]string{
	crypto.SHA1:   "SHA1",
	crypto.SHA256: "SHA256",
}

// TODO(mrjones) Do we definitely want separate "Request" and "Access" token classes?
// They're identical structurally, but used for different purposes.

// RequestToken is the temporary token/secret pair obtained at the start of the
// authorization flow and later exchanged for an AccessToken.
type RequestToken struct {
	Token  string
	Secret string
}

// AccessToken is the token/secret pair used to sign requests for protected
// resources. AdditionalData carries any other parameters that were returned
// alongside the token and secret in the service provider's response.
type AccessToken struct {
	Token          string
	Secret         string
	AdditionalData map[string]string
}

// DataLocation selects how the deprecated Consumer request helpers place
// request data and which Content-Type they set.
type DataLocation int

const (
	LOC_BODY DataLocation = iota + 1
	LOC_URL
	LOC_MULTIPART
	LOC_JSON
	LOC_XML
)

// Information about how to contact the service provider (see #1 above).
// You usually find all of these URLs by reading the documentation for the service
// that you're trying to connect to.
// Some common examples are:
//   (1) Google, standard APIs:
//       http://code.google.com/apis/accounts/docs/OAuth_ref.html
//       - RequestTokenUrl:   https://www.google.com/accounts/OAuthGetRequestToken
//       - AuthorizeTokenUrl: https://www.google.com/accounts/OAuthAuthorizeToken
//       - AccessTokenUrl:    https://www.google.com/accounts/OAuthGetAccessToken
//       Note: Some Google APIs (for example, Google Latitude) use different values for
//       one or more of those URLs.
// (2) Twitter API: // http://dev.twitter.com/pages/auth // - RequestTokenUrl: http://api.twitter.com/oauth/request_token // - AuthorizeTokenUrl: https://api.twitter.com/oauth/authorize // - AccessTokenUrl: https://api.twitter.com/oauth/access_token // (3) NetFlix API: // http://developer.netflix.com/docs/Security // - RequestTokenUrl: http://api.netflix.com/oauth/request_token // - AuthroizeTokenUrl: https://api-user.netflix.com/oauth/login // - AccessTokenUrl: http://api.netflix.com/oauth/access_token // Set HttpMethod if the service provider requires a different HTTP method // to be used for OAuth token requests type ServiceProvider struct { RequestTokenUrl string AuthorizeTokenUrl string AccessTokenUrl string HttpMethod string BodyHash bool IgnoreTimestamp bool // Enables non spec-compliant behavior: // Allow parameters to be passed in the query string rather // than the body. // See https://github.com/mrjones/oauth/pull/63 SignQueryParams bool } func (sp *ServiceProvider) httpMethod() string { if sp.HttpMethod != "" { return sp.HttpMethod } return "GET" } // lockedNonceGenerator wraps a non-reentrant random number generator with a // lock type lockedNonceGenerator struct { nonceGenerator nonceGenerator lock sync.Mutex } func newLockedNonceGenerator(c clock) *lockedNonceGenerator { return &lockedNonceGenerator{ nonceGenerator: rand.New(rand.NewSource(c.Nanos())), } } func (n *lockedNonceGenerator) Int63() int64 { n.lock.Lock() r := n.nonceGenerator.Int63() n.lock.Unlock() return r } // Consumers are stateless, you can call the various methods (GetRequestTokenAndUrl, // AuthorizeToken, and Get) on various different instances of Consumers *as long as // they were set up in the same way.* It is up to you, as the caller to persist the // necessary state (RequestTokens and AccessTokens). type Consumer struct { // Some ServiceProviders require extra parameters to be passed for various reasons. 
// For example Google APIs require you to set a scope= parameter to specify how much // access is being granted. The proper values for scope= depend on the service: // For more, see: http://code.google.com/apis/accounts/docs/OAuth.html#prepScope AdditionalParams map[string]string // The rest of this class is configured via the NewConsumer function. consumerKey string serviceProvider ServiceProvider // Some APIs (e.g. Netflix) aren't quite standard OAuth, and require passing // additional parameters when authorizing the request token. For most APIs // this field can be ignored. For Netflix, do something like: // consumer.AdditionalAuthorizationUrlParams = map[string]string{ // "application_name": "YourAppName", // "oauth_consumer_key": "YourConsumerKey", // } AdditionalAuthorizationUrlParams map[string]string debug bool // Defaults to http.Client{}, can be overridden (e.g. for testing) as necessary HttpClient HttpClient // Some APIs (e.g. Intuit/Quickbooks) require sending additional headers along with // requests. (like "Accept" to specify the response type as XML or JSON) Note that this // will only *add* headers, not set existing ones. AdditionalHeaders map[string][]string // Private seams for mocking dependencies when testing clock clock // Seeded generators are not reentrant nonceGenerator nonceGenerator signer signer } func newConsumer(consumerKey string, serviceProvider ServiceProvider, httpClient *http.Client) *Consumer { clock := &defaultClock{} if httpClient == nil { httpClient = &http.Client{} } return &Consumer{ consumerKey: consumerKey, serviceProvider: serviceProvider, clock: clock, HttpClient: httpClient, nonceGenerator: newLockedNonceGenerator(clock), AdditionalParams: make(map[string]string), AdditionalAuthorizationUrlParams: make(map[string]string), } } // Creates a new Consumer instance, with a HMAC-SHA1 signer // - consumerKey and consumerSecret: // values you should obtain from the ServiceProvider when you register your // application. 
// // - serviceProvider: // see the documentation for ServiceProvider for how to create this. // func NewConsumer(consumerKey string, consumerSecret string, serviceProvider ServiceProvider) *Consumer { consumer := newConsumer(consumerKey, serviceProvider, nil) consumer.signer = &HMACSigner{ consumerSecret: consumerSecret, hashFunc: crypto.SHA1, } return consumer } // Creates a new Consumer instance, with a HMAC-SHA1 signer // - consumerKey and consumerSecret: // values you should obtain from the ServiceProvider when you register your // application. // // - serviceProvider: // see the documentation for ServiceProvider for how to create this. // // - httpClient: // Provides a custom implementation of the httpClient used under the hood // to make the request. This is especially useful if you want to use // Google App Engine. // func NewCustomHttpClientConsumer(consumerKey string, consumerSecret string, serviceProvider ServiceProvider, httpClient *http.Client) *Consumer { consumer := newConsumer(consumerKey, serviceProvider, httpClient) consumer.signer = &HMACSigner{ consumerSecret: consumerSecret, hashFunc: crypto.SHA1, } return consumer } // Creates a new Consumer instance, with a HMAC signer // - consumerKey and consumerSecret: // values you should obtain from the ServiceProvider when you register your // application. // // - hashFunc: // the crypto.Hash to use for signatures // // - serviceProvider: // see the documentation for ServiceProvider for how to create this. // // - httpClient: // Provides a custom implementation of the httpClient used under the hood // to make the request. This is especially useful if you want to use // Google App Engine. Can be nil for default. 
// func NewCustomConsumer(consumerKey string, consumerSecret string, hashFunc crypto.Hash, serviceProvider ServiceProvider, httpClient *http.Client) *Consumer { consumer := newConsumer(consumerKey, serviceProvider, httpClient) consumer.signer = &HMACSigner{ consumerSecret: consumerSecret, hashFunc: hashFunc, } return consumer } // Creates a new Consumer instance, with a RSA-SHA1 signer // - consumerKey: // value you should obtain from the ServiceProvider when you register your // application. // // - privateKey: // the private key to use for signatures // // - serviceProvider: // see the documentation for ServiceProvider for how to create this. // func NewRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey, serviceProvider ServiceProvider) *Consumer { consumer := newConsumer(consumerKey, serviceProvider, nil) consumer.signer = &RSASigner{ privateKey: privateKey, hashFunc: crypto.SHA1, rand: cryptoRand.Reader, } return consumer } // Creates a new Consumer instance, with a RSA signer // - consumerKey: // value you should obtain from the ServiceProvider when you register your // application. // // - privateKey: // the private key to use for signatures // // - hashFunc: // the crypto.Hash to use for signatures // // - serviceProvider: // see the documentation for ServiceProvider for how to create this. // // - httpClient: // Provides a custom implementation of the httpClient used under the hood // to make the request. This is especially useful if you want to use // Google App Engine. Can be nil for default. // func NewCustomRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey, hashFunc crypto.Hash, serviceProvider ServiceProvider, httpClient *http.Client) *Consumer { consumer := newConsumer(consumerKey, serviceProvider, httpClient) consumer.signer = &RSASigner{ privateKey: privateKey, hashFunc: hashFunc, rand: cryptoRand.Reader, } return consumer } // Kicks off the OAuth authorization process. 
// - callbackUrl: // Authorizing a token *requires* redirecting to the service provider. This is the // URL which the service provider will redirect the user back to after that // authorization is completed. The service provider will pass back a verification // code which is necessary to complete the rest of the process (in AuthorizeToken). // Notes on callbackUrl: // - Some (all?) service providers allow for setting "oob" (for out-of-band) as a // callback url. If this is set the service provider will present the // verification code directly to the user, and you must provide a place for // them to copy-and-paste it into. // - Otherwise, the user will be redirected to callbackUrl in the browser, and // will append a "oauth_verifier=<verifier>" parameter. // // This function returns: // - rtoken: // A temporary RequestToken, used during the authorization process. You must save // this since it will be necessary later in the process when calling // AuthorizeToken(). // // - url: // A URL that you should redirect the user to in order that they may authorize you // to the service provider. // // - err: // Set only if there was an error, nil otherwise. 
func (c *Consumer) GetRequestTokenAndUrl(callbackUrl string) (rtoken *RequestToken, loginUrl string, err error) { return c.GetRequestTokenAndUrlWithParams(callbackUrl, c.AdditionalParams) } func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additionalParams map[string]string) (rtoken *RequestToken, loginUrl string, err error) { params := c.baseParams(c.consumerKey, additionalParams) if callbackUrl != "" { params.Add(CALLBACK_PARAM, callbackUrl) } req := &request{ method: c.serviceProvider.httpMethod(), url: c.serviceProvider.RequestTokenUrl, oauthParams: params, } if _, err := c.signRequest(req, ""); err != nil { // We don't have a token secret for the key yet return nil, "", err } resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, params) if err != nil { return nil, "", errors.New("getBody: " + err.Error()) } requestToken, err := parseRequestToken(*resp) if err != nil { return nil, "", errors.New("parseRequestToken: " + err.Error()) } loginParams := make(url.Values) for k, v := range c.AdditionalAuthorizationUrlParams { loginParams.Set(k, v) } loginParams.Set(TOKEN_PARAM, requestToken.Token) loginUrl = c.serviceProvider.AuthorizeTokenUrl + "?" + loginParams.Encode() return requestToken, loginUrl, nil } // After the user has authorized you to the service provider, use this method to turn // your temporary RequestToken into a permanent AccessToken. You must pass in two values: // - rtoken: // The RequestToken returned from GetRequestTokenAndUrl() // // - verificationCode: // The string which passed back from the server, either as the oauth_verifier // query param appended to callbackUrl *OR* a string manually entered by the user // if callbackUrl is "oob" // // It will return: // - atoken: // A permanent AccessToken which can be used to access the user's data (until it is // revoked by the user or the service provider). // // - err: // Set only if there was an error, nil otherwise. 
func (c *Consumer) AuthorizeToken(rtoken *RequestToken, verificationCode string) (atoken *AccessToken, err error) { return c.AuthorizeTokenWithParams(rtoken, verificationCode, c.AdditionalParams) } func (c *Consumer) AuthorizeTokenWithParams(rtoken *RequestToken, verificationCode string, additionalParams map[string]string) (atoken *AccessToken, err error) { params := map[string]string{ TOKEN_PARAM: rtoken.Token, } if verificationCode != "" { params[VERIFIER_PARAM] = verificationCode } return c.makeAccessTokenRequestWithParams(params, rtoken.Secret, additionalParams) } // Use the service provider to refresh the AccessToken for a given session. // Note that this is only supported for service providers that manage an // authorization session (e.g. Yahoo). // // Most providers do not return the SESSION_HANDLE_PARAM needed to refresh // the token. // // See http://oauth.googlecode.com/svn/spec/ext/session/1.0/drafts/1/spec.html // for more information. // - accessToken: // The AccessToken returned from AuthorizeToken() // // It will return: // - atoken: // An AccessToken which can be used to access the user's data (until it is // revoked by the user or the service provider). // // - err: // Set if accessToken does not contain the SESSION_HANDLE_PARAM needed to // refresh the token, or if an error occurred when making the request. func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, err error) { params := make(map[string]string) sessionHandle, ok := accessToken.AdditionalData[SESSION_HANDLE_PARAM] if !ok { return nil, errors.New("Missing " + SESSION_HANDLE_PARAM + " in access token.") } params[SESSION_HANDLE_PARAM] = sessionHandle params[TOKEN_PARAM] = accessToken.Token return c.makeAccessTokenRequest(params, accessToken.Secret) } // Use the service provider to obtain an AccessToken for a given session // - params: // The access token request paramters. // // - secret: // Secret key to use when signing the access token request. 
// // It will return: // - atoken // An AccessToken which can be used to access the user's data (until it is // revoked by the user or the service provider). // // - err: // Set only if there was an error, nil otherwise. func (c *Consumer) makeAccessTokenRequest(params map[string]string, secret string) (atoken *AccessToken, err error) { return c.makeAccessTokenRequestWithParams(params, secret, c.AdditionalParams) } func (c *Consumer) makeAccessTokenRequestWithParams(params map[string]string, secret string, additionalParams map[string]string) (atoken *AccessToken, err error) { orderedParams := c.baseParams(c.consumerKey, additionalParams) for key, value := range params { orderedParams.Add(key, value) } req := &request{ method: c.serviceProvider.httpMethod(), url: c.serviceProvider.AccessTokenUrl, oauthParams: orderedParams, } if _, err := c.signRequest(req, secret); err != nil { return nil, err } resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, orderedParams) if err != nil { return nil, err } return parseAccessToken(*resp) } type RoundTripper struct { consumer *Consumer token *AccessToken } func (c *Consumer) MakeRoundTripper(token *AccessToken) (*RoundTripper, error) { return &RoundTripper{consumer: c, token: token}, nil } func (c *Consumer) MakeHttpClient(token *AccessToken) (*http.Client, error) { return &http.Client{ Transport: &RoundTripper{consumer: c, token: token}, }, nil } // ** DEPRECATED ** // Please call Get on the http client returned by MakeHttpClient instead! // // Executes an HTTP Get, authorized via the AccessToken. // - url: // The base url, without any query params, which is being accessed // // - userParams: // Any key=value params to be included in the query string // // - token: // The AccessToken returned by AuthorizeToken() // // This method returns: // - resp: // The HTTP Response resulting from making this request. // // - err: // Set only if there was an error, nil otherwise. 
func (c *Consumer) Get(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { return c.makeAuthorizedRequest("GET", url, LOC_URL, "", userParams, token) } func encodeUserParams(userParams map[string]string) string { data := url.Values{} for k, v := range userParams { data.Add(k, v) } return data.Encode() } // ** DEPRECATED ** // Please call "Post" on the http client returned by MakeHttpClient instead func (c *Consumer) PostForm(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { return c.PostWithBody(url, "", userParams, token) } // ** DEPRECATED ** // Please call "Post" on the http client returned by MakeHttpClient instead func (c *Consumer) Post(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { return c.PostWithBody(url, "", userParams, token) } // ** DEPRECATED ** // Please call "Post" on the http client returned by MakeHttpClient instead func (c *Consumer) PostWithBody(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { return c.makeAuthorizedRequest("POST", url, LOC_BODY, body, userParams, token) } // ** DEPRECATED ** // Please call "Do" on the http client returned by MakeHttpClient instead // (and set the "Content-Type" header explicitly in the http.Request) func (c *Consumer) PostJson(url string, body string, token *AccessToken) (resp *http.Response, err error) { return c.makeAuthorizedRequest("POST", url, LOC_JSON, body, nil, token) } // ** DEPRECATED ** // Please call "Do" on the http client returned by MakeHttpClient instead // (and set the "Content-Type" header explicitly in the http.Request) func (c *Consumer) PostXML(url string, body string, token *AccessToken) (resp *http.Response, err error) { return c.makeAuthorizedRequest("POST", url, LOC_XML, body, nil, token) } // ** DEPRECATED ** // Please call "Do" on the http client returned by MakeHttpClient instead // 
(and setup the multipart data explicitly in the http.Request) func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { return c.makeAuthorizedRequestReader("POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token) } // ** DEPRECATED ** // Please call "Delete" on the http client returned by MakeHttpClient instead func (c *Consumer) Delete(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { return c.makeAuthorizedRequest("DELETE", url, LOC_URL, "", userParams, token) } // ** DEPRECATED ** // Please call "Put" on the http client returned by MakeHttpClient instead func (c *Consumer) Put(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { return c.makeAuthorizedRequest("PUT", url, LOC_URL, body, userParams, token) } func (c *Consumer) Debug(enabled bool) { c.debug = enabled c.signer.Debug(enabled) } type pair struct { key string value string } type pairs []pair func (p pairs) Len() int { return len(p) } func (p pairs) Less(i, j int) bool { return p[i].key < p[j].key } func (p pairs) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // This function has basically turned into a backwards compatibility layer // between the old API (where clients explicitly called consumer.Get() // consumer.Post() etc), and the new API (which takes actual http.Requests) // // So, here we construct the appropriate HTTP request for the inputs. 
func (c *Consumer) makeAuthorizedRequestReader(method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { urlObject, err := url.Parse(urlString) if err != nil { return nil, err } request := &http.Request{ Method: method, URL: urlObject, Header: http.Header{}, Body: body, ContentLength: int64(contentLength), } vals := url.Values{} for k, v := range userParams { vals.Add(k, v) } if dataLocation != LOC_BODY { request.URL.RawQuery = vals.Encode() request.URL.RawQuery = strings.Replace( request.URL.RawQuery, ";", "%3B", -1) } else { // TODO(mrjones): validate that we're not overrideing an exising body? request.ContentLength = int64(len(vals.Encode())) if request.ContentLength == 0 { request.Body = nil } else { request.Body = ioutil.NopCloser(strings.NewReader(vals.Encode())) } } for k, vs := range c.AdditionalHeaders { for _, v := range vs { request.Header.Set(k, v) } } if dataLocation == LOC_BODY { request.Header.Set("Content-Type", "application/x-www-form-urlencoded") } if dataLocation == LOC_JSON { request.Header.Set("Content-Type", "application/json") } if dataLocation == LOC_XML { request.Header.Set("Content-Type", "application/xml") } if dataLocation == LOC_MULTIPART { pipeReader, pipeWriter := io.Pipe() writer := multipart.NewWriter(pipeWriter) if request.URL.Host == "www.mrjon.es" && request.URL.Path == "/unittest" { writer.SetBoundary("UNITTESTBOUNDARY") } go func(body io.Reader) { part, err := writer.CreateFormFile(multipartName, "/no/matter") if err != nil { writer.Close() pipeWriter.CloseWithError(err) return } _, err = io.Copy(part, body) if err != nil { writer.Close() pipeWriter.CloseWithError(err) return } writer.Close() pipeWriter.Close() }(body) request.Body = pipeReader request.Header.Set("Content-Type", writer.FormDataContentType()) } rt := RoundTripper{consumer: c, token: token} resp, err = 
rt.RoundTrip(request) if err != nil { return resp, err } if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { defer resp.Body.Close() bytes, _ := ioutil.ReadAll(resp.Body) return resp, HTTPExecuteError{ RequestHeaders: "", ResponseBodyBytes: bytes, Status: resp.Status, StatusCode: resp.StatusCode, } } return resp, nil } // cloneReq clones the src http.Request, making deep copies of the Header and // the URL but shallow copies of everything else func cloneReq(src *http.Request) *http.Request { dst := &http.Request{} *dst = *src dst.Header = make(http.Header, len(src.Header)) for k, s := range src.Header { dst.Header[k] = append([]string(nil), s...) } if src.URL != nil { dst.URL = cloneURL(src.URL) } return dst } // cloneURL shallow clones the src *url.URL func cloneURL(src *url.URL) *url.URL { dst := &url.URL{} *dst = *src return dst } func canonicalizeUrl(u *url.URL) string { var buf bytes.Buffer buf.WriteString(u.Scheme) buf.WriteString("://") buf.WriteString(u.Host) buf.WriteString(u.Path) return buf.String() } func getBody(request *http.Request) ([]byte, error) { if request.Body == nil { return nil, nil } defer request.Body.Close() originalBody, err := ioutil.ReadAll(request.Body) if err != nil { return nil, err } // We have to re-install the body (because we've ruined it by reading it). 
if len(originalBody) > 0 { request.Body = ioutil.NopCloser(bytes.NewReader(originalBody)) } else { request.Body = nil } return originalBody, nil } func parseBody(request *http.Request) (map[string][]string, error) { userParams := map[string][]string{} // TODO(mrjones): factor parameter extraction into a separate method if request.Header.Get("Content-Type") != "application/x-www-form-urlencoded" { // Most of the time we get parameters from the query string: for k, vs := range request.URL.Query() { userParams[k] = vs } } else { // x-www-form-urlencoded parameters come from the body instead: body, err := getBody(request) if err != nil { return nil, err } params, err := url.ParseQuery(string(body)) if err != nil { return nil, err } for k, vs := range params { userParams[k] = vs } } return userParams, nil } func paramsToSortedPairs(params map[string][]string) pairs { // Sort parameters alphabetically paramPairs := pairs([]pair{}) for key, values := range params { for _, value := range values { paramPairs = append(paramPairs, pair{key: key, value: value}) } } sort.Sort(paramPairs) return paramPairs } func calculateBodyHash(request *http.Request, s signer) (string, error) { if request.Header.Get("Content-Type") == "application/x-www-form-urlencoded" { return "", nil } var body []byte if request.Body != nil { var err error body, err = getBody(request) if err != nil { return "", err } } h := s.HashFunc().New() h.Write(body) rawSignature := h.Sum(nil) return base64.StdEncoding.EncodeToString(rawSignature), nil } func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, error) { serverRequest := cloneReq(userRequest) allParams := rt.consumer.baseParams( rt.consumer.consumerKey, rt.consumer.AdditionalParams) // Do not add the "oauth_token" parameter, if the access token has not been // specified. By omitting this parameter when it is not specified, allows // two-legged OAuth calls. 
if len(rt.token.Token) > 0 { allParams.Add(TOKEN_PARAM, rt.token.Token) } if rt.consumer.serviceProvider.BodyHash { bodyHash, err := calculateBodyHash(serverRequest, rt.consumer.signer) if err != nil { return nil, err } if bodyHash != "" { allParams.Add(BODY_HASH_PARAM, bodyHash) } } authParams := allParams.Clone() // TODO(mrjones): put these directly into the paramPairs below? userParams, err := parseBody(serverRequest) if err != nil { return nil, err } paramPairs := paramsToSortedPairs(userParams) for i := range paramPairs { allParams.Add(paramPairs[i].key, paramPairs[i].value) } signingURL := cloneURL(serverRequest.URL) if host := serverRequest.Host; host != "" { signingURL.Host = host } baseString := rt.consumer.requestString(serverRequest.Method, canonicalizeUrl(signingURL), allParams) signature, err := rt.consumer.signer.Sign(baseString, rt.token.Secret) if err != nil { return nil, err } authParams.Add(SIGNATURE_PARAM, signature) // Set auth header. oauthHdr := OAUTH_HEADER for pos, key := range authParams.Keys() { for innerPos, value := range authParams.Get(key) { if pos+innerPos > 0 { oauthHdr += "," } oauthHdr += key + "=\"" + value + "\"" } } serverRequest.Header.Add(HTTP_AUTH_HEADER, oauthHdr) if rt.consumer.debug { fmt.Printf("Request: %v\n", serverRequest) } resp, err := rt.consumer.HttpClient.Do(serverRequest) if err != nil { return resp, err } return resp, nil } func (c *Consumer) makeAuthorizedRequest(method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { return c.makeAuthorizedRequestReader(method, url, dataLocation, len(body), "", ioutil.NopCloser(strings.NewReader(body)), userParams, token) } type request struct { method string url string oauthParams *OrderedParams userParams map[string]string } type HttpClient interface { Do(req *http.Request) (resp *http.Response, err error) } type clock interface { Seconds() int64 Nanos() int64 } type 
nonceGenerator interface { Int63() int64 } type key interface { String() string } type signer interface { Sign(message string, tokenSecret string) (string, error) Verify(message string, signature string) error SignatureMethod() string HashFunc() crypto.Hash Debug(enabled bool) } type defaultClock struct{} func (*defaultClock) Seconds() int64 { return time.Now().Unix() } func (*defaultClock) Nanos() int64 { return time.Now().UnixNano() } func (c *Consumer) signRequest(req *request, tokenSecret string) (*request, error) { baseString := c.requestString(req.method, req.url, req.oauthParams) signature, err := c.signer.Sign(baseString, tokenSecret) if err != nil { return nil, err } req.oauthParams.Add(SIGNATURE_PARAM, signature) return req, nil } // Obtains an AccessToken from the response of a service provider. // - data: // The response body. // // This method returns: // - atoken: // The AccessToken generated from the response body. // // - err: // Set if an AccessToken could not be parsed from the given input. func parseAccessToken(data string) (atoken *AccessToken, err error) { parts, err := url.ParseQuery(data) if err != nil { return nil, err } tokenParam := parts[TOKEN_PARAM] parts.Del(TOKEN_PARAM) if len(tokenParam) < 1 { return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " + "Full response body: '" + data + "'") } tokenSecretParam := parts[TOKEN_SECRET_PARAM] parts.Del(TOKEN_SECRET_PARAM) if len(tokenSecretParam) < 1 { return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response." + "Full response body: '" + data + "'") } additionalData := parseAdditionalData(parts) return &AccessToken{tokenParam[0], tokenSecretParam[0], additionalData}, nil } func parseRequestToken(data string) (*RequestToken, error) { parts, err := url.ParseQuery(data) if err != nil { return nil, err } tokenParam := parts[TOKEN_PARAM] if len(tokenParam) < 1 { return nil, errors.New("Missing " + TOKEN_PARAM + " in response. 
" + "Full response body: '" + data + "'")
	}
	tokenSecretParam := parts[TOKEN_SECRET_PARAM]
	if len(tokenSecretParam) < 1 {
		return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response." +
			"Full response body: '" + data + "'")
	}
	return &RequestToken{tokenParam[0], tokenSecretParam[0]}, nil
}

// baseParams builds the OAuth parameters common to every signed request:
// version, signature method, timestamp, nonce and consumer key. Each entry of
// additionalParams is then added as well. Values are percent-escaped by
// OrderedParams.Add.
func (c *Consumer) baseParams(consumerKey string, additionalParams map[string]string) *OrderedParams {
	params := NewOrderedParams()
	params.Add(VERSION_PARAM, OAUTH_VERSION)
	params.Add(SIGNATURE_METHOD_PARAM, c.signer.SignatureMethod())
	params.Add(TIMESTAMP_PARAM, strconv.FormatInt(c.clock.Seconds(), 10))
	params.Add(NONCE_PARAM, strconv.FormatInt(c.nonceGenerator.Int63(), 10))
	params.Add(CONSUMER_KEY_PARAM, consumerKey)
	for key, value := range additionalParams {
		params.Add(key, value)
	}
	return params
}

// parseAdditionalData flattens url.Values into a plain map, keeping only the
// first value of each key. Keys that carry no values are omitted.
func parseAdditionalData(parts url.Values) map[string]string {
	params := make(map[string]string, len(parts))
	for key, values := range parts {
		if len(values) > 0 {
			params[key] = values[0]
		}
	}
	return params
}

// HMACSigner signs messages with an HMAC whose key is derived from the
// consumer secret and the token secret.
type HMACSigner struct {
	consumerSecret string
	hashFunc       crypto.Hash
	debug          bool
}

// Debug turns debug printing to stdout on or off.
func (s *HMACSigner) Debug(enabled bool) {
	s.debug = enabled
}

// Sign returns the base64-encoded HMAC of message. The HMAC key is
// escape(consumerSecret) + "&" + escape(tokenSecret). The returned error is
// always nil.
//
// Note that with debug enabled the signing key (which contains the secrets)
// is printed to stdout.
func (s *HMACSigner) Sign(message string, tokenSecret string) (string, error) {
	key := escape(s.consumerSecret) + "&" + escape(tokenSecret)
	if s.debug {
		fmt.Println("Signing:", message)
		fmt.Println("Key:", key)
	}

	h := hmac.New(s.HashFunc().New, []byte(key))
	h.Write([]byte(message))
	rawSignature := h.Sum(nil)

	base64signature := base64.StdEncoding.EncodeToString(rawSignature)
	if s.debug {
		fmt.Println("Base64 signature:", base64signature)
	}
	return base64signature, nil
}

// Verify recomputes the signature of message with an empty token secret and
// compares it against signature. If the raw value does not match, the
// comparison is retried with the query-unescaped form of signature. It
// returns nil on a match and an error otherwise.
func (s *HMACSigner) Verify(message string, signature string) error {
	if s.debug {
		fmt.Println("Verifying Base64 signature:", signature)
	}
	validSignature, err := s.Sign(message, "")
	if err != nil {
		return err
	}

	// hmac.Equal compares without leaking timing information about where the
	// two values first differ.
	if hmac.Equal([]byte(validSignature), []byte(signature)) {
		return nil
	}

	// The caller may have passed a still percent-encoded signature. The
	// unescape error is deliberately ignored: on failure the decoded value is
	// empty and therefore cannot match.
	decodedSignature, _ := url.QueryUnescape(signature)
	if hmac.Equal([]byte(validSignature), []byte(decodedSignature)) {
		return nil
	}
	return errors.New("signature did not match")
}

// SignatureMethod returns the signature method name: the HMAC prefix followed
// by the HASH_METHOD_MAP entry for the configured hash.
func (s *HMACSigner) SignatureMethod() string {
	return SIGNATURE_METHOD_HMAC + HASH_METHOD_MAP[s.HashFunc()]
}

// HashFunc returns the hash used by this signer.
func (s *HMACSigner) HashFunc() crypto.Hash {
	return s.hashFunc
}

// RSASigner signs messages with PKCS #1 v1.5 RSA signatures over a digest
// computed with the configured hash.
type RSASigner struct {
	debug      bool
	rand       io.Reader
	privateKey *rsa.PrivateKey
	hashFunc   crypto.Hash
}

// Debug turns debug printing to stdout on or off.
func (s *RSASigner) Debug(enabled bool) {
	s.debug = enabled
}

// Sign hashes message and returns the base64-encoded PKCS #1 v1.5 signature
// of the digest. tokenSecret is not used. An error from the RSA signing
// operation is returned to the caller.
func (s *RSASigner) Sign(message string, tokenSecret string) (string, error) {
	if s.debug {
		fmt.Println("Signing:", message)
	}

	h := s.HashFunc().New()
	h.Write([]byte(message))
	digest := h.Sum(nil)

	signature, err := rsa.SignPKCS1v15(s.rand, s.privateKey, s.HashFunc(), digest)
	if err != nil {
		// Propagate the failure; returning a nil error here would hand the
		// caller an empty signature that looks like a success.
		return "", err
	}

	base64signature := base64.StdEncoding.EncodeToString(signature)
	if s.debug {
		fmt.Println("Base64 signature:", base64signature)
	}
	return base64signature, nil
}

// Verify decodes base64signature and checks it against the digest of message
// using the public half of the signer's private key. It returns nil when the
// signature is valid, and the decode or verification error otherwise.
func (s *RSASigner) Verify(message string, base64signature string) error {
	if s.debug {
		fmt.Println("Verifying:", message)
		fmt.Println("Verifying Base64 signature:", base64signature)
	}

	h := s.HashFunc().New()
	h.Write([]byte(message))
	digest := h.Sum(nil)

	signature, err := base64.StdEncoding.DecodeString(base64signature)
	if err != nil {
		return err
	}
	return rsa.VerifyPKCS1v15(&s.privateKey.PublicKey, s.HashFunc(), digest, signature)
}

// SignatureMethod returns the signature method name: the RSA prefix followed
// by the HASH_METHOD_MAP entry for the configured hash.
func (s *RSASigner) SignatureMethod() string {
	return SIGNATURE_METHOD_RSA + HASH_METHOD_MAP[s.HashFunc()]
}

// HashFunc returns the hash used by this signer.
func (s *RSASigner) HashFunc() crypto.Hash {
	return s.hashFunc
}

// escape percent-encodes every byte of s for which isEscapable reports true,
// using upper-case hex digits, and copies all other bytes unchanged.
func escape(s string) string {
	const upperHex = "0123456789ABCDEF"
	// Worst case every byte expands to three ("%XX").
	t := make([]byte, 0, 3*len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if isEscapable(c) {
			t = append(t, '%', upperHex[c>>4], upperHex[c&15])
		} else {
			t = append(t, c)
		}
	}
	return string(t)
}

// isEscapable reports whether b must be percent-encoded, i.e. whether it is
// anything other than an ASCII letter, an ASCII digit, '-', '.', '_' or '~'.
func isEscapable(b byte) bool {
	switch {
	case 'A' <= b && b <= 'Z', 'a' <= b && b <= 'z', '0' <= b && b <= '9':
		return false
	case b == '-', b == '.', b == '_', b == '~':
		return false
	}
	return true
}

// requestString builds the string that gets signed: the method, "&", the
// escaped url, then "&" followed by the escaped "key=value" pairs of params
// in sorted order, joined by an escaped "&". When params is empty nothing is
// appended after the escaped url.
func (c *Consumer) requestString(method string, url string, params *OrderedParams) string {
	var b strings.Builder
	b.WriteString(method)
	b.WriteString("&")
	b.WriteString(escape(url))

	first := true
	for _, key := range params.Keys() {
		for _, value := range params.Get(key) {
			if first {
				// A plain "&" separates the url from the parameter section.
				b.WriteString("&")
				first = false
			} else {
				// Separators between pairs belong to the parameter section
				// and are therefore escaped along with it.
				b.WriteString(escape("&"))
			}
			b.WriteString(escape(key + "=" + value))
		}
	}
	return b.String()
}

// getBody issues the request through httpExecute with no content type, no
// content length and no body, and returns the whole response body as a
// string. Errors from httpExecute and from reading the body are returned
// with a short prefix naming the failing step.
func (c *Consumer) getBody(method, url string, oauthParams *OrderedParams) (*string, error) {
	resp, err := c.httpExecute(method, url, "", 0, nil, oauthParams)
	if err != nil {
		return nil, errors.New("httpExecute: " + err.Error())
	}
	defer resp.Body.Close()

	bodyBytes, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, errors.New("ReadAll: " + err.Error())
	}
	bodyStr := string(bodyBytes)
	if c.debug {
		fmt.Printf("STATUS: %d %s\n", resp.StatusCode, resp.Status)
		fmt.Println("BODY RESPONSE: " + bodyStr)
	}
	return &bodyStr, nil
}

// HTTPExecuteError signals that a call to httpExecute failed.
type HTTPExecuteError struct {
	// RequestHeaders provides a stringified listing of request headers.
	RequestHeaders string
	// ResponseBodyBytes is the response read into a byte slice.
	ResponseBodyBytes []byte
	// Status is the status code string response.
	Status string
	// StatusCode is the parsed status code.
	StatusCode int
}

// Error provides a printable string description of an HTTPExecuteError.
func (e HTTPExecuteError) Error() string {
	return "HTTP response is not 200/OK as expected. Actual response: \n" +
		"\tResponse Status: '" + e.Status + "'\n" +
		"\tResponse Code: " + strconv.Itoa(e.StatusCode) + "\n" +
		"\tResponse Body: " + string(e.ResponseBodyBytes) + "\n" +
		"\tRequest Headers: " + e.RequestHeaders
}

// httpExecute builds an HTTP request for method and urlStr with the given
// body, attaches the OAuth Authorization header built from oauthParams plus
// the consumer's AdditionalHeaders, and sends it with c.HttpClient.
//
// contentType and contentLength are only applied when non-empty / positive.
//
// For a status code outside the 2xx range the response body is read and
// closed, and the response is returned together with an HTTPExecuteError
// carrying the body, the status and the request headers. Otherwise the
// response is returned with a nil error and an open body.
func (c *Consumer) httpExecute(
	method string, urlStr string, contentType string, contentLength int, body io.Reader, oauthParams *OrderedParams) (*http.Response, error) {
	// Create base request.
	req, err := http.NewRequest(method, urlStr, body)
	if err != nil {
		return nil, errors.New("NewRequest failed: " + err.Error())
	}

	// Set auth header.
	req.Header = http.Header{}
	var oauthHdr strings.Builder
	oauthHdr.WriteString(OAUTH_HEADER)
	first := true
	for _, key := range oauthParams.Keys() {
		for _, value := range oauthParams.Get(key) {
			if !first {
				oauthHdr.WriteString(",")
			}
			first = false
			oauthHdr.WriteString(key + "=\"" + value + "\"")
		}
	}
	req.Header.Add(HTTP_AUTH_HEADER, oauthHdr.String())

	// Add additional custom headers
	for key, vals := range c.AdditionalHeaders {
		for _, val := range vals {
			req.Header.Add(key, val)
		}
	}

	// Set contentType if passed.
	if contentType != "" {
		req.Header.Set("Content-Type", contentType)
	}

	// Set contentLength if passed.
	// NOTE(review): net/http documents that Content-Length values in a client
	// request's Header may be ignored in favour of Request.ContentLength;
	// confirm whether req.ContentLength should be set here as well.
	if contentLength > 0 {
		req.Header.Set("Content-Length", strconv.Itoa(contentLength))
	}

	if c.debug {
		fmt.Printf("Request: %v\n", req)
	}
	resp, err := c.HttpClient.Do(req)
	if err != nil {
		return nil, errors.New("Do: " + err.Error())
	}

	// StatusMultipleChoices is 300, any 2xx response should be treated as success
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		defer resp.Body.Close()
		// Best effort: a read failure still yields an HTTPExecuteError, just
		// with whatever part of the body could be read.
		respBody, _ := ioutil.ReadAll(resp.Body)

		// The header dump is only needed for the error, so build it here.
		var debugHeader strings.Builder
		for k, vals := range req.Header {
			for _, val := range vals {
				debugHeader.WriteString("[key: " + k + ", val: " + val + "]")
			}
		}

		return resp, HTTPExecuteError{
			RequestHeaders:    debugHeader.String(),
			ResponseBodyBytes: respBody,
			Status:            resp.Status,
			StatusCode:        resp.StatusCode,
		}
	}
	return resp, nil
}

//
// String Sorting helpers
//

// ByValue implements sort.Interface for a string slice, ordering the
// elements in increasing order.
type ByValue []string

// Len returns the number of elements.
func (a ByValue) Len() int {
	return len(a)
}

// Swap exchanges the elements at positions i and j.
func (a ByValue) Swap(i, j int) {
	a[i], a[j] = a[j], a[i]
}

// Less reports whether element i sorts before element j.
func (a ByValue) Less(i, j int) bool {
	return a[i] < a[j]
}

//
// ORDERED PARAMS
//

// OrderedParams is a multi-valued parameter collection that hands out its
// keys, and the values of each key, in sorted order. It implements
// sort.Interface over its keys.
type OrderedParams struct {
	allParams   map[string][]string
	keyOrdering []string
}

// NewOrderedParams returns an empty OrderedParams ready for use.
func NewOrderedParams() *OrderedParams {
	return &OrderedParams{
		allParams:   make(map[string][]string),
		keyOrdering: make([]string, 0),
	}
}

// Get sorts the values stored under key in place and returns them. The
// returned slice is the collection's own storage, not a copy. An unknown key
// yields a nil slice.
func (o *OrderedParams) Get(key string) []string {
	values := o.allParams[key]
	sort.Strings(values)
	return values
}

// Keys sorts the keys in place and returns them. The returned slice is the
// collection's own storage, not a copy.
func (o *OrderedParams) Keys() []string {
	sort.Sort(o)
	return o.keyOrdering
}

// Add stores value under key after percent-escaping the value.
func (o *OrderedParams) Add(key, value string) {
	o.AddUnescaped(key, escape(value))
}

// AddUnescaped stores value under key exactly as given. A key seen for the
// first time is also appended to the key list.
func (o *OrderedParams) AddUnescaped(key, value string) {
	if _, exists := o.allParams[key]; !exists {
		o.keyOrdering = append(o.keyOrdering, key)
	}
	// append on the nil slice of a new key creates its one-element slice.
	o.allParams[key] = append(o.allParams[key], value)
}

// Len returns the number of distinct keys.
func (o *OrderedParams) Len() int {
	return len(o.keyOrdering)
}

// Less reports whether key i sorts before key j.
func (o *OrderedParams) Less(i int, j int) bool {
	return o.keyOrdering[i] < o.keyOrdering[j]
}

// Swap exchanges the keys at positions i and j.
func (o *OrderedParams) Swap(i int, j int) {
	o.keyOrdering[i], o.keyOrdering[j] = o.keyOrdering[j], o.keyOrdering[i]
}

// Clone returns a new OrderedParams holding the same keys and values. Values
// are copied without being escaped again. Because it goes through Keys and
// Get, the receiver's keys and values are sorted as a side effect.
func (o *OrderedParams) Clone() *OrderedParams {
	clone := NewOrderedParams()
	for _, key := range o.Keys() {
		for _, value := range o.Get(key) {
			clone.AddUnescaped(key, value)
		}
	}
	return clone
}