Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/OceananigansReactantExt/OceananigansReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ function Oceananigans.TimeSteppers.tick!(clock::Oceananigans.TimeSteppers.Clock{

if stage # tick a stage update
clock.stage += 1
clock.last_stage_Δt = Δt
else # tick an iteration and reset stage
clock.iteration.mlir_data = (clock.iteration + 1).mlir_data
clock.stage = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ const OnlyParticleTrackingModel = HydrostaticFreeSurfaceModel{TS, E, A, S, G, T,

function time_step!(model::OnlyParticleTrackingModel, Δt; callbacks = [], kwargs...)
tick!(model.clock, Δt)
model.clock.last_Δt = Δt
step_lagrangian_particles!(model, Δt)
update_state!(model, callbacks)
end
Expand Down
7 changes: 4 additions & 3 deletions src/Models/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import Oceananigans: initialize!
import Oceananigans.Architectures: architecture
import Oceananigans.Solvers: iteration
import Oceananigans.Simulations: timestepper
import Oceananigans.TimeSteppers: reset!, set_clock!

# A prototype interface for AbstractModel.
#
Expand Down Expand Up @@ -114,8 +115,10 @@ const OceananigansModels = Union{HydrostaticFreeSurfaceModel,
NonhydrostaticModel,
ShallowWaterModel}

set_clock!(model::OceananigansModels, new_clock) = set_clock!(model.clock, new_clock)

"""
possible_field_time_series(model::HydrostaticFreeSurfaceModel)
possible_field_time_series(model::OceananigansModels)

Return a `Tuple` containing properties of and `OceananigansModel` that could contain `FieldTimeSeries`.
"""
Expand Down Expand Up @@ -145,8 +148,6 @@ function update_model_field_time_series!(model::OceananigansModels, clock::Clock
return nothing
end

import Oceananigans.TimeSteppers: reset!

function reset!(model::OceananigansModels)

for field in fields(model)
Expand Down
6 changes: 3 additions & 3 deletions src/TimeSteppers/TimeSteppers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ TimeStepper(name::Symbol, args...; kwargs...) = TimeStepper(Val(name), args...;
# Fallback
TimeStepper(stepper::AbstractTimeStepper, args...; kwargs...) = stepper

#individual contructors
#individual constructors
TimeStepper(::Val{:QuasiAdamsBashforth2}, args...; kwargs...) =
QuasiAdamsBashforth2TimeStepper(args...; kwargs...)

Expand All @@ -69,15 +69,15 @@ TimeStepper(::Val{:SplitRungeKutta3}, args...; kwargs...) =

function first_time_step!(model::AbstractModel, Δt)
initialize!(model)
# The first update_state is conditionally gated from within time_step!
# The first update_state! is conditionally gated from within time_step!
# update_state!(model)
time_step!(model, Δt)
return nothing
end

function first_time_step!(model::AbstractModel{<:QuasiAdamsBashforth2TimeStepper}, Δt)
initialize!(model)
# The first update_state is conditionally gated from within time_step!
# The first update_state! is conditionally gated from within time_step!
# update_state!(model)
time_step!(model, Δt, euler=true)
return nothing
Expand Down
46 changes: 36 additions & 10 deletions src/TimeSteppers/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,10 @@ mutable struct Clock{TT, DT, IT, S}
stage :: S
end

function reset!(clock::Clock{TT, DT, IT, S}) where {TT, DT, IT, S}
clock.time = zero(TT)
clock.iteration = zero(IT)
clock.stage = zero(S)
clock.last_Δt = Inf
clock.last_stage_Δt = Inf
return nothing
end

"""
Clock(; time, last_Δt=Inf, last_stage_Δt=Inf, iteration=0, stage=1)

Returns a `Clock` object. By default, `Clock` is initialized to the zeroth `iteration`
Return a `Clock` object. By default, `Clock` is initialized to the zeroth `iteration`
and first time step `stage` with `last_Δt=last_stage_Δt=Inf`.
"""
function Clock(; time,
Expand All @@ -49,6 +40,38 @@ function Clock(; time,
return Clock{TT, DT, IT, typeof(stage)}(time, last_Δt, last_stage_Δt, iteration, stage)
end

function reset!(clock::Clock{TT, DT, IT, S}) where {TT, DT, IT, S}
clock.time = zero(TT)
clock.iteration = zero(IT)
clock.stage = zero(S)
clock.last_Δt = Inf
clock.last_stage_Δt = Inf
return nothing
end

"""
set_clock!(clock::Clock, new_clock::Clock)

Set `clock` to the `new_clock`.
"""
function set_clock!(clock::Clock, new_clock::Clock)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps even better is to just add a method to set!?

Suggested change
function set_clock!(clock::Clock, new_clock::Clock)
import Oceananigans: set!
function set!(clock::Clock, new_clock::Clock)

clock.time = new_clock.time
clock.iteration = new_clock.iteration
clock.last_Δt = new_clock.last_Δt
clock.last_stage_Δt = new_clock.last_stage_Δt
clock.stage = new_clock.stage

return nothing
end

function Base.:(==)(clock1::Clock, clock2::Clock)
return clock1.time == clock2.time &&
clock1.iteration == clock2.iteration &&
clock1.last_Δt == clock2.last_Δt &&
clock1.last_stage_Δt == clock2.last_stage_Δt &&
clock1.stage == clock2.stage
end

# TODO: when supporting DateTime, this function will have to be extended
time_step_type(TT) = TT

Expand Down Expand Up @@ -107,9 +130,12 @@ function tick!(clock, Δt; stage=false)

if stage # tick a stage update
clock.stage += 1
clock.last_stage_Δt = Δt
else # tick an iteration and reset stage
clock.iteration += 1
clock.stage = 1
clock.last_Δt = Δt
clock.last_stage_Δt = Δt
end

return nothing
Expand Down
3 changes: 0 additions & 3 deletions src/TimeSteppers/quasi_adams_bashforth_2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ function time_step!(model::AbstractModel{<:QuasiAdamsBashforth2TimeStepper}, Δt
ab2_step!(model, Δt)

tick!(model.clock, Δt)
model.clock.last_Δt = Δt
model.clock.last_stage_Δt = Δt # just one stage

compute_pressure_correction!(model, Δt)
@apply_regionally correct_velocities_and_cache_previous_tendencies!(model, Δt)
Expand Down Expand Up @@ -175,4 +173,3 @@ Time step velocity fields via the 2nd-order quasi Adams-Bashforth method
end

@kernel ab2_step_field!(::FunctionField, Δt, χ, Gⁿ, G⁻) = nothing

20 changes: 11 additions & 9 deletions src/TimeSteppers/runge_kutta_3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ function RungeKutta3TimeStepper(grid, prognostic_fields;
Gⁿ::TG = map(similar, prognostic_fields),
G⁻ = map(similar, prognostic_fields)) where {TI, TG}

!isnothing(implicit_solver) &&
@warn("Implicit-explicit time-stepping with RungeKutta3TimeStepper is not tested. " *
"\n implicit_solver: $(typeof(implicit_solver))")

γ¹ = 8 // 15
γ² = 5 // 12
γ³ = 3 // 4
Expand Down Expand Up @@ -92,12 +96,13 @@ function time_step!(model::AbstractModel{<:RungeKutta3TimeStepper}, Δt; callbac
γ² = model.timestepper.γ²
γ³ = model.timestepper.γ³

ζ¹ = nothing
ζ² = model.timestepper.ζ²
ζ³ = model.timestepper.ζ³

first_stage_Δt = γ¹ * Δt
second_stage_Δt = (γ² + ζ²) * Δt
third_stage_Δt = (γ³ + ζ³) * Δt
first_stage_Δt = stage_Δt(Δt, γ¹, ζ¹) # = γ¹ * Δt
second_stage_Δt = stage_Δt(Δt, γ², ζ²) # = (γ² + ζ²) * Δt
third_stage_Δt = stage_Δt(Δt, γ³, ζ³) # = (γ³ + ζ³) * Δt

# Compute the next time step a priori to reduce floating point error accumulation
tⁿ⁺¹ = next_time(model.clock, Δt)
Expand All @@ -109,7 +114,6 @@ function time_step!(model::AbstractModel{<:RungeKutta3TimeStepper}, Δt; callbac
rk3_substep!(model, Δt, γ¹, nothing)

tick!(model.clock, first_stage_Δt; stage=true)
model.clock.last_stage_Δt = first_stage_Δt

compute_pressure_correction!(model, first_stage_Δt)
make_pressure_correction!(model, first_stage_Δt)
Expand All @@ -125,7 +129,6 @@ function time_step!(model::AbstractModel{<:RungeKutta3TimeStepper}, Δt; callbac
rk3_substep!(model, Δt, γ², ζ²)

tick!(model.clock, second_stage_Δt; stage=true)
model.clock.last_stage_Δt = second_stage_Δt

compute_pressure_correction!(model, second_stage_Δt)
make_pressure_correction!(model, second_stage_Δt)
Expand All @@ -144,8 +147,9 @@ function time_step!(model::AbstractModel{<:RungeKutta3TimeStepper}, Δt; callbac
# round-off error when Δt is added to model.clock.time. Note that we still use
# third_stage_Δt for the substep, pressure correction, and Lagrangian particles step.
corrected_third_stage_Δt = tⁿ⁺¹ - model.clock.time

tick!(model.clock, third_stage_Δt)
# now model.clock.last_Δt = clock.last_stage_Δt = third_stage_Δt
# we correct those below
model.clock.last_stage_Δt = corrected_third_stage_Δt
model.clock.last_Δt = Δt

Expand Down Expand Up @@ -193,9 +197,7 @@ end
"""
Time step velocity fields via the 3rd-order Runge-Kutta method

```
Uᵐ⁺¹ = Uᵐ + Δt * (γᵐ * Gᵐ + ζᵐ * Gᵐ⁻¹)
```
Uᵐ⁺¹ = Uᵐ + Δt * (γᵐ * Gᵐ + ζᵐ * Gᵐ⁻¹)

where `m` denotes the substage.
"""
Expand Down
4 changes: 2 additions & 2 deletions test/test_grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ function test_regular_rectilinear_constructor_errors(FT)
@test_throws ArgumentError RectilinearGrid(CPU(), FT, topology=(Flat, Flat, Periodic), size=(16, 16), extent=1)

@test_throws ArgumentError RectilinearGrid(CPU(), FT, topology=(Flat, Flat, Flat), size=16, extent=1)

@test_throws ArgumentError RectilinearGrid(CPU(), FT, size=(4, 4, 4), x=(0, 1), y=(0, 1), z=[-50.0, -30.0, -20.0, 0.0]) # too few z-faces
@test_throws ArgumentError RectilinearGrid(CPU(), FT, size=(4, 4, 4), x=(0, 1), y=(0, 1), z=[-2000.0, -1000.0, -50.0, -30.0, -20.0, 0.0]) # too many z-faces

Expand Down Expand Up @@ -352,7 +352,7 @@ function test_grid_equality(arch)
grid2 = RectilinearGrid(arch, topology=topo, size=(Nx, Ny, Nz), x=(0, 1), y=(-1, 1), z=0:Nz)
grid3 = RectilinearGrid(arch, topology=topo, size=(Nx, Ny, Nz), x=(0, 1), y=(-1, 1), z=0:Nz)

return grid1==grid1 && grid2 == grid3 && grid1 !== grid3
return grid1 == grid1 && grid2 == grid3 && grid1 !== grid3
end

function test_grid_equality_over_architectures()
Expand Down