sshproxy_rust/
main.rs

1use anyhow::{Context, Result};
2use clap::Parser;
3use reqwest::Client;
4
5// Platform-specific imports
6#[cfg(target_os = "macos")]
7use security_framework::passwords::{get_generic_password, set_generic_password};
8
9#[cfg(target_os = "linux")]
10use keyring::Entry;
11
12use std::os::unix::fs::PermissionsExt;
13use std::path::PathBuf;
14use std::{env, fs};
15
/// Base URL of the NERSC sshproxy service.
const URL: &str = "https://sshproxy.nersc.gov";
/// sshproxy scope, used as the last path segment of the `create_pair/<scope>/` endpoint.
const SCOPE: &str = "default";
/// Credential-store service name for the NERSC password; the TOTP secret is stored under
/// the service `<SERVICE_NAME>_SECRET`.
const SERVICE_NAME: &str = "NERSC";
19
20#[derive(Parser)]
21#[command(
22    author = "Dinesh Kumar",
23    about = "Retrieve NERSC SSH keys using system credential storage",
24    long_about = None,
25    version = "2.0.0"
26    )]
27struct Args {
28    /// Username, if not provided, taken from USER env variable
29    // #[clap(long, env = "USER")]
30    username: Option<String>,
31
32    /// Update NERSC password in macOS Keychain
33    #[clap(short = 'p', long)]
34    update_password: bool,
35
36    /// Update NERSC TOTP secret in macOS Keychain
37    #[clap(long)]
38    update_secret: bool,
39}
40
/// Save the NERSC password for `username` in the macOS Keychain as a generic password
/// under the `NERSC` service.
///
/// NERSC passwords expire every year.
///
/// # Errors
/// Fails if the Keychain write fails.
#[cfg(target_os = "macos")]
fn update_password(username: &str, password: &str) -> Result<()> {
    let secret_bytes = password.as_bytes();
    set_generic_password(SERVICE_NAME, username, secret_bytes)
        .context("Failed to save password to keychain")
}
/// Save the base32 TOTP secret for `username` in the macOS Keychain as a generic
/// password under the `NERSC_SECRET` service.
///
/// Usually TOTP secrets do not expire.
///
/// # Errors
/// Fails if the Keychain write fails.
#[cfg(target_os = "macos")]
fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
    let secret_service = [SERVICE_NAME, "_SECRET"].concat();
    set_generic_password(&secret_service, username, otp_secret.as_bytes())
        .context("Failed to save OTP secret to keychain")
}
58
/// Retrieve the NERSC password for `username` from the macOS Keychain (`NERSC` service).
///
/// # Errors
/// Fails if the Keychain lookup fails or the stored bytes are not valid UTF-8.
#[cfg(target_os = "macos")]
fn get_password(username: &str) -> Result<String> {
    let password = get_generic_password(SERVICE_NAME, username)
        .context("Failed to retrieve password from keychain")?;
    // The lookup already hands back an owned byte buffer; convert it directly instead of
    // copying it first with `to_vec()`.
    String::from_utf8(password).context("Password stored in keychain is not valid UTF-8")
}
66
/// Retrieve the TOTP secret for `username` from the macOS Keychain (`NERSC_SECRET`
/// service).
///
/// # Errors
/// Fails if the Keychain lookup fails or the stored bytes are not valid UTF-8.
#[cfg(target_os = "macos")]
fn get_otp_secret(username: &str) -> Result<String> {
    let service = format!("{}_SECRET", SERVICE_NAME);
    let secret = get_generic_password(&service, username)
        .context("Failed to retrieve OTP secret from keychain")?;
    // The lookup already hands back an owned byte buffer; convert it directly instead of
    // copying it first with `to_vec()`.
    String::from_utf8(secret).context("OTP secret stored in keychain is not valid UTF-8")
}
75
76/// NERSC passwords expire every year.
77#[cfg(target_os = "linux")]
78fn update_password(username: &str, password: &str) -> Result<()> {
79    let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
80    entry
81        .set_password(password)
82        .context("Failed to save password to credential storage")?;
83    Ok(())
84}
85
86/// usually totp secrets do not expire
87#[cfg(target_os = "linux")]
88fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
89    let service = format!("{}_SECRET", SERVICE_NAME);
90    let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
91    entry
92        .set_password(otp_secret)
93        .context("Failed to save OTP secret to credential storage")?;
94    Ok(())
95}
96
97/// Retrieve password from credential storage
98#[cfg(target_os = "linux")]
99fn get_password(username: &str) -> Result<String> {
100    let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
101    entry
102        .get_password()
103        .context("Failed to retrieve password from credential storage")
104}
105
106/// Retrieve OTP secret from credential storage
107#[cfg(target_os = "linux")]
108fn get_otp_secret(username: &str) -> Result<String> {
109    let service = format!("{}_SECRET", SERVICE_NAME);
110    let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
111    entry
112        .get_password()
113        .context("Failed to retrieve OTP secret from credential storage")
114}
115
116/// Generate TOTP code from secret
117fn generate_totp(secret: &str) -> Result<String> {
118    use std::time::{SystemTime, UNIX_EPOCH};
119    use totp_lite::{totp_custom, Sha1};
120
121    // Decode base32 secret
122    let secret_bytes = data_encoding::BASE32_NOPAD
123        .decode(secret.to_uppercase().as_bytes())
124        .context("Failed to decode base32 OTP secret")?;
125
126    // Get current Unix timestamp
127    let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
128
129    // Generate TOTP (30 second interval, 6 digits)
130    let totp = totp_custom::<Sha1>(30, 6, &secret_bytes, timestamp);
131
132    Ok(format!("{:06}", totp))
133}
134
135/// Request SSH key and certificate from sshproxy API
136async fn request_ssh_key(username: &str, password_otp: &str) -> Result<String> {
137    let endpoint = format!("{}/create_pair/{}/", URL, SCOPE);
138
139    let client = Client::builder()
140        .http1_only()
141        .redirect(reqwest::redirect::Policy::none())
142        .build()?;
143
144    let request = client
145        .post(&endpoint)
146        .basic_auth(username, Some(password_otp));
147
148    let response = request
149        .send()
150        .await
151        .context("Failed to send request to sshproxy server")?;
152
153    let status = response.status();
154    let body = response.text().await?;
155
156    if !status.is_success() {
157        anyhow::bail!("Server returned error: {} - {}", status, body);
158    }
159
160    // Check for authentication failure
161    if body.contains("Authentication failed") {
162        anyhow::bail!("Authentication failed. Check your password and OTP");
163    }
164
165    // Check for valid RSA private key
166    if !body.contains("-----BEGIN RSA PRIVATE KEY-----")
167        && !body.contains("-----BEGIN OPENSSH PRIVATE KEY-----")
168    {
169        anyhow::bail!(
170            "Response does not contain a valid SSH private key:\n{}",
171            body
172        );
173    }
174
175    Ok(body)
176}
177
178/// Extract certificate from combined key file
179fn extract_certificate(key_content: &str) -> Result<String> {
180    for line in key_content.lines() {
181        if line.contains("ssh-rsa") || line.contains("ssh-ed25519") {
182            return Ok(line.to_string());
183        }
184    }
185    anyhow::bail!("No certificate found in key file")
186}
187
188/// Save key files to disk with proper permissions
189fn save_key_files(key_path: &PathBuf, key_content: &str, cert_content: &str) -> Result<()> {
190    // Save private key
191    fs::write(key_path, key_content).context("Failed to write private key")?;
192
193    // Set permissions to 600
194    let metadata = fs::metadata(key_path)?;
195    let mut permissions = metadata.permissions();
196    permissions.set_mode(0o600);
197    fs::set_permissions(key_path, permissions)?;
198
199    // Save certificate
200    let cert_path = key_path
201        .with_extension("")
202        .with_extension("pub")
203        .with_extension("");
204    let cert_path = format!("{}-cert.pub", cert_path.display());
205    fs::write(&cert_path, cert_content).context("Failed to write certificate")?;
206
207    // Generate and save public key using ssh-keygen
208    let output = std::process::Command::new("ssh-keygen")
209        .arg("-y")
210        .arg("-f")
211        .arg(key_path)
212        .output()
213        .context("Failed to generate public key with ssh-keygen")?;
214
215    if !output.status.success() {
216        anyhow::bail!(
217            "ssh-keygen failed: {}",
218            String::from_utf8_lossy(&output.stderr)
219        );
220    }
221
222    let pub_path = key_path.with_extension("pub");
223    fs::write(&pub_path, output.stdout).context("Failed to write public key")?;
224
225    Ok(())
226}
227
228/// Get certificate validity information
229fn get_cert_validity(cert_path: &str) -> Result<String> {
230    let output = std::process::Command::new("ssh-keygen")
231        .arg("-L")
232        .arg("-f")
233        .arg(cert_path)
234        .output()
235        .context("Failed to read certificate with ssh-keygen")?;
236
237    if !output.status.success() {
238        anyhow::bail!("ssh-keygen -L failed");
239    }
240
241    let output_str = String::from_utf8_lossy(&output.stdout);
242    for line in output_str.lines() {
243        if line.trim().starts_with("Valid:") {
244            return Ok(line.trim().to_string());
245        }
246    }
247
248    Ok("Valid: unknown".to_string())
249}
250
251#[tokio::main]
252async fn main() -> Result<()> {
253    // Parse command line arguments
254    let args = Args::parse();
255
256    // get username
257    let username = args.username.unwrap_or_else(|| {
258        env::var("USER")
259            .expect("Could not determine username from environment. Please provide --username.")
260    });
261
262    // check if we need to update password
263    if args.update_password {
264        println!("Enter new password for user {}: ", username);
265        let password = rpassword::read_password().context("Failed to read password")?;
266        update_password(&username, &password)?;
267        println!("Password updated successfully.");
268        return Ok(());
269    }
270
271    // check if we need to update otp secret
272    if args.update_secret {
273        println!("Enter TOTP secret for user {}: ", username);
274        let otp_secret = rpassword::read_password().context("Failed to read OTP secret")?;
275        update_secret(&username, &otp_secret)?;
276        println!("OTP secret updated successfully.");
277        return Ok(());
278    }
279
280    // Determine output path
281    let home = dirs::home_dir().context("Could not determine home directory")?;
282    let key_path = home.join(".ssh").join("nersc");
283
284    // Retrieve credentials from keychain
285    let password = get_password(&username)
286        .context("Failed to get password. Run with --update-password first")?;
287
288    let otp_secret = get_otp_secret(&username)
289        .context("Failed to get OTP secret. Run with --update-secret first")?;
290
291    // Generate TOTP code
292    let totp_code = generate_totp(&otp_secret)?;
293
294    // Combine password and OTP
295    let password_otp = format!("{}{}", password, totp_code);
296
297    println!("Requesting SSH key for user: {}", username);
298
299    // Request key from API
300    let key_content = request_ssh_key(&username, &password_otp).await?;
301
302    // Extract certificate
303    let cert_content = extract_certificate(&key_content)?;
304
305    // Save files
306    save_key_files(&key_path, &key_content, &cert_content)?;
307
308    println!("Successfully obtained ssh key: {}", key_path.display());
309
310    // Show validity
311    let cert_path = format!("{}-cert.pub", key_path.display());
312    if let Ok(validity) = get_cert_validity(&cert_path) {
313        println!("Key is {}", validity.to_lowercase());
314    }
315
316    Ok(())
317}