|
| 1 | +use std::io::Write; |
| 2 | +use std::{env, fs}; |
| 3 | + |
| 4 | +extern crate regex; |
| 5 | + |
| 6 | +use crate::version; |
| 7 | +use clap::{Arg, ArgMatches}; |
| 8 | +use flate2::read::GzDecoder; |
| 9 | +use futures::stream::StreamExt; |
| 10 | +use indicatif::{ProgressBar, ProgressStyle}; |
| 11 | +use regex::Regex; |
| 12 | +use serde::Deserialize; |
| 13 | +use serde_json::Value; |
| 14 | +use std::path::Path; |
| 15 | +use tar::Archive; |
| 16 | + |
| 17 | +pub fn cli() -> clap::Command { |
| 18 | + clap::Command::new("upgrade") |
| 19 | + .about("Checks for updates for the currently running spacetime CLI tool") |
| 20 | + .arg(Arg::new("version").help("The specific version to upgrade to")) |
| 21 | + .after_help("Run `spacetime help upgrade` for more detailed information.\n") |
| 22 | +} |
| 23 | + |
| 24 | +#[derive(Deserialize)] |
| 25 | +struct ReleaseAsset { |
| 26 | + name: String, |
| 27 | + browser_download_url: String, |
| 28 | +} |
| 29 | + |
| 30 | +#[derive(Deserialize)] |
| 31 | +struct Release { |
| 32 | + tag_name: String, |
| 33 | + assets: Vec<ReleaseAsset>, |
| 34 | +} |
| 35 | + |
| 36 | +fn get_download_name() -> String { |
| 37 | + let os = env::consts::OS; |
| 38 | + let arch = env::consts::ARCH; |
| 39 | + |
| 40 | + let os_str = match os { |
| 41 | + "macos" => "darwin", |
| 42 | + "windows" => return "spacetime.exe".to_string(), |
| 43 | + "linux" => "linux", |
| 44 | + _ => panic!("Unsupported OS"), |
| 45 | + }; |
| 46 | + |
| 47 | + let arch_str = match arch { |
| 48 | + "x86_64" => "amd64", |
| 49 | + "aarch64" => "arm64", |
| 50 | + _ => panic!("Unsupported architecture"), |
| 51 | + }; |
| 52 | + |
| 53 | + format!("spacetime.{}-{}.tar.gz", os_str, arch_str) |
| 54 | +} |
| 55 | + |
| 56 | +fn clean_version(version: &str) -> Option<String> { |
| 57 | + let re = Regex::new(r"v?(\d+\.\d+\.\d+)").unwrap(); |
| 58 | + re.captures(version) |
| 59 | + .and_then(|cap| cap.get(1)) |
| 60 | + .map(|match_| match_.as_str().to_string()) |
| 61 | +} |
| 62 | + |
| 63 | +async fn get_release_tag_from_version(release_version: &str) -> Result<Option<String>, reqwest::Error> { |
| 64 | + let release_version = format!("v{}-beta", release_version); |
| 65 | + let url = "https://api.github.com/repos/clockworklabs/SpacetimeDB/releases"; |
| 66 | + let client = reqwest::Client::builder() |
| 67 | + .user_agent(format!("SpacetimeDB CLI/{}", version::CLI_VERSION)) |
| 68 | + .build()?; |
| 69 | + let releases: Vec<Value> = client |
| 70 | + .get(url) |
| 71 | + .header( |
| 72 | + reqwest::header::USER_AGENT, |
| 73 | + format!("SpacetimeDB CLI/{}", version::CLI_VERSION).as_str(), |
| 74 | + ) |
| 75 | + .send() |
| 76 | + .await? |
| 77 | + .json() |
| 78 | + .await?; |
| 79 | + |
| 80 | + for release in releases.iter() { |
| 81 | + if let Some(release_tag) = release["tag_name"].as_str() { |
| 82 | + if release_tag.starts_with(&release_version) { |
| 83 | + return Ok(Some(release_tag.to_string())); |
| 84 | + } |
| 85 | + } |
| 86 | + } |
| 87 | + Ok(None) |
| 88 | +} |
| 89 | + |
| 90 | +async fn download_with_progress(client: &reqwest::Client, url: &str, temp_path: &Path) -> Result<(), anyhow::Error> { |
| 91 | + let response = client.get(url).send().await?; |
| 92 | + let total_size = match response.headers().get(reqwest::header::CONTENT_LENGTH) { |
| 93 | + Some(size) => size.to_str().unwrap().parse::<u64>().unwrap(), |
| 94 | + None => 0, |
| 95 | + }; |
| 96 | + |
| 97 | + let pb = ProgressBar::new(total_size); |
| 98 | + pb.set_style( |
| 99 | + ProgressStyle::default_bar().template("{spinner} Downloading update... {bytes}/{total_bytes} ({eta})"), |
| 100 | + ); |
| 101 | + |
| 102 | + let mut file = fs::File::create(temp_path)?; |
| 103 | + let mut downloaded_bytes = 0; |
| 104 | + |
| 105 | + let mut response_stream = response.bytes_stream(); |
| 106 | + while let Some(chunk) = response_stream.next().await { |
| 107 | + let chunk = chunk?; |
| 108 | + downloaded_bytes += chunk.len(); |
| 109 | + pb.set_position(downloaded_bytes as u64); |
| 110 | + file.write_all(&chunk)?; |
| 111 | + } |
| 112 | + |
| 113 | + pb.finish_with_message("Download complete."); |
| 114 | + Ok(()) |
| 115 | +} |
| 116 | + |
| 117 | +pub async fn exec(args: &ArgMatches) -> Result<(), anyhow::Error> { |
| 118 | + let version = args.get_one::<String>("version"); |
| 119 | + let current_exe_path = env::current_exe()?; |
| 120 | + |
| 121 | + let url = match version { |
| 122 | + None => "https://api.github.com/repos/clockworklabs/SpacetimeDB/releases/latest".to_string(), |
| 123 | + Some(release_version) => { |
| 124 | + let release_tag = get_release_tag_from_version(release_version).await?; |
| 125 | + if release_tag.is_none() { |
| 126 | + return Err(anyhow::anyhow!("No release found for version {}", release_version)); |
| 127 | + } |
| 128 | + format!( |
| 129 | + "https://api.github.com/repos/clockworklabs/SpacetimeDB/releases/tags/{}", |
| 130 | + release_tag.unwrap() |
| 131 | + ) |
| 132 | + } |
| 133 | + }; |
| 134 | + |
| 135 | + let client = reqwest::Client::builder() |
| 136 | + .user_agent(format!("SpacetimeDB CLI/{}", version::CLI_VERSION)) |
| 137 | + .build()?; |
| 138 | + |
| 139 | + print!("Finding version..."); |
| 140 | + std::io::stdout().flush()?; |
| 141 | + let release: Release = client.get(url).send().await?.json().await?; |
| 142 | + let release_version = clean_version(&release.tag_name).unwrap(); |
| 143 | + println!("done."); |
| 144 | + |
| 145 | + if release_version == version::CLI_VERSION { |
| 146 | + println!("You're already running the latest version: {}", version::CLI_VERSION); |
| 147 | + return Ok(()); |
| 148 | + } |
| 149 | + |
| 150 | + let download_name = get_download_name(); |
| 151 | + let asset = release.assets.iter().find(|&asset| asset.name == download_name); |
| 152 | + |
| 153 | + if asset.is_none() { |
| 154 | + return Err(anyhow::anyhow!( |
| 155 | + "No assets available for the detected OS and architecture." |
| 156 | + )); |
| 157 | + } |
| 158 | + |
| 159 | + println!( |
| 160 | + "You are currently running version {} of spacetime. The version you're upgrading to is {}.", |
| 161 | + version::CLI_VERSION, |
| 162 | + release_version, |
| 163 | + ); |
| 164 | + println!( |
| 165 | + "This will replace the current executable at {}.", |
| 166 | + current_exe_path.display() |
| 167 | + ); |
| 168 | + print!("Do you want to continue? [y/N] "); |
| 169 | + std::io::stdout().flush()?; |
| 170 | + let mut input = String::new(); |
| 171 | + std::io::stdin().read_line(&mut input)?; |
| 172 | + if input.trim().to_lowercase() != "y" && input.trim().to_lowercase() != "yes" { |
| 173 | + println!("Aborting upgrade."); |
| 174 | + return Ok(()); |
| 175 | + } |
| 176 | + |
| 177 | + let temp_dir = tempfile::tempdir()?.into_path(); |
| 178 | + let temp_path = &temp_dir.join(download_name.clone()); |
| 179 | + download_with_progress(&client, &asset.unwrap().browser_download_url, temp_path).await?; |
| 180 | + |
| 181 | + if download_name.to_lowercase().ends_with(".tar.gz") || download_name.to_lowercase().ends_with("tgz") { |
| 182 | + let tar_gz = fs::File::open(temp_path)?; |
| 183 | + let tar = GzDecoder::new(tar_gz); |
| 184 | + let mut archive = Archive::new(tar); |
| 185 | + let mut spacetime_found = false; |
| 186 | + for mut file in archive.entries()?.filter_map(|e| e.ok()) { |
| 187 | + if let Ok(path) = file.path() { |
| 188 | + if path.ends_with("spacetime") { |
| 189 | + spacetime_found = true; |
| 190 | + file.unpack(temp_dir.join("spacetime"))?; |
| 191 | + } |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + if !spacetime_found { |
| 196 | + fs::remove_dir_all(&temp_dir)?; |
| 197 | + return Err(anyhow::anyhow!("Spacetime executable not found in archive")); |
| 198 | + } |
| 199 | + } |
| 200 | + |
| 201 | + let new_exe_path = if temp_path.to_str().unwrap().ends_with(".exe") { |
| 202 | + temp_path.clone() |
| 203 | + } else if download_name.ends_with(".tar.gz") { |
| 204 | + temp_dir.join("spacetime") |
| 205 | + } else { |
| 206 | + fs::remove_dir_all(&temp_dir)?; |
| 207 | + return Err(anyhow::anyhow!("Unsupported download type")); |
| 208 | + }; |
| 209 | + |
| 210 | + // Move the current executable into a temporary directory, which will later be deleted by the OS |
| 211 | + let current_exe_temp_dir = env::temp_dir(); |
| 212 | + let current_exe_to_temp = current_exe_temp_dir.join("spacetime_old"); |
| 213 | + fs::rename(¤t_exe_path, current_exe_to_temp)?; |
| 214 | + fs::rename(new_exe_path, ¤t_exe_path)?; |
| 215 | + fs::remove_dir_all(&temp_dir)?; |
| 216 | + |
| 217 | + println!("spacetime has been updated to version {}", release_version); |
| 218 | + |
| 219 | + Ok(()) |
| 220 | +} |
0 commit comments