1use anyhow::{Context, Result};
2use clap::Parser;
3use reqwest::Client;
4
5#[cfg(target_os = "macos")]
7use security_framework::passwords::{get_generic_password, set_generic_password};
8
9#[cfg(target_os = "linux")]
10use keyring::Entry;
11
12use std::os::unix::fs::PermissionsExt;
13use std::path::PathBuf;
14use std::{env, fs};
15
/// Service name under which the account password is stored in the system
/// credential store; the OTP secret is stored under this name plus `_SECRET`.
const SERVICE_NAME: &str = "NERSC";
/// Base URL of the NERSC sshproxy service.
const URL: &str = "https://sshproxy.nersc.gov";
/// Scope placed in the `create_pair/<scope>/` endpoint path.
const SCOPE: &str = "default";
19
20#[derive(Parser)]
21#[command(
22 author = "Dinesh Kumar",
23 about = "Retrieve NERSC SSH keys using system credential storage",
24 long_about = None,
25 version = "2.0.0"
26 )]
27struct Args {
28 username: Option<String>,
31
32 #[clap(short = 'p', long)]
34 update_password: bool,
35
36 #[clap(long)]
38 update_secret: bool,
39}
40
/// Save `password` for `username` in the macOS keychain under [`SERVICE_NAME`].
///
/// # Errors
/// Fails if the keychain write fails.
#[cfg(target_os = "macos")]
fn update_password(username: &str, password: &str) -> Result<()> {
    let password_bytes = password.as_bytes();
    set_generic_password(SERVICE_NAME, username, password_bytes)
        .context("Failed to save password to keychain")
}
/// Save the TOTP secret for `username` in the macOS keychain under the
/// `<SERVICE_NAME>_SECRET` service (the same service `get_otp_secret` reads).
///
/// # Errors
/// Fails if the keychain write fails.
#[cfg(target_os = "macos")]
fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
    let secret_service = format!("{}_SECRET", SERVICE_NAME);
    set_generic_password(&secret_service, username, otp_secret.as_bytes())
        .context("Failed to save OTP secret to keychain")
}
58
/// Read the password stored for `username` in the macOS keychain under
/// [`SERVICE_NAME`].
///
/// # Errors
/// Fails if the keychain lookup fails or the stored bytes are not valid UTF-8.
#[cfg(target_os = "macos")]
fn get_password(username: &str) -> Result<String> {
    let password = get_generic_password(SERVICE_NAME, username)
        .context("Failed to retrieve password from keychain")?;
    // The lookup already yields an owned `Vec<u8>`; convert it directly rather
    // than copying it first.
    String::from_utf8(password).context("Password stored in keychain is not valid UTF-8")
}
66
/// Read the TOTP secret stored for `username` in the macOS keychain under the
/// `<SERVICE_NAME>_SECRET` service.
///
/// # Errors
/// Fails if the keychain lookup fails or the stored bytes are not valid UTF-8.
#[cfg(target_os = "macos")]
fn get_otp_secret(username: &str) -> Result<String> {
    let service = format!("{}_SECRET", SERVICE_NAME);
    let secret = get_generic_password(&service, username)
        .context("Failed to retrieve OTP secret from keychain")?;
    // The lookup already yields an owned `Vec<u8>`; convert it directly rather
    // than copying it first.
    String::from_utf8(secret).context("OTP secret stored in keychain is not valid UTF-8")
}
75
76#[cfg(target_os = "linux")]
78fn update_password(username: &str, password: &str) -> Result<()> {
79 let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
80 entry
81 .set_password(password)
82 .context("Failed to save password to credential storage")?;
83 Ok(())
84}
85
86#[cfg(target_os = "linux")]
88fn update_secret(username: &str, otp_secret: &str) -> Result<()> {
89 let service = format!("{}_SECRET", SERVICE_NAME);
90 let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
91 entry
92 .set_password(otp_secret)
93 .context("Failed to save OTP secret to credential storage")?;
94 Ok(())
95}
96
97#[cfg(target_os = "linux")]
99fn get_password(username: &str) -> Result<String> {
100 let entry = Entry::new(SERVICE_NAME, username).context("Failed to create keyring entry")?;
101 entry
102 .get_password()
103 .context("Failed to retrieve password from credential storage")
104}
105
106#[cfg(target_os = "linux")]
108fn get_otp_secret(username: &str) -> Result<String> {
109 let service = format!("{}_SECRET", SERVICE_NAME);
110 let entry = Entry::new(&service, username).context("Failed to create keyring entry")?;
111 entry
112 .get_password()
113 .context("Failed to retrieve OTP secret from credential storage")
114}
115
116fn generate_totp(secret: &str) -> Result<String> {
118 use std::time::{SystemTime, UNIX_EPOCH};
119 use totp_lite::{totp_custom, Sha1};
120
121 let secret_bytes = data_encoding::BASE32_NOPAD
123 .decode(secret.to_uppercase().as_bytes())
124 .context("Failed to decode base32 OTP secret")?;
125
126 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
128
129 let totp = totp_custom::<Sha1>(30, 6, &secret_bytes, timestamp);
131
132 Ok(format!("{:06}", totp))
133}
134
135async fn request_ssh_key(username: &str, password_otp: &str) -> Result<String> {
137 let endpoint = format!("{}/create_pair/{}/", URL, SCOPE);
138
139 let client = Client::builder()
140 .http1_only()
141 .redirect(reqwest::redirect::Policy::none())
142 .build()?;
143
144 let request = client
145 .post(&endpoint)
146 .basic_auth(username, Some(password_otp));
147
148 let response = request
149 .send()
150 .await
151 .context("Failed to send request to sshproxy server")?;
152
153 let status = response.status();
154 let body = response.text().await?;
155
156 if !status.is_success() {
157 anyhow::bail!("Server returned error: {} - {}", status, body);
158 }
159
160 if body.contains("Authentication failed") {
162 anyhow::bail!("Authentication failed. Check your password and OTP");
163 }
164
165 if !body.contains("-----BEGIN RSA PRIVATE KEY-----")
167 && !body.contains("-----BEGIN OPENSSH PRIVATE KEY-----")
168 {
169 anyhow::bail!(
170 "Response does not contain a valid SSH private key:\n{}",
171 body
172 );
173 }
174
175 Ok(body)
176}
177
178fn extract_certificate(key_content: &str) -> Result<String> {
180 for line in key_content.lines() {
181 if line.contains("ssh-rsa") || line.contains("ssh-ed25519") {
182 return Ok(line.to_string());
183 }
184 }
185 anyhow::bail!("No certificate found in key file")
186}
187
188fn save_key_files(key_path: &PathBuf, key_content: &str, cert_content: &str) -> Result<()> {
190 fs::write(key_path, key_content).context("Failed to write private key")?;
192
193 let metadata = fs::metadata(key_path)?;
195 let mut permissions = metadata.permissions();
196 permissions.set_mode(0o600);
197 fs::set_permissions(key_path, permissions)?;
198
199 let cert_path = key_path
201 .with_extension("")
202 .with_extension("pub")
203 .with_extension("");
204 let cert_path = format!("{}-cert.pub", cert_path.display());
205 fs::write(&cert_path, cert_content).context("Failed to write certificate")?;
206
207 let output = std::process::Command::new("ssh-keygen")
209 .arg("-y")
210 .arg("-f")
211 .arg(key_path)
212 .output()
213 .context("Failed to generate public key with ssh-keygen")?;
214
215 if !output.status.success() {
216 anyhow::bail!(
217 "ssh-keygen failed: {}",
218 String::from_utf8_lossy(&output.stderr)
219 );
220 }
221
222 let pub_path = key_path.with_extension("pub");
223 fs::write(&pub_path, output.stdout).context("Failed to write public key")?;
224
225 Ok(())
226}
227
228fn get_cert_validity(cert_path: &str) -> Result<String> {
230 let output = std::process::Command::new("ssh-keygen")
231 .arg("-L")
232 .arg("-f")
233 .arg(cert_path)
234 .output()
235 .context("Failed to read certificate with ssh-keygen")?;
236
237 if !output.status.success() {
238 anyhow::bail!("ssh-keygen -L failed");
239 }
240
241 let output_str = String::from_utf8_lossy(&output.stdout);
242 for line in output_str.lines() {
243 if line.trim().starts_with("Valid:") {
244 return Ok(line.trim().to_string());
245 }
246 }
247
248 Ok("Valid: unknown".to_string())
249}
250
251#[tokio::main]
252async fn main() -> Result<()> {
253 let args = Args::parse();
255
256 let username = args.username.unwrap_or_else(|| {
258 env::var("USER")
259 .expect("Could not determine username from environment. Please provide --username.")
260 });
261
262 if args.update_password {
264 println!("Enter new password for user {}: ", username);
265 let password = rpassword::read_password().context("Failed to read password")?;
266 update_password(&username, &password)?;
267 println!("Password updated successfully.");
268 return Ok(());
269 }
270
271 if args.update_secret {
273 println!("Enter TOTP secret for user {}: ", username);
274 let otp_secret = rpassword::read_password().context("Failed to read OTP secret")?;
275 update_secret(&username, &otp_secret)?;
276 println!("OTP secret updated successfully.");
277 return Ok(());
278 }
279
280 let home = dirs::home_dir().context("Could not determine home directory")?;
282 let key_path = home.join(".ssh").join("nersc");
283
284 let password = get_password(&username)
286 .context("Failed to get password. Run with --update-password first")?;
287
288 let otp_secret = get_otp_secret(&username)
289 .context("Failed to get OTP secret. Run with --update-secret first")?;
290
291 let totp_code = generate_totp(&otp_secret)?;
293
294 let password_otp = format!("{}{}", password, totp_code);
296
297 println!("Requesting SSH key for user: {}", username);
298
299 let key_content = request_ssh_key(&username, &password_otp).await?;
301
302 let cert_content = extract_certificate(&key_content)?;
304
305 save_key_files(&key_path, &key_content, &cert_content)?;
307
308 println!("Successfully obtained ssh key: {}", key_path.display());
309
310 let cert_path = format!("{}-cert.pub", key_path.display());
312 if let Ok(validity) = get_cert_validity(&cert_path) {
313 println!("Key is {}", validity.to_lowercase());
314 }
315
316 Ok(())
317}