diff --git a/decode_response.go b/decode_response.go index 7f40bc7..6cc37a4 100644 --- a/decode_response.go +++ b/decode_response.go @@ -277,22 +277,55 @@ func (sp *SAMLServiceProvider) ValidateEncodedResponse(encodedResponse string) ( return decodedResponse, nil } -// parseResponse is a helper function that was refactored out so that the XML parsing behavior can be isolated and unit tested -func parseResponse(xml []byte) (*etree.Document, *etree.Element, error) { - doc := etree.NewDocument() - err := doc.ReadFromBytes(xml) +// DecodeUnverifiedBaseResponse decodes several attributes from a SAML response for the purpose +// of determining how to validate the response. This is useful for Service Providers which +// expose a single Assertion Consumer Service URL but consume Responses from many IdPs. +func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBaseResponse, error) { + raw, err := base64.StdEncoding.DecodeString(encodedResponse) + if err != nil { + return nil, err + } + + var response *types.UnverifiedBaseResponse + + err = maybeDeflate(raw, func(maybeXML []byte) error { + response = &types.UnverifiedBaseResponse{} + return xml.Unmarshal(maybeXML, response) + }) if err != nil { - // Attempt to inflate the response in case it happens to be compressed (as with one case at saml.oktadev.com) - buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(xml))) + return nil, err + } + + return response, nil +} + +// maybeDeflate invokes the passed decoder over the passed data. If an error is +// returned, it then attempts to deflate the passed data before re-invoking +// the decoder over the deflated data. +func maybeDeflate(data []byte, decoder func([]byte) error) error { + err := decoder(data) + if err != nil { + deflated, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(data))) if err != nil { - return nil, nil, err + return err } + return decoder(deflated) + } + + return nil +} + +// parseResponse is a helper function that was refactored out so that the XML parsing behavior can be isolated and unit tested +func parseResponse(xml []byte) (*etree.Document, *etree.Element, error) { + var doc *etree.Document + + err := maybeDeflate(xml, func(xml []byte) error { doc = etree.NewDocument() - err = doc.ReadFromBytes(buf) - if err != nil { - return nil, nil, err - } + return doc.ReadFromBytes(xml) + }) + if err != nil { + return nil, nil, err } el := doc.Root() diff --git a/providertests/exercise.go b/providertests/exercise.go index 7c0c19a..cfe0aa7 100644 --- a/providertests/exercise.go +++ b/providertests/exercise.go @@ -5,12 +5,22 @@ package providertests import ( "testing" + saml2 "github.com/russellhaering/gosaml2" "github.com/stretchr/testify/require" ) func ExerciseProviderTestScenarios(t *testing.T, scenarios []ProviderTestScenario) { + println("TESTING") for _, scenario := range scenarios { t.Run(scenario.ScenarioName, func(t *testing.T) { + _, err := saml2.DecodeUnverifiedBaseResponse(scenario.Response) + // DecodeUnverifiedBaseResponse is more permissive than RetrieveAssertionInfo. + // If an error _is_ returned it should match, but it is OK for no error to be + // returned even when one is expected during full validation. + if err != nil { + scenario.CheckError(t, err) + } + assertionInfo, err := scenario.ServiceProvider.RetrieveAssertionInfo(scenario.Response) if scenario.CheckError != nil { scenario.CheckError(t, err) diff --git a/providertests/exercise_go_1_6.go b/providertests/exercise_go_1_6.go index 324ff67..fd5fdcf 100644 --- a/providertests/exercise_go_1_6.go +++ b/providertests/exercise_go_1_6.go @@ -10,6 +10,13 @@ import ( func ExerciseProviderTestScenarios(t *testing.T, scenarios []ProviderTestScenario) { for _, scenario := range scenarios { + // DecodeUnverifiedBaseResponse is more permissive than RetrieveAssertionInfo. + // If an error _is_ returned it should match, but it is OK for no error to be + // returned even when one is expected during full validation. + if err != nil { + scenario.CheckError(t, err) + } + assertionInfo, err := scenario.ServiceProvider.RetrieveAssertionInfo(scenario.Response) if scenario.CheckError != nil { scenario.CheckError(t, err) diff --git a/types/response.go b/types/response.go index adbc792..4f338f0 100644 --- a/types/response.go +++ b/types/response.go @@ -5,6 +5,20 @@ import ( "time" ) +// UnverifiedBaseResponse extracts several basic attributes of a SAML Response +// which may be useful in deciding how to validate the Response. An UnverifiedBaseResponse +// is parsed by this library prior to any validation of the Response, so the +// values it contains may have been supplied by an attacker and should not be +// trusted as authoritative from the IdP. +type UnverifiedBaseResponse struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"` + ID string `xml:"ID,attr"` + InResponseTo string `xml:"InResponseTo,attr"` + Destination string `xml:"Destination,attr"` + Version string `xml:"Version,attr"` + Issuer *Issuer `xml:"Issuer"` +} + type Response struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"` ID string `xml:"ID,attr"`