diff --git a/src/client.c b/src/client.c
index 3e3f95d..b575b27 100644
--- a/src/client.c
+++ b/src/client.c
@@ -154,6 +154,15 @@ void write_not_zeroes(struct client* client, uint64_t from, int len)
 }
 
 
+/* Reads one raw NBD request (struct nbd_request_raw) from fd into
+ * *out_request.  Returns readloop()'s result unchanged;
+ * client_read_request() treats -1 as failure, errno == 0 meaning EOF.
+ */
+int fd_read_request( int fd, struct nbd_request_raw *out_request)
+{
+	return readloop(fd, out_request, sizeof(struct nbd_request_raw));
+}
+
 /* Returns 1 if *request was filled with a valid request which we should
  * try to honour. 0 otherwise. */
 int client_read_request( struct client * client , struct nbd_request *out_request )
@@ -175,7 +184,7 @@ int client_read_request( struct client * client , struct nbd_request *out_reques
 		return 0;
 	}
 
-	if (readloop(client->socket, &request_raw, sizeof(request_raw)) == -1) {
+	if (fd_read_request(client->socket, &request_raw) == -1) {
 		if (errno == 0) {
 			debug("EOF reading request");
 			return 0; /* neat point to close the socket */
@@ -190,6 +199,42 @@ int client_read_request( struct client * client , struct nbd_request *out_reques
 
 	return 1;
 }
+
+/* Builds a reply carrying the 8-byte handle and the given error code,
+ * converts it with nbd_h2r_reply() and writes all of it to fd.
+ * Short writes are continued and EINTR is retried.
+ * Returns 1 once the whole reply has been written, 0 if write()
+ * failed or made no progress (errno is left as write() set it).
+ */
+int fd_write_reply( int fd, char *handle, int error )
+{
+	struct nbd_reply reply;
+	struct nbd_reply_raw reply_raw;
+	const char *cursor = (const char *) &reply_raw;
+	size_t remaining = sizeof( reply_raw );
+
+	reply.magic = REPLY_MAGIC;
+	reply.error = error;
+	memcpy( reply.handle, handle, 8 );
+
+	nbd_h2r_reply( &reply, &reply_raw );
+
+	while ( remaining > 0 ) {
+		ssize_t written = write( fd, cursor, remaining );
+
+		if ( written <= 0 ) {
+			if ( written < 0 && errno == EINTR ) {
+				continue;
+			}
+			return 0;
+		}
+
+		cursor += written;
+		remaining -= (size_t) written;
+	}
+
+	return 1;
+}
 
 /* Writes a reply to request *request, with error, to the client's
  * socket.
@@ -198,20 +243,10 @@ int client_read_request( struct client * client , struct nbd_request *out_reques
  */
 int client_write_reply( struct client * client, struct nbd_request *request, int error )
 {
-	struct nbd_reply reply;
-	struct nbd_reply_raw reply_raw;
-
-	reply.magic = REPLY_MAGIC;
-	reply.error = error;
-	memcpy( reply.handle, &request->handle, 8 );
-
-	nbd_h2r_reply( &reply, &reply_raw );
-
-	write( client->socket, &reply_raw, sizeof( reply_raw ) );
-
-	return 1;
+	return fd_write_reply( client->socket, request->handle, error);
 }
 
+
 void client_write_init( struct client * client, uint64_t size )
 {
 	struct nbd_init init = {{0}};
diff --git a/src/control.c b/src/control.c
index f88e0a9..88dcd32 100644
--- a/src/control.c
+++ b/src/control.c
@@ -155,9 +155,59 @@ int mirror_pass(struct server * serve, int should_lock, uint64_t *written)
 }
 
 
+/* Asks the remote end to take over: sends an entrust, then a disconnect. */
+void mirror_transfer_control( struct mirror_status * mirror )
+{
+	/* TODO: set up an error handler to clean up properly on ERROR.
+	 */
+
+	/* A transfer of control is expressed as a 3-way handshake.
+	 * First, we send a REQUEST_ENTRUST. If this fails to be
+	 * received, this thread will simply block until the server is
+	 * restarted. If the remote end doesn't understand it, it'll
+	 * disconnect us, and an ERROR *should* bomb this thread.
+	 * FIXME: make the ERROR work.
+	 * If we get an explicit error back from the remote end, then
+	 * again, this thread will bomb out.
+	 * On receiving a valid response, we send a REQUEST_DISCONNECT,
+	 * and we quit without checking for a response. This is the
+	 * remote server's signal to assume control of the file. The
+	 * reason we don't check for a response is the state we end up
+	 * in if the final message goes astray: if we lose the
+	 * REQUEST_DISCONNECT, the sender has quit and the receiver
+	 * hasn't had a signal to take over yet, so the data is safe.
+	 * If we were to wait for a response to the REQUEST_DISCONNECT,
+	 * the sender and receiver would *both* be servicing write
+	 * requests while the response was in flight, and if the
+	 * response went astray we'd have two servers claiming
+	 * responsibility for the same data.
+	 */
+	socket_nbd_entrust( mirror->client );
+	socket_nbd_disconnect( mirror->client );
+}
+
+
+/* THIS FUNCTION MUST ONLY BE CALLED WITH THE SERVER'S IO LOCKED. */
 void mirror_on_exit( struct server * serve )
 {
+	/* Send an explicit entrust and disconnect. After this
+	 * point we cannot allow any reads or writes to the local file.
+	 * We do this *before* trying to shut down the server so that if
+	 * the transfer of control fails, we haven't stopped the server
+	 * and already-connected clients don't get needlessly
+	 * disconnected.
+	 */
+	mirror_transfer_control( serve->mirror );
+
+	/* If we're still here, the transfer of control went ok, and the
+	 * remote is listening (or will be shortly). We can shut the
+	 * server down.
+	 *
+	 * It doesn't matter if we get new client connections before
+	 * now, the IO lock will stop them from doing anything.
+	 */
 	serve_signal_close( serve );
+
 	/* We have to wait until the server is closed before unlocking
 	 * IO.  This is because the client threads check to see if the
 	 * server is still open before reading or writing inside their
@@ -168,6 +218,7 @@ void mirror_on_exit( struct server * serve )
 	serve_wait_for_close( serve );
 }
 
+
 /** Thread launched to drive mirror process */
 void* mirror_runner(void* serve_params_uncast)
 {
diff --git a/src/readwrite.c b/src/readwrite.c
index 8f1ad9b..aab1dad 100644
--- a/src/readwrite.c
+++ b/src/readwrite.c
@@ -109,6 +109,39 @@ void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf)
 	read_reply(fd, &request, &reply);
 }
 
+
+/* Sends a REQUEST_ENTRUST on fd, then waits for the response with
+ * read_reply().  Failing to write the request is fatal.
+ */
+void socket_nbd_entrust( int fd )
+{
+	struct nbd_request request;
+	struct nbd_reply reply;
+
+	fill_request( &request, REQUEST_ENTRUST, 0, 0 );
+	FATAL_IF_NEGATIVE( writeloop( fd, &request, sizeof( request ) ),
+			"Couldn't write request");
+	read_reply( fd, &request, &reply );
+}
+
+
+/* Sends a REQUEST_DISCONNECT on fd and returns without reading any
+ * response.  Always returns 1: a failed write is (for now) fatal
+ * rather than reported to the caller.
+ */
+int socket_nbd_disconnect( int fd )
+{
+	struct nbd_request request;
+
+	fill_request( &request, REQUEST_DISCONNECT, 0, 0 );
+	/* FIXME: This shouldn't be a FATAL error. We should just drop
+	 * the mirror without affecting the main server.
+	 */
+	FATAL_IF_NEGATIVE( writeloop( fd, &request, sizeof( request ) ),
+			"Failed to write the disconnect request." );
+	return 1;
+}
+
 #define CHECK_RANGE(error_type) { \
 	off64_t size = socket_nbd_read_hello(params->client); \
 	if (params->from < 0 || (params->from + params->len) > size) {\
diff --git a/src/readwrite.h b/src/readwrite.h
index 22ce723..fb1736a 100644
--- a/src/readwrite.h
+++ b/src/readwrite.h
@@ -2,10 +2,17 @@
 #define READWRITE_H
 
 
+#include <sys/types.h>
+#include <sys/socket.h>
+
 int socket_connect(struct sockaddr* to, struct sockaddr* from);
 off64_t socket_nbd_read_hello(int fd);
 void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf);
 void socket_nbd_write(int fd, off64_t from, int len, int out_fd, void* out_buf);
+/* Send a REQUEST_ENTRUST on fd and wait for the reply. */
+void socket_nbd_entrust(int fd);
+/* Send a REQUEST_DISCONNECT on fd without waiting for a reply. */
+int socket_nbd_disconnect( int fd );
 
 #endif
 
diff --git a/src/util.c b/src/util.c
index 96736b8..510d4f8 100644
--- a/src/util.c
+++ b/src/util.c
@@ -23,6 +23,9 @@ void error_handler(int fatal __attribute__ ((unused)) )
 	DECLARE_ERROR_CONTEXT(context);
 
 	if (!context) {
+		/* FIXME: This can't be right - by default we exit()
+		 * with a status of 0 in this case.
+		 */
 		pthread_exit((void*) 1);
 	}
 
diff --git a/tests/check_readwrite.c b/tests/check_readwrite.c
new file mode 100644
index 0000000..a91b7f5
--- /dev/null
+++ b/tests/check_readwrite.c
@@ -0,0 +1,192 @@
+#include "readwrite.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <pthread.h>
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <check.h>
+
+#include "util.h"
+#include "nbdtypes.h"
+
+
+/* Implemented in src/client.c; declared by hand for this test. */
+int fd_read_request( int, struct nbd_request_raw *);
+int fd_write_reply( int, char *, int );
+
+/* No-op cleanup handler, passed to error_set_handler() below. */
+void dummy_error_handler(void * foo __attribute__((unused)))
+{
+	return;
+}
+
+/* State for one fake remote end ("responder") and its thread. */
+struct respond {
+	int sock_fds[2]; /* [0]: given to the code under test, [1]: responder's end */
+	int do_fail;     /* non-zero: reply with a handle that doesn't match */
+	pthread_t thread_id;
+	pthread_attr_t thread_attr;
+	struct nbd_request received; /* the request the responder read */
+};
+
+/* Thread body: reads a single request from sock_fds[1], records it in
+ * resp->received, and answers it once with error 0 - using a bogus
+ * handle when resp->do_fail is set.
+ */
+void * responder( void *respond_uncast )
+{
+	struct respond * resp = (struct respond *) respond_uncast;
+	int sock_fd = resp->sock_fds[1];
+	struct nbd_request_raw request_raw;
+	char wrong_handle[] = "WHOOPSIE";
+
+	if( fd_read_request( sock_fd, &request_raw ) == -1){
+		fprintf(stderr, "Problem with fd_read_request\n");
+	} else {
+		nbd_r2h_request( &request_raw, &resp->received);
+		if (resp->do_fail){
+			fd_write_reply( sock_fd, wrong_handle, 0 );
+		}
+		else {
+			fd_write_reply( sock_fd, resp->received.handle, 0 );
+		}
+	}
+	return NULL;
+}
+
+
+/* Allocates a responder, creates its socket pair and starts its thread.
+ * Any failure here fails the current test.
+ */
+struct respond * respond_create( int do_fail )
+{
+	struct respond * respond = calloc( 1, sizeof( *respond ) );
+	fail_unless( respond != NULL, "Couldn't allocate the responder." );
+
+	int sock_ok = socketpair( PF_UNIX, SOCK_STREAM, 0, respond->sock_fds );
+	fail_unless( sock_ok == 0, "Couldn't create the socket pair." );
+	respond->do_fail = do_fail;
+
+	pthread_attr_init( &respond->thread_attr );
+	int thread_ok = pthread_create( &respond->thread_id, &respond->thread_attr, responder, respond );
+	fail_unless( thread_ok == 0, "Couldn't start the responder thread." );
+
+	return respond;
+}
+
+/* Waits for the responder thread to finish, then releases its sockets
+ * and memory.
+ */
+void respond_destroy( struct respond * respond ){
+	NULLCHECK( respond );
+
+	pthread_join( respond->thread_id, NULL );
+	pthread_attr_destroy( &respond->thread_attr );
+
+	close( respond->sock_fds[0] );
+	close( respond->sock_fds[1] );
+	free( respond );
+}
+
+
+/* A reply carrying the wrong handle must be rejected; the suite expects
+ * this test to die with signal 6.
+ */
+START_TEST( test_rejects_mismatched_handle )
+{
+	struct respond * respond = respond_create( 1 );
+
+	DECLARE_ERROR_CONTEXT( error_context );
+	error_init();
+	error_set_handler( (cleanup_handler *)dummy_error_handler, error_context );
+
+	log_level=5;
+	socket_nbd_entrust( respond->sock_fds[0] );
+	log_level=2;
+
+	respond_destroy( respond );
+}
+END_TEST
+
+
+START_TEST( test_accepts_matched_handle )
+{
+	struct respond * respond = respond_create( 0 );
+
+	socket_nbd_entrust( respond->sock_fds[0] );
+
+	respond_destroy( respond );
+}
+END_TEST
+
+
+START_TEST( test_entrust_type_sent )
+{
+	struct respond * respond = respond_create( 0 );
+
+	socket_nbd_entrust( respond->sock_fds[0] );
+	fail_unless( respond->received.type == REQUEST_ENTRUST, "Wrong type sent." );
+
+	respond_destroy( respond );
+}
+END_TEST
+
+
+START_TEST( test_disconnect_doesnt_read_reply )
+{
+	struct respond * respond = respond_create( 1 );
+
+	socket_nbd_disconnect( respond->sock_fds[0] );
+
+	respond_destroy( respond );
+}
+END_TEST
+
+
+Suite* readwrite_suite(void)
+{
+	Suite *s = suite_create("readwrite");
+	TCase *tc_transfer = tcase_create("entrust");
+	TCase *tc_disconnect = tcase_create("disconnect");
+
+
+	tcase_add_test_raise_signal(tc_transfer, test_rejects_mismatched_handle, 6);
+	tcase_add_exit_test(tc_transfer, test_accepts_matched_handle, 0);
+	tcase_add_test( tc_transfer, test_entrust_type_sent );
+
+	/* This test is a little funny.  We respond with a dodgy handle
+	 * and check that this *doesn't* cause a message rejection,
+	 * because we want to know that the sender won't even try to
+	 * read the response.
+	 */
+	tcase_add_exit_test( tc_disconnect, test_disconnect_doesnt_read_reply,0 );
+
+	suite_add_tcase(s, tc_transfer);
+	suite_add_tcase(s, tc_disconnect);
+
+	return s;
+}
+
+
+
+#ifdef DEBUG
+# define LOG_LEVEL 0
+#else
+# define LOG_LEVEL 2
+#endif
+
+int main(void)
+{
+	log_level = LOG_LEVEL;
+	int number_failed;
+	Suite *s = readwrite_suite();
+	SRunner *sr = srunner_create(s);
+	srunner_run_all(sr, CK_NORMAL);
+	log_level = 0;
+	number_failed = srunner_ntests_failed(sr);
+	srunner_free(sr);
+	return (number_failed == 0) ? 0 : 1;
+}