| 1 | #include <stdio.h> |
|---|
| 2 | #include <stdlib.h> |
|---|
| 3 | #include <errno.h> |
|---|
| 4 | #include <sys/types.h> |
|---|
| 5 | #include <sys/stat.h> |
|---|
| 6 | #include <fcntl.h> |
|---|
| 7 | #include "host.h" |
|---|
| 8 | #include "scoped_make_raw.h" |
|---|
| 9 | #include "ptyshell.h" |
|---|
| 10 | #include "unio.h" |
|---|
| 11 | |
|---|
| 12 | #include <iostream> |
|---|
| 13 | |
|---|
| 14 | namespace Partty { |
|---|
| 15 | |
|---|
| 16 | |
|---|
| 17 | // FIXME Serverとのコネクションが切断したら再接続する? |
|---|
| 18 | |
|---|
| 19 | Host::Host(int server_socket, char lock_code, |
|---|
| 20 | const char* session_name, size_t session_name_length, |
|---|
| 21 | const char* password, size_t password_length ) : |
|---|
| 22 | impl(new HostIMPL(server_socket, lock_code, |
|---|
| 23 | session_name, session_name_length, |
|---|
| 24 | password, password_length )) {} |
|---|
| 25 | |
|---|
| 26 | HostIMPL::HostIMPL(int server_socket, char lock_code, |
|---|
| 27 | const char* session_name, size_t session_name_length, |
|---|
| 28 | const char* password, size_t password_length ) : |
|---|
| 29 | server(server_socket), |
|---|
| 30 | m_lock_code(lock_code), m_locking(false), |
|---|
| 31 | m_session_name_length(session_name_length), |
|---|
| 32 | m_password_length(password_length) |
|---|
| 33 | { |
|---|
| 34 | if( session_name_length > MAX_SESSION_NAME_LENGTH ) { |
|---|
| 35 | throw initialize_error("session name is too long"); |
|---|
| 36 | } |
|---|
| 37 | if( password_length > MAX_PASSWORD_LENGTH ) { |
|---|
| 38 | throw initialize_error("password is too long"); |
|---|
| 39 | } |
|---|
| 40 | std::memcpy(m_session_name, session_name, m_session_name_length); |
|---|
| 41 | std::memcpy(m_password, password, m_password_length); |
|---|
| 42 | } |
|---|
| 43 | |
|---|
| 44 | |
|---|
| 45 | Host::~Host() { delete impl; } |
|---|
| 46 | HostIMPL::~HostIMPL() {} |
|---|
| 47 | |
|---|
| 48 | |
|---|
| 49 | int Host::run(void) { return impl->run(); } |
|---|
| 50 | int HostIMPL::run(void) |
|---|
| 51 | { |
|---|
| 52 | // Serverにヘッダを送る |
|---|
| 53 | negotiation_header_t header; |
|---|
| 54 | memcpy(header.magic, NEGOTIATION_MAGIC_STRING, NEGOTIATION_MAGIC_STRING_LENGTH); |
|---|
| 55 | |
|---|
| 56 | // headerにはネットワークバイトオーダーで入れる |
|---|
| 57 | header.user_name_length = htons(0); // user_nameは今のところ空 |
|---|
| 58 | header.session_name_length = htons(m_session_name_length); |
|---|
| 59 | header.password_length = htons(m_password_length); |
|---|
| 60 | if( write_all(server, (char*)&header, sizeof(header)) != sizeof(header) ) { |
|---|
| 61 | throw initialize_error("failed to send negotiation header"); |
|---|
| 62 | } |
|---|
| 63 | if( write_all(server, m_session_name, m_session_name_length) != m_session_name_length ) { |
|---|
| 64 | throw initialize_error("failed to send session name"); |
|---|
| 65 | } |
|---|
| 66 | if( write_all(server, m_password, m_password_length) != m_password_length ) { |
|---|
| 67 | throw initialize_error("failed to send session password"); |
|---|
| 68 | } |
|---|
| 69 | |
|---|
| 70 | // 新しい仮想端末を確保して、シェルを起動する |
|---|
| 71 | ptyshell psh(STDIN_FILENO); |
|---|
| 72 | if( psh.fork(NULL) < 0 ) { throw initialize_error("can't execute shell"); } |
|---|
| 73 | sh = psh.masterfd(); |
|---|
| 74 | |
|---|
| 75 | // 標準入力をRawモードにする |
|---|
| 76 | // makerawはPtyChildShellをforkした後 |
|---|
| 77 | // (forkする前にRawモードにすると子仮想端末までRawモードになってしまう) |
|---|
| 78 | PtyScopedMakeRaw makeraw(STDIN_FILENO); |
|---|
| 79 | |
|---|
| 80 | // 監視対象のファイルディスクリプタにO_NONBLOCKをセット |
|---|
| 81 | if( fcntl(STDIN_FILENO, F_SETFL, O_NONBLOCK) < 0 ) |
|---|
| 82 | { throw initialize_error("failed to set stdinput nonblocking mode"); } |
|---|
| 83 | if( fcntl(server, F_SETFL, O_NONBLOCK) < 0 ) |
|---|
| 84 | { throw initialize_error("failed to set server socket nonblocking mode"); } |
|---|
| 85 | if( fcntl(sh, F_SETFL, O_NONBLOCK) < 0 ) |
|---|
| 86 | { throw initialize_error("failed to set pty nonblocking mode"); } |
|---|
| 87 | |
|---|
| 88 | // mp::dispatchに登録 |
|---|
| 89 | using namespace mp::placeholders; |
|---|
| 90 | if( mpdp.add(STDIN_FILENO, mp::EV_READ, |
|---|
| 91 | mp::bind(&HostIMPL::io_stdin, this, _1, _2)) < 0 ) { |
|---|
| 92 | throw initialize_error("can't add stdinput to IO multiplexer"); |
|---|
| 93 | } |
|---|
| 94 | if( mpdp.add(server, mp::EV_READ, |
|---|
| 95 | mp::bind(&HostIMPL::io_server, this, _1, _2)) < 0 ) { |
|---|
| 96 | throw initialize_error("can't add server socket to IO multiplexer"); |
|---|
| 97 | } |
|---|
| 98 | if( mpdp.add(sh, mp::EV_READ, |
|---|
| 99 | mp::bind(&HostIMPL::io_shell, this, _1, _2)) < 0 ) { |
|---|
| 100 | throw initialize_error("can't add pty to IO multiplexer"); |
|---|
| 101 | } |
|---|
| 102 | |
|---|
| 103 | // 端末のウィンドウサイズを取得しておく |
|---|
| 104 | get_window_size(STDIN_FILENO, &winsz); |
|---|
| 105 | |
|---|
| 106 | // mp::dispatch::run |
|---|
| 107 | return mpdp.run(); |
|---|
| 108 | } |
|---|
| 109 | |
|---|
| 110 | |
|---|
| 111 | int HostIMPL::io_stdin(int fd, short event) |
|---|
| 112 | { |
|---|
| 113 | // 標準入力 -> シェル |
|---|
| 114 | ssize_t len = read(fd, shared_buffer, SHARED_BUFFER_SIZE); |
|---|
| 115 | if( len < 0 ) { |
|---|
| 116 | if( errno == EAGAIN || errno == EINTR ) { return 0; } |
|---|
| 117 | else { throw io_error("stdinput is broken"); } |
|---|
| 118 | } else if( len == 0 ) { throw io_end_error("end of stdinput"); } |
|---|
| 119 | // ブロックしながらシェルに書き込む |
|---|
| 120 | // XXX 書き込み可能になるまでビジーループ |
|---|
| 121 | if( write_all(sh, shared_buffer, len) != (size_t)len ){ |
|---|
| 122 | throw io_error("pty is broken"); |
|---|
| 123 | } |
|---|
| 124 | // 標準入力のウィンドウサイズが変更されたら子仮想端末にも反映する |
|---|
| 125 | struct winsize next; |
|---|
| 126 | get_window_size(fd, &next); |
|---|
| 127 | if( winsz.ws_row != next.ws_row || winsz.ws_col != next.ws_col ) { |
|---|
| 128 | set_window_size(sh, &next); |
|---|
| 129 | winsz = next; |
|---|
| 130 | } |
|---|
| 131 | // lock_codeが含まれていたらm_lockingをトグルする |
|---|
| 132 | for(const char *p=shared_buffer, *p_end=p+len; p != p_end; ++p) { |
|---|
| 133 | if(*p == m_lock_code) { |
|---|
| 134 | m_locking = !m_locking; |
|---|
| 135 | } |
|---|
| 136 | } |
|---|
| 137 | return 0; |
|---|
| 138 | } |
|---|
| 139 | |
|---|
| 140 | int HostIMPL::io_server(int fd, short event) |
|---|
| 141 | { |
|---|
| 142 | // Server -> シェル |
|---|
| 143 | ssize_t len = read(fd, shared_buffer, SHARED_BUFFER_SIZE); |
|---|
| 144 | if( len < 0 ) { |
|---|
| 145 | if( errno == EAGAIN || errno == EINTR ) { return 0; } |
|---|
| 146 | else { throw io_error("server connection is broken"); } |
|---|
| 147 | } else if( len == 0 ) { throw io_end_error("server connection closed"); } |
|---|
| 148 | // ロック中ならServerからの入力は捨てる |
|---|
| 149 | if( m_locking ) { return 0; } |
|---|
| 150 | // ブロックしながらシェルに書き込む |
|---|
| 151 | // XXX 書き込み可能になるまでビジーループ |
|---|
| 152 | if( write_all(sh, shared_buffer, len) != (size_t)len ){ |
|---|
| 153 | throw io_error("pty is broken"); |
|---|
| 154 | } |
|---|
| 155 | return 0; |
|---|
| 156 | } |
|---|
| 157 | |
|---|
| 158 | int HostIMPL::io_shell(int fd, short event) |
|---|
| 159 | { |
|---|
| 160 | // シェル -> 標準出力 |
|---|
| 161 | // シェル -> Server |
|---|
| 162 | ssize_t len = read(fd, shared_buffer, SHARED_BUFFER_SIZE); |
|---|
| 163 | if( len < 0 ) { |
|---|
| 164 | if( errno == EAGAIN || errno == EINTR ) { return 0; } |
|---|
| 165 | else { throw io_error("pty is broken"); } |
|---|
| 166 | } else if( len == 0 ) { throw io_end_error("session ends"); } |
|---|
| 167 | // ブロックしながら標準出力に書き込む |
|---|
| 168 | // XXX 書き込み可能になるまでビジーループ |
|---|
| 169 | if( write_all(STDOUT_FILENO, shared_buffer, len) != (size_t)len ) { |
|---|
| 170 | throw io_error("stdoutput is broken"); |
|---|
| 171 | } |
|---|
| 172 | // ブロックしながらServerに書き込む |
|---|
| 173 | // XXX 書き込み可能になるまでビジーループ |
|---|
| 174 | if( write_all(server, shared_buffer, len) != (size_t)len ) { |
|---|
| 175 | throw io_error("server connection is broken"); |
|---|
| 176 | } |
|---|
| 177 | // 標準出力の転送が終わるまで待つ |
|---|
| 178 | // (転送スピードを超過して書き込むと端末が壊れるため) |
|---|
| 179 | tcdrain(STDOUT_FILENO); |
|---|
| 180 | return 0; |
|---|
| 181 | } |
|---|
| 182 | |
|---|
| 183 | int HostIMPL::get_window_size(int fd, struct winsize* ws) |
|---|
| 184 | { |
|---|
| 185 | return ioctl(fd, TIOCGWINSZ, ws); |
|---|
| 186 | } |
|---|
| 187 | |
|---|
| 188 | int HostIMPL::set_window_size(int fd, struct winsize* ws) |
|---|
| 189 | { |
|---|
| 190 | return ioctl(fd, TIOCSWINSZ, ws); |
|---|
| 191 | } |
|---|
| 192 | |
|---|
| 193 | |
|---|
| 194 | } // namespace Partty |
|---|
| 195 | |
|---|