diff --git a/nfs/nfs.go b/nfs/nfs.go index f60594d..124aeb8 100644 --- a/nfs/nfs.go +++ b/nfs/nfs.go @@ -252,7 +252,7 @@ type FSInfo struct { Properties uint32 } -// Dial an RPC svc after getting the port from the portmapper +// DialService Dial an RPC svc after getting the port from the portmapper func DialService(addr string, prog rpc.Mapping) (*rpc.Client, error) { pm, err := rpc.DialPortmapper("tcp", addr) if err != nil { @@ -328,7 +328,7 @@ func dialService(addr string, port int) (*rpc.Client, error) { } func isAddrInUse(err error) bool { - if er, ok := (err.(*net.OpError)); ok { + if er, ok := err.(*net.OpError); ok { if syser, ok := er.Err.(*os.SyscallError); ok { return syser.Err == syscall.EADDRINUSE } diff --git a/nfs/rpc/portmap.go b/nfs/rpc/portmap.go index e8ad39b..2030800 100644 --- a/nfs/rpc/portmap.go +++ b/nfs/rpc/portmap.go @@ -5,6 +5,7 @@ package rpc import ( "fmt" + "io" "github.com/willscott/go-nfs-client/nfs/xdr" ) @@ -17,7 +18,9 @@ const ( PmapProg = 100000 PmapVers = 2 - PmapProcGetPort = 3 + PmapProcSetPort = 1 + PMapProcUnsetPort = 2 + PmapProcGetPort = 3 IPProtoTCP = 6 IPProtoUDP = 17 @@ -45,22 +48,7 @@ type Portmapper struct { } func (p *Portmapper) Getport(mapping Mapping) (int, error) { - type getport struct { - Header - Mapping - } - msg := &getport{ - Header{ - Rpcvers: 2, - Prog: PmapProg, - Vers: PmapVers, - Proc: PmapProcGetPort, - Cred: AuthNull, - Verf: AuthNull, - }, - mapping, - } - res, err := p.Call(msg) + res, err := p.call(PmapProcGetPort, mapping) if err != nil { return 0, err } @@ -71,6 +59,41 @@ func (p *Portmapper) Getport(mapping Mapping) (int, error) { return int(port), nil } +func (p *Portmapper) Setport(mapping Mapping) (bool, error) { + res, err := p.call(PmapProcSetPort, mapping) + if err != nil { + return false, err + } + + return xdr.ReadBoolean(res) +} + +func (p *Portmapper) Unsetport(mapping Mapping) (bool, error) { + res, err := p.call(PMapProcUnsetPort, mapping) + if err != nil { + return false, err + } + + return xdr.ReadBoolean(res) +} + +func (p *Portmapper) call(proc uint32, mapping Mapping) (io.ReadSeeker, error) { + return p.Call(struct { + Header + Mapping + }{ + Header: Header{ + Rpcvers: 2, + Prog: PmapProg, + Vers: PmapVers, + Proc: proc, + Cred: AuthNull, + Verf: AuthNull, + }, + Mapping: mapping, + }) +} + func DialPortmapper(net, host string) (*Portmapper, error) { client, err := DialTCP(net, nil, fmt.Sprintf("%s:%d", host, PmapPort)) if err != nil { diff --git a/nfs/xdr/decode.go b/nfs/xdr/decode.go index 9754798..9dfb776 100644 --- a/nfs/xdr/decode.go +++ b/nfs/xdr/decode.go @@ -23,6 +23,14 @@ func ReadUint32(r io.Reader) (uint32, error) { return n, nil } +func ReadBoolean(r io.Reader) (bool, error) { + var b bool + if err := Read(r, &b); err != nil { + return false, err + } + return b, nil +} + func ReadOpaque(r io.Reader) ([]byte, error) { length, err := ReadUint32(r) if err != nil {