diff --git a/src/common/readwrite.c b/src/common/readwrite.c index 30f63f6..81de855 100644 --- a/src/common/readwrite.c +++ b/src/common/readwrite.c @@ -41,7 +41,7 @@ int socket_connect(struct sockaddr* to, struct sockaddr* from) return fd; } -int nbd_check_hello( struct nbd_init_raw* init_raw, uint64_t* out_size ) +int nbd_check_hello( struct nbd_init_raw* init_raw, uint64_t* out_size, uint32_t* out_flags ) { if ( strncmp( init_raw->passwd, INIT_PASSWD, 8 ) != 0 ) { warn( "wrong passwd" ); @@ -55,6 +55,10 @@ int nbd_check_hello( struct nbd_init_raw* init_raw, uint64_t* out_size ) if ( NULL != out_size ) { *out_size = be64toh( init_raw->size ); } + + if ( NULL != out_flags ) { + *out_flags = be32toh( init_raw->flags ); + } return 1; fail: @@ -62,7 +66,7 @@ fail: } -int socket_nbd_read_hello( int fd, uint64_t* out_size ) +int socket_nbd_read_hello( int fd, uint64_t* out_size, uint32_t* out_flags ) { struct nbd_init_raw init_raw; @@ -72,16 +76,17 @@ int socket_nbd_read_hello( int fd, uint64_t* out_size ) return 0; } - return nbd_check_hello( &init_raw, out_size ); + return nbd_check_hello( &init_raw, out_size, out_flags ); } -void nbd_hello_to_buf( struct nbd_init_raw *buf, off64_t out_size ) +void nbd_hello_to_buf( struct nbd_init_raw *buf, off64_t out_size, uint32_t out_flags ) { struct nbd_init init; memcpy( &init.passwd, INIT_PASSWD, 8 ); init.magic = INIT_MAGIC; init.size = out_size; + init.flags = out_flags; memset( buf, 0, sizeof( struct nbd_init_raw ) ); // ensure reserved is 0s nbd_h2r_init( &init, buf ); @@ -89,10 +94,10 @@ void nbd_hello_to_buf( struct nbd_init_raw *buf, off64_t out_size ) return; } -int socket_nbd_write_hello(int fd, off64_t out_size) +int socket_nbd_write_hello( int fd, off64_t out_size, uint32_t out_flags ) { struct nbd_init_raw init_raw; - nbd_hello_to_buf( &init_raw, out_size ); + nbd_hello_to_buf( &init_raw, out_size, out_flags ); if ( 0 > writeloop( fd, &init_raw, sizeof( init_raw ) ) ) { warn( SHOW_ERRNO( "failed to write hello to socket" ) ); @@ -213,7 +218,8 @@ int socket_nbd_disconnect( int fd ) #define CHECK_RANGE(error_type) { \ uint64_t size;\ - int success = socket_nbd_read_hello(params->client, &size); \ + uint32_t flags;\ + int success = socket_nbd_read_hello(params->client, &size, &flags); \ if ( success ) {\ uint64_t endpoint = params->from + params->len; \ if (endpoint > size || \ diff --git a/src/common/readwrite.h b/src/common/readwrite.h index 04b12c6..8b7371b 100644 --- a/src/common/readwrite.h +++ b/src/common/readwrite.h @@ -7,8 +7,8 @@ #include "nbdtypes.h" int socket_connect(struct sockaddr* to, struct sockaddr* from); -int socket_nbd_read_hello(int fd, uint64_t* size); -int socket_nbd_write_hello(int fd, uint64_t size); +int socket_nbd_read_hello(int fd, uint64_t* size, uint32_t* flags); +int socket_nbd_write_hello(int fd, uint64_t size, uint32_t flags); void socket_nbd_read(int fd, uint64_t from, uint32_t len, int out_fd, void* out_buf, int timeout_secs); void socket_nbd_write(int fd, uint64_t from, uint32_t len, int out_fd, void* out_buf, int timeout_secs); int socket_nbd_disconnect( int fd ); @@ -16,8 +16,8 @@ int socket_nbd_disconnect( int fd ); /* as you can see, we're slowly accumulating code that should really be in an * NBD library */ -void nbd_hello_to_buf( struct nbd_init_raw* buf, uint64_t out_size ); -int nbd_check_hello( struct nbd_init_raw* init_raw, uint64_t* out_size ); +void nbd_hello_to_buf( struct nbd_init_raw* buf, uint64_t out_size, uint32_t out_flags ); +int nbd_check_hello( struct nbd_init_raw* init_raw, uint64_t* out_size, uint32_t* out_flags ); #endif diff --git a/src/proxy/proxy.c b/src/proxy/proxy.c index c8af40e..78726ca 100644 --- a/src/proxy/proxy.c +++ b/src/proxy/proxy.c @@ -106,7 +106,7 @@ void proxy_destroy( struct proxier* proxy ) } /* Shared between our two different connect_to_upstream paths */ -void proxy_finish_connect_to_upstream( struct proxier *proxy, uint64_t size ); +void proxy_finish_connect_to_upstream( struct proxier *proxy, uint64_t size, uint32_t flags ); /* Try to establish a connection to our upstream server. Return 1 on success, * 0 on failure. this is a blocking call that returns a non-blocking socket. @@ -120,12 +120,13 @@ int proxy_connect_to_upstream( struct proxier* proxy ) int fd = socket_connect( &proxy->connect_to.generic, connect_from ); uint64_t size = 0; + uint32_t flags = 0; if ( -1 == fd ) { return 0; } - if( !socket_nbd_read_hello( fd, &size ) ) { + if( !socket_nbd_read_hello( fd, &size, &flags ) ) { WARN_IF_NEGATIVE( sock_try_close( fd ), "Couldn't close() after failed read of NBD hello on fd %i", fd @@ -135,7 +136,7 @@ int proxy_connect_to_upstream( struct proxier* proxy ) proxy->upstream_fd = fd; sock_set_nonblock( fd, 1 ); - proxy_finish_connect_to_upstream( proxy, size ); + proxy_finish_connect_to_upstream( proxy, size, flags ); return 1; } @@ -191,7 +192,7 @@ error: return; } -void proxy_finish_connect_to_upstream( struct proxier *proxy, uint64_t size ) { +void proxy_finish_connect_to_upstream( struct proxier *proxy, uint64_t size, uint32_t flags ) { if ( proxy->upstream_size == 0 ) { info( "Size of upstream image is %"PRIu64" bytes", size ); @@ -203,6 +204,17 @@ void proxy_finish_connect_to_upstream( struct proxier *proxy, uint64_t size ) { } proxy->upstream_size = size; + + if ( proxy->upstream_flags == 0 ) { + info( "Upstream transmission flags set to %"PRIu32"", flags ); + } else if ( proxy->upstream_flags != flags ) { + warn( + "Upstream transmission flags changed from %"PRIu32" to %"PRIu32"", + proxy->upstream_flags, flags + ); + } + + proxy->upstream_flags = flags; if ( AF_UNIX != proxy->connect_to.family ) { if ( sock_set_tcp_nodelay( proxy->upstream_fd, 1 ) == -1 ) { @@ -516,11 +528,18 @@ int proxy_read_init_from_upstream( struct proxier* proxy, int state ) if ( proxy->init.needle == proxy->init.size ) { uint64_t upstream_size; - if ( !nbd_check_hello( (struct nbd_init_raw*) proxy->init.buf, &upstream_size ) ) { + uint32_t upstream_flags; + if ( !nbd_check_hello( (struct nbd_init_raw*) proxy->init.buf, &upstream_size, &upstream_flags ) ) { warn( "Upstream sent invalid init" ); goto disconnect; } + /* record the flags, and log the reconnection */ + // TODO: Should we call this at all here? We lose the + // upstream_size and flags otherwise, but then we can't + // renegotiate anyway. + // proxy_finish_connect_to_upstream( proxy, upstream_size, upstream_flags ); + /* Currently, we only get disconnected from upstream (so needing to come * here) when we have an outstanding request. If that becomes false, * we'll need to choose the right state to return to here */ @@ -683,7 +702,7 @@ void proxy_session( struct proxier* proxy ) /* First action: Write hello to downstream */ - nbd_hello_to_buf( (struct nbd_init_raw *) proxy->rsp.buf, proxy->upstream_size ); + nbd_hello_to_buf( (struct nbd_init_raw *) proxy->rsp.buf, proxy->upstream_size, proxy->upstream_flags ); proxy->rsp.size = sizeof( struct nbd_init_raw ); proxy->rsp.needle = 0; state = WRITE_TO_DOWNSTREAM; diff --git a/src/proxy/proxy.h b/src/proxy/proxy.h index 5bf24dd..4173fb7 100644 --- a/src/proxy/proxy.h +++ b/src/proxy/proxy.h @@ -46,10 +46,13 @@ struct proxier { /* This is the size we advertise to the downstream server */ uint64_t upstream_size; + /* These are thet transmission flags sent as part of the handshake */ + uint32_t upstream_flags; + /* We transform the raw request header into here */ struct nbd_request req_hdr; - /* We transform the raw reply header into here */ + /* We transform the raw reply header into here */ struct nbd_reply rsp_hdr; /* Used for our non-blocking negotiation with upstream. TODO: maybe use diff --git a/src/server/mirror.c b/src/server/mirror.c index b3d22d5..bd25008 100644 --- a/src/server/mirror.c +++ b/src/server/mirror.c @@ -293,7 +293,8 @@ int mirror_connect( struct mirror * mirror, uint64_t local_size ) if( FD_ISSET( mirror->client, &fds ) ){ uint64_t remote_size; - if ( socket_nbd_read_hello( mirror->client, &remote_size ) ) { + uint32_t remote_flags; + if ( socket_nbd_read_hello( mirror->client, &remote_size, &remote_flags ) ) { if( remote_size == local_size ){ connected = 1; mirror_set_state( mirror, MS_GO );