diff --git a/lib/netlink/nlsocket.rb b/lib/netlink/nlsocket.rb index 6a9ea30..ced2b3b 100644 --- a/lib/netlink/nlsocket.rb +++ b/lib/netlink/nlsocket.rb @@ -10,6 +10,13 @@ module Netlink ERRNO_MAP[klass::Errno] = klass end + # Raise an Errno exception if the given rc is not NOERROR + def self.check_error(rc) + if -rc != Errno::NOERROR::Errno + raise ERRNO_MAP[-rc] || "Netlink Error: #{msg.inspect}" + end + end + # NLSocket provides low-level sending and receiving of messages across # a netlink socket, adding headers to sent messages and parsing # received messages. @@ -38,6 +45,7 @@ module Netlink attr_accessor :seq # the last sequence number used attr_accessor :pid # default pid to include in message headers attr_accessor :timeout # default timeout when receiving message + attr_accessor :junk_handler # proc to log or handle unexpected messages # Create a new Netlink socket. Pass in chosen protocol: # :protocol => Netlink::NETLINK_ARPD @@ -137,18 +145,19 @@ module Netlink end # Loop receiving responses until a DONE message is received (or you - # break out of the loop, or a timeout exception occurs). + # break out of the loop, or a timeout exception occurs). Checks the + # message type and pid/seq. # # Yields Netlink::Message objects, or if no block is given, returns an # array of those objects. # # (Compare: rtnl_dump_filter_l in lib/libnetlink.c) - def receive_until_done(expected_type=nil, timeout=@timeout, &blk) #:yields: msg + def receive_until_done(expected_type, timeout=@timeout, &blk) #:yields: msg res = [] blk ||= lambda { |obj| res << obj } receive_responses(true, timeout) do |type,msg| return res if type == NLMSG_DONE - if expected_type && type != expected_type + if type != expected_type false else blk.call(msg) if msg @@ -156,11 +165,12 @@ module Netlink end end - # Loop infinitely receiving responses and yielding message objects - # of the given type. - def receive_stream(expected_type=nil) + # This is the entry point for protocols which yield an infinite stream + # of messages (e.g. firewall, ulog). There is no timeout, and + # the pid/seq are not checked. + def receive_stream(expected_type) #:yields: msg receive_responses(false, nil) do |type, msg| - if expected_type && type != expected_type + if type != expected_type false else yield msg @@ -168,19 +178,19 @@ module Netlink end end - # This is the main loop for receiving responses. It optionally checks - # the pid/seq of received messages, and discards those which don't match. - # Raises an exception on NLMSG_ERROR. + # This is the main loop for receiving responses, yielding the type and + # message object for each received message. It optionally checks the pid/seq + # and discards those which don't match. If the block returns 'false' then + # they are also logged as junk. # - # Matching messages are yielded to the block. If the block returns - # false then they are treated as junk. - def receive_responses(check_pid_seq=false, timeout=nil) + # Raises an exception on NLMSG_ERROR (other than Errno::NOERROR), or if + # no packet received within the specified timeout. Pass nil for infinite + # timeout. + def receive_responses(check_pid_seq, timeout=@timeout) loop do - receive_response(timeout) do |type, flags, seq, pid, msg| + parse_yield(recvmsg(timeout)) do |type, flags, seq, pid, msg| if !check_pid_seq || (pid == @pid && seq == @seq) - if type == NLMSG_ERROR && -msg.error != Errno::NOERROR::Errno - raise ERRNO_MAP[-msg.error] || "Netlink Error: #{msg.inspect}" - end + self.class.check_error(msg.error) if type == NLMSG_ERROR res = yield type, msg next unless res == false end @@ -203,17 +213,6 @@ module Netlink end end - # Receive one datagram from kernel. Yield header fields plus - # Netlink::Message objects (maybe multiple times if the datagram - # includes multiple netlink messages). Raise an exception if no - # datagram received within the specified or default timeout period; - # pass nil for infinite timeout. - # - # receive_response { |type, flags, seq, pid, msg| p msg } - def receive_response(timeout=@timeout, &blk) # :yields: type, flags, seq, pid, Message - parse_yield(recvmsg(timeout), &blk) - end - # Parse netlink packet in a string buffer. Yield header fields plus # a Netlink::Message-derived object for each message. For unknown message # types it will yield a raw String, or nil if there is no message body.