1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use super::bindings::size_t;
use super::bindings::{
    DMatrixHandle, TreeliteDMatrixCreateFromCSR, TreeliteDMatrixCreateFromMat, TreeliteDMatrixFree,
    TreeliteDMatrixGetDimension,
};
use super::{DataType, RetCodeCheck};
use crate::errors::TreeRiteError;
use fehler::{throw, throws};
use libc::c_void;
use ndarray::{ArrayView, Ix2};
use num_traits::Float;
use std::ffi::CStr;
use std::ptr::null_mut;
use std::{f32, f64};

/// Releases a DMatrix handle through `TreeliteDMatrixFree`.
///
/// # Errors
/// Propagates whatever error `.check()` produces for the return code of
/// `TreeliteDMatrixFree`.
#[throws(TreeRiteError)]
pub fn treelite_dmatrix_free(handle: DMatrixHandle) {
    // SAFETY: this function is safe to call, so soundness relies on callers passing a
    // handle that came from treelite and has not been freed yet.
    // NOTE(review): nothing here prevents a double free — consider an owning wrapper.
    let ret = unsafe { TreeliteDMatrixFree(handle) };
    ret.check()?;
}

/// Describes how a Rust floating-point type is presented to treelite: the element
/// type tag for the data buffer and the value handed over as the missing-value marker.
pub trait FloatInfo {
    /// Element type tag passed to treelite alongside buffers of this type.
    const DATA_TYPE: DataType;
    /// Value passed to treelite as the missing-value argument (NaN for both impls here).
    const MISSING: Self;
}

impl FloatInfo for f32 {
    const DATA_TYPE: DataType = DataType::Float32;
    const MISSING: Self = f32::NAN;
}

impl FloatInfo for f64 {
    const DATA_TYPE: DataType = DataType::Float64;
    const MISSING: Self = f64::NAN;
}

/// Builds a dense treelite DMatrix from a two-dimensional ndarray view.
///
/// The view's buffer is passed to treelite as one flat block of
/// `nrows * ncols` values. For that reason the view must be in standard
/// (row-major, C-contiguous) layout. `F::MISSING` is passed as the
/// missing-value argument.
///
/// # Errors
/// * `TreeRiteError::DataNotCContiguous` if `data` is not in standard layout.
/// * Whatever `.check()` produces for the return code of `TreeliteDMatrixCreateFromMat`.
#[throws(TreeRiteError)]
pub fn treelite_dmatrix_create_from_array<'a, F: Float + FloatInfo>(
    data: ArrayView<'a, F, Ix2>,
) -> DMatrixHandle {
    if !data.is_standard_layout() {
        throw!(TreeRiteError::DataNotCContiguous);
    }

    let (num_rows, num_cols) = data.dim();
    let type_name: &'static CStr = F::DATA_TYPE.into();
    // Bound to a named local so the pointer handed to C clearly outlives the call.
    let missing = F::MISSING;
    let mut handle = null_mut();

    // SAFETY: standard layout means `as_ptr()` addresses `num_rows * num_cols` contiguous
    // values; `type_name` is 'static and `missing` lives until the end of this function.
    // NOTE(review): the returned handle is not tied to `data`'s lifetime, which assumes
    // treelite copies the buffer during creation — confirm for the linked version.
    let ret = unsafe {
        TreeliteDMatrixCreateFromMat(
            data.as_ptr() as *const c_void,
            type_name.as_ptr(),
            num_rows as size_t,
            num_cols as size_t,
            &missing as *const F as *const c_void,
            &mut handle,
        )
    };
    ret.check()?;
    handle
}

/// Builds a single-row dense treelite DMatrix whose columns are the elements of `data`.
///
/// `T::MISSING` is passed as the missing-value argument.
///
/// # Errors
/// Propagates whatever `.check()` produces for the return code of
/// `TreeliteDMatrixCreateFromMat`.
#[throws(TreeRiteError)]
pub fn treelite_dmatrix_create_from_slice<'a, T: Float + FloatInfo>(
    data: &'a [T],
) -> DMatrixHandle {
    let type_name: &'static CStr = T::DATA_TYPE.into();
    let missing = T::MISSING;
    let num_cols = data.len() as size_t;
    let mut handle = null_mut();

    // SAFETY: a slice is contiguous, so the pointer addresses exactly `num_cols` values
    // forming one row; `type_name` is 'static and `missing` outlives the call.
    // NOTE(review): assumes treelite copies the buffer rather than retaining the pointer.
    let ret = unsafe {
        TreeliteDMatrixCreateFromMat(
            data.as_ptr().cast::<c_void>(),
            type_name.as_ptr(),
            1,
            num_cols,
            (&missing as *const T).cast::<c_void>(),
            &mut handle,
        )
    };
    ret.check()?;
    handle
}

/// Builds a treelite DMatrix from a matrix in compressed sparse row (CSR) form.
///
/// * `headers` – row pointers. Row `r` occupies `headers[r]..headers[r + 1]` of
///   `indices`/`data`, so at least `num_row + 1` entries are required.
/// * `indices` – column index of each stored value.
/// * `data` – the stored values.
/// * `num_row`, `num_col` – logical shape passed through to treelite.
///
/// NOTE(review): the returned handle is not tied to the input slices' lifetimes.
/// This assumes treelite copies the buffers during creation — confirm for the
/// linked version.
///
/// # Panics
/// Panics if any of the following hold:
/// * `headers` has fewer than `num_row + 1` entries;
/// * `headers[..=num_row]` is not non-decreasing;
/// * `indices` or `data` is shorter than `headers[num_row]`.
///
/// These checks stop the foreign code from reading past the end of the supplied
/// slices, which would otherwise be undefined behaviour reachable from safe code.
///
/// # Errors
/// Propagates whatever `.check()` produces for the return code of
/// `TreeliteDMatrixCreateFromCSR`.
#[throws(TreeRiteError)]
pub fn treelite_dmatrix_create_from_csr_format<T: Float + FloatInfo>(
    headers: &[u64],
    indices: &[u32],
    data: &[T],
    num_row: u64,
    num_col: u64,
) -> DMatrixHandle {
    // usize -> u64 is lossless on every platform Rust supports.
    assert!(
        num_row < headers.len() as u64,
        "CSR headers must contain num_row + 1 entries (num_row = {}, headers.len() = {})",
        num_row,
        headers.len()
    );
    // Cannot truncate: num_row < headers.len() <= usize::MAX.
    let row_count = num_row as usize;
    let row_ptr = &headers[..=row_count];
    // Monotonic row pointers also guarantee every header is <= the final one (nnz).
    assert!(
        row_ptr.windows(2).all(|pair| pair[0] <= pair[1]),
        "CSR headers must be non-decreasing"
    );
    let nnz = row_ptr[row_count];
    assert!(
        nnz <= indices.len() as u64 && nnz <= data.len() as u64,
        "CSR headers reference {} values, but indices.len() = {} and data.len() = {}",
        nnz,
        indices.len(),
        data.len()
    );

    // Treelite takes the row pointers as `size_t*`. Reinterpreting `&[u64]` is only valid
    // when `size_t` is 64 bits wide; otherwise copy into a correctly-typed buffer. Every
    // header is <= nnz <= data.len() (checked above), so the element cast cannot truncate
    // as long as `size_t` is at least as wide as `usize`.
    let converted: Vec<size_t>;
    let row_ptr_raw: *const size_t =
        if std::mem::size_of::<size_t>() == std::mem::size_of::<u64>() {
            row_ptr.as_ptr() as *const size_t
        } else {
            converted = row_ptr.iter().map(|&h| h as size_t).collect();
            converted.as_ptr()
        };

    let mut out = null_mut();
    // SAFETY: the assertions above bound treelite's expected reads to `row_ptr`,
    // `indices[..nnz]` and `data[..nnz]`. `row_ptr_raw` points at `num_row + 1`
    // `size_t` values, owned either by `headers` or by `converted`, both of which
    // outlive this call.
    unsafe {
        TreeliteDMatrixCreateFromCSR(
            data.as_ptr() as *const c_void,
            Into::<&'static CStr>::into(T::DATA_TYPE).as_ptr(),
            indices.as_ptr(),
            row_ptr_raw,
            row_count as size_t,
            num_col as size_t,
            &mut out,
        )
    }
    .check()?;
    out
}

// #[throws(TreeRiteError)]
// pub fn treelite_dmatrix_create_from_file(path: &Path, format: Option<String>, data_type: DataType, nthread: usize, verbose: bool) -> DMatrixHandle {
//     let format = format.unwrap_or_else(|| "libsvm".to_string());
//     let format = CString::new(format)?;
//     let path = CString::new(path.to_string_lossy().to_owned().to_string())?;
//     let verbose = if verbose { 1 } else { 0 };
//     let mut out = null_mut();
//     let retcode = unsafe {
//         TreeliteDMatrixCreateFromFile(
//             path.as_ptr(),
//             format.as_ptr(),
//             Into::<&'static CStr>::into(data_type).as_ptr(),
//             nthread as i32,
//             verbose,
//             &mut out,
//         )
//     };
//     if retcode != 0 {
//         throw!(get_last_error())
//     }
//     out
// }

/// Queries the dimensions of a treelite DMatrix.
///
/// Returns `(rows, columns, elements)` exactly as reported by
/// `TreeliteDMatrixGetDimension`, widened to `u64`.
///
/// # Errors
/// Propagates whatever `.check()` produces for the return code of
/// `TreeliteDMatrixGetDimension`.
#[throws(TreeRiteError)]
pub fn treelite_dmatrix_get_dimension(handle: DMatrixHandle) -> (u64, u64, u64) {
    let mut num_row = 0;
    let mut num_col = 0;
    let mut num_elem = 0;

    // SAFETY: the three out-pointers refer to live locals. The handle's validity is the
    // caller's responsibility, since this function is safe to call.
    let ret = unsafe {
        TreeliteDMatrixGetDimension(handle, &mut num_row, &mut num_col, &mut num_elem)
    };
    ret.check()?;

    (num_row as u64, num_col as u64, num_elem as u64)
}