blob: 3b412ce818a7a7f9b120d9570070ab648b5eddec [file] [log] [blame]
Matthias Andreas Benkard832a54e2019-01-29 09:27:38 +01001package matchers
2
3import (
4 "bytes"
5 "encoding/xml"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "sort"
11 "strings"
12
13 "github.com/onsi/gomega/format"
14 "golang.org/x/net/html/charset"
15)
16
// MatchXMLMatcher compares a value against an expected XML document by
// parsing both sides and comparing the resulting trees, rather than
// comparing raw text.
type MatchXMLMatcher struct {
	// XMLToMatch is the expected XML. It must be a string, a stringer or a
	// []byte; anything else makes the matcher report an error.
	XMLToMatch interface{}
}
20
21func (matcher *MatchXMLMatcher) Match(actual interface{}) (success bool, err error) {
22 actualString, expectedString, err := matcher.formattedPrint(actual)
23 if err != nil {
24 return false, err
25 }
26
27 aval, err := parseXmlContent(actualString)
28 if err != nil {
29 return false, fmt.Errorf("Actual '%s' should be valid XML, but it is not.\nUnderlying error:%s", actualString, err)
30 }
31
32 eval, err := parseXmlContent(expectedString)
33 if err != nil {
34 return false, fmt.Errorf("Expected '%s' should be valid XML, but it is not.\nUnderlying error:%s", expectedString, err)
35 }
36
37 return reflect.DeepEqual(aval, eval), nil
38}
39
40func (matcher *MatchXMLMatcher) FailureMessage(actual interface{}) (message string) {
41 actualString, expectedString, _ := matcher.formattedPrint(actual)
42 return fmt.Sprintf("Expected\n%s\nto match XML of\n%s", actualString, expectedString)
43}
44
45func (matcher *MatchXMLMatcher) NegatedFailureMessage(actual interface{}) (message string) {
46 actualString, expectedString, _ := matcher.formattedPrint(actual)
47 return fmt.Sprintf("Expected\n%s\nnot to match XML of\n%s", actualString, expectedString)
48}
49
50func (matcher *MatchXMLMatcher) formattedPrint(actual interface{}) (actualString, expectedString string, err error) {
51 var ok bool
52 actualString, ok = toString(actual)
53 if !ok {
54 return "", "", fmt.Errorf("MatchXMLMatcher matcher requires a string, stringer, or []byte. Got actual:\n%s", format.Object(actual, 1))
55 }
56 expectedString, ok = toString(matcher.XMLToMatch)
57 if !ok {
58 return "", "", fmt.Errorf("MatchXMLMatcher matcher requires a string, stringer, or []byte. Got expected:\n%s", format.Object(matcher.XMLToMatch, 1))
59 }
60 return actualString, expectedString, nil
61}
62
63func parseXmlContent(content string) (*xmlNode, error) {
64 allNodes := []*xmlNode{}
65
66 dec := newXmlDecoder(strings.NewReader(content))
67 for {
68 tok, err := dec.Token()
69 if err != nil {
70 if err == io.EOF {
71 break
72 }
73 return nil, fmt.Errorf("failed to decode next token: %v", err)
74 }
75
76 lastNodeIndex := len(allNodes) - 1
77 var lastNode *xmlNode
78 if len(allNodes) > 0 {
79 lastNode = allNodes[lastNodeIndex]
80 } else {
81 lastNode = &xmlNode{}
82 }
83
84 switch tok := tok.(type) {
85 case xml.StartElement:
86 attrs := attributesSlice(tok.Attr)
87 sort.Sort(attrs)
88 allNodes = append(allNodes, &xmlNode{XMLName: tok.Name, XMLAttr: tok.Attr})
89 case xml.EndElement:
90 if len(allNodes) > 1 {
91 allNodes[lastNodeIndex-1].Nodes = append(allNodes[lastNodeIndex-1].Nodes, lastNode)
92 allNodes = allNodes[:lastNodeIndex]
93 }
94 case xml.CharData:
95 lastNode.Content = append(lastNode.Content, tok.Copy()...)
96 case xml.Comment:
97 lastNode.Comments = append(lastNode.Comments, tok.Copy())
98 case xml.ProcInst:
99 lastNode.ProcInsts = append(lastNode.ProcInsts, tok.Copy())
100 }
101 }
102
103 if len(allNodes) == 0 {
104 return nil, errors.New("found no nodes")
105 }
106 firstNode := allNodes[0]
107 trimParentNodesContentSpaces(firstNode)
108
109 return firstNode, nil
110}
111
112func newXmlDecoder(reader io.Reader) *xml.Decoder {
113 dec := xml.NewDecoder(reader)
114 dec.CharsetReader = charset.NewReaderLabel
115 return dec
116}
117
118func trimParentNodesContentSpaces(node *xmlNode) {
119 if len(node.Nodes) > 0 {
120 node.Content = bytes.TrimSpace(node.Content)
121 for _, childNode := range node.Nodes {
122 trimParentNodesContentSpaces(childNode)
123 }
124 }
125}
126
// xmlNode is one parsed XML element. Trees of xmlNode are compared with
// reflect.DeepEqual, so every field participates in equality.
type xmlNode struct {
	// XMLName is the element's (namespace, local) name.
	XMLName xml.Name
	// Comments are the comments found directly inside this element.
	Comments []xml.Comment
	// ProcInsts are the processing instructions directly inside this element.
	ProcInsts []xml.ProcInst
	// XMLAttr holds the element's attributes, sorted by the parser.
	XMLAttr []xml.Attr
	// Content is the concatenated character data directly inside this element.
	Content []byte
	// Nodes are the child elements, in document order.
	Nodes []*xmlNode
}