Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions pineappl_py/src/fk_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use pineappl::fk_table::{FkAssumptions, FkTable};
use pineappl::grid::Grid;
use pineappl::lumi::LumiCache;

use numpy::{IntoPyArray, PyArray1, PyArray4};
use numpy::{IntoPyArray, PyArray1, PyArray4, PyReadonlyArray1};
use pyo3::prelude::*;

use std::collections::HashMap;
Expand All @@ -11,6 +11,8 @@ use std::io::BufReader;
use std::path::PathBuf;
use std::str::FromStr;

use crate::grid::PyGrid;

/// PyO3 wrapper to :rustdoc:`pineappl::fk_table::FkTable <fk_table/struct.FkTable.html>`
///
/// *Usage*: `pineko`, `yadism`
Expand Down Expand Up @@ -38,6 +40,13 @@ impl PyFkAssumptions {

#[pymethods]
impl PyFkTable {
#[new]
pub fn new(grid: PyGrid) -> Self {
Self {
fk_table: FkTable::try_from(grid.grid).unwrap(),
}
}

#[staticmethod]
pub fn read(path: PathBuf) -> Self {
Self {
Expand Down Expand Up @@ -214,17 +223,24 @@ impl PyFkTable {
/// -------
/// numpy.ndarray(float) :
/// cross sections for all bins
#[pyo3(signature = (pdg_id, xfx, bin_indices = None, lumi_mask= None))]
pub fn convolute_with_one<'py>(
&self,
pdg_id: i32,
xfx: &PyAny,
bin_indices: Option<PyReadonlyArray1<usize>>,
lumi_mask: Option<PyReadonlyArray1<bool>>,
py: Python<'py>,
) -> &'py PyArray1<f64> {
let mut xfx = |id, x, q2| f64::extract(xfx.call1((id, x, q2)).unwrap()).unwrap();
let mut alphas = |_| 1.0;
let mut lumi_cache = LumiCache::with_one(pdg_id, &mut xfx, &mut alphas);
self.fk_table
.convolute(&mut lumi_cache, &[], &[])
.convolute(
&mut lumi_cache,
&bin_indices.map_or(vec![], |b| b.to_vec().unwrap()),
&lumi_mask.map_or(vec![], |l| l.to_vec().unwrap()),
)
.into_pyarray(py)
}

Expand Down
40 changes: 40 additions & 0 deletions pineappl_py/tests/test_fk_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import numpy as np

import pineappl


class TestFkTable:
def fake_grid(self, bins=None):
lumis = [pineappl.lumi.LumiEntry([(1, 21, 0.1)])]
orders = [pineappl.grid.Order(3, 0, 0, 0)]
bin_limits = np.array([1e-7, 1e-3, 1] if bins is None else bins, dtype=float)
subgrid_params = pineappl.subgrid.SubgridParams()
g = pineappl.grid.Grid.create(lumis, orders, bin_limits, subgrid_params)
return g

def test_convolute_with_one(self):
g = self.fake_grid()

# DIS grid
xs = np.linspace(0.5, 1.0, 5)
vs = xs.copy()
subgrid = pineappl.import_only_subgrid.ImportOnlySubgridV1(
vs[np.newaxis, :, np.newaxis],
np.array([90.0]),
xs,
np.array([1.0]),
)
g.set_subgrid(0, 0, 0, subgrid)
fk = pineappl.fk_table.FkTable(g)
np.testing.assert_allclose(
fk.convolute_with_one(2212, lambda pid, x, q2: 0.0, lambda q2: 0.0),
[0.0] * 2,
)
np.testing.assert_allclose(
fk.convolute_with_one(2212, lambda pid, x, q2: 1, lambda q2: 1.0),
[5e6 / 9999, 0.0],
)
np.testing.assert_allclose(
fk.convolute_with_one(2212, lambda pid, x, q2: 1, lambda q2: 2.0),
[2**3 * 5e6 / 9999, 0.0],
)