forked from localtunnel/go-localtunnel
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlistener.go
226 lines (201 loc) · 4.52 KB
/
listener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
package localtunnel
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"
)
// Listener implements a net.Listener using localtunnel.me
//
// A Listener is created by Listen, hands out tunnelled connections through
// Accept, and is shut down with Close.
type Listener struct {
// log receives progress and error messages.
log Logger
// remote is the "host:port" address that the proxy goroutines dial.
remote string
// url is the tunnel URL reported by the server; it is exposed through Addr.
url string
// context is cancelled by abort to stop all proxy goroutines.
context context.Context
// mErr guards the one-time assignment of err in abort.
mErr sync.Mutex
// err is the first error passed to abort; Accept returns it once the
// listener has stopped.
err error
// cancel cancels context.
cancel func()
// nConns is incremented when a tunnel connection has been dialled and
// decremented when a connection is handed over to Accept.
nConns counter
// incoming carries connections that have received data and are ready to
// be returned by Accept.
incoming chan net.Conn
// done tracks the running proxy goroutines.
done sync.WaitGroup
}
// Listen creates a *Listener that gets incoming connections from localtunnel.me
//
// It registers a tunnel with the server at options.BaseURL (requesting
// options.Subdomain when one is set, a new tunnel otherwise), starts one
// proxy goroutine per allowed connection and then blocks in
// nConns.WaitFor(1) -- presumably until at least one tunnel connection is
// open -- before returning.
//
// An error is returned when the registration request fails, the server
// answers with a status other than 200, the reply cannot be read or parsed,
// or options.BaseURL cannot be parsed as a URL.
func Listen(options Options) (*Listener, error) {
	options.setDefaults()
	l := &Listener{log: options.Log}

	// Create a setup URL
	setupURL := options.BaseURL + "/"
	if options.Subdomain != "" {
		setupURL += options.Subdomain
	} else {
		setupURL += "?new"
	}

	// Call the setupURL
	l.log.Println("registering tunnel:", setupURL)
	client := http.Client{Timeout: 30 * time.Second}
	res, err := client.Get(setupURL)
	if err != nil {
		return nil, fmt.Errorf("failed to setup tunnel, error: %s", err)
	}
	// The response body must be closed on every path, including the error
	// returns below, or the underlying connection is leaked.
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("internal server error, statusCode: %d", res.StatusCode)
	}
	body, err := readAtmost(res.Body, 4*1024)
	if err != nil {
		return nil, fmt.Errorf("failed to read server response, error: %s", err)
	}
	var reply struct {
		ID           string `json:"id"`
		Port         int    `json:"port"`
		MaxConnCount int    `json:"max_conn_count"`
		URL          string `json:"url"`
	}
	err = json.Unmarshal(body, &reply)
	if err != nil {
		return nil, fmt.Errorf("failed to parse server response, error: %s", err)
	}
	l.log.Println("registered tunnel:", reply.URL)

	// Set some sanity values. A missing or negative count from the server is
	// treated as 1; a negative value would otherwise panic when sizing the
	// channel and the WaitGroup below.
	if reply.MaxConnCount < 1 {
		reply.MaxConnCount = 1
	}
	if reply.MaxConnCount > options.MaxConnections {
		reply.MaxConnCount = options.MaxConnections
	}

	// Extract remote host
	u, err := url.Parse(options.BaseURL)
	if err != nil {
		return nil, fmt.Errorf("failed to parse base URL, error: %s", err)
	}
	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	l.remote = net.JoinHostPort(u.Hostname(), fmt.Sprint(reply.Port))

	// Set remote URL
	l.url = reply.URL

	// Create the cancellable context only once setup can no longer fail, so
	// that no error return leaves an un-cancelled context behind.
	ctx, cancel := context.WithCancel(context.Background())
	l.context = ctx
	l.cancel = cancel

	// Start listening for new connections
	l.incoming = make(chan net.Conn, reply.MaxConnCount)
	l.done.Add(reply.MaxConnCount)
	for i := 0; i < reply.MaxConnCount; i++ {
		go l.proxy()
	}
	l.nConns.WaitFor(1)

	return l, nil
}
// Accept blocks until a tunnelled connection is available and returns it.
// Once the listener's context is done, or the incoming channel has been
// closed, it returns a nil connection together with the listener's error.
func (l *Listener) Accept() (net.Conn, error) {
	var next net.Conn
	select {
	case next = <-l.incoming:
		// A closed channel yields nil, which is handled below.
	case <-l.context.Done():
		// next stays nil.
	}
	if next == nil {
		return nil, l.err
	}
	return next, nil
}
// proxy is the body of one tunnel worker goroutine. Until the listener's
// context is cancelled it repeatedly dials l.remote and passes the resulting
// connection to handle.
//
// Each dial is attempted up to three times, waiting 0s, 3s and 12s before
// the respective attempt. If all attempts fail, or handle reports an error,
// the whole listener is aborted with that error. The goroutine signals
// l.done when it exits.
func (l *Listener) proxy() {
	defer l.done.Done()

	var d net.Dialer
	for l.context.Err() == nil {
		// Dial with Context
		var c net.Conn
		var err error
		for i := 0; i < 3; i++ {
			// Quadratic backoff. Wait on a timer instead of sleeping so that
			// cancelling the listener does not have to sit out the delay.
			if delay := time.Duration(i*i) * 3 * time.Second; delay > 0 {
				timer := time.NewTimer(delay)
				select {
				case <-timer.C:
				case <-l.context.Done():
					timer.Stop()
					return
				}
			}
			c, err = d.DialContext(l.context, "tcp", l.remote)
			if err == nil || l.context.Err() != nil {
				break
			}
			l.log.Println("error opening connection to ", l.remote, "error:", err)
		}
		if err != nil {
			l.abort(err)
			return
		}

		l.nConns.Add(1)
		if err = l.handle(c); err != nil {
			// abort cancels the context asynchronously; return right away
			// rather than racing into another dial.
			l.abort(err)
			return
		}
	}
}
// handle waits for the first byte to arrive on tunnel connection c. Once it
// has, the connection (together with the byte already read) is wrapped in a
// conn and sent to l.incoming for Accept, and handle blocks until the done
// channel given to that conn fires -- presumably when the consumer closes it
// -- or the listener's context is cancelled.
//
// It returns nil when the connection was handed over, and also when the
// initial read failed more than 30 seconds after it started (such failures
// are ignored). A read failure within 30 seconds is returned to the caller.
// c is closed on every return path.
func (l *Listener) handle(c net.Conn) error {
	var (
		n   int
		err error
		b   [1]byte
	)

	// Ensure that we close the connection if we are not done reading before
	// context.Done(); closing it unblocks the Read below.
	doneReading := make(chan struct{})
	go func() {
		select {
		case <-doneReading:
		case <-l.context.Done():
			c.Close()
		}
	}()

	start := time.Now()
	for n == 0 && err == nil {
		n, err = c.Read(b[:])
	}
	close(doneReading)

	if err != nil {
		// The connection is of no further use; release it before deciding
		// whether the failure is worth reporting.
		c.Close()
		// Ignore if it took more than 30s
		if time.Since(start) > 30*time.Second {
			return nil
		}
		return err
	}

	// NOTE(review): the caller increments nConns before calling handle, but
	// the counter is only decremented on this success path -- confirm
	// whether failed reads should decrement it as well.
	l.nConns.Add(-1)
	done := make(chan struct{})
	l.incoming <- &conn{Conn: c, Buffer: b, Done: done}

	// Wait for conn to be closed
	select {
	case <-done:
	case <-l.context.Done():
	}

	// Always close the remote connection
	c.Close()
	return nil
}
// Addr implements net.Addr for a Listener; it carries the URL of the tunnel.
type Addr struct {
// URL is the tunnel URL the Listener was registered under.
URL string
}
// Addr reports the listener's address as required by net.Listener. The
// returned value is an Addr wrapping the tunnel URL.
func (l *Listener) Addr() net.Addr {
	addr := Addr{URL: l.url}
	return addr
}
// abort records err as the listener's error and starts tearing the listener
// down in the background. Only the first call has any effect; later calls
// return without doing anything.
func (l *Listener) abort(err error) {
	l.mErr.Lock()
	defer l.mErr.Unlock()

	// Only abort once
	if l.err != nil {
		return
	}
	l.err = err

	// drain closes every connection still queued for Accept; it ends when
	// the incoming channel is closed.
	drain := func() {
		for pending := range l.incoming {
			pending.Close()
		}
	}

	// teardown stops the proxies, waits for all of them to exit and only
	// then closes the incoming channel, so that no proxy can send on a
	// closed channel.
	teardown := func() {
		l.cancel()
		go drain()
		l.done.Wait()
		close(l.incoming)
	}

	// Close all tunnels and stop creating new ones
	go teardown()
}
// Close shuts the listener down, breaking all connections proxied by it, and
// waits for the proxy goroutines to exit. It returns nil unless the listener
// had already been aborted with some other error, in which case that error
// is returned.
func (l *Listener) Close() error {
	l.abort(ErrListenerClosed)
	l.done.Wait()

	err := l.err
	if err == ErrListenerClosed {
		return nil
	}
	return err
}