// sshproxy_rust/main.rs

use anyhow::{Context, Result};
use clap::Parser;
use reqwest::Client;

// Platform-specific imports
#[cfg(target_os = "macos")]
use security_framework::passwords::{get_generic_password, set_generic_password};

#[cfg(any(target_os = "linux", target_os = "windows"))]
use keyring::Entry;

#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
use std::time::Duration;
use std::{env, fs};

/// Credential-store service name holding the NERSC password.
/// The TOTP secret lives under the same name with a `_SECRET` suffix.
const SERVICE_NAME: &str = "NERSC";

/// Base URL of the NERSC sshproxy API.
const URL: &str = "https://sshproxy.nersc.gov";

/// Scope requested from sshproxy, used in the `/create_pair/{scope}/` endpoint.
const SCOPE: &str = "default";

21#[derive(Parser)]
22#[command(
23    author = "Dinesh Kumar",
24    about = "Retrieve NERSC SSH keys using system credential storage",
25    long_about = None,
26    version = "2.1.0"
27    )]
28struct Args {
29    /// Username, if not provided, taken from USER env variable
30    // #[clap(long, env = "USER")]
31    username: Option<String>,
32
33    /// Update NERSC password in macOS Keychain
34    #[clap(short = 'p', long)]
35    update_password: bool,
36
37    /// Update NERSC TOTP secret in macOS Keychain
38    #[clap(long)]
39    update_secret: bool,
40}
41
/// Keychain service name under which the TOTP secret is stored (`NERSC_SECRET`).
#[cfg(target_os = "macos")]
fn keychain_secret_service() -> String {
    format!("{}_SECRET", SERVICE_NAME)
}

/// Store the NERSC password in the macOS Keychain.
///
/// Saved as a generic password with service `SERVICE_NAME` and account `username`.
/// NERSC passwords expire every year, so this is re-run when the password changes.
#[cfg(target_os = "macos")]
fn update_password(username: &str, password: &str) -> Result<()> {
    set_generic_password(SERVICE_NAME, username, password.as_bytes())
        .context("Failed to save password to keychain")
}

/// Store the NERSC TOTP secret in the macOS Keychain.
///
/// Saved under the `NERSC_SECRET` service for account `username`.
/// TOTP secrets usually do not expire, so this is normally needed only once.
#[cfg(target_os = "macos")]
fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
    set_generic_password(&keychain_secret_service(), username, otp_secret.as_bytes())
        .context("Failed to save OTP secret to keychain")
}

/// Read the NERSC password for `username` from the macOS Keychain.
///
/// # Errors
/// Fails if no item exists or the stored bytes are not valid UTF-8.
#[cfg(target_os = "macos")]
fn get_password(username: &str) -> Result<String> {
    let bytes = get_generic_password(SERVICE_NAME, username)
        .context("Failed to retrieve password from keychain")?;
    String::from_utf8(bytes).context("Password stored in keychain is not valid UTF-8")
}

/// Read the NERSC TOTP secret for `username` from the macOS Keychain.
///
/// # Errors
/// Fails if no item exists or the stored bytes are not valid UTF-8.
#[cfg(target_os = "macos")]
fn get_otp_secret(username: &str) -> Result<String> {
    let bytes = get_generic_password(&keychain_secret_service(), username)
        .context("Failed to retrieve OTP secret from keychain")?;
    String::from_utf8(bytes).context("OTP secret stored in keychain is not valid UTF-8")
}

77/// NERSC passwords expire every year.
78#[cfg(target_os = "linux")]
79fn update_password(username: &str, password: &str) -> Result<()> {
80    let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
81    entry
82        .set_password(password)
83        .context("Failed to save password to credential storage")?;
84    Ok(())
85}
86
87/// usually totp secrets do not expire
88#[cfg(target_os = "linux")]
89fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
90    let service = format!("{}_SECRET", SERVICE_NAME);
91    let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
92    entry
93        .set_password(otp_secret)
94        .context("Failed to save OTP secret to credential storage")?;
95    Ok(())
96}
97
98/// Retrieve password from credential storage
99#[cfg(target_os = "linux")]
100fn get_password(username: &str) -> Result<String> {
101    let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
102    entry
103        .get_password()
104        .context("Failed to retrieve password from credential storage")
105}
106
107/// Retrieve OTP secret from credential storage
108#[cfg(target_os = "linux")]
109fn get_otp_secret(username: &str) -> Result<String> {
110    let service = format!("{}_SECRET", SERVICE_NAME);
111    let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
112    entry
113        .get_password()
114        .context("Failed to retrieve OTP secret from credential storage")
115}
116
117/// NERSC passwords expire every year.
118#[cfg(target_os = "windows")]
119fn update_password(username: &str, password: &str) -> Result<()> {
120    let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
121    entry
122        .set_password(password)
123        .context("Failed to save password to credential storage")?;
124    Ok(())
125}
126
127/// usually totp secrets do not expire
128#[cfg(target_os = "windows")]
129fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
130    let service = format!("{}_SECRET", SERVICE_NAME);
131    let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
132    entry
133        .set_password(otp_secret)
134        .context("Failed to save OTP secret to credential storage")?;
135    Ok(())
136}
137
138/// Retrieve password from credential storage
139#[cfg(target_os = "windows")]
140fn get_password(username: &str) -> Result<String> {
141    let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
142    entry
143        .get_password()
144        .context("Failed to retrieve password from credential storage")
145}
146
147/// Retrieve OTP secret from credential storage
148#[cfg(target_os = "windows")]
149fn get_otp_secret(username: &str) -> Result<String> {
150    let service = format!("{}_SECRET", SERVICE_NAME);
151    let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
152    entry
153        .get_password()
154        .context("Failed to retrieve OTP secret from credential storage")
155}
156
157/// Generate TOTP code from secret
158fn generate_totp(secret: &str) -> Result<String> {
159    use std::time::{SystemTime, UNIX_EPOCH};
160    use totp_lite::{totp_custom, Sha1};
161
162    // Decode base32 secret
163    let secret_bytes = data_encoding::BASE32_NOPAD
164        .decode(secret.to_uppercase().as_bytes())
165        .context("Failed to decode base32 OTP secret")?;
166
167    // Get current Unix timestamp
168    let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
169
170    // Generate TOTP (30 second interval, 6 digits)
171    let totp = totp_custom::<Sha1>(30, 6, &secret_bytes, timestamp);
172
173    Ok(format!("{:06}", totp))
174}
175
176/// Request SSH key and certificate from sshproxy API
177async fn request_ssh_key(username: &str, password_otp: &str) -> Result<String> {
178    let endpoint = format!("{}/create_pair/{}/", URL, SCOPE);
179
180    let client = Client::builder()
181        .http1_only()
182        .redirect(reqwest::redirect::Policy::none())
183        .build()?;
184
185    let request = client
186        .post(&endpoint)
187        .basic_auth(username, Some(password_otp));
188
189    let response = request
190        .send()
191        .await
192        .context("Failed to send request to sshproxy server")?;
193
194    let status = response.status();
195    let body = response.text().await?;
196
197    if !status.is_success() {
198        anyhow::bail!("Server returned error: {} - {}", status, body);
199    }
200
201    // Check for authentication failure
202    if body.contains("Authentication failed") {
203        anyhow::bail!("Authentication failed. Check your password and OTP");
204    }
205
206    // Check for valid RSA private key
207    if !body.contains("-----BEGIN RSA PRIVATE KEY-----")
208        && !body.contains("-----BEGIN OPENSSH PRIVATE KEY-----")
209    {
210        anyhow::bail!(
211            "Response does not contain a valid SSH private key:\n{}",
212            body
213        );
214    }
215
216    Ok(body)
217}
218
219/// Extract certificate from combined key file
220fn extract_certificate(key_content: &str) -> Result<String> {
221    for line in key_content.lines() {
222        if line.contains("ssh-rsa") || line.contains("ssh-ed25519") {
223            return Ok(line.to_string());
224        }
225    }
226    anyhow::bail!("No certificate found in key file")
227}
228
229/// Save key files to disk with proper permissions
230fn save_key_files(key_path: &PathBuf, key_content: &str, cert_content: &str) -> Result<()> {
231    // Save private key
232    fs::write(key_path, key_content).context("Failed to write private key")?;
233
234    // Set permissions to 600 (Unix only)
235    #[cfg(unix)]
236    {
237        let metadata = fs::metadata(key_path)?;
238        let mut permissions = metadata.permissions();
239        permissions.set_mode(0o600);
240        fs::set_permissions(key_path, permissions)?;
241    }
242
243    // Save certificate
244    let cert_path = key_path
245        .with_extension("")
246        .with_extension("pub")
247        .with_extension("");
248    let cert_path = format!("{}-cert.pub", cert_path.display());
249    fs::write(&cert_path, cert_content).context("Failed to write certificate")?;
250
251    // Generate and save public key using ssh-keygen
252    let output = std::process::Command::new("ssh-keygen")
253        .arg("-y")
254        .arg("-f")
255        .arg(key_path)
256        .output()
257        .context("Failed to generate public key with ssh-keygen")?;
258
259    if !output.status.success() {
260        anyhow::bail!(
261            "ssh-keygen failed: {}",
262            String::from_utf8_lossy(&output.stderr)
263        );
264    }
265
266    let pub_path = key_path.with_extension("pub");
267    fs::write(&pub_path, output.stdout).context("Failed to write public key")?;
268
269    Ok(())
270}
271
272/// Get certificate validity information
273fn get_cert_validity(cert_path: &str) -> Result<String> {
274    let output = std::process::Command::new("ssh-keygen")
275        .arg("-L")
276        .arg("-f")
277        .arg(cert_path)
278        .output()
279        .context("Failed to read certificate with ssh-keygen")?;
280
281    if !output.status.success() {
282        anyhow::bail!("ssh-keygen -L failed");
283    }
284
285    let output_str = String::from_utf8_lossy(&output.stdout);
286    for line in output_str.lines() {
287        if line.trim().starts_with("Valid:") {
288            return Ok(line.trim().to_string());
289        }
290    }
291
292    Ok("Valid: unknown".to_string())
293}
294
295#[tokio::main]
296async fn main() -> Result<()> {
297    // Parse command line arguments
298    let args = Args::parse();
299
300    // get username
301    let username = args.username.unwrap_or_else(|| {
302        env::var("USER")
303            .or_else(|_| env::var("USERNAME"))
304            .unwrap_or_else(|_| whoami::username())
305    });
306
307    // check if we need to update password
308    if args.update_password {
309        println!("Enter new password for user {}: ", username);
310        let password = rpassword::read_password().context("Failed to read password")?;
311        update_password(&username, &password)?;
312        println!("Password updated successfully.");
313        return Ok(());
314    }
315
316    // check if we need to update otp secret
317    if args.update_secret {
318        println!("Enter TOTP secret for user {}: ", username);
319        let otp_secret = rpassword::read_password().context("Failed to read OTP secret")?;
320        update_secret(&username, &otp_secret)?;
321        println!("OTP secret updated successfully.");
322        return Ok(());
323    }
324
325    // Determine output path
326    let home = dirs::home_dir().context("Could not determine home directory")?;
327    let key_path = home.join(".ssh").join("nersc");
328
329    // Retrieve credentials from keychain
330    let password = get_password(&username)
331        .context("Failed to get password. Run with --update-password first")?;
332
333    let otp_secret = get_otp_secret(&username)
334        .context("Failed to get OTP secret. Run with --update-secret first")?;
335
336    // Generate TOTP code
337    let totp_code = generate_totp(&otp_secret)?;
338
339    // Combine password and OTP
340    let password_otp = format!("{}{}", password, totp_code);
341
342    println!("Requesting SSH key for user: {}", username);
343
344    // Request key from API
345    let key_content = request_ssh_key(&username, &password_otp).await?;
346
347    // Extract certificate
348    let cert_content = extract_certificate(&key_content)?;
349
350    // Save files
351    save_key_files(&key_path, &key_content, &cert_content)?;
352
353    println!("Successfully obtained ssh key: {}", key_path.display());
354
355    // Show validity
356    let cert_path = format!("{}-cert.pub", key_path.display());
357    if let Ok(validity) = get_cert_validity(&cert_path) {
358        println!("Key is {}", validity.to_lowercase());
359    }
360
361    Ok(())
362}