diff --git a/src/client.c b/src/client.c index 48a7a36..a7c275b 100644 --- a/src/client.c +++ b/src/client.c @@ -11,6 +11,13 @@ #include #include +#include +#include +#include + + + + struct client *client_create( struct server *serve, int socket ) { @@ -286,6 +293,49 @@ void client_write_init( struct client * client, uint64_t size ) } +/* Remove len bytes from the client socket. This is needed when the + * client sends a write we can't honour - we need to get rid of the + * bytes they've already written before we can look for another request. + */ +void client_flush( struct client * client, size_t len ) +{ + int devnull = open("/dev/null", O_WRONLY); + FATAL_IF_NEGATIVE( devnull, + "Couldn't open /dev/null: %s", strerror(errno)); + int pipes[2]; + pipe( pipes ); + + const unsigned int flags = SPLICE_F_MORE | SPLICE_F_MOVE; + size_t spliced = 0; + + while ( spliced < len ) { + ssize_t received = splice( + client->socket, NULL, + pipes[1], NULL, + len-spliced, flags ); + FATAL_IF_NEGATIVE( received, + "splice error: %s", + strerror(errno)); + ssize_t junked = 0; + while( junked < received ) { + ssize_t junk; + junk = splice( + pipes[0], NULL, + devnull, NULL, + received, flags ); + FATAL_IF_NEGATIVE( junk, + "splice error: %s", + strerror(errno)); + junked += junk; + } + spliced += received; + } + debug("Flushed %d bytes", len); + + + close( devnull ); +} + /* Check to see if the client's request needs a reply constructing. * Returns 1 if we do, 0 otherwise. @@ -321,6 +371,7 @@ int client_request_needs_reply( struct client * client, request.len ); client_write_reply( client, &request, 1 ); + client_flush( client, request.len ); client->disconnect = 0; return 0; } diff --git a/src/ioutil.h b/src/ioutil.h index 09552d9..0687cea 100644 --- a/src/ioutil.h +++ b/src/ioutil.h @@ -27,6 +27,9 @@ int readloop(int filedes, void *buffer, size_t size); */ int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count); +/** Repeat a splice() operation until we have 'len' bytes. */ +ssize_t spliceloop(int fd_in, loff_t *off_in, int fd_out, loff_t *off_out, size_t len, unsigned int flags2); + /** Copy ''len'' bytes from ''fd_in'' to ''fd_out'' by creating a temporary * pipe and using the Linux splice call repeatedly until it has transferred * all the data. Returns -1 on error. diff --git a/tests/acceptance/fakes/source/write_out_of_range.rb b/tests/acceptance/fakes/source/write_out_of_range.rb index cc093a7..2954fad 100755 --- a/tests/acceptance/fakes/source/write_out_of_range.rb +++ b/tests/acceptance/fakes/source/write_out_of_range.rb @@ -4,6 +4,9 @@ # Connect, read the hello then make a write request with an impossible # (from,len) pair. We expect an error response, and not to be # disconnected. +# +# We then expect to be able to issue a successful write: the destination +# has to flush the data in the socket. require 'flexnbd/fake_source' include FlexNBD @@ -11,12 +14,19 @@ include FlexNBD addr, port = *ARGV client = FakeSource.new( addr, port, "Timed out connecting" ) -client.read_hello -client.write_write_request( 1 << 31, 1 << 31, "myhandle" ) +hello = client.read_hello +client.write_write_request( hello[:size]+1, 32, "myhandle" ) +client.write_data("1"*32) response = client.read_response fail "Not an error" if response[:error] == 0 fail "Wrong handle" unless "myhandle" == response[:handle] +client.write_write_request( 0, 32 ) +client.write_data( "2"*32 ) +success_response = client.read_response + +fail "Second write failed" unless success_response[:error] == 0 + client.close exit(0) diff --git a/tests/acceptance/flexnbd/fake_source.rb b/tests/acceptance/flexnbd/fake_source.rb index 5ba8fb9..41ffdbe 100644 --- a/tests/acceptance/flexnbd/fake_source.rb +++ b/tests/acceptance/flexnbd/fake_source.rb @@ -26,9 +26,17 @@ module FlexNBD def read_hello() timing_out( FlexNBD::MS_HELLO_TIME_SECS, "Timed out waiting for hello." ) do - fail "No hello." unless (hello = @sock.read( 152 )) && - hello.length==152 - hello + fail "No hello." unless (hello = @sock.read( 152 )) && + hello.length==152 + + magic_s = hello[0..7] + ignore_s= hello[8..15] + size_s = hello[16..23] + + size_h, size_l = size_s.unpack("NN") + size = (size_h << 32) + size_l + + return { :magic => magic_s, :size => size } end end @@ -99,7 +107,7 @@ module FlexNBD { :magic => magic, - :error => error_s.unpack("N"), + :error => error_s.unpack("N").first, :handle => handle } end