1use anyhow::{Context, Result};
2use clap::Parser;
3use reqwest::Client;
4
5#[cfg(target_os = "macos")]
7use security_framework::passwords::{get_generic_password, set_generic_password};
8
9#[cfg(any(target_os = "linux", target_os = "windows"))]
10use keyring::Entry;
11
12#[cfg(unix)]
13use std::os::unix::fs::PermissionsExt;
14use std::path::PathBuf;
15use std::{env, fs};
16
/// Service name of the password entry in the platform credential store; the
/// OTP secret is stored under this name with a `_SECRET` suffix appended.
const SERVICE_NAME: &str = "NERSC";
/// Base URL of the NERSC sshproxy service that issues keys.
const URL: &str = "https://sshproxy.nersc.gov";
/// sshproxy scope, inserted into the `/create_pair/{SCOPE}/` endpoint path.
const SCOPE: &str = "default";
20
21#[derive(Parser)]
22#[command(
23 author = "Dinesh Kumar",
24 about = "Retrieve NERSC SSH keys using system credential storage",
25 long_about = None,
26 version = "2.1.0"
27 )]
28struct Args {
29 username: Option<String>,
32
33 #[clap(short = 'p', long)]
35 update_password: bool,
36
37 #[clap(long)]
39 update_secret: bool,
40}
41
/// Keychain service name for the OTP secret (the password uses `SERVICE_NAME`).
#[cfg(target_os = "macos")]
fn keychain_secret_service() -> String {
    format!("{}_SECRET", SERVICE_NAME)
}

/// Decode raw keychain bytes as UTF-8, naming `what` in the error on failure.
#[cfg(target_os = "macos")]
fn keychain_utf8(bytes: Vec<u8>, what: &str) -> Result<String> {
    // `get_generic_password` already hands back an owned buffer, so decode it
    // in place instead of copying it first.
    String::from_utf8(bytes).with_context(|| format!("Stored {} is not valid UTF-8", what))
}

/// Store the NERSC password for `username` in the macOS keychain.
///
/// # Errors
/// Fails if the keychain rejects the write.
#[cfg(target_os = "macos")]
fn update_password(username: &str, password: &str) -> Result<()> {
    set_generic_password(SERVICE_NAME, username, password.as_bytes())
        .context("Failed to save password to keychain")
}

/// Store the TOTP secret for `username` in the macOS keychain.
///
/// # Errors
/// Fails if the keychain rejects the write.
#[cfg(target_os = "macos")]
fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
    set_generic_password(&keychain_secret_service(), username, otp_secret.as_bytes())
        .context("Failed to save OTP secret to keychain")
}

/// Read the NERSC password for `username` from the macOS keychain.
///
/// # Errors
/// Fails if no entry can be read or the stored bytes are not UTF-8.
#[cfg(target_os = "macos")]
fn get_password(username: &str) -> Result<String> {
    let bytes = get_generic_password(SERVICE_NAME, username)
        .context("Failed to retrieve password from keychain")?;
    keychain_utf8(bytes, "password")
}

/// Read the TOTP secret for `username` from the macOS keychain.
///
/// # Errors
/// Fails if no entry can be read or the stored bytes are not UTF-8.
#[cfg(target_os = "macos")]
fn get_otp_secret(username: &str) -> Result<String> {
    let bytes = get_generic_password(&keychain_secret_service(), username)
        .context("Failed to retrieve OTP secret from keychain")?;
    keychain_utf8(bytes, "OTP secret")
}
76
77#[cfg(target_os = "linux")]
79fn update_password(username: &str, password: &str) -> Result<()> {
80 let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
81 entry
82 .set_password(password)
83 .context("Failed to save password to credential storage")?;
84 Ok(())
85}
86
87#[cfg(target_os = "linux")]
89fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
90 let service = format!("{}_SECRET", SERVICE_NAME);
91 let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
92 entry
93 .set_password(otp_secret)
94 .context("Failed to save OTP secret to credential storage")?;
95 Ok(())
96}
97
98#[cfg(target_os = "linux")]
100fn get_password(username: &str) -> Result<String> {
101 let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
102 entry
103 .get_password()
104 .context("Failed to retrieve password from credential storage")
105}
106
107#[cfg(target_os = "linux")]
109fn get_otp_secret(username: &str) -> Result<String> {
110 let service = format!("{}_SECRET", SERVICE_NAME);
111 let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
112 entry
113 .get_password()
114 .context("Failed to retrieve OTP secret from credential storage")
115}
116
117#[cfg(target_os = "windows")]
119fn update_password(username: &str, password: &str) -> Result<()> {
120 let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
121 entry
122 .set_password(password)
123 .context("Failed to save password to credential storage")?;
124 Ok(())
125}
126
127#[cfg(target_os = "windows")]
129fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
130 let service = format!("{}_SECRET", SERVICE_NAME);
131 let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
132 entry
133 .set_password(otp_secret)
134 .context("Failed to save OTP secret to credential storage")?;
135 Ok(())
136}
137
138#[cfg(target_os = "windows")]
140fn get_password(username: &str) -> Result<String> {
141 let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
142 entry
143 .get_password()
144 .context("Failed to retrieve password from credential storage")
145}
146
147#[cfg(target_os = "windows")]
149fn get_otp_secret(username: &str) -> Result<String> {
150 let service = format!("{}_SECRET", SERVICE_NAME);
151 let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
152 entry
153 .get_password()
154 .context("Failed to retrieve OTP secret from credential storage")
155}
156
157fn generate_totp(secret: &str) -> Result<String> {
159 use std::time::{SystemTime, UNIX_EPOCH};
160 use totp_lite::{totp_custom, Sha1};
161
162 let secret_bytes = data_encoding::BASE32_NOPAD
164 .decode(secret.to_uppercase().as_bytes())
165 .context("Failed to decode base32 OTP secret")?;
166
167 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
169
170 let totp = totp_custom::<Sha1>(30, 6, &secret_bytes, timestamp);
172
173 Ok(format!("{:06}", totp))
174}
175
176async fn request_ssh_key(username: &str, password_otp: &str) -> Result<String> {
178 let endpoint = format!("{}/create_pair/{}/", URL, SCOPE);
179
180 let client = Client::builder()
181 .http1_only()
182 .redirect(reqwest::redirect::Policy::none())
183 .build()?;
184
185 let request = client
186 .post(&endpoint)
187 .basic_auth(username, Some(password_otp));
188
189 let response = request
190 .send()
191 .await
192 .context("Failed to send request to sshproxy server")?;
193
194 let status = response.status();
195 let body = response.text().await?;
196
197 if !status.is_success() {
198 anyhow::bail!("Server returned error: {} - {}", status, body);
199 }
200
201 if body.contains("Authentication failed") {
203 anyhow::bail!("Authentication failed. Check your password and OTP");
204 }
205
206 if !body.contains("-----BEGIN RSA PRIVATE KEY-----")
208 && !body.contains("-----BEGIN OPENSSH PRIVATE KEY-----")
209 {
210 anyhow::bail!(
211 "Response does not contain a valid SSH private key:\n{}",
212 body
213 );
214 }
215
216 Ok(body)
217}
218
219fn extract_certificate(key_content: &str) -> Result<String> {
221 for line in key_content.lines() {
222 if line.contains("ssh-rsa") || line.contains("ssh-ed25519") {
223 return Ok(line.to_string());
224 }
225 }
226 anyhow::bail!("No certificate found in key file")
227}
228
229fn save_key_files(key_path: &PathBuf, key_content: &str, cert_content: &str) -> Result<()> {
231 fs::write(key_path, key_content).context("Failed to write private key")?;
233
234 #[cfg(unix)]
236 {
237 let metadata = fs::metadata(key_path)?;
238 let mut permissions = metadata.permissions();
239 permissions.set_mode(0o600);
240 fs::set_permissions(key_path, permissions)?;
241 }
242
243 let cert_path = key_path
245 .with_extension("")
246 .with_extension("pub")
247 .with_extension("");
248 let cert_path = format!("{}-cert.pub", cert_path.display());
249 fs::write(&cert_path, cert_content).context("Failed to write certificate")?;
250
251 let output = std::process::Command::new("ssh-keygen")
253 .arg("-y")
254 .arg("-f")
255 .arg(key_path)
256 .output()
257 .context("Failed to generate public key with ssh-keygen")?;
258
259 if !output.status.success() {
260 anyhow::bail!(
261 "ssh-keygen failed: {}",
262 String::from_utf8_lossy(&output.stderr)
263 );
264 }
265
266 let pub_path = key_path.with_extension("pub");
267 fs::write(&pub_path, output.stdout).context("Failed to write public key")?;
268
269 Ok(())
270}
271
272fn get_cert_validity(cert_path: &str) -> Result<String> {
274 let output = std::process::Command::new("ssh-keygen")
275 .arg("-L")
276 .arg("-f")
277 .arg(cert_path)
278 .output()
279 .context("Failed to read certificate with ssh-keygen")?;
280
281 if !output.status.success() {
282 anyhow::bail!("ssh-keygen -L failed");
283 }
284
285 let output_str = String::from_utf8_lossy(&output.stdout);
286 for line in output_str.lines() {
287 if line.trim().starts_with("Valid:") {
288 return Ok(line.trim().to_string());
289 }
290 }
291
292 Ok("Valid: unknown".to_string())
293}
294
295#[tokio::main]
296async fn main() -> Result<()> {
297 let args = Args::parse();
299
300 let username = args.username.unwrap_or_else(|| {
302 env::var("USER")
303 .or_else(|_| env::var("USERNAME"))
304 .unwrap_or_else(|_| whoami::username())
305 });
306
307 if args.update_password {
309 println!("Enter new password for user {}: ", username);
310 let password = rpassword::read_password().context("Failed to read password")?;
311 update_password(&username, &password)?;
312 println!("Password updated successfully.");
313 return Ok(());
314 }
315
316 if args.update_secret {
318 println!("Enter TOTP secret for user {}: ", username);
319 let otp_secret = rpassword::read_password().context("Failed to read OTP secret")?;
320 update_secret(&username, &otp_secret)?;
321 println!("OTP secret updated successfully.");
322 return Ok(());
323 }
324
325 let home = dirs::home_dir().context("Could not determine home directory")?;
327 let key_path = home.join(".ssh").join("nersc");
328
329 let password = get_password(&username)
331 .context("Failed to get password. Run with --update-password first")?;
332
333 let otp_secret = get_otp_secret(&username)
334 .context("Failed to get OTP secret. Run with --update-secret first")?;
335
336 let totp_code = generate_totp(&otp_secret)?;
338
339 let password_otp = format!("{}{}", password, totp_code);
341
342 println!("Requesting SSH key for user: {}", username);
343
344 let key_content = request_ssh_key(&username, &password_otp).await?;
346
347 let cert_content = extract_certificate(&key_content)?;
349
350 save_key_files(&key_path, &key_content, &cert_content)?;
352
353 println!("Successfully obtained ssh key: {}", key_path.display());
354
355 let cert_path = format!("{}-cert.pub", key_path.display());
357 if let Ok(validity) = get_cert_validity(&cert_path) {
358 println!("Key is {}", validity.to_lowercase());
359 }
360
361 Ok(())
362}