blob: 8b6518b5c8fd3f8f4acddeb2419b1f63d342b473 [file] [log] [blame]
Matthias Andreas Benkard832a54e2019-01-29 09:27:38 +01001package leafnodes
2
3import (
4 "fmt"
5 "reflect"
6 "time"
7
8 "github.com/onsi/ginkgo/internal/codelocation"
9 "github.com/onsi/ginkgo/internal/failer"
10 "github.com/onsi/ginkgo/types"
11)
12
13type runner struct {
14 isAsync bool
15 asyncFunc func(chan<- interface{})
16 syncFunc func()
17 codeLocation types.CodeLocation
18 timeoutThreshold time.Duration
19 nodeType types.SpecComponentType
20 componentIndex int
21 failer *failer.Failer
22}
23
24func newRunner(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, nodeType types.SpecComponentType, componentIndex int) *runner {
25 bodyType := reflect.TypeOf(body)
26 if bodyType.Kind() != reflect.Func {
27 panic(fmt.Sprintf("Expected a function but got something else at %v", codeLocation))
28 }
29
30 runner := &runner{
31 codeLocation: codeLocation,
32 timeoutThreshold: timeout,
33 failer: failer,
34 nodeType: nodeType,
35 componentIndex: componentIndex,
36 }
37
38 switch bodyType.NumIn() {
39 case 0:
40 runner.syncFunc = body.(func())
41 return runner
42 case 1:
43 if !(bodyType.In(0).Kind() == reflect.Chan && bodyType.In(0).Elem().Kind() == reflect.Interface) {
44 panic(fmt.Sprintf("Must pass a Done channel to function at %v", codeLocation))
45 }
46
47 wrappedBody := func(done chan<- interface{}) {
48 bodyValue := reflect.ValueOf(body)
49 bodyValue.Call([]reflect.Value{reflect.ValueOf(done)})
50 }
51
52 runner.isAsync = true
53 runner.asyncFunc = wrappedBody
54 return runner
55 }
56
57 panic(fmt.Sprintf("Too many arguments to function at %v", codeLocation))
58}
59
60func (r *runner) run() (outcome types.SpecState, failure types.SpecFailure) {
61 if r.isAsync {
62 return r.runAsync()
63 } else {
64 return r.runSync()
65 }
66}
67
68func (r *runner) runAsync() (outcome types.SpecState, failure types.SpecFailure) {
69 done := make(chan interface{}, 1)
70
71 go func() {
72 finished := false
73
74 defer func() {
75 if e := recover(); e != nil || !finished {
76 r.failer.Panic(codelocation.New(2), e)
77 select {
78 case <-done:
79 break
80 default:
81 close(done)
82 }
83 }
84 }()
85
86 r.asyncFunc(done)
87 finished = true
88 }()
89
90 select {
91 case <-done:
92 case <-time.After(r.timeoutThreshold):
93 r.failer.Timeout(r.codeLocation)
94 }
95
96 failure, outcome = r.failer.Drain(r.nodeType, r.componentIndex, r.codeLocation)
97 return
98}
99func (r *runner) runSync() (outcome types.SpecState, failure types.SpecFailure) {
100 finished := false
101
102 defer func() {
103 if e := recover(); e != nil || !finished {
104 r.failer.Panic(codelocation.New(2), e)
105 }
106
107 failure, outcome = r.failer.Drain(r.nodeType, r.componentIndex, r.codeLocation)
108 }()
109
110 r.syncFunc()
111 finished = true
112
113 return
114}