Compare commits

...

2 commits

11 changed files with 141 additions and 32 deletions

2
Cargo.lock generated
View file

@ -9,7 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "mathematical"
name = "math"
version = "0.1.0"
dependencies = [
"num",

View file

@ -1,7 +1,5 @@
[package]
name = "mathematical"
version = "0.1.0"
edition = "2021"
[dependencies]
num = "0.4.3"
[workspace]
resolver = "2"
members = [
"math"
]

7
math/Cargo.toml Normal file
View file

@ -0,0 +1,7 @@
[package]
name = "math"
version = "0.1.0"
edition = "2021"
[dependencies]
num = "0.4.3"

2
math/src/lib.rs Normal file
View file

@ -0,0 +1,2 @@
pub mod matrix;
pub mod sq_matrix;

View file

@ -108,7 +108,7 @@ impl<T: Num> IndexMut<(usize, usize)> for Matrix<T> {
#[macro_export]
macro_rules! matrix {
[ $w:expr; $( $x:expr ),+ ] => {
$crate::math::Matrix::new(vec![$( $x, )+], $w)
$crate::matrix::Matrix::new(vec![$( $x, )+], $w)
};
}

View file

@ -4,6 +4,24 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi
use num::{traits::NumAssign, Num, Signed};
impl<T: Num + Clone> Matrix<T> {
pub fn multiply(&self, rhs: &Self) -> Self {
if !self.can_mul(&rhs) {
panic!("Unable to multiply matrices with sizes {}x{} and {}x{}",
self.width, self.height(), rhs.width, rhs.height());
}
let mut new = Matrix::new_zeroes(rhs.width, self.height());
for (i, j) in new.indices() {
let new_elem = self.row(i).into_iter()
.zip(rhs.column(j).into_iter())
.map(|(e1, e2)| e1 * e2)
.reduce(|acc, e| acc + e);
new[i][j] = new_elem.unwrap();
}
return new;
}
}
impl<T: Num + Clone> Add for Matrix<T> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
@ -109,15 +127,7 @@ impl<T: NumAssign + Clone> DivAssign<T> for Matrix<T> {
impl<T: Num + Clone> Mul for Matrix<T> {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
if !self.can_mul(&rhs) {
panic!("Unable to multiply matrices with sizes {}x{} and {}x{}",
self.width, self.height(), rhs.width, rhs.height());
}
let mut new = Matrix::new_zeroes(rhs.width, self.height());
for (i, j) in new.indices() {
todo!()
}
return new;
self.multiply(&rhs)
}
}
@ -190,8 +200,50 @@ mod tests {
#[test]
#[should_panic(expected = "Unable to multiply")]
fn mul() {
fn mul_wrong_size() {
let _ = matrix![2; 1, 2, 3, 4] * matrix![2; 1, 2, 3, 4, 5, 6];
}
#[test]
fn mul() {
let first = matrix![3;
1, 2, 3,
4, 5, 6,
7, 8, 9
];
let second = matrix![3;
2, 3, 4,
5, 6, 7,
8, 9, 10
];
assert_eq!(first.multiply(&second), matrix![3;
36, 42, 48,
81, 96, 111,
126, 150, 174
]);
assert_eq!(second * first, matrix![3;
42, 51, 60,
78, 96, 114,
114, 141, 168
]);
let first = matrix![3;
1, 2, 3,
4, 5, 6
];
let second = matrix![2;
2, 3,
4, 5,
6, 7
];
assert_eq!(first.multiply(&second), matrix![2;
28, 34,
64, 79
]);
assert_eq!(second * first, matrix![3;
14, 19, 24,
24, 33, 42,
34, 47, 60
])
}
}

View file

@ -120,7 +120,6 @@ impl<'a, T: Num> Iterator for ColumnsIter<'a, T> {
}
}
impl Iterator for IterIndices {
type Item = (usize, usize);
fn next(&mut self) -> Option<Self::Item> {
@ -139,7 +138,7 @@ mod tests {
use super::*;
#[test]
fn iter() {
fn iter_elem() {
let data = vec![1,2,3,4,5,6,7,8,9];
let mut matrix = Matrix::new(data.clone(), 3);
for (i, e) in matrix.iter().enumerate() {
@ -153,5 +152,66 @@ mod tests {
assert_eq!(data[i] + 2, e);
}
}
#[test]
fn iter_indexed() {
let data = vec![1,2,3,4,5,6,7,8,9];
let mut matrix = Matrix::new(data.clone(), 3);
let (width, height) = (matrix.width(), matrix.height());
let mut matrix_iter = matrix.iter_indexed();
for i in 0..height {
for j in 0..width {
let (mi, mj, e) = matrix_iter.next().unwrap();
assert_eq!((mi, mj), (i, j));
assert_eq!(*e, data[i * width + j]);
}
}
let mut matrix_iter = matrix.iter_indexed_mut();
for i in 0..height {
for j in 0..width {
let (mi, mj, e) = matrix_iter.next().unwrap();
assert_eq!((mi, mj), (i, j));
assert_eq!(*e, data[i * width + j]);
*e += 2;
assert_eq!(*e, data[i * width + j] + 2);
}
}
}
#[test]
fn iter_rows() {
let mut matrix = matrix![3;
1, 2, 3,
4, 5, 6,
7, 8, 9
];
let mut matrix_iter = matrix.iter_rows();
assert_eq!(matrix_iter.next().unwrap(), &[1, 2, 3]);
assert_eq!(matrix_iter.next().unwrap(), &[4, 5, 6]);
assert_eq!(matrix_iter.next().unwrap(), &[7, 8, 9]);
assert_eq!(matrix_iter.next(), None);
for row in matrix.iter_rows_mut() {
row[0] += 9;
}
let mut matrix_iter = matrix.iter_rows_mut();
assert_eq!(matrix_iter.next().unwrap(), &[10, 2, 3]);
assert_eq!(matrix_iter.next().unwrap(), &[13, 5, 6]);
assert_eq!(matrix_iter.next().unwrap(), &[16, 8, 9]);
assert_eq!(matrix_iter.next(), None);
}
#[test]
fn iter_columns() {
let matrix = matrix![3;
1, 2, 3,
4, 5, 6,
7, 8, 9
];
for (j, column) in matrix.iter_columns().enumerate() {
for (i, e) in column.enumerate() {
assert_eq!(*e, matrix[i][j]);
}
}
}
}

View file

@ -45,8 +45,6 @@ fn check_size(order: usize) -> usize {
#[macro_export]
macro_rules! sq_matrix {
[ $o:expr; $( $x:expr ),+ ] => {
$crate::math::SquareMatrix::new(vec![$( $x, )+], $o)
$crate::matrix::SquareMatrix::new(vec![$( $x, )+], $o)
};
}
pub use sq_matrix;

View file

@ -1 +0,0 @@
pub mod math;

View file

@ -1,7 +0,0 @@
mod matrix;
pub use matrix::Matrix;
pub use matrix::matrix;
mod sq_matrix;
pub use sq_matrix::SquareMatrix;
pub use sq_matrix::sq_matrix;