mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2024-10-30 14:19:12 +01:00
395 lines
11 KiB
Go
395 lines
11 KiB
Go
|
package ssh
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
gossh "golang.org/x/crypto/ssh"
|
||
|
)
|
||
|
|
||
|
// ErrServerClosed is returned by the Server's Serve, ListenAndServe,
|
||
|
// and ListenAndServeTLS methods after a call to Shutdown or Close.
|
||
|
var ErrServerClosed = errors.New("ssh: Server closed")
|
||
|
|
||
|
type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
||
|
|
||
|
var DefaultRequestHandlers = map[string]RequestHandler{}
|
||
|
|
||
|
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||
|
|
||
|
var DefaultChannelHandlers = map[string]ChannelHandler{
|
||
|
"session": DefaultSessionHandler,
|
||
|
}
|
||
|
|
||
|
// Server defines parameters for running an SSH server. The zero value for
|
||
|
// Server is a valid configuration. When both PasswordHandler and
|
||
|
// PublicKeyHandler are nil, no client authentication is performed.
|
||
|
type Server struct {
|
||
|
Addr string // TCP address to listen on, ":22" if empty
|
||
|
Handler Handler // handler to invoke, ssh.DefaultHandler if nil
|
||
|
HostSigners []Signer // private keys for the host key, must have at least one
|
||
|
Version string // server version to be sent before the initial handshake
|
||
|
|
||
|
KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler
|
||
|
PasswordHandler PasswordHandler // password authentication handler
|
||
|
PublicKeyHandler PublicKeyHandler // public key authentication handler
|
||
|
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
|
||
|
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
|
||
|
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
|
||
|
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
|
||
|
ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
|
||
|
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions
|
||
|
|
||
|
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
||
|
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
||
|
|
||
|
// ChannelHandlers allow overriding the built-in session handlers or provide
|
||
|
// extensions to the protocol, such as tcpip forwarding. By default only the
|
||
|
// "session" handler is enabled.
|
||
|
ChannelHandlers map[string]ChannelHandler
|
||
|
|
||
|
// RequestHandlers allow overriding the server-level request handlers or
|
||
|
// provide extensions to the protocol, such as tcpip forwarding. By default
|
||
|
// no handlers are enabled.
|
||
|
RequestHandlers map[string]RequestHandler
|
||
|
|
||
|
listenerWg sync.WaitGroup
|
||
|
mu sync.Mutex
|
||
|
listeners map[net.Listener]struct{}
|
||
|
conns map[*gossh.ServerConn]struct{}
|
||
|
connWg sync.WaitGroup
|
||
|
doneChan chan struct{}
|
||
|
}
|
||
|
|
||
|
func (srv *Server) ensureHostSigner() error {
|
||
|
if len(srv.HostSigners) == 0 {
|
||
|
signer, err := generateSigner()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
srv.HostSigners = append(srv.HostSigners, signer)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (srv *Server) ensureHandlers() {
|
||
|
srv.mu.Lock()
|
||
|
defer srv.mu.Unlock()
|
||
|
if srv.RequestHandlers == nil {
|
||
|
srv.RequestHandlers = map[string]RequestHandler{}
|
||
|
for k, v := range DefaultRequestHandlers {
|
||
|
srv.RequestHandlers[k] = v
|
||
|
}
|
||
|
}
|
||
|
if srv.ChannelHandlers == nil {
|
||
|
srv.ChannelHandlers = map[string]ChannelHandler{}
|
||
|
for k, v := range DefaultChannelHandlers {
|
||
|
srv.ChannelHandlers[k] = v
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (srv *Server) config(ctx Context) *gossh.ServerConfig {
|
||
|
var config *gossh.ServerConfig
|
||
|
if srv.ServerConfigCallback == nil {
|
||
|
config = &gossh.ServerConfig{}
|
||
|
} else {
|
||
|
config = srv.ServerConfigCallback(ctx)
|
||
|
}
|
||
|
for _, signer := range srv.HostSigners {
|
||
|
config.AddHostKey(signer)
|
||
|
}
|
||
|
if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil {
|
||
|
config.NoClientAuth = true
|
||
|
}
|
||
|
if srv.Version != "" {
|
||
|
config.ServerVersion = "SSH-2.0-" + srv.Version
|
||
|
}
|
||
|
if srv.PasswordHandler != nil {
|
||
|
config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
|
||
|
applyConnMetadata(ctx, conn)
|
||
|
if ok := srv.PasswordHandler(ctx, string(password)); !ok {
|
||
|
return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
|
||
|
}
|
||
|
return ctx.Permissions().Permissions, nil
|
||
|
}
|
||
|
}
|
||
|
if srv.PublicKeyHandler != nil {
|
||
|
config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
|
||
|
applyConnMetadata(ctx, conn)
|
||
|
if ok := srv.PublicKeyHandler(ctx, key); !ok {
|
||
|
return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
|
||
|
}
|
||
|
ctx.SetValue(ContextKeyPublicKey, key)
|
||
|
return ctx.Permissions().Permissions, nil
|
||
|
}
|
||
|
}
|
||
|
if srv.KeyboardInteractiveHandler != nil {
|
||
|
config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) {
|
||
|
if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok {
|
||
|
return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
|
||
|
}
|
||
|
return ctx.Permissions().Permissions, nil
|
||
|
}
|
||
|
}
|
||
|
return config
|
||
|
}
|
||
|
|
||
|
// Handle sets the Handler for the server.
|
||
|
func (srv *Server) Handle(fn Handler) {
|
||
|
srv.Handler = fn
|
||
|
}
|
||
|
|
||
|
// Close immediately closes all active listeners and all active
|
||
|
// connections.
|
||
|
//
|
||
|
// Close returns any error returned from closing the Server's
|
||
|
// underlying Listener(s).
|
||
|
func (srv *Server) Close() error {
|
||
|
srv.mu.Lock()
|
||
|
defer srv.mu.Unlock()
|
||
|
srv.closeDoneChanLocked()
|
||
|
err := srv.closeListenersLocked()
|
||
|
for c := range srv.conns {
|
||
|
c.Close()
|
||
|
delete(srv.conns, c)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Shutdown gracefully shuts down the server without interrupting any
|
||
|
// active connections. Shutdown works by first closing all open
|
||
|
// listeners, and then waiting indefinitely for connections to close.
|
||
|
// If the provided context expires before the shutdown is complete,
|
||
|
// then the context's error is returned.
|
||
|
func (srv *Server) Shutdown(ctx context.Context) error {
|
||
|
srv.mu.Lock()
|
||
|
lnerr := srv.closeListenersLocked()
|
||
|
srv.closeDoneChanLocked()
|
||
|
srv.mu.Unlock()
|
||
|
|
||
|
finished := make(chan struct{}, 1)
|
||
|
go func() {
|
||
|
srv.listenerWg.Wait()
|
||
|
srv.connWg.Wait()
|
||
|
finished <- struct{}{}
|
||
|
}()
|
||
|
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return ctx.Err()
|
||
|
case <-finished:
|
||
|
return lnerr
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Serve accepts incoming connections on the Listener l, creating a new
|
||
|
// connection goroutine for each. The connection goroutines read requests and then
|
||
|
// calls srv.Handler to handle sessions.
|
||
|
//
|
||
|
// Serve always returns a non-nil error.
|
||
|
func (srv *Server) Serve(l net.Listener) error {
|
||
|
srv.ensureHandlers()
|
||
|
defer l.Close()
|
||
|
if err := srv.ensureHostSigner(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if srv.Handler == nil {
|
||
|
srv.Handler = DefaultHandler
|
||
|
}
|
||
|
var tempDelay time.Duration
|
||
|
|
||
|
srv.trackListener(l, true)
|
||
|
defer srv.trackListener(l, false)
|
||
|
for {
|
||
|
conn, e := l.Accept()
|
||
|
if e != nil {
|
||
|
select {
|
||
|
case <-srv.getDoneChan():
|
||
|
return ErrServerClosed
|
||
|
default:
|
||
|
}
|
||
|
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
||
|
if tempDelay == 0 {
|
||
|
tempDelay = 5 * time.Millisecond
|
||
|
} else {
|
||
|
tempDelay *= 2
|
||
|
}
|
||
|
if max := 1 * time.Second; tempDelay > max {
|
||
|
tempDelay = max
|
||
|
}
|
||
|
time.Sleep(tempDelay)
|
||
|
continue
|
||
|
}
|
||
|
return e
|
||
|
}
|
||
|
go srv.handleConn(conn)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (srv *Server) handleConn(newConn net.Conn) {
|
||
|
if srv.ConnCallback != nil {
|
||
|
cbConn := srv.ConnCallback(newConn)
|
||
|
if cbConn == nil {
|
||
|
newConn.Close()
|
||
|
return
|
||
|
}
|
||
|
newConn = cbConn
|
||
|
}
|
||
|
ctx, cancel := newContext(srv)
|
||
|
conn := &serverConn{
|
||
|
Conn: newConn,
|
||
|
idleTimeout: srv.IdleTimeout,
|
||
|
closeCanceler: cancel,
|
||
|
}
|
||
|
if srv.MaxTimeout > 0 {
|
||
|
conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
|
||
|
}
|
||
|
defer conn.Close()
|
||
|
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
|
||
|
if err != nil {
|
||
|
// TODO: trigger event callback
|
||
|
return
|
||
|
}
|
||
|
|
||
|
srv.trackConn(sshConn, true)
|
||
|
defer srv.trackConn(sshConn, false)
|
||
|
|
||
|
ctx.SetValue(ContextKeyConn, sshConn)
|
||
|
applyConnMetadata(ctx, sshConn)
|
||
|
//go gossh.DiscardRequests(reqs)
|
||
|
go srv.handleRequests(ctx, reqs)
|
||
|
for ch := range chans {
|
||
|
handler := srv.ChannelHandlers[ch.ChannelType()]
|
||
|
if handler == nil {
|
||
|
handler = srv.ChannelHandlers["default"]
|
||
|
}
|
||
|
if handler == nil {
|
||
|
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
||
|
continue
|
||
|
}
|
||
|
go handler(srv, sshConn, ch, ctx)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
|
||
|
for req := range in {
|
||
|
handler := srv.RequestHandlers[req.Type]
|
||
|
if handler == nil {
|
||
|
handler = srv.RequestHandlers["default"]
|
||
|
}
|
||
|
if handler == nil {
|
||
|
req.Reply(false, nil)
|
||
|
continue
|
||
|
}
|
||
|
/*reqCtx, cancel := context.WithCancel(ctx)
|
||
|
defer cancel() */
|
||
|
ret, payload := handler(ctx, srv, req)
|
||
|
req.Reply(ret, payload)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ListenAndServe listens on the TCP network address srv.Addr and then calls
|
||
|
// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used.
|
||
|
// ListenAndServe always returns a non-nil error.
|
||
|
func (srv *Server) ListenAndServe() error {
|
||
|
addr := srv.Addr
|
||
|
if addr == "" {
|
||
|
addr = ":22"
|
||
|
}
|
||
|
ln, err := net.Listen("tcp", addr)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return srv.Serve(ln)
|
||
|
}
|
||
|
|
||
|
// AddHostKey adds a private key as a host key. If an existing host key exists
|
||
|
// with the same algorithm, it is overwritten. Each server config must have at
|
||
|
// least one host key.
|
||
|
func (srv *Server) AddHostKey(key Signer) {
|
||
|
// these are later added via AddHostKey on ServerConfig, which performs the
|
||
|
// check for one of every algorithm.
|
||
|
srv.HostSigners = append(srv.HostSigners, key)
|
||
|
}
|
||
|
|
||
|
// SetOption runs a functional option against the server.
|
||
|
func (srv *Server) SetOption(option Option) error {
|
||
|
return option(srv)
|
||
|
}
|
||
|
|
||
|
func (srv *Server) getDoneChan() <-chan struct{} {
|
||
|
srv.mu.Lock()
|
||
|
defer srv.mu.Unlock()
|
||
|
return srv.getDoneChanLocked()
|
||
|
}
|
||
|
|
||
|
func (srv *Server) getDoneChanLocked() chan struct{} {
|
||
|
if srv.doneChan == nil {
|
||
|
srv.doneChan = make(chan struct{})
|
||
|
}
|
||
|
return srv.doneChan
|
||
|
}
|
||
|
|
||
|
func (srv *Server) closeDoneChanLocked() {
|
||
|
ch := srv.getDoneChanLocked()
|
||
|
select {
|
||
|
case <-ch:
|
||
|
// Already closed. Don't close again.
|
||
|
default:
|
||
|
// Safe to close here. We're the only closer, guarded
|
||
|
// by srv.mu.
|
||
|
close(ch)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (srv *Server) closeListenersLocked() error {
|
||
|
var err error
|
||
|
for ln := range srv.listeners {
|
||
|
if cerr := ln.Close(); cerr != nil && err == nil {
|
||
|
err = cerr
|
||
|
}
|
||
|
delete(srv.listeners, ln)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (srv *Server) trackListener(ln net.Listener, add bool) {
|
||
|
srv.mu.Lock()
|
||
|
defer srv.mu.Unlock()
|
||
|
if srv.listeners == nil {
|
||
|
srv.listeners = make(map[net.Listener]struct{})
|
||
|
}
|
||
|
if add {
|
||
|
// If the *Server is being reused after a previous
|
||
|
// Close or Shutdown, reset its doneChan:
|
||
|
if len(srv.listeners) == 0 && len(srv.conns) == 0 {
|
||
|
srv.doneChan = nil
|
||
|
}
|
||
|
srv.listeners[ln] = struct{}{}
|
||
|
srv.listenerWg.Add(1)
|
||
|
} else {
|
||
|
delete(srv.listeners, ln)
|
||
|
srv.listenerWg.Done()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
|
||
|
srv.mu.Lock()
|
||
|
defer srv.mu.Unlock()
|
||
|
if srv.conns == nil {
|
||
|
srv.conns = make(map[*gossh.ServerConn]struct{})
|
||
|
}
|
||
|
if add {
|
||
|
srv.conns[c] = struct{}{}
|
||
|
srv.connWg.Add(1)
|
||
|
} else {
|
||
|
delete(srv.conns, c)
|
||
|
srv.connWg.Done()
|
||
|
}
|
||
|
}
|