Support rate limit (#145)
Support rate limit Reviewed-on: #145 Co-Authored-By: Lunny Xiao <xiaolunwen@gmail.com> Co-Committed-By: Lunny Xiao <xiaolunwen@gmail.com>
This commit was merged in pull request #145.
This commit is contained in:
@@ -15,6 +15,8 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"goftp.io/server/v2/ratelimit"
|
||||
)
|
||||
|
||||
// DataSocket describes a data socket is used to send non-control data between the client and
|
||||
@@ -39,6 +41,8 @@ type DataSocket interface {
|
||||
|
||||
type activeSocket struct {
|
||||
conn *net.TCPConn
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
sess *Session
|
||||
host string
|
||||
port int
|
||||
@@ -66,6 +70,8 @@ func newActiveSocket(sess *Session, remote string, port int) (DataSocket, error)
|
||||
socket := new(activeSocket)
|
||||
socket.sess = sess
|
||||
socket.conn = tcpConn
|
||||
socket.reader = ratelimit.Reader(tcpConn, sess.server.rateLimiter)
|
||||
socket.writer = ratelimit.Writer(tcpConn, sess.server.rateLimiter)
|
||||
socket.host = remote
|
||||
socket.port = port
|
||||
|
||||
@@ -81,15 +87,15 @@ func (socket *activeSocket) Port() int {
|
||||
}
|
||||
|
||||
func (socket *activeSocket) Read(p []byte) (n int, err error) {
|
||||
return socket.conn.Read(p)
|
||||
return socket.reader.Read(p)
|
||||
}
|
||||
|
||||
func (socket *activeSocket) ReadFrom(r io.Reader) (int64, error) {
|
||||
return socket.conn.ReadFrom(r)
|
||||
return io.Copy(socket.writer, r)
|
||||
}
|
||||
|
||||
func (socket *activeSocket) Write(p []byte) (n int, err error) {
|
||||
return socket.conn.Write(p)
|
||||
return socket.writer.Write(p)
|
||||
}
|
||||
|
||||
func (socket *activeSocket) Close() error {
|
||||
@@ -99,6 +105,8 @@ func (socket *activeSocket) Close() error {
|
||||
type passiveSocket struct {
|
||||
sess *Session
|
||||
conn net.Conn
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
port int
|
||||
host string
|
||||
ingress chan []byte
|
||||
@@ -169,7 +177,7 @@ func (socket *passiveSocket) Read(p []byte) (n int, err error) {
|
||||
if socket.err != nil {
|
||||
return 0, socket.err
|
||||
}
|
||||
return socket.conn.Read(p)
|
||||
return socket.reader.Read(p)
|
||||
}
|
||||
|
||||
func (socket *passiveSocket) ReadFrom(r io.Reader) (int64, error) {
|
||||
@@ -181,7 +189,7 @@ func (socket *passiveSocket) ReadFrom(r io.Reader) (int64, error) {
|
||||
|
||||
// For normal TCPConn, this will use sendfile syscall; if not,
|
||||
// it will just downgrade to normal read/write procedure
|
||||
return io.Copy(socket.conn, r)
|
||||
return io.Copy(socket.writer, r)
|
||||
}
|
||||
|
||||
func (socket *passiveSocket) Write(p []byte) (n int, err error) {
|
||||
@@ -190,7 +198,7 @@ func (socket *passiveSocket) Write(p []byte) (n int, err error) {
|
||||
if socket.err != nil {
|
||||
return 0, socket.err
|
||||
}
|
||||
return socket.conn.Write(p)
|
||||
return socket.writer.Write(p)
|
||||
}
|
||||
|
||||
func (socket *passiveSocket) Close() error {
|
||||
@@ -250,6 +258,8 @@ func (socket *passiveSocket) ListenAndServe() (err error) {
|
||||
}
|
||||
socket.err = nil
|
||||
socket.conn = conn
|
||||
socket.reader = ratelimit.Reader(socket.conn, socket.sess.server.rateLimiter)
|
||||
socket.writer = ratelimit.Writer(socket.conn, socket.sess.server.rateLimiter)
|
||||
_ = listener.Close()
|
||||
}()
|
||||
return nil
|
||||
|
||||
@@ -21,7 +21,8 @@ func main() {
|
||||
Name: "admin",
|
||||
Password: "admin",
|
||||
},
|
||||
Perm: server.NewSimplePerm("root", "root"),
|
||||
Perm: server.NewSimplePerm("root", "root"),
|
||||
RateLimit: 1000000, // 1MB/s limit
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
||||
38
ratelimit/limiter.go
Normal file
38
ratelimit/limiter.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright 2020 The goftp Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Limiter represents a rate limiter
|
||||
type Limiter struct {
|
||||
rate time.Duration
|
||||
count int64
|
||||
t time.Time
|
||||
}
|
||||
|
||||
// New create a limiter for transfer speed, parameter rate means bytes per second
|
||||
// 0 means don't limit
|
||||
func New(rate int64) *Limiter {
|
||||
return &Limiter{
|
||||
rate: time.Duration(rate),
|
||||
count: 0,
|
||||
t: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Wait sleep when write count bytes
|
||||
func (l *Limiter) Wait(count int) {
|
||||
if l.rate == 0 {
|
||||
return
|
||||
}
|
||||
l.count += int64(count)
|
||||
t := time.Duration(l.count)*time.Second/l.rate - time.Since(l.t)
|
||||
if t > 0 {
|
||||
time.Sleep(t)
|
||||
}
|
||||
}
|
||||
27
ratelimit/reader.go
Normal file
27
ratelimit/reader.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright 2020 The goftp Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ratelimit
|
||||
|
||||
import "io"
|
||||
|
||||
type reader struct {
|
||||
r io.Reader
|
||||
l *Limiter
|
||||
}
|
||||
|
||||
// Read Read
|
||||
func (r *reader) Read(buf []byte) (int, error) {
|
||||
n, err := r.r.Read(buf)
|
||||
r.l.Wait(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Reader returns a reader with limiter
|
||||
func Reader(r io.Reader, l *Limiter) io.Reader {
|
||||
return &reader{
|
||||
r: r,
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
26
ratelimit/writer.go
Normal file
26
ratelimit/writer.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright 2020 The goftp Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ratelimit
|
||||
|
||||
import "io"
|
||||
|
||||
type writer struct {
|
||||
w io.Writer
|
||||
l *Limiter
|
||||
}
|
||||
|
||||
// Write Write
|
||||
func (w *writer) Write(buf []byte) (int, error) {
|
||||
w.l.Wait(len(buf))
|
||||
return w.w.Write(buf)
|
||||
}
|
||||
|
||||
// Writer returns a writer with limiter
|
||||
func Writer(w io.Writer, l *Limiter) io.Writer {
|
||||
return &writer{
|
||||
w: w,
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"goftp.io/server/v2/ratelimit"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,6 +71,9 @@ type Options struct {
|
||||
|
||||
// A logger implementation, if nil the StdLogger is used
|
||||
Logger Logger
|
||||
|
||||
// Rate Limit per connection bytes per second, 0 means no limit
|
||||
RateLimit int64
|
||||
}
|
||||
|
||||
// Server is the root of your FTP application. You should instantiate one
|
||||
@@ -85,6 +90,8 @@ type Server struct {
|
||||
cancel context.CancelFunc
|
||||
feats string
|
||||
notifiers notifierList
|
||||
// rate limiter per connection
|
||||
rateLimiter *ratelimit.Limiter
|
||||
}
|
||||
|
||||
// ErrServerClosed is returned by ListenAndServe() or Serve() when a shutdown
|
||||
@@ -143,6 +150,7 @@ func optsWithDefaults(opts *Options) *Options {
|
||||
|
||||
newOpts.PublicIP = opts.PublicIP
|
||||
newOpts.PassivePorts = opts.PassivePorts
|
||||
newOpts.RateLimit = opts.RateLimit
|
||||
|
||||
return &newOpts
|
||||
}
|
||||
@@ -186,6 +194,7 @@ func NewServer(opts *Options) (*Server, error) {
|
||||
featCmds += " AUTH TLS\n PBSZ\n PROT\n"
|
||||
}
|
||||
s.feats = fmt.Sprintf(feats, featCmds)
|
||||
s.rateLimiter = ratelimit.New(opts.RateLimit)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user