| package matchers |
| |
| import ( |
| "bytes" |
| "encoding/xml" |
| "errors" |
| "fmt" |
| "io" |
| "reflect" |
| "sort" |
| "strings" |
| |
| "github.com/onsi/gomega/format" |
| "golang.org/x/net/html/charset" |
| ) |
| |
// MatchXMLMatcher is a gomega matcher that succeeds when the actual value and
// XMLToMatch parse to equivalent XML trees (see Match for how they are compared).
type MatchXMLMatcher struct {
	// XMLToMatch is the expected XML document. It must be a string, a stringer,
	// or a []byte; any other type makes Match return an error.
	XMLToMatch interface{}
}
| |
| func (matcher *MatchXMLMatcher) Match(actual interface{}) (success bool, err error) { |
| actualString, expectedString, err := matcher.formattedPrint(actual) |
| if err != nil { |
| return false, err |
| } |
| |
| aval, err := parseXmlContent(actualString) |
| if err != nil { |
| return false, fmt.Errorf("Actual '%s' should be valid XML, but it is not.\nUnderlying error:%s", actualString, err) |
| } |
| |
| eval, err := parseXmlContent(expectedString) |
| if err != nil { |
| return false, fmt.Errorf("Expected '%s' should be valid XML, but it is not.\nUnderlying error:%s", expectedString, err) |
| } |
| |
| return reflect.DeepEqual(aval, eval), nil |
| } |
| |
| func (matcher *MatchXMLMatcher) FailureMessage(actual interface{}) (message string) { |
| actualString, expectedString, _ := matcher.formattedPrint(actual) |
| return fmt.Sprintf("Expected\n%s\nto match XML of\n%s", actualString, expectedString) |
| } |
| |
| func (matcher *MatchXMLMatcher) NegatedFailureMessage(actual interface{}) (message string) { |
| actualString, expectedString, _ := matcher.formattedPrint(actual) |
| return fmt.Sprintf("Expected\n%s\nnot to match XML of\n%s", actualString, expectedString) |
| } |
| |
| func (matcher *MatchXMLMatcher) formattedPrint(actual interface{}) (actualString, expectedString string, err error) { |
| var ok bool |
| actualString, ok = toString(actual) |
| if !ok { |
| return "", "", fmt.Errorf("MatchXMLMatcher matcher requires a string, stringer, or []byte. Got actual:\n%s", format.Object(actual, 1)) |
| } |
| expectedString, ok = toString(matcher.XMLToMatch) |
| if !ok { |
| return "", "", fmt.Errorf("MatchXMLMatcher matcher requires a string, stringer, or []byte. Got expected:\n%s", format.Object(matcher.XMLToMatch, 1)) |
| } |
| return actualString, expectedString, nil |
| } |
| |
| func parseXmlContent(content string) (*xmlNode, error) { |
| allNodes := []*xmlNode{} |
| |
| dec := newXmlDecoder(strings.NewReader(content)) |
| for { |
| tok, err := dec.Token() |
| if err != nil { |
| if err == io.EOF { |
| break |
| } |
| return nil, fmt.Errorf("failed to decode next token: %v", err) |
| } |
| |
| lastNodeIndex := len(allNodes) - 1 |
| var lastNode *xmlNode |
| if len(allNodes) > 0 { |
| lastNode = allNodes[lastNodeIndex] |
| } else { |
| lastNode = &xmlNode{} |
| } |
| |
| switch tok := tok.(type) { |
| case xml.StartElement: |
| attrs := attributesSlice(tok.Attr) |
| sort.Sort(attrs) |
| allNodes = append(allNodes, &xmlNode{XMLName: tok.Name, XMLAttr: tok.Attr}) |
| case xml.EndElement: |
| if len(allNodes) > 1 { |
| allNodes[lastNodeIndex-1].Nodes = append(allNodes[lastNodeIndex-1].Nodes, lastNode) |
| allNodes = allNodes[:lastNodeIndex] |
| } |
| case xml.CharData: |
| lastNode.Content = append(lastNode.Content, tok.Copy()...) |
| case xml.Comment: |
| lastNode.Comments = append(lastNode.Comments, tok.Copy()) |
| case xml.ProcInst: |
| lastNode.ProcInsts = append(lastNode.ProcInsts, tok.Copy()) |
| } |
| } |
| |
| if len(allNodes) == 0 { |
| return nil, errors.New("found no nodes") |
| } |
| firstNode := allNodes[0] |
| trimParentNodesContentSpaces(firstNode) |
| |
| return firstNode, nil |
| } |
| |
| func newXmlDecoder(reader io.Reader) *xml.Decoder { |
| dec := xml.NewDecoder(reader) |
| dec.CharsetReader = charset.NewReaderLabel |
| return dec |
| } |
| |
| func trimParentNodesContentSpaces(node *xmlNode) { |
| if len(node.Nodes) > 0 { |
| node.Content = bytes.TrimSpace(node.Content) |
| for _, childNode := range node.Nodes { |
| trimParentNodesContentSpaces(childNode) |
| } |
| } |
| } |
| |
// xmlNode is the normalized form of one XML element, built by parseXmlContent.
// Match compares two trees of xmlNode values with reflect.DeepEqual, so every
// field below takes part in the comparison.
type xmlNode struct {
	// XMLName is the element name as reported by the decoder (namespace and local part).
	XMLName xml.Name
	// Comments holds the comments parseXmlContent attached to this element.
	Comments []xml.Comment
	// ProcInsts holds the processing instructions parseXmlContent attached to this element.
	ProcInsts []xml.ProcInst
	// XMLAttr holds the element's attributes, sorted by parseXmlContent (via
	// attributesSlice) so that attribute order does not affect equality.
	XMLAttr []xml.Attr
	// Content is the concatenated character data attached to this element. It is
	// whitespace-trimmed by trimParentNodesContentSpaces only when the element
	// has child elements.
	Content []byte
	// Nodes lists the child elements in document order.
	Nodes []*xmlNode
}