blob: cf72ea990bda2c11ae7a87865c352b7fc9029730 [file] [log] [blame]
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package astutil
6
7import (
8 "fmt"
9 "go/ast"
10 "reflect"
11 "sort"
12)
13
14// An ApplyFunc is invoked by Apply for each node n, even if n is nil,
15// before and/or after the node's children, using a Cursor describing
16// the current node and providing operations on it.
17//
18// The return value of ApplyFunc controls the syntax tree traversal.
19// See Apply for details.
20type ApplyFunc func(*Cursor) bool
21
22// Apply traverses a syntax tree recursively, starting with root,
23// and calling pre and post for each node as described below.
24// Apply returns the syntax tree, possibly modified.
25//
26// If pre is not nil, it is called for each node before the node's
27// children are traversed (pre-order). If pre returns false, no
28// children are traversed, and post is not called for that node.
29//
30// If post is not nil, and a prior call of pre didn't return false,
31// post is called for each node after its children are traversed
32// (post-order). If post returns false, traversal is terminated and
33// Apply returns immediately.
34//
35// Only fields that refer to AST nodes are considered children;
36// i.e., token.Pos, Scopes, Objects, and fields of basic types
37// (strings, etc.) are ignored.
38//
39// Children are traversed in the order in which they appear in the
40// respective node's struct definition. A package's files are
41// traversed in the filenames' alphabetical order.
42//
43func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
44 parent := &struct{ ast.Node }{root}
45 defer func() {
46 if r := recover(); r != nil && r != abort {
47 panic(r)
48 }
49 result = parent.Node
50 }()
51 a := &application{pre: pre, post: post}
52 a.apply(parent, "Node", nil, root)
53 return
54}
55
// abort is the sentinel panic value that apply uses to stop a traversal
// when post returns false; Apply recovers exactly this value. Only its
// pointer identity matters, which makes it distinct from any other panic.
var abort = new(int)
57
58// A Cursor describes a node encountered during Apply.
59// Information about the node and its parent is available
60// from the Node, Parent, Name, and Index methods.
61//
62// If p is a variable of type and value of the current parent node
63// c.Parent(), and f is the field identifier with name c.Name(),
64// the following invariants hold:
65//
66// p.f == c.Node() if c.Index() < 0
67// p.f[c.Index()] == c.Node() if c.Index() >= 0
68//
69// The methods Replace, Delete, InsertBefore, and InsertAfter
70// can be used to change the AST without disrupting Apply.
71type Cursor struct {
72 parent ast.Node
73 name string
74 iter *iterator // valid if non-nil
75 node ast.Node
76}
77
78// Node returns the current Node.
79func (c *Cursor) Node() ast.Node { return c.node }
80
81// Parent returns the parent of the current Node.
82func (c *Cursor) Parent() ast.Node { return c.parent }
83
84// Name returns the name of the parent Node field that contains the current Node.
85// If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
86// the filename for the current Node.
87func (c *Cursor) Name() string { return c.name }
88
89// Index reports the index >= 0 of the current Node in the slice of Nodes that
90// contains it, or a value < 0 if the current Node is not part of a slice.
91// The index of the current node changes if InsertBefore is called while
92// processing the current node.
93func (c *Cursor) Index() int {
94 if c.iter != nil {
95 return c.iter.index
96 }
97 return -1
98}
99
100// field returns the current node's parent field value.
101func (c *Cursor) field() reflect.Value {
102 return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
103}
104
105// Replace replaces the current Node with n.
106// The replacement node is not walked by Apply.
107func (c *Cursor) Replace(n ast.Node) {
108 if _, ok := c.node.(*ast.File); ok {
109 file, ok := n.(*ast.File)
110 if !ok {
111 panic("attempt to replace *ast.File with non-*ast.File")
112 }
113 c.parent.(*ast.Package).Files[c.name] = file
114 return
115 }
116
117 v := c.field()
118 if i := c.Index(); i >= 0 {
119 v = v.Index(i)
120 }
121 v.Set(reflect.ValueOf(n))
122}
123
124// Delete deletes the current Node from its containing slice.
125// If the current Node is not part of a slice, Delete panics.
126// As a special case, if the current node is a package file,
127// Delete removes it from the package's Files map.
128func (c *Cursor) Delete() {
129 if _, ok := c.node.(*ast.File); ok {
130 delete(c.parent.(*ast.Package).Files, c.name)
131 return
132 }
133
134 i := c.Index()
135 if i < 0 {
136 panic("Delete node not contained in slice")
137 }
138 v := c.field()
139 l := v.Len()
140 reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
141 v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
142 v.SetLen(l - 1)
143 c.iter.step--
144}
145
146// InsertAfter inserts n after the current Node in its containing slice.
147// If the current Node is not part of a slice, InsertAfter panics.
148// Apply does not walk n.
149func (c *Cursor) InsertAfter(n ast.Node) {
150 i := c.Index()
151 if i < 0 {
152 panic("InsertAfter node not contained in slice")
153 }
154 v := c.field()
155 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
156 l := v.Len()
157 reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
158 v.Index(i + 1).Set(reflect.ValueOf(n))
159 c.iter.step++
160}
161
162// InsertBefore inserts n before the current Node in its containing slice.
163// If the current Node is not part of a slice, InsertBefore panics.
164// Apply will not walk n.
165func (c *Cursor) InsertBefore(n ast.Node) {
166 i := c.Index()
167 if i < 0 {
168 panic("InsertBefore node not contained in slice")
169 }
170 v := c.field()
171 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
172 l := v.Len()
173 reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
174 v.Index(i).Set(reflect.ValueOf(n))
175 c.iter.index++
176}
177
178// application carries all the shared data so we can pass it around cheaply.
179type application struct {
180 pre, post ApplyFunc
181 cursor Cursor
182 iter iterator
183}
184
185func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
186 // convert typed nil into untyped nil
187 if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
188 n = nil
189 }
190
191 // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
192 saved := a.cursor
193 a.cursor.parent = parent
194 a.cursor.name = name
195 a.cursor.iter = iter
196 a.cursor.node = n
197
198 if a.pre != nil && !a.pre(&a.cursor) {
199 a.cursor = saved
200 return
201 }
202
203 // walk children
204 // (the order of the cases matches the order of the corresponding node types in go/ast)
205 switch n := n.(type) {
206 case nil:
207 // nothing to do
208
209 // Comments and fields
210 case *ast.Comment:
211 // nothing to do
212
213 case *ast.CommentGroup:
214 if n != nil {
215 a.applyList(n, "List")
216 }
217
218 case *ast.Field:
219 a.apply(n, "Doc", nil, n.Doc)
220 a.applyList(n, "Names")
221 a.apply(n, "Type", nil, n.Type)
222 a.apply(n, "Tag", nil, n.Tag)
223 a.apply(n, "Comment", nil, n.Comment)
224
225 case *ast.FieldList:
226 a.applyList(n, "List")
227
228 // Expressions
229 case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
230 // nothing to do
231
232 case *ast.Ellipsis:
233 a.apply(n, "Elt", nil, n.Elt)
234
235 case *ast.FuncLit:
236 a.apply(n, "Type", nil, n.Type)
237 a.apply(n, "Body", nil, n.Body)
238
239 case *ast.CompositeLit:
240 a.apply(n, "Type", nil, n.Type)
241 a.applyList(n, "Elts")
242
243 case *ast.ParenExpr:
244 a.apply(n, "X", nil, n.X)
245
246 case *ast.SelectorExpr:
247 a.apply(n, "X", nil, n.X)
248 a.apply(n, "Sel", nil, n.Sel)
249
250 case *ast.IndexExpr:
251 a.apply(n, "X", nil, n.X)
252 a.apply(n, "Index", nil, n.Index)
253
254 case *ast.SliceExpr:
255 a.apply(n, "X", nil, n.X)
256 a.apply(n, "Low", nil, n.Low)
257 a.apply(n, "High", nil, n.High)
258 a.apply(n, "Max", nil, n.Max)
259
260 case *ast.TypeAssertExpr:
261 a.apply(n, "X", nil, n.X)
262 a.apply(n, "Type", nil, n.Type)
263
264 case *ast.CallExpr:
265 a.apply(n, "Fun", nil, n.Fun)
266 a.applyList(n, "Args")
267
268 case *ast.StarExpr:
269 a.apply(n, "X", nil, n.X)
270
271 case *ast.UnaryExpr:
272 a.apply(n, "X", nil, n.X)
273
274 case *ast.BinaryExpr:
275 a.apply(n, "X", nil, n.X)
276 a.apply(n, "Y", nil, n.Y)
277
278 case *ast.KeyValueExpr:
279 a.apply(n, "Key", nil, n.Key)
280 a.apply(n, "Value", nil, n.Value)
281
282 // Types
283 case *ast.ArrayType:
284 a.apply(n, "Len", nil, n.Len)
285 a.apply(n, "Elt", nil, n.Elt)
286
287 case *ast.StructType:
288 a.apply(n, "Fields", nil, n.Fields)
289
290 case *ast.FuncType:
291 a.apply(n, "Params", nil, n.Params)
292 a.apply(n, "Results", nil, n.Results)
293
294 case *ast.InterfaceType:
295 a.apply(n, "Methods", nil, n.Methods)
296
297 case *ast.MapType:
298 a.apply(n, "Key", nil, n.Key)
299 a.apply(n, "Value", nil, n.Value)
300
301 case *ast.ChanType:
302 a.apply(n, "Value", nil, n.Value)
303
304 // Statements
305 case *ast.BadStmt:
306 // nothing to do
307
308 case *ast.DeclStmt:
309 a.apply(n, "Decl", nil, n.Decl)
310
311 case *ast.EmptyStmt:
312 // nothing to do
313
314 case *ast.LabeledStmt:
315 a.apply(n, "Label", nil, n.Label)
316 a.apply(n, "Stmt", nil, n.Stmt)
317
318 case *ast.ExprStmt:
319 a.apply(n, "X", nil, n.X)
320
321 case *ast.SendStmt:
322 a.apply(n, "Chan", nil, n.Chan)
323 a.apply(n, "Value", nil, n.Value)
324
325 case *ast.IncDecStmt:
326 a.apply(n, "X", nil, n.X)
327
328 case *ast.AssignStmt:
329 a.applyList(n, "Lhs")
330 a.applyList(n, "Rhs")
331
332 case *ast.GoStmt:
333 a.apply(n, "Call", nil, n.Call)
334
335 case *ast.DeferStmt:
336 a.apply(n, "Call", nil, n.Call)
337
338 case *ast.ReturnStmt:
339 a.applyList(n, "Results")
340
341 case *ast.BranchStmt:
342 a.apply(n, "Label", nil, n.Label)
343
344 case *ast.BlockStmt:
345 a.applyList(n, "List")
346
347 case *ast.IfStmt:
348 a.apply(n, "Init", nil, n.Init)
349 a.apply(n, "Cond", nil, n.Cond)
350 a.apply(n, "Body", nil, n.Body)
351 a.apply(n, "Else", nil, n.Else)
352
353 case *ast.CaseClause:
354 a.applyList(n, "List")
355 a.applyList(n, "Body")
356
357 case *ast.SwitchStmt:
358 a.apply(n, "Init", nil, n.Init)
359 a.apply(n, "Tag", nil, n.Tag)
360 a.apply(n, "Body", nil, n.Body)
361
362 case *ast.TypeSwitchStmt:
363 a.apply(n, "Init", nil, n.Init)
364 a.apply(n, "Assign", nil, n.Assign)
365 a.apply(n, "Body", nil, n.Body)
366
367 case *ast.CommClause:
368 a.apply(n, "Comm", nil, n.Comm)
369 a.applyList(n, "Body")
370
371 case *ast.SelectStmt:
372 a.apply(n, "Body", nil, n.Body)
373
374 case *ast.ForStmt:
375 a.apply(n, "Init", nil, n.Init)
376 a.apply(n, "Cond", nil, n.Cond)
377 a.apply(n, "Post", nil, n.Post)
378 a.apply(n, "Body", nil, n.Body)
379
380 case *ast.RangeStmt:
381 a.apply(n, "Key", nil, n.Key)
382 a.apply(n, "Value", nil, n.Value)
383 a.apply(n, "X", nil, n.X)
384 a.apply(n, "Body", nil, n.Body)
385
386 // Declarations
387 case *ast.ImportSpec:
388 a.apply(n, "Doc", nil, n.Doc)
389 a.apply(n, "Name", nil, n.Name)
390 a.apply(n, "Path", nil, n.Path)
391 a.apply(n, "Comment", nil, n.Comment)
392
393 case *ast.ValueSpec:
394 a.apply(n, "Doc", nil, n.Doc)
395 a.applyList(n, "Names")
396 a.apply(n, "Type", nil, n.Type)
397 a.applyList(n, "Values")
398 a.apply(n, "Comment", nil, n.Comment)
399
400 case *ast.TypeSpec:
401 a.apply(n, "Doc", nil, n.Doc)
402 a.apply(n, "Name", nil, n.Name)
403 a.apply(n, "Type", nil, n.Type)
404 a.apply(n, "Comment", nil, n.Comment)
405
406 case *ast.BadDecl:
407 // nothing to do
408
409 case *ast.GenDecl:
410 a.apply(n, "Doc", nil, n.Doc)
411 a.applyList(n, "Specs")
412
413 case *ast.FuncDecl:
414 a.apply(n, "Doc", nil, n.Doc)
415 a.apply(n, "Recv", nil, n.Recv)
416 a.apply(n, "Name", nil, n.Name)
417 a.apply(n, "Type", nil, n.Type)
418 a.apply(n, "Body", nil, n.Body)
419
420 // Files and packages
421 case *ast.File:
422 a.apply(n, "Doc", nil, n.Doc)
423 a.apply(n, "Name", nil, n.Name)
424 a.applyList(n, "Decls")
425 // Don't walk n.Comments; they have either been walked already if
426 // they are Doc comments, or they can be easily walked explicitly.
427
428 case *ast.Package:
429 // collect and sort names for reproducible behavior
430 var names []string
431 for name := range n.Files {
432 names = append(names, name)
433 }
434 sort.Strings(names)
435 for _, name := range names {
436 a.apply(n, name, nil, n.Files[name])
437 }
438
439 default:
440 panic(fmt.Sprintf("Apply: unexpected node type %T", n))
441 }
442
443 if a.post != nil && !a.post(&a.cursor) {
444 panic(abort)
445 }
446
447 a.cursor = saved
448}
449
// iterator records where applyList is within a slice of nodes.
// index is the position of the element currently being visited. step is
// how far applyList advances once that visit is done; Cursor.Delete and
// Cursor.InsertAfter adjust it so that the traversal stays on track.
type iterator struct {
	index int
	step  int
}
454
455func (a *application) applyList(parent ast.Node, name string) {
456 // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
457 saved := a.iter
458 a.iter.index = 0
459 for {
460 // must reload parent.name each time, since cursor modifications might change it
461 v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
462 if a.iter.index >= v.Len() {
463 break
464 }
465
466 // element x may be nil in a bad AST - be cautious
467 var x ast.Node
468 if e := v.Index(a.iter.index); e.IsValid() {
469 x = e.Interface().(ast.Node)
470 }
471
472 a.iter.step = 1
473 a.apply(parent, name, &a.iter, x)
474 a.iter.index += a.iter.step
475 }
476 a.iter = saved
477}