commit 933f20823fcb085ea0934bd73154f16b3a8de3b6
Author: Brian Candler
Date:   Fri Apr 29 11:51:10 2011 +0100

    Initial commit, work in progress

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..90281b0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+*~
+*.orig
+*.gem
+*.sw?
+.irbrc
+.#*
+rdoc
+coverage
+tmp/*
diff --git a/README b/README
new file mode 100644
index 0000000..314bab6
--- /dev/null
+++ b/README
@@ -0,0 +1,43 @@
+Ruby Netlink
+============
+
+This library provides an API for using a Linux Netlink socket, for doing
+things like manipulating IP interfaces, routes and firewall rules
+programmatically.
+
+Useful reference material
+=========================
+
+* http://www.linuxjournal.com/article/7356
+* http://people.redhat.com/nhorman/papers/netlink.pdf
+* apt-get source iproute
+
+Note there are some errors in the nhorman paper. On page 8/9, it says
+
+    nlmsg_pid ... Also note that it is
+    imperative that any program receiving netlink socket messages from
+    the kernel verify that this field is set to zero, or it is possible to expose
+    the software to unexpected influences from other non-privlidged user
+    space programs.
+
+However, what really needs to be checked is the pid in the sockaddr_nl
+structure returned by recvmsg msghdr, as shown by this code in
+lib/libnetlink.c:
+
+        struct msghdr msg = {
+                .msg_name = &nladdr,
+                .msg_namelen = sizeof(nladdr),
+                .msg_iov = &iov,
+                .msg_iovlen = 1,
+        };
+...
+        status = recvmsg(rth->fd, &msg, 0);
+...
+                if (nladdr.nl_pid != 0 ||
+                    h->nlmsg_pid != rth->local.nl_pid ||
+                    h->nlmsg_seq != rth->dump) {
+
+Copyright
+=========
+
+Copyright (C) 2011 Bytemark Computer Consulting Ltd
diff --git a/lib/netlink/constants.rb b/lib/netlink/constants.rb
new file mode 100644
index 0000000..43f91d7
--- /dev/null
+++ b/lib/netlink/constants.rb
@@ -0,0 +1,149 @@
+require 'socket'
+class Socket
+  # From /usr/include/bits/socket.h
+  PF_NETLINK = 16 unless defined? Socket::PF_NETLINK
+  AF_NETLINK = PF_NETLINK unless defined? Socket::AF_NETLINK
+end
+
+module Netlink
+  # From linux/netlink.h
+  NETLINK_ROUTE = 0
+  NETLINK_UNUSED = 1
+  NETLINK_USERSOCK = 2
+  NETLINK_FIREWALL = 3
+  NETLINK_INET_DIAG = 4
+  NETLINK_NFLOG = 5
+  NETLINK_XFRM = 6
+  NETLINK_SELINUX = 7
+  NETLINK_ISCSI = 8
+  NETLINK_AUDIT = 9
+  NETLINK_FIB_LOOKUP = 10
+  NETLINK_CONNECTOR = 11
+  NETLINK_NETFILTER = 12
+  NETLINK_IP6_FW = 13
+  NETLINK_DNRTMSG = 14
+  NETLINK_KOBJECT_UEVENT = 15
+  NETLINK_GENERIC = 16
+  NETLINK_SCSITRANSPORT = 18
+  NETLINK_ECRYPTFS = 19
+
+  NLM_F_REQUEST = 1
+  NLM_F_MULTI = 2
+  NLM_F_ACK = 4
+  NLM_F_ECHO = 8
+  NLM_F_ROOT = 0x100
+  NLM_F_MATCH = 0x200
+  NLM_F_ATOMIC = 0x400
+  NLM_F_DUMP = (NLM_F_ROOT|NLM_F_MATCH)
+
+  NLM_F_REPLACE = 0x100
+  NLM_F_EXCL = 0x200
+  NLM_F_CREATE = 0x400
+  NLM_F_APPEND = 0x800
+
+  NLMSG_ALIGNTO = 4
+
+  NLMSG_NOOP = 0x1
+  NLMSG_ERROR = 0x2
+  NLMSG_DONE = 0x3
+  NLMSG_OVERRUN = 0x4
+
+  NETLINK_ADD_MEMBERSHIP = 1
+  NETLINK_DROP_MEMBERSHIP = 2
+  NETLINK_PKTINFO = 3
+  NETLINK_BROADCAST_ERROR = 4
+  NETLINK_NO_ENOBUFS = 5
+
+  NETLINK_UNCONNECTED = 0
+  NETLINK_CONNECTED = 1
+
+  NLA_F_NESTED = (1 << 15)
+  NLA_F_NET_BYTEORDER = (1 << 14)
+  NLA_TYPE_MASK = ~(NLA_F_NESTED | NLA_F_NET_BYTEORDER)
+
+  NLA_ALIGNTO = 4
+
+  # from linux/rtnetlink.h.
+  # Should we put each of these groups under separate namespace?
+  #   Netlink::Message::GETROUTE
+  #   Netlink::Route::Type::UNICAST
+  #   Netlink::Route::Protocol::STATIC
+  RTM_NEWLINK = 16
+  RTM_DELLINK = 17
+  RTM_GETLINK = 18
+  RTM_SETLINK = 19
+
+  RTM_NEWADDR = 20
+  RTM_DELADDR = 21
+  RTM_GETADDR = 22
+
+  RTM_NEWROUTE = 24
+  RTM_DELROUTE = 25
+  RTM_GETROUTE = 26
+
+  RTM_NEWNEIGH = 28
+  RTM_DELNEIGH = 29
+  RTM_GETNEIGH = 30
+
+  RTM_NEWRULE = 32
+  RTM_DELRULE = 33
+  RTM_GETRULE = 34
+
+  RTM_NEWQDISC = 36
+  RTM_DELQDISC = 37
+  RTM_GETQDISC = 38
+
+  RTM_NEWTCLASS = 40
+  RTM_DELTCLASS = 41
+  RTM_GETTCLASS = 42
+
+  RTM_NEWTFILTER = 44
+  RTM_DELTFILTER = 45
+  RTM_GETTFILTER = 46
+
+  RTM_NEWACTION = 48
+  RTM_DELACTION = 49
+  RTM_GETACTION = 50
+
+  RTM_NEWPREFIX = 52
+  RTM_GETMULTICAST = 58
+  RTM_GETANYCAST = 62
+
+  RTM_NEWNEIGHTBL = 64
+  RTM_GETNEIGHTBL = 66
+  RTM_SETNEIGHTBL = 67
+
+  RTM_NEWNDUSEROPT = 68
+
+  RTM_NEWADDRLABEL = 72
+  RTM_DELADDRLABEL = 73
+  RTM_GETADDRLABEL = 74
+
+  RTM_GETDCB = 78
+  RTM_SETDCB = 79
+
+  # Route#type
+
+  RTN_UNSPEC = 0
+  RTN_UNICAST = 1
+  RTN_LOCAL = 2
+  RTN_BROADCAST = 3
+  RTN_ANYCAST = 4
+  RTN_MULTICAST = 5
+  RTN_BLACKHOLE = 6
+  RTN_UNREACHABLE = 7
+  RTN_PROHIBIT = 8
+  RTN_THROW = 9
+  RTN_NAT = 10
+  RTN_XRESOLVE = 11
+
+  # Route#protocol
+
+  RTPROT_UNSPEC = 0
+  RTPROT_REDIRECT = 1
+  RTPROT_KERNEL = 2
+  RTPROT_BOOT = 3
+  RTPROT_STATIC = 4
+
+  # MORE TO GO
+end
diff --git a/lib/netlink/message.rb b/lib/netlink/message.rb
new file mode 100644
index 0000000..8435cb6
--- /dev/null
+++ b/lib/netlink/message.rb
@@ -0,0 +1,154 @@
+require 'netlink/constants'
+
+module Netlink
+  # Base class for Netlink messages
+  class Message
+    # Map of numeric message type code => message class
+    CODE_TO_MESSAGE = {}
+
+    # You can initialize a message from a Hash or from another
+    # instance of itself.
+    #
+    #   class Foo < Message
+    #     field :foo, "C", 0xff
+    #     field :bar, "L", 0
+    #   end
+    #   msg = Foo.new(:bar => 123) # or ("bar" => 123)
+    #   msg2 = Foo.new(msg)
+    #   msg3 = Foo.new(:qux => 999) # error, no method qux=
+    def initialize(h={})
+      if h.instance_of?(self.class)
+        @attrs = h.to_hash.dup
+      else
+        @attrs = self.class::DEFAULTS.dup
+        h.each { |k,v| self[k] = v }
+      end
+    end
+
+    # Returns the underlying Hash of field name => value (not a copy)
+    def to_hash
+      @attrs
+    end
+
+    # Iterate over each field name and value
+    def each(&blk)
+      @attrs.each(&blk)
+    end
+
+    # Set a field by name. Can use either symbol or string as key.
+    def []=(k,v)
+      send "#{k}=", v
+    end
+
+    # Retrieve a field by name. Must use symbol as key.
+    def [](k)
+      @attrs[k]
+    end
+
+    def self.inherited(subclass) #:nodoc:
+      subclass.const_set(:FIELDS, [])
+      subclass.const_set(:FORMAT, "")
+      subclass.const_set(:DEFAULTS, {})
+    end
+
+    # Define which message type code(s) use this structure
+    def self.code(*codes)
+      codes.each { |code| CODE_TO_MESSAGE[code] = self }
+    end
+
+    # Define a field for this message, which creates accessor methods. The
+    # "pattern" is the Array#pack or String#unpack code to extract this field.
+    def self.field(name, pattern, default=nil, opt={})
+      self::FIELDS << name
+      self::FORMAT << pattern
+      self::DEFAULTS[name] = default
+      define_method name do
+        @attrs.fetch name
+      end
+      define_method "#{name}=" do |val|
+        @attrs.store name, val
+      end
+    end
+
+    # Shortcuts which call field with a fixed pack code and a default of 0
+    def self.uchar(name, *args); field name, "C", 0, *args; end
+    def self.uint16(name, *args); field name, "S", 0, *args; end
+    def self.uint32(name, *args); field name, "L", 0, *args; end
+    def self.char(name, *args); field name, "c", 0, *args; end
+    def self.int16(name, *args); field name, "s", 0, *args; end
+    def self.int32(name, *args); field name, "l", 0, *args; end
+    def self.ushort(name, *args); field name, "S_", 0, *args; end
+    def self.uint(name, *args); field name, "I", 0, *args; end
+    def self.ulong(name, *args); field name, "L_", 0, *args; end
+    def self.short(name, *args); field name, "s_", 0, *args; end
+    def self.int(name, *args); field name, "i", 0, *args; end
+    def self.long(name, *args); field name, "l_", 0, *args; end
+
+    # Returns the packed binary representation of this message (without
+    # header, and not padded to NLMSG_ALIGNTO bytes)
+    def to_s
+      self.class::FIELDS.map { |key| self[key] }.pack(self.class::FORMAT)
+    end
+
+    def inspect
+      "#<#{self.class} #{@attrs.inspect}>"
+    end
+
+    # Convert a binary representation of this message into an object instance
+    def self.parse(str)
+      res = new
+      str.unpack(self::FORMAT).zip(self::FIELDS).each do |val, key|
+        res[key] = val
+      end
+      res
+    end
+
+    NLMSG_ALIGNTO_1 = NLMSG_ALIGNTO-1 #:nodoc:
+    NLMSG_ALIGNTO_1_MASK = ~NLMSG_ALIGNTO_1 #:nodoc:
+
+    # Round up a length to a multiple of NLMSG_ALIGNTO bytes
+    def self.align(n)
+      (n + NLMSG_ALIGNTO_1) & NLMSG_ALIGNTO_1_MASK
+    end
+
+    PADDING = ("\000" * NLMSG_ALIGNTO).freeze #:nodoc:
+
+    # Pad a string up to a multiple of NLMSG_ALIGNTO bytes. Returns str.
+    def self.pad(str)
+      str << PADDING[0, align(str.bytesize) - str.bytesize]
+    end
+
+  end
+
+  class Link < Message
+    code RTM_NEWLINK, RTM_DELLINK, RTM_GETLINK
+    uchar :family
+    uchar :pad
+    ushort :type
+    int :index
+    uint :flags
+    uint :change
+  end
+
+  class Addr < Message
+    code RTM_NEWADDR, RTM_DELADDR, RTM_GETADDR
+    uchar :family
+    uchar :prefixlen
+    uchar :flags
+    uchar :scope
+    int :index
+  end
+
+  class Route < Message
+    code RTM_NEWROUTE, RTM_DELROUTE, RTM_GETROUTE
+    uchar :family
+    uchar :dst_len
+    uchar :src_len
+    uchar :tos
+    uchar :table
+    uchar :protocol
+    uchar :scope
+    uchar :type
+    uint :flags
+  end
+end
diff --git a/lib/netlink/nlsocket.rb b/lib/netlink/nlsocket.rb
new file mode 100644
index 0000000..facf229
--- /dev/null
+++ b/lib/netlink/nlsocket.rb
@@ -0,0 +1,187 @@
+require 'socket'
+require 'netlink/constants'
+require 'netlink/message'
+
+module Netlink
+  class NLSocket
+    DEFAULT_TIMEOUT = 2
+
+    SOCKADDR_PACK = "SSLL".freeze #:nodoc:
+    SOCKADDR_SIZE = 12 # :nodoc:
+
+    # Generate a sockaddr_nl. Pass :pid and/or :groups.
+    def self.sockaddr(opt={})
+      [Socket::AF_NETLINK, 0, opt[:pid] || 0, opt[:groups] || 0].pack(SOCKADDR_PACK)
+    end
+
+    # Default sockaddr_nl with 0 pid (send to kernel) and no multicast groups
+    SOCKADDR_DEFAULT = sockaddr.freeze
+
+    # Check the sockaddr on a received message. Raises an error if the AF
+    # is not AF_NETLINK or the PID is not 0 (this is important for security)
+    def self.parse_sockaddr(str)
+      af, _pad, pid, _groups = str.unpack(SOCKADDR_PACK)
+      raise "Bad AF #{af}!" if af != Socket::AF_NETLINK
+      raise "Bad PID #{pid}!" if pid != 0
+    end
+
+    attr_accessor :socket
+    attr_accessor :seq
+    attr_accessor :pid
+
+    # Create a new Netlink socket. Pass in chosen protocol, which is one of
+    # the NETLINK_* constants in netlink/constants.rb, for example:
+    #   :protocol => Netlink::NETLINK_ROUTE
+    #   :protocol => Netlink::NETLINK_FIREWALL
+    #   :protocol => Netlink::NETLINK_INET_DIAG
+    #   :protocol => Netlink::NETLINK_NFLOG
+    #   :protocol => Netlink::NETLINK_XFRM
+    #   :protocol => Netlink::NETLINK_NETFILTER
+    #   :protocol => Netlink::NETLINK_IP6_FW
+    # Other options:
+    #   :socket => an existing socket to use instead (it is not bound again)
+    #   :groups => N (subscribe to multicast groups, default to 0)
+    #   :seq => N (override initial sequence number)
+    #   :pid => N (override PID)
+    #   :timeout => N (seconds, default to DEFAULT_TIMEOUT. Pass nil for no timeout)
+    def initialize(opt)
+      @socket ||= opt[:socket] || ::Socket.new(
+        Socket::AF_NETLINK,
+        Socket::SOCK_DGRAM,
+        opt[:protocol] || (raise "Missing :protocol")
+      )
+      @socket.bind(NLSocket.sockaddr(opt)) unless opt[:socket]
+      @seq = opt[:seq] || Time.now.to_i
+      @pid = opt[:pid] || $$
+      @timeout = opt.has_key?(:timeout) ? opt[:timeout] : DEFAULT_TIMEOUT
+    end
+
+    # Send a Netlink::Message object over the socket
+    # type:: the message type code placed in the header
+    # obj:: the object to send (responds to #to_s)
+    # flags:: message header flags, default NLM_F_REQUEST
+    # sockaddr:: destination sockaddr, defaults to pid=0 and groups=0
+    # seq:: sequence number, defaults to bump internal sequence
+    # pid:: pid, defaults to $$
+    # vflags:: sendmsg flags, defaults to 0
+    # controls:: extra arguments passed on to sendmsg, defaults to none
+    def send_request(type, obj, flags=NLM_F_REQUEST, sockaddr=SOCKADDR_DEFAULT, seq=(@seq += 1), pid=@pid, vflags=0, controls=[])
+      @socket.sendmsg(
+        build_message(type, obj, flags, seq, pid),
+        vflags, sockaddr, *controls
+      )
+    end
+
+    NLMSGHDR_PACK = "LSSLL".freeze # :nodoc:
+    NLMSGHDR_SIZE = 16 # :nodoc:
+
+    # Build a message comprising header+body. It is not padded at the end.
+    def build_message(type, body, flags=NLM_F_REQUEST, seq=(@seq += 1), pid=@pid)
+      body = body.to_s
+      header = [
+        body.bytesize + NLMSGHDR_SIZE,
+        type, flags, seq, pid
+      ].pack(NLMSGHDR_PACK)
+      # assume the header is already aligned
+      header + body
+    end
+
+    # Send multiple Netlink::Message objects in a single datagram. They
+    # need to share the same type and flags, and will be sent with sequential
+    # sequence nos. Every message but the last has NLM_F_MULTI set and is
+    # padded to NLMSG_ALIGNTO bytes so that the next header starts aligned.
+    # sockaddr, vflags and controls are as for #send_request. Nothing is sent
+    # if objs is empty.
+    def send_requests(type, objs, flags=NLM_F_REQUEST, pid=@pid, sockaddr=SOCKADDR_DEFAULT, vflags=0, controls=[])
+      return if objs.empty?
+      data = String.new
+      last = objs.size - 1
+      objs.each_with_index do |obj, index|
+        if index < last
+          data << build_message(type, obj, flags|NLM_F_MULTI, @seq+=1, pid)
+          Message.pad(data)
+        else
+          data << build_message(type, obj, flags, @seq+=1, pid)
+        end
+      end
+      @socket.sendmsg(data, vflags, sockaddr, *controls)
+    end
+
+    # Discard all waiting messages
+    def flush
+      while select([@socket], nil, nil, 0)
+        @socket.recvmsg
+      end
+    end
+
+    # Loop receiving responses until a NLMSG_DONE message arrives, yielding
+    # the objects found. Also filters so that only expected pid and seq
+    # are accepted.
+    #
+    # (Compare: rtnl_dump_filter_l in lib/libnetlink.c)
+    def receive_until_done(timeout=@timeout, junk_handler=nil, &blk) #:yields: type, flags, obj
+      res = []
+      blk ||= lambda { |type, flags, obj| res << obj if obj }
+      junk_handler ||= lambda { |obj| warn "Discarding junk message #{obj}" } if $VERBOSE
+      loop do
+        receive_response(timeout) do |type, flags, seq, pid, obj|
+          if pid != @pid || seq != @seq
+            junk_handler[obj] if junk_handler
+            next
+          end
+          case type
+          when NLMSG_DONE
+            return res
+          when NLMSG_ERROR
+            raise "Netlink Error received"
+          end
+          blk.call(type, flags, obj)
+        end
+      end
+    end
+
+    # Receive one datagram from kernel. If a block is given, then yield
+    # Netlink::Message objects (maybe multiple times if the datagram
+    # includes multiple netlink messages).
+    #
+    #   receive_response { |type, flags, seq, pid, msg| p msg }
+    def receive_response(timeout=@timeout, &blk) # :yields: type, flags, seq, pid, Message
+      if select([@socket], nil, nil, timeout)
+        mesg, sender = @socket.recvmsg
+        raise EOFError unless mesg
+        NLSocket.parse_sockaddr(sender.to_sockaddr)
+        parse_yield(mesg, &blk)
+      else
+        raise "Timeout"
+      end
+    end
+
+    # Parse message(s) in a string buffer and yield message object, flags,
+    # seq and pid
+    def parse_yield(mesg) # :yields: type, flags, seq, pid, Message
+      dechunk(mesg) do |h_type, h_flags, h_seq, h_pid, data|
+        klass = Message::CODE_TO_MESSAGE[h_type]
+        yield h_type, h_flags, h_seq, h_pid, klass && klass.parse(data)
+      end
+    end
+
+    # Take message(s) in a string buffer and yield fields in turn. The whole
+    # buffer is walked, whether or not a message has NLM_F_MULTI set.
+    def dechunk(mesg) # :yields: type, flags, seq, pid, data
+      ptr = 0
+      while ptr < mesg.bytesize
+        raise "Truncated netlink header!" if ptr + NLMSGHDR_SIZE > mesg.bytesize
+        len, type, flags, seq, pid = mesg[ptr,NLMSGHDR_SIZE].unpack(NLMSGHDR_PACK)
+        STDERR.puts " len=#{len}, type=#{type}, flags=#{flags}, seq=#{seq}, pid=#{pid}" if $DEBUG
+        # A length shorter than the header would leave no valid body to slice
+        # and, when zero, would stop ptr advancing so the loop never ends
+        raise "Bad netlink message length #{len}!" if len < NLMSGHDR_SIZE
+        raise "Truncated netlink message!" if ptr + len > mesg.bytesize
+        data = mesg[ptr+NLMSGHDR_SIZE, len-NLMSGHDR_SIZE]
+        STDERR.puts " data=#{data.inspect}" if $DEBUG && !data.empty?
+        yield type, flags, seq, pid, data
+        ptr += Message.align(len)
+      end
+    end
+  end
+end
diff --git a/lib/netlink/rtsocket.rb b/lib/netlink/rtsocket.rb
new file mode 100644
index 0000000..7f63030
--- /dev/null
+++ b/lib/netlink/rtsocket.rb
@@ -0,0 +1,54 @@
+require 'netlink/nlsocket'
+require 'netlink/message'
+
+module Netlink
+  # This is the high-level API using a NETLINK_ROUTE protocol socket
+  class RTSocket < NLSocket
+    def initialize(opt={})
+      super(opt.merge(:protocol => Netlink::NETLINK_ROUTE))
+    end
+
+    # List links. Returns an array of Netlink::Link objects
+    def link_list(opt={})
+      send_request RTM_GETLINK, Link.new(opt),
+        NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST
+      receive_until_done
+    end
+
+    # List routes. Returns an array of Netlink::Route objects
+    #   res = nl.route_list(:family => Socket::AF_INET)
+    #   #=> [..., ...]
+    def route_list(opt={})
+      send_request RTM_GETROUTE, Route.new(opt),
+        NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST
+      receive_until_done
+    end
+
+    # List addresses. Returns an array of Netlink::Addr objects
+    def addr_list(opt={})
+      send_request RTM_GETADDR, Addr.new(opt),
+        NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST
+      receive_until_done
+    end
+
+    # Add a route
+    #   nl.route_add(:family => Socket::AF_INET, ...)
+    def route_add(r)
+      send_request RTM_NEWROUTE, Route.new(r)
+      # Do we get any success/fail?
+    end
+
+    # Delete a route
+    def route_delete(r)
+      send_request RTM_DELROUTE, Route.new(r)
+    end
+  end
+end
+
+if __FILE__ == $0
+  require 'pp'
+  nl = Netlink::RTSocket.new
+  pp nl.route_list(:family => Socket::AF_INET)
+  pp nl.link_list(:family => Socket::AF_INET)
+  pp nl.addr_list(:family => Socket::AF_INET)
+end