diff --git a/tests/acceptance/fakes/dest/write_wrong_magic.rb b/tests/acceptance/fakes/dest/write_wrong_magic.rb new file mode 100755 index 0000000..9a14ff7 --- /dev/null +++ b/tests/acceptance/fakes/dest/write_wrong_magic.rb @@ -0,0 +1,24 @@ +#!/usr/bin/env ruby +# encoding: utf-8 + +# Accept a connection, write hello, wait for a write request, read the +# data, then write back a reply with a bad magic field. We then +# expect a reconnect. + +require 'flexnbd/fake_dest' +include FlexNBD + +addr, port = *ARGV +server = FakeDest.new( addr, port ) + +client = server.accept +client.write_hello +req = client.read_request +client.read_data( req[:len] ) +client.write_reply( req[:handle], 0, :magic => :wrong ) + +client2 = server.accept +client.close +client2.close + +exit(0) diff --git a/tests/acceptance/flexnbd/fake_dest.rb b/tests/acceptance/flexnbd/fake_dest.rb index 23ade2f..6901967 100644 --- a/tests/acceptance/flexnbd/fake_dest.rb +++ b/tests/acceptance/flexnbd/fake_dest.rb @@ -56,10 +56,21 @@ module FlexNBD } end + REPLY_MAGIC="\x67\x44\x66\x98" def write_error( handle ) - @sock.write( "\x67\x44\x66\x98" ) - @sock.write( "\x00\x00\x00\x01" ) + write_reply( handle, 1 ) + end + + + def write_reply( handle, err=0, opts={} ) + if opts[:magic] == :wrong + write_rand( @sock, 4 ) + else + @sock.write( REPLY_MAGIC ) + end + + @sock.write( [err].pack("N") ) @sock.write( handle ) end @@ -69,6 +80,10 @@ module FlexNBD end + def read_data( len ) + @sock.read( len ) + end + def self.parse_be64(str) raise "String is the wrong length: 8 bytes expected (#{str.length} received)" unless diff --git a/tests/acceptance/test_source_error_handling.rb b/tests/acceptance/test_source_error_handling.rb index 09ccab7..835bc38 100644 --- a/tests/acceptance/test_source_error_handling.rb +++ b/tests/acceptance/test_source_error_handling.rb @@ -89,6 +89,13 @@ class TestSourceErrorHandling < Test::Unit::TestCase end + def test_bad_write_reply_causes_retry + run_fake( "dest/write_wrong_magic" ) + @env.mirror12_unchecked + assert_success + end + + private def run_fake(name) @env.run_fake( name, @env.ip, @env.port2 )